snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/telemetry.py +4 -2
  7. snowflake/ml/_internal/type_utils.py +3 -3
  8. snowflake/ml/_internal/utils/import_utils.py +31 -0
  9. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
  10. snowflake/ml/data/__init__.py +5 -0
  11. snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
  12. snowflake/ml/data/data_connector.py +1 -1
  13. snowflake/ml/data/torch_utils.py +33 -14
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
  16. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
  18. snowflake/ml/feature_store/examples/example_helper.py +6 -3
  19. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
  20. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
  21. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
  22. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
  23. snowflake/ml/feature_store/feature_store.py +1 -2
  24. snowflake/ml/feature_store/feature_view.py +5 -1
  25. snowflake/ml/model/_client/model/model_version_impl.py +145 -11
  26. snowflake/ml/model/_client/ops/model_ops.py +56 -16
  27. snowflake/ml/model/_client/ops/service_ops.py +46 -30
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  30. snowflake/ml/model/_client/sql/service.py +25 -1
  31. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  34. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
  36. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  37. snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
  38. snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
  39. snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
  41. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
  42. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
  43. snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
  44. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
  45. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  46. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  47. snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
  48. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  49. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  50. snowflake/ml/model/_packager/model_packager.py +0 -11
  51. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  52. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  53. snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
  54. snowflake/ml/model/_signatures/core.py +63 -16
  55. snowflake/ml/model/_signatures/pandas_handler.py +87 -27
  56. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  57. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  58. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  59. snowflake/ml/model/_signatures/utils.py +4 -0
  60. snowflake/ml/model/custom_model.py +47 -7
  61. snowflake/ml/model/model_signature.py +40 -9
  62. snowflake/ml/model/type_hints.py +9 -1
  63. snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
  64. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
  65. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
  66. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
  67. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
  68. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
  69. snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
  70. snowflake/ml/modeling/cluster/dbscan.py +5 -2
  71. snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
  72. snowflake/ml/modeling/cluster/k_means.py +14 -19
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  74. snowflake/ml/modeling/cluster/optics.py +6 -6
  75. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
  76. snowflake/ml/modeling/compose/column_transformer.py +15 -5
  77. snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
  78. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  79. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  80. snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
  81. snowflake/ml/modeling/covariance/oas.py +1 -1
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
  85. snowflake/ml/modeling/decomposition/pca.py +28 -15
  86. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
  87. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
  88. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
  89. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
  90. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
  91. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
  92. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
  93. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
  94. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
  96. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  104. snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
  106. snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lars.py +0 -10
  109. snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
  110. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
  112. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
  113. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
  119. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
  120. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  122. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
  123. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
  124. snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
  125. snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
  126. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
  127. snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
  128. snowflake/ml/modeling/manifold/isomap.py +1 -1
  129. snowflake/ml/modeling/manifold/mds.py +3 -3
  130. snowflake/ml/modeling/manifold/tsne.py +10 -4
  131. snowflake/ml/modeling/metrics/classification.py +12 -16
  132. snowflake/ml/modeling/metrics/ranking.py +3 -3
  133. snowflake/ml/modeling/metrics/regression.py +3 -3
  134. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  135. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  136. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  137. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  138. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
  139. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
  140. snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
  141. snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
  142. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  143. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
  144. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  145. snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
  146. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
  147. snowflake/ml/modeling/pipeline/pipeline.py +16 -14
  148. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
  149. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
  150. snowflake/ml/modeling/svm/linear_svc.py +25 -16
  151. snowflake/ml/modeling/svm/linear_svr.py +23 -17
  152. snowflake/ml/modeling/svm/nu_svc.py +5 -3
  153. snowflake/ml/modeling/svm/nu_svr.py +3 -1
  154. snowflake/ml/modeling/svm/svc.py +9 -5
  155. snowflake/ml/modeling/svm/svr.py +3 -1
  156. snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
  157. snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
  158. snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
  159. snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
  160. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
  161. snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
  162. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  163. snowflake/ml/monitoring/model_monitor.py +37 -0
  164. snowflake/ml/registry/_manager/model_manager.py +15 -1
  165. snowflake/ml/registry/registry.py +32 -37
  166. snowflake/ml/version.py +1 -1
  167. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
  168. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
  169. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  170. snowflake/ml/monitoring/_client/model_monitor.py +0 -126
  171. snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
  172. snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
  173. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  174. /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
  175. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  176. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,448 @@
1
+ import typing
2
+ from collections import Counter
3
+ from typing import Any, Dict, List, Mapping, Optional, Set
4
+
5
+ from snowflake import snowpark
6
+ from snowflake.ml._internal.utils import (
7
+ db_utils,
8
+ query_result_checker,
9
+ sql_identifier,
10
+ table_manager,
11
+ )
12
+ from snowflake.ml.model._client.sql import _base
13
+ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
14
+ from snowflake.snowpark import session, types
15
+
16
# Metadata table (in the client's database/schema) whose rows are removed by
# ModelMonitorSQLClient.delete_monitor_metadata.
SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"

# JSON column / field names. NOTE(review): not referenced in this module;
# presumably kept for other callers — verify before removing.
MODEL_JSON_COL_NAME, MODEL_JSON_MODEL_NAME_FIELD, MODEL_JSON_VERSION_NAME_FIELD = (
    "model",
    "model_name",
    "version_name",
)

# Metadata table column names. Only MONITOR_NAME_COL_NAME is used in this module
# (as the filter column when deleting a monitor's metadata row).
MONITOR_NAME_COL_NAME = "MONITOR_NAME"
SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME"
VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME"
FUNCTION_NAME_COL_NAME = "FUNCTION_NAME"
TASK_COL_NAME = "TASK"
MONITORING_ENABLED_COL_NAME = "IS_ENABLED"
TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME"
PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
33
+
34
+
35
def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
    """Render columns as a parenthesised list of single-quoted SQL string literals.

    For example ``[A, B]`` becomes ``('A', 'B')``; an empty list becomes ``()``.

    Args:
        columns: Column identifiers; each is formatted with its string form.

    Returns:
        The SQL list text, suitable for the *_COLUMNS clauses of CREATE MODEL MONITOR.
    """
    quoted_columns = (f"'{col}'" for col in columns)
    return "(" + ", ".join(quoted_columns) + ")"
38
+
39
+
40
class ModelMonitorSQLClient:
    """Issues the SQL that creates, lists, validates, alters and drops MODEL MONITOR objects.

    Methods that accept optional database/schema arguments fall back to the database and schema
    this client was constructed with when those arguments are not provided.
    """

    def __init__(
        self,
        session: session.Session,
        *,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
    ) -> None:
        """Client to manage model monitors in a given database and schema.

        Args:
            session: Active snowpark session.
            database_name: Name of the Database where monitoring resources are provisioned.
            schema_name: Name of the Schema where monitoring resources are provisioned.
        """
        self._sql_client = _base._BaseSQLClient(session, database_name=database_name, schema_name=schema_name)
        self._database_name = database_name
        self._schema_name = schema_name

    def _infer_qualified_schema(
        self, database_name: Optional[sql_identifier.SqlIdentifier], schema_name: Optional[sql_identifier.SqlIdentifier]
    ) -> str:
        """Return ``<database>.<schema>``, using the client's defaults for any falsy (e.g. None) argument."""
        return f"{database_name or self._database_name}.{schema_name or self._schema_name}"

    def create_model_monitor(
        self,
        *,
        monitor_database: Optional[sql_identifier.SqlIdentifier],
        monitor_schema: Optional[sql_identifier.SqlIdentifier],
        monitor_name: sql_identifier.SqlIdentifier,
        source_database: Optional[sql_identifier.SqlIdentifier],
        source_schema: Optional[sql_identifier.SqlIdentifier],
        source: sql_identifier.SqlIdentifier,
        model_database: Optional[sql_identifier.SqlIdentifier],
        model_schema: Optional[sql_identifier.SqlIdentifier],
        model_name: sql_identifier.SqlIdentifier,
        version_name: sql_identifier.SqlIdentifier,
        function_name: str,
        warehouse_name: sql_identifier.SqlIdentifier,
        timestamp_column: sql_identifier.SqlIdentifier,
        id_columns: List[sql_identifier.SqlIdentifier],
        prediction_score_columns: List[sql_identifier.SqlIdentifier],
        prediction_class_columns: List[sql_identifier.SqlIdentifier],
        actual_score_columns: List[sql_identifier.SqlIdentifier],
        actual_class_columns: List[sql_identifier.SqlIdentifier],
        refresh_interval: str,
        aggregation_window: str,
        baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
        baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
        baseline: Optional[sql_identifier.SqlIdentifier] = None,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``CREATE MODEL MONITOR`` for a model version, reading data from ``source``.

        Every optional database/schema argument falls back to the client's own database/schema.
        The result is checked (via SqlResultValidator) to be one row with a "status" column.

        Args:
            monitor_database: Database for the new monitor.
            monitor_schema: Schema for the new monitor.
            monitor_name: Name of the new monitor.
            source_database: Database of the source.
            source_schema: Schema of the source.
            source: Name of the source holding the monitored data.
            model_database: Database of the registry model.
            model_schema: Schema of the registry model.
            model_name: Name of the registry model.
            version_name: Model version to monitor.
            function_name: Model function to monitor.
            warehouse_name: Warehouse passed as the monitor's WAREHOUSE.
            timestamp_column: Timestamp column in the source.
            id_columns: ID columns in the source.
            prediction_score_columns: Prediction score columns in the source.
            prediction_class_columns: Prediction class columns in the source.
            actual_score_columns: Actual score columns in the source.
            actual_class_columns: Actual class columns in the source.
            refresh_interval: Passed verbatim as REFRESH_INTERVAL.
            aggregation_window: Passed verbatim as AGGREGATION_WINDOW.
            baseline_database: Database of the optional baseline.
            baseline_schema: Schema of the optional baseline.
            baseline: Optional baseline name; the BASELINE clause is omitted when not given.
            statement_params: Optional set of statement params to include in the query.
        """
        baseline_sql = ""
        if baseline:
            baseline_sql = f"BASELINE='{self._infer_qualified_schema(baseline_database, baseline_schema)}.{baseline}'"
        query_result_checker.SqlResultValidator(
            self._sql_client._session,
            f"""
            CREATE MODEL MONITOR {self._infer_qualified_schema(monitor_database, monitor_schema)}.{monitor_name}
                WITH
                    MODEL='{self._infer_qualified_schema(model_database, model_schema)}.{model_name}'
                    VERSION='{version_name}'
                    FUNCTION='{function_name}'
                    WAREHOUSE='{warehouse_name}'
                    SOURCE='{self._infer_qualified_schema(source_database, source_schema)}.{source}'
                    ID_COLUMNS={_build_sql_list_from_columns(id_columns)}
                    PREDICTION_SCORE_COLUMNS={_build_sql_list_from_columns(prediction_score_columns)}
                    PREDICTION_CLASS_COLUMNS={_build_sql_list_from_columns(prediction_class_columns)}
                    ACTUAL_SCORE_COLUMNS={_build_sql_list_from_columns(actual_score_columns)}
                    ACTUAL_CLASS_COLUMNS={_build_sql_list_from_columns(actual_class_columns)}
                    TIMESTAMP_COLUMN='{timestamp_column}'
                    REFRESH_INTERVAL='{refresh_interval}'
                    AGGREGATION_WINDOW='{aggregation_window}'
                    {baseline_sql}""",
            statement_params=statement_params,
        ).has_column("status").has_dimensions(1, 1).validate()

    def drop_model_monitor(
        self,
        *,
        database_name: Optional[sql_identifier.SqlIdentifier] = None,
        schema_name: Optional[sql_identifier.SqlIdentifier] = None,
        monitor_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``DROP MODEL MONITOR`` on the named monitor.

        Args:
            database_name: Database of the monitor; defaults to the client's database.
            schema_name: Schema of the monitor; defaults to the client's schema.
            monitor_name: Name of the monitor to drop.
            statement_params: Optional set of statement params to include in the query.
        """
        qualified_schema = self._infer_qualified_schema(database_name, schema_name)
        query_result_checker.SqlResultValidator(
            self._sql_client._session,
            f"DROP MODEL MONITOR {qualified_schema}.{monitor_name}",
            statement_params=statement_params,
        ).validate()

    def show_model_monitors(
        self,
        *,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> List[snowpark.Row]:
        """List the model monitors in the client's database and schema.

        Args:
            statement_params: Optional set of statement params to include in the query.

        Returns:
            Rows of ``SHOW MODEL MONITORS``; the result must have a "name" column (may be empty).
        """
        fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
        return (
            query_result_checker.SqlResultValidator(
                self._sql_client._session,
                f"SHOW MODEL MONITORS IN {fully_qualified_schema_name}",
                statement_params=statement_params,
            )
            .has_column("name", allow_empty=True)
            .validate()
        )

    def _validate_unique_columns(
        self,
        timestamp_column: sql_identifier.SqlIdentifier,
        id_columns: List[sql_identifier.SqlIdentifier],
        prediction_columns: List[sql_identifier.SqlIdentifier],
        label_columns: List[sql_identifier.SqlIdentifier],
    ) -> None:
        """Ensure no column name is used more than once across all column roles.

        Raises:
            ValueError: If any column name appears more than once.
        """
        all_columns = [*id_columns, *prediction_columns, *label_columns, timestamp_column]
        num_all_columns = len(all_columns)
        num_unique_columns = len(set(all_columns))
        if num_all_columns != num_unique_columns:
            raise ValueError("Column names must be unique across id, timestamp, prediction, and label columns.")

    def validate_existence_by_name(
        self,
        *,
        database_name: Optional[sql_identifier.SqlIdentifier] = None,
        schema_name: Optional[sql_identifier.SqlIdentifier] = None,
        monitor_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Check whether a model monitor with exactly ``monitor_name`` exists.

        Args:
            database_name: Database to search; defaults to the client's database.
            schema_name: Schema to search; defaults to the client's schema.
            monitor_name: Name of the monitor to look for.
            statement_params: Optional set of statement params to include in the query.

        Returns:
            True if a monitor whose stored name equals ``monitor_name.resolved()`` is listed.
        """
        qualified_schema = self._infer_qualified_schema(database_name, schema_name)
        target_name = monitor_name.resolved()
        rows = (
            query_result_checker.SqlResultValidator(
                self._sql_client._session,
                f"SHOW MODEL MONITORS LIKE '{target_name}' IN {qualified_schema}",
                statement_params=statement_params,
            )
            .has_column("name", allow_empty=True)
            .validate()
        )
        # SHOW ... LIKE matches case-insensitively and treats '_' and '%' as wildcards, so it can
        # return monitors with other names (e.g. 'MY_MON' also matches 'MYXMON'). Counting rows is
        # therefore unreliable; only an exact match on the stored name means the monitor exists.
        return any(row["name"] == target_name for row in rows)

    def validate_monitor_warehouse(
        self,
        warehouse_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Validate warehouse provided for monitoring exists.

        Args:
            warehouse_name: Warehouse name
            statement_params: Optional set of statement params to include in queries.

        Raises:
            ValueError: If warehouse does not exist.
        """
        if not db_utils.db_object_exists(
            session=self._sql_client._session,
            object_type=db_utils.SnowflakeDbObjectType.WAREHOUSE,
            object_name=warehouse_name,
            statement_params=statement_params,
        ):
            raise ValueError(f"Warehouse '{warehouse_name}' not found.")

    def _validate_columns_exist_in_source(
        self,
        *,
        source_column_schema: Mapping[str, types.DataType],
        timestamp_column: sql_identifier.SqlIdentifier,
        prediction_score_columns: List[sql_identifier.SqlIdentifier],
        prediction_class_columns: List[sql_identifier.SqlIdentifier],
        actual_score_columns: List[sql_identifier.SqlIdentifier],
        actual_class_columns: List[sql_identifier.SqlIdentifier],
        id_columns: List[sql_identifier.SqlIdentifier],
    ) -> None:
        """Ensures all columns exist in the source table.

        Args:
            source_column_schema: Dictionary of column names and types in the source.
            timestamp_column: Name of the timestamp column.
            prediction_score_columns: List of prediction score column names.
            prediction_class_columns: List of prediction class names.
            actual_score_columns: List of actual score column names.
            actual_class_columns: List of actual class column names.
            id_columns: List of id column names.

        Raises:
            ValueError: If any of the columns do not exist in the source.
        """

        if timestamp_column not in source_column_schema:
            raise ValueError(f"Timestamp column {timestamp_column} does not exist in source.")

        if not all(column_name in source_column_schema for column_name in prediction_score_columns):
            raise ValueError(f"Prediction Score column(s): {prediction_score_columns} do not exist in source.")
        if not all(column_name in source_column_schema for column_name in prediction_class_columns):
            raise ValueError(f"Prediction Class column(s): {prediction_class_columns} do not exist in source.")
        if not all(column_name in source_column_schema for column_name in actual_score_columns):
            raise ValueError(f"Actual Score column(s): {actual_score_columns} do not exist in source.")

        if not all(column_name in source_column_schema for column_name in actual_class_columns):
            raise ValueError(f"Actual Class column(s): {actual_class_columns} do not exist in source.")

        if not all(column_name in source_column_schema for column_name in id_columns):
            raise ValueError(f"ID column(s): {id_columns} do not exist in source.")

    def _validate_timestamp_column_type(
        self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
    ) -> None:
        """Ensures the timestamp column is of TimestampType.

        Args:
            table_schema: Dictionary of column names and types in the source table.
            timestamp_column: Name of the timestamp column.

        Raises:
            ValueError: If the timestamp column is not of type TimestampType.
        """
        if not isinstance(table_schema[timestamp_column], types.TimestampType):
            raise ValueError(
                f"Timestamp column: {timestamp_column} must be TimestampType. "
                f"Found: {table_schema[timestamp_column]}"
            )

    def _validate_id_columns_types(
        self, table_schema: Mapping[str, types.DataType], id_columns: List[sql_identifier.SqlIdentifier]
    ) -> None:
        """Ensures id columns have the correct type.

        Args:
            table_schema: Dictionary of column names and types in the source table.
            id_columns: List of id column names.

        Raises:
            ValueError: If the id column is not of type StringType.
        """
        # Deduplicate so the error message lists each offending type once.
        id_column_types = list({table_schema[column_name] for column_name in id_columns})
        all_id_columns_string = all(isinstance(column_type, types.StringType) for column_type in id_column_types)
        if not all_id_columns_string:
            raise ValueError(f"Id columns must all be StringType. Found: {id_column_types}")

    def _validate_prediction_columns_types(
        self, table_schema: Mapping[str, types.DataType], prediction_columns: List[sql_identifier.SqlIdentifier]
    ) -> None:
        """Ensures prediction columns have the same type.

        Args:
            table_schema: Dictionary of column names and types in the source table.
            prediction_columns: List of prediction column names.

        Raises:
            ValueError: If the prediction columns do not share the same type.
        """

        prediction_column_types = {table_schema[column_name] for column_name in prediction_columns}
        if len(prediction_column_types) > 1:
            raise ValueError(f"Prediction column types must be the same. Found: {prediction_column_types}")

    def _validate_label_columns_types(
        self,
        table_schema: Mapping[str, types.DataType],
        label_columns: List[sql_identifier.SqlIdentifier],
    ) -> None:
        """Ensures label columns all share the same type.

        Args:
            table_schema: Dictionary of column names and types in the source table.
            label_columns: List of label column names.

        Raises:
            ValueError: If the label columns do not share the same type.
        """
        label_column_types = {table_schema[column_name] for column_name in label_columns}
        if len(label_column_types) > 1:
            raise ValueError(f"Label column types must be the same. Found: {label_column_types}")

    def _validate_column_types(
        self,
        *,
        table_schema: Mapping[str, types.DataType],
        timestamp_column: sql_identifier.SqlIdentifier,
        id_columns: List[sql_identifier.SqlIdentifier],
        prediction_columns: List[sql_identifier.SqlIdentifier],
        label_columns: List[sql_identifier.SqlIdentifier],
    ) -> None:
        """Ensures columns have the expected type.

        Args:
            table_schema: Dictionary of column names and types in the source table.
            timestamp_column: Name of the timestamp column.
            id_columns: List of id column names.
            prediction_columns: List of prediction column names.
            label_columns: List of label column names.
        """
        self._validate_timestamp_column_type(table_schema, timestamp_column)
        self._validate_id_columns_types(table_schema, id_columns)
        self._validate_prediction_columns_types(table_schema, prediction_columns)
        self._validate_label_columns_types(table_schema, label_columns)
        # TODO(SNOW-1646693): Validate label makes sense with model task

    def _validate_source_table_features_shape(
        self,
        table_schema: Mapping[str, types.DataType],
        special_columns: Set[sql_identifier.SqlIdentifier],
        model_function: model_manifest_schema.ModelFunctionInfo,
    ) -> None:
        """Ensure the non-special source columns match the model function's inputs by type multiset.

        Only the count of each data type is compared, not column names or order.

        Args:
            table_schema: Dictionary of column names and types in the source table.
            special_columns: Columns (id/timestamp/prediction/label) excluded from the comparison.
            model_function: Model function info whose signature inputs are compared against.

        Raises:
            ValueError: If the per-type column counts differ from the function's input types.
        """
        table_schema_without_special_columns = {
            k: v for k, v in table_schema.items() if sql_identifier.SqlIdentifier(k) not in special_columns
        }
        schema_column_types_to_count: typing.Counter[types.DataType] = Counter(
            table_schema_without_special_columns.values()
        )

        inputs = model_function["signature"].inputs
        function_input_types_to_count: typing.Counter[types.DataType] = Counter(
            feature_spec.as_snowpark_type() for feature_spec in inputs
        )

        if function_input_types_to_count != schema_column_types_to_count:
            raise ValueError(
                "Model function input types do not match the source table input columns types. "
                f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
            )

    def validate_source(
        self,
        *,
        source_database: Optional[sql_identifier.SqlIdentifier],
        source_schema: Optional[sql_identifier.SqlIdentifier],
        source: sql_identifier.SqlIdentifier,
        timestamp_column: sql_identifier.SqlIdentifier,
        prediction_score_columns: List[sql_identifier.SqlIdentifier],
        prediction_class_columns: List[sql_identifier.SqlIdentifier],
        actual_score_columns: List[sql_identifier.SqlIdentifier],
        actual_class_columns: List[sql_identifier.SqlIdentifier],
        id_columns: List[sql_identifier.SqlIdentifier],
    ) -> None:
        """Fetch the source's column schema and check that every configured column is present.

        Args:
            source_database: Database of the source; defaults to the client's database.
            source_schema: Schema of the source; defaults to the client's schema.
            source: Name of the source.
            timestamp_column: Name of the timestamp column.
            prediction_score_columns: List of prediction score column names.
            prediction_class_columns: List of prediction class column names.
            actual_score_columns: List of actual score column names.
            actual_class_columns: List of actual class column names.
            id_columns: List of id column names.
        """
        source_database = source_database or self._database_name
        source_schema = source_schema or self._schema_name
        # Get Schema of the source. Implicitly validates that the source exists.
        source_column_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
            self._sql_client._session,
            source_database,
            source_schema,
            source,
        )
        self._validate_columns_exist_in_source(
            source_column_schema=source_column_schema,
            timestamp_column=timestamp_column,
            prediction_score_columns=prediction_score_columns,
            prediction_class_columns=prediction_class_columns,
            actual_score_columns=actual_score_columns,
            actual_class_columns=actual_class_columns,
            id_columns=id_columns,
        )

    def delete_monitor_metadata(
        self,
        name: str,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete the row in the metadata table corresponding to the given monitor name.

        The table is ``<client database>.<client schema>._SYSTEM_MONITORING_METADATA`` and rows are
        matched on the MONITOR_NAME column.

        Args:
            name: Name of the model monitor whose metadata should be deleted.
            statement_params: Optional set of statement_params to include with query.
        """
        self._sql_client._session.sql(
            f"""DELETE FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
            WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
        ).collect(statement_params=statement_params)

    def _alter_monitor(
        self,
        operation: str,
        monitor_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Run ``ALTER MODEL MONITOR ... SUSPEND|RESUME`` on a monitor in the client's schema.

        Args:
            operation: Either "SUSPEND" or "RESUME".
            monitor_name: Name of the monitor to alter.
            statement_params: Optional set of statement params to include in the query.

        Raises:
            ValueError: If ``operation`` is not "SUSPEND" or "RESUME".
        """
        if operation not in {"SUSPEND", "RESUME"}:
            raise ValueError(f"Operation {operation} not supported for altering Model Monitors")
        query_result_checker.SqlResultValidator(
            self._sql_client._session,
            f"""ALTER MODEL MONITOR {self._database_name}.{self._schema_name}.{monitor_name} {operation}""",
            statement_params=statement_params,
        ).has_column("status").has_dimensions(1, 1).validate()

    def suspend_monitor(
        self,
        monitor_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Suspend the named monitor in the client's database and schema."""
        self._alter_monitor(
            operation="SUSPEND",
            monitor_name=monitor_name,
            statement_params=statement_params,
        )

    def resume_monitor(
        self,
        monitor_name: sql_identifier.SqlIdentifier,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Resume the named monitor in the client's database and schema."""
        self._alter_monitor(
            operation="RESUME",
            monitor_name=monitor_name,
            statement_params=statement_params,
        )
@@ -0,0 +1,238 @@
1
+ import json
2
+ from typing import Any, Dict, List, Optional
3
+
4
+ from snowflake import snowpark
5
+ from snowflake.ml._internal.utils import sql_identifier
6
+ from snowflake.ml.model import type_hints
7
+ from snowflake.ml.model._client.model import model_version_impl
8
+ from snowflake.ml.monitoring import model_monitor
9
+ from snowflake.ml.monitoring._client import model_monitor_sql_client
10
+ from snowflake.ml.monitoring.entities import model_monitor_config
11
+ from snowflake.snowpark import session
12
+
13
+
14
class ModelMonitorManager:
    """Class to manage internal operations for Model Monitor workflows."""

    def __init__(
        self,
        session: session.Session,
        database_name: sql_identifier.SqlIdentifier,
        schema_name: sql_identifier.SqlIdentifier,
        *,
        statement_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Open a ModelMonitorManager for a given database and schema.

        A ModelMonitorSQLClient scoped to ``database_name``/``schema_name`` is created for all
        monitor SQL operations issued by this manager.

        Args:
            session: The Snowpark Session to connect with Snowflake.
            database_name: The name of the database.
            schema_name: The name of the schema.
            statement_params: Optional set of statement params, forwarded to the SQL client calls
                made by this manager.
        """
        self._database_name = database_name
        self._schema_name = schema_name
        self.statement_params = statement_params

        self._model_monitor_client = model_monitor_sql_client.ModelMonitorSQLClient(
            session,
            database_name=self._database_name,
            schema_name=self._schema_name,
        )

    def _validate_task_from_model_version(
        self,
        model_version: model_version_impl.ModelVersion,
    ) -> type_hints.Task:
        """Return the task the model version was logged with.

        Args:
            model_version: Model version to inspect.

        Returns:
            The model version's task.

        Raises:
            ValueError: If the model version's task is ``Task.UNKNOWN``.
        """
        task = model_version.get_model_task()
        if task == type_hints.Task.UNKNOWN:
            raise ValueError("Registry model must be logged with task in order to be monitored.")
        return task

    def _validate_model_function_from_model_version(
        self, function: str, model_version: model_version_impl.ModelVersion
    ) -> None:
        """Check that ``function`` is one of the model version's target methods.

        Args:
            function: Name of the model function to look for.
            model_version: Model version whose functions are listed via ``show_functions()``.

        Raises:
            ValueError: If no entry returned by ``show_functions()`` has ``target_method == function``.
        """
        functions = model_version.show_functions()
        if any(f["target_method"] == function for f in functions):
            return
        existing_target_methods = {f["target_method"] for f in functions}
        raise ValueError(
            f"Function with name {function} does not exist in the given model version. "
            f"Found: {existing_target_methods}."
        )

    def _build_column_list_from_input(self, columns: Optional[List[str]]) -> List[sql_identifier.SqlIdentifier]:
        """Convert optional column names to SqlIdentifiers; ``None`` or an empty list yields ``[]``."""
        return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []

    def add_monitor(
        self,
        name: str,
        source_config: model_monitor_config.ModelMonitorSourceConfig,
        model_monitor_config: model_monitor_config.ModelMonitorConfig,
    ) -> model_monitor.ModelMonitor:
        """Add a new Model Monitor.

        The warehouse, model function, model task and source table are validated before the monitor
        is created.

        Args:
            name: Name of Model Monitor to create.
            source_config: Configuration options for the source table used in ModelMonitor.
            model_monitor_config: Configuration options of ModelMonitor.

        Returns:
            The newly added ModelMonitor object.

        Raises:
            ValueError: If the configured model function is not a target method of the model version,
                or the model version was logged without a task. The SQL client's validation calls may
                raise as well.
        """
        # NOTE: inside this method `model_monitor_config` is the argument, not the module of the same
        # name (the module is only referenced by the annotations above).
        warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
        self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
        self._validate_model_function_from_model_version(
            model_monitor_config.model_function_name, model_monitor_config.model_version
        )
        self._validate_task_from_model_version(model_monitor_config.model_version)
        monitor_database_name_id, monitor_schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(
            name
        )
        source_database_name_id, source_schema_name_id, source_name_id = sql_identifier.parse_fully_qualified_name(
            source_config.source
        )
        # The baseline table is optional; without it every baseline component is None.
        baseline_database_name_id, baseline_schema_name_id, baseline_name_id = (
            sql_identifier.parse_fully_qualified_name(source_config.baseline)
            if source_config.baseline
            else (None, None, None)
        )
        model_database_name_id, model_schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(
            model_monitor_config.model_version.fully_qualified_model_name
        )

        prediction_score_columns = self._build_column_list_from_input(source_config.prediction_score_columns)
        prediction_class_columns = self._build_column_list_from_input(source_config.prediction_class_columns)
        actual_score_columns = self._build_column_list_from_input(source_config.actual_score_columns)
        actual_class_columns = self._build_column_list_from_input(source_config.actual_class_columns)

        id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in source_config.id_columns]
        ts_column = sql_identifier.SqlIdentifier(source_config.timestamp_column)

        # Validate source table
        self._model_monitor_client.validate_source(
            source_database=source_database_name_id,
            source_schema=source_schema_name_id,
            source=source_name_id,
            timestamp_column=ts_column,
            prediction_score_columns=prediction_score_columns,
            prediction_class_columns=prediction_class_columns,
            actual_score_columns=actual_score_columns,
            actual_class_columns=actual_class_columns,
            id_columns=id_columns,
        )

        self._model_monitor_client.create_model_monitor(
            monitor_database=monitor_database_name_id,
            monitor_schema=monitor_schema_name_id,
            monitor_name=monitor_name_id,
            source_database=source_database_name_id,
            source_schema=source_schema_name_id,
            source=source_name_id,
            model_database=model_database_name_id,
            model_schema=model_schema_name_id,
            model_name=model_name_id,
            version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
            function_name=model_monitor_config.model_function_name,
            warehouse_name=warehouse_name_id,
            timestamp_column=ts_column,
            id_columns=id_columns,
            prediction_score_columns=prediction_score_columns,
            prediction_class_columns=prediction_class_columns,
            actual_score_columns=actual_score_columns,
            actual_class_columns=actual_class_columns,
            refresh_interval=model_monitor_config.refresh_interval,
            aggregation_window=model_monitor_config.aggregation_window,
            baseline_database=baseline_database_name_id,
            baseline_schema=baseline_schema_name_id,
            baseline=baseline_name_id,
            statement_params=self.statement_params,
        )
        # NOTE(review): only the bare monitor name is handed to the returned ModelMonitor; the database
        # and schema parsed from `name` are dropped. Confirm the client resolves it in the intended schema.
        return model_monitor.ModelMonitor._ref(
            model_monitor_client=self._model_monitor_client,
            name=monitor_name_id,
        )

    def get_monitor_by_model_version(
        self, model_version: model_version_impl.ModelVersion
    ) -> model_monitor.ModelMonitor:
        """Get a Model Monitor by Model Version.

        All monitors are listed and filtered on the model name and version name recorded in each
        monitor's model JSON column.

        Args:
            model_version: ModelVersion to retrieve Model Monitor for.

        Returns:
            The fetched ModelMonitor.

        Raises:
            ValueError: If model monitor is not found, or if more than one monitor matches.
        """
        rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)

        def model_match_fn(model_details: Dict[str, str]) -> bool:
            # NOTE(review): compares the JSON values with ModelVersion.model_name/version_name as-is;
            # confirm both sides use the same identifier form for quoted / case-sensitive names.
            return (
                model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
                and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
            )

        matching_rows = [
            row for row in rows if model_match_fn(json.loads(row[model_monitor_sql_client.MODEL_JSON_COL_NAME]))
        ]
        if not matching_rows:
            raise ValueError("Unable to find model monitor for the given model version.")
        if len(matching_rows) > 1:
            raise ValueError("Found multiple model monitors for the given model version.")

        # SHOW output contains the stored (resolved) name. Wrap it case-sensitively, otherwise a
        # lower-case or quoted monitor name would be re-parsed, upper-cased and refer to another object.
        return model_monitor.ModelMonitor._ref(
            model_monitor_client=self._model_monitor_client,
            name=sql_identifier.SqlIdentifier(matching_rows[0]["name"], case_sensitive=True),
        )

    def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
        """Get a Model Monitor from the Registry.

        Args:
            name: Name of Model Monitor to retrieve; may be fully qualified.

        Raises:
            ValueError: If model monitor is not found.

        Returns:
            The fetched ModelMonitor.
        """
        database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name)

        if not self._model_monitor_client.validate_existence_by_name(
            database_name=database_name_id,
            schema_name=schema_name_id,
            monitor_name=monitor_name_id,
            statement_params=self.statement_params,
        ):
            raise ValueError(f"Unable to find model monitor '{name}'")
        return model_monitor.ModelMonitor._ref(
            model_monitor_client=self._model_monitor_client,
            name=monitor_name_id,
        )

    def show_model_monitors(self) -> List[snowpark.Row]:
        """Show all model monitors in the registry.

        Returns:
            List of snowpark.Row containing metadata for each model monitor.
        """
        return self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)

    def delete_monitor(self, name: str) -> None:
        """Delete a Model Monitor from the Registry.

        Args:
            name: Name of the Model Monitor to delete; may be fully qualified.
        """
        database_name_id, schema_name_id, monitor_name_id = sql_identifier.parse_fully_qualified_name(name)
        self._model_monitor_client.drop_model_monitor(
            database_name=database_name_id,
            schema_name=schema_name_id,
            monitor_name=monitor_name_id,
            statement_params=self.statement_params,
        )