snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/telemetry.py +4 -2
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/import_utils.py +31 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/data/torch_utils.py +33 -14
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
- snowflake/ml/feature_store/examples/example_helper.py +6 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
- snowflake/ml/feature_store/feature_store.py +1 -2
- snowflake/ml/feature_store/feature_view.py +5 -1
- snowflake/ml/model/_client/model/model_version_impl.py +145 -11
- snowflake/ml/model/_client/ops/model_ops.py +56 -16
- snowflake/ml/model/_client/ops/service_ops.py +46 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
- snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_packager.py +0 -11
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +87 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/custom_model.py +47 -7
- snowflake/ml/model/model_signature.py +40 -9
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
- snowflake/ml/modeling/cluster/dbscan.py +5 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
- snowflake/ml/modeling/cluster/k_means.py +14 -19
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
- snowflake/ml/modeling/cluster/optics.py +6 -6
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
- snowflake/ml/modeling/compose/column_transformer.py +15 -5
- snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
- snowflake/ml/modeling/decomposition/pca.py +28 -15
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
- snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +0 -10
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +3 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
- snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
- snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +3 -3
- snowflake/ml/modeling/manifold/tsne.py +10 -4
- snowflake/ml/modeling/metrics/classification.py +12 -16
- snowflake/ml/modeling/metrics/ranking.py +3 -3
- snowflake/ml/modeling/metrics/regression.py +3 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +16 -14
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
- snowflake/ml/modeling/svm/linear_svc.py +25 -16
- snowflake/ml/modeling/svm/linear_svr.py +23 -17
- snowflake/ml/modeling/svm/nu_svc.py +5 -3
- snowflake/ml/modeling/svm/nu_svr.py +3 -1
- snowflake/ml/modeling/svm/svc.py +9 -5
- snowflake/ml/modeling/svm/svr.py +3 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
- snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +37 -0
- snowflake/ml/registry/_manager/model_manager.py +15 -1
- snowflake/ml/registry/registry.py +32 -37
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/_client/model_monitor.py +0 -126
- snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
- snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,1335 +0,0 @@
|
|
1
|
-
import json
|
2
|
-
import string
|
3
|
-
import textwrap
|
4
|
-
import typing
|
5
|
-
from collections import Counter
|
6
|
-
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, TypedDict
|
7
|
-
|
8
|
-
from importlib_resources import files
|
9
|
-
from typing_extensions import Required
|
10
|
-
|
11
|
-
from snowflake import snowpark
|
12
|
-
from snowflake.connector import errors
|
13
|
-
from snowflake.ml._internal.utils import (
|
14
|
-
db_utils,
|
15
|
-
formatting,
|
16
|
-
query_result_checker,
|
17
|
-
sql_identifier,
|
18
|
-
table_manager,
|
19
|
-
)
|
20
|
-
from snowflake.ml.model import type_hints
|
21
|
-
from snowflake.ml.model._client.sql import _base
|
22
|
-
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
23
|
-
from snowflake.ml.monitoring.entities import model_monitor_interval, output_score_type
|
24
|
-
from snowflake.ml.monitoring.entities.model_monitor_interval import (
|
25
|
-
ModelMonitorAggregationWindow,
|
26
|
-
ModelMonitorRefreshInterval,
|
27
|
-
)
|
28
|
-
from snowflake.snowpark import DataFrame, exceptions, session, types
|
29
|
-
from snowflake.snowpark._internal import type_utils
|
30
|
-
|
31
|
-
SNOWML_MONITORING_METADATA_TABLE_NAME = "_SYSTEM_MONITORING_METADATA"
|
32
|
-
_SNOWML_MONITORING_TABLE_NAME_PREFIX = "_SNOWML_OBS_MONITORING_"
|
33
|
-
_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX = "_SNOWML_OBS_ACCURACY_"
|
34
|
-
|
35
|
-
MONITOR_NAME_COL_NAME = "MONITOR_NAME"
|
36
|
-
SOURCE_TABLE_NAME_COL_NAME = "SOURCE_TABLE_NAME"
|
37
|
-
FQ_MODEL_NAME_COL_NAME = "FULLY_QUALIFIED_MODEL_NAME"
|
38
|
-
VERSION_NAME_COL_NAME = "MODEL_VERSION_NAME"
|
39
|
-
FUNCTION_NAME_COL_NAME = "FUNCTION_NAME"
|
40
|
-
TASK_COL_NAME = "TASK"
|
41
|
-
MONITORING_ENABLED_COL_NAME = "IS_ENABLED"
|
42
|
-
TIMESTAMP_COL_NAME_COL_NAME = "TIMESTAMP_COLUMN_NAME"
|
43
|
-
PREDICTION_COL_NAMES_COL_NAME = "PREDICTION_COLUMN_NAMES"
|
44
|
-
LABEL_COL_NAMES_COL_NAME = "LABEL_COLUMN_NAMES"
|
45
|
-
ID_COL_NAMES_COL_NAME = "ID_COLUMN_NAMES"
|
46
|
-
|
47
|
-
_DASHBOARD_UDTFS_COMMON_LIST = ["record_count"]
|
48
|
-
_DASHBOARD_UDTFS_REGRESSION_LIST = ["rmse"]
|
49
|
-
|
50
|
-
|
51
|
-
def _initialize_monitoring_metadata_tables(
|
52
|
-
session: session.Session,
|
53
|
-
database_name: sql_identifier.SqlIdentifier,
|
54
|
-
schema_name: sql_identifier.SqlIdentifier,
|
55
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
56
|
-
) -> None:
|
57
|
-
"""Create tables necessary for Model Monitoring in provided schema.
|
58
|
-
|
59
|
-
Args:
|
60
|
-
session: Active Snowpark session.
|
61
|
-
database_name: The database in which to setup resources for Model Monitoring.
|
62
|
-
schema_name: The schema in which to setup resources for Model Monitoring.
|
63
|
-
statement_params: Optional statement params for queries.
|
64
|
-
"""
|
65
|
-
table_manager.create_single_table(
|
66
|
-
session,
|
67
|
-
database_name,
|
68
|
-
schema_name,
|
69
|
-
SNOWML_MONITORING_METADATA_TABLE_NAME,
|
70
|
-
[
|
71
|
-
(MONITOR_NAME_COL_NAME, "VARCHAR"),
|
72
|
-
(SOURCE_TABLE_NAME_COL_NAME, "VARCHAR"),
|
73
|
-
(FQ_MODEL_NAME_COL_NAME, "VARCHAR"),
|
74
|
-
(VERSION_NAME_COL_NAME, "VARCHAR"),
|
75
|
-
(FUNCTION_NAME_COL_NAME, "VARCHAR"),
|
76
|
-
(TASK_COL_NAME, "VARCHAR"),
|
77
|
-
(MONITORING_ENABLED_COL_NAME, "BOOLEAN"),
|
78
|
-
(TIMESTAMP_COL_NAME_COL_NAME, "VARCHAR"),
|
79
|
-
(PREDICTION_COL_NAMES_COL_NAME, "ARRAY"),
|
80
|
-
(LABEL_COL_NAMES_COL_NAME, "ARRAY"),
|
81
|
-
(ID_COL_NAMES_COL_NAME, "ARRAY"),
|
82
|
-
],
|
83
|
-
statement_params=statement_params,
|
84
|
-
)
|
85
|
-
|
86
|
-
|
87
|
-
def _create_baseline_table_name(model_name: str, version_name: str) -> str:
|
88
|
-
return f"_SNOWML_OBS_BASELINE_{model_name}_{version_name}"
|
89
|
-
|
90
|
-
|
91
|
-
def _infer_numeric_categoric_feature_column_names(
|
92
|
-
*,
|
93
|
-
source_table_schema: Mapping[str, types.DataType],
|
94
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
95
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
96
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
97
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
98
|
-
) -> Tuple[List[sql_identifier.SqlIdentifier], List[sql_identifier.SqlIdentifier]]:
|
99
|
-
cols_to_remove = {timestamp_column, *id_columns, *prediction_columns, *label_columns}
|
100
|
-
cols_to_consider = [
|
101
|
-
(col_name, source_table_schema[col_name]) for col_name in source_table_schema if col_name not in cols_to_remove
|
102
|
-
]
|
103
|
-
numeric_cols = [
|
104
|
-
sql_identifier.SqlIdentifier(column[0])
|
105
|
-
for column in cols_to_consider
|
106
|
-
if isinstance(column[1], types._NumericType)
|
107
|
-
]
|
108
|
-
categorical_cols = [
|
109
|
-
sql_identifier.SqlIdentifier(column[0])
|
110
|
-
for column in cols_to_consider
|
111
|
-
if isinstance(column[1], types.StringType) or isinstance(column[1], types.BooleanType)
|
112
|
-
]
|
113
|
-
return (numeric_cols, categorical_cols)
|
114
|
-
|
115
|
-
|
116
|
-
class _ModelMonitorParams(TypedDict):
|
117
|
-
"""Class to transfer model monitor parameters to the ModelMonitor class."""
|
118
|
-
|
119
|
-
monitor_name: Required[str]
|
120
|
-
fully_qualified_model_name: Required[str]
|
121
|
-
version_name: Required[str]
|
122
|
-
function_name: Required[str]
|
123
|
-
prediction_columns: Required[List[sql_identifier.SqlIdentifier]]
|
124
|
-
label_columns: Required[List[sql_identifier.SqlIdentifier]]
|
125
|
-
|
126
|
-
|
127
|
-
class _ModelMonitorSQLClient:
|
128
|
-
def __init__(
|
129
|
-
self,
|
130
|
-
session: session.Session,
|
131
|
-
*,
|
132
|
-
database_name: sql_identifier.SqlIdentifier,
|
133
|
-
schema_name: sql_identifier.SqlIdentifier,
|
134
|
-
) -> None:
|
135
|
-
"""Client to manage monitoring metadata persisted in SNOWML_OBSERVABILITY.METADATA schema.
|
136
|
-
|
137
|
-
Args:
|
138
|
-
session: Active snowpark session.
|
139
|
-
database_name: Name of the Database where monitoring resources are provisioned.
|
140
|
-
schema_name: Name of the Schema where monitoring resources are provisioned.
|
141
|
-
"""
|
142
|
-
self._sql_client = _base._BaseSQLClient(session, database_name=database_name, schema_name=schema_name)
|
143
|
-
self._database_name = database_name
|
144
|
-
self._schema_name = schema_name
|
145
|
-
|
146
|
-
@staticmethod
|
147
|
-
def initialize_monitoring_schema(
|
148
|
-
session: session.Session,
|
149
|
-
database_name: sql_identifier.SqlIdentifier,
|
150
|
-
schema_name: sql_identifier.SqlIdentifier,
|
151
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
152
|
-
) -> None:
|
153
|
-
"""Initialize tables for tracking metadata associated with model monitoring.
|
154
|
-
|
155
|
-
Args:
|
156
|
-
session: The Snowpark Session to connect with Snowflake.
|
157
|
-
database_name: The database in which to setup resources for Model Monitoring.
|
158
|
-
schema_name: The schema in which to setup resources for Model Monitoring.
|
159
|
-
statement_params: Optional set of statement_params to include with query.
|
160
|
-
"""
|
161
|
-
# Create metadata management tables
|
162
|
-
_initialize_monitoring_metadata_tables(session, database_name, schema_name, statement_params)
|
163
|
-
|
164
|
-
def _validate_is_initialized(self) -> bool:
|
165
|
-
"""Validates whether monitoring metadata has been initialized.
|
166
|
-
|
167
|
-
Returns:
|
168
|
-
boolean to indicate whether tables have been initialized.
|
169
|
-
"""
|
170
|
-
try:
|
171
|
-
return table_manager.validate_table_exist(
|
172
|
-
self._sql_client._session,
|
173
|
-
SNOWML_MONITORING_METADATA_TABLE_NAME,
|
174
|
-
f"{self._database_name}.{self._schema_name}",
|
175
|
-
)
|
176
|
-
except exceptions.SnowparkSQLException:
|
177
|
-
return False
|
178
|
-
|
179
|
-
def _validate_unique_columns(
|
180
|
-
self,
|
181
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
182
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
183
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
184
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
185
|
-
) -> None:
|
186
|
-
all_columns = [*id_columns, *prediction_columns, *label_columns, timestamp_column]
|
187
|
-
num_all_columns = len(all_columns)
|
188
|
-
num_unique_columns = len(set(all_columns))
|
189
|
-
if num_all_columns != num_unique_columns:
|
190
|
-
raise ValueError("Column names must be unique across id, timestamp, prediction, and label columns.")
|
191
|
-
|
192
|
-
def validate_existence_by_name(
|
193
|
-
self,
|
194
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
195
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
196
|
-
) -> bool:
|
197
|
-
res = (
|
198
|
-
query_result_checker.SqlResultValidator(
|
199
|
-
self._sql_client._session,
|
200
|
-
f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}
|
201
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
202
|
-
WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""",
|
203
|
-
statement_params=statement_params,
|
204
|
-
)
|
205
|
-
.has_column(FQ_MODEL_NAME_COL_NAME, allow_empty=True)
|
206
|
-
.has_column(VERSION_NAME_COL_NAME, allow_empty=True)
|
207
|
-
.validate()
|
208
|
-
)
|
209
|
-
return len(res) >= 1
|
210
|
-
|
211
|
-
def validate_existence(
|
212
|
-
self,
|
213
|
-
fully_qualified_model_name: str,
|
214
|
-
version_name: sql_identifier.SqlIdentifier,
|
215
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
216
|
-
) -> bool:
|
217
|
-
"""Validate existence of a ModelMonitor on a Model Version.
|
218
|
-
|
219
|
-
Args:
|
220
|
-
fully_qualified_model_name: Fully qualified name of model.
|
221
|
-
version_name: Name of model version.
|
222
|
-
statement_params: Optional set of statement_params to include with query.
|
223
|
-
|
224
|
-
Returns:
|
225
|
-
Boolean indicating whether monitor exists on model version.
|
226
|
-
"""
|
227
|
-
res = (
|
228
|
-
query_result_checker.SqlResultValidator(
|
229
|
-
self._sql_client._session,
|
230
|
-
f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}
|
231
|
-
FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
|
232
|
-
WHERE {FQ_MODEL_NAME_COL_NAME} = '{fully_qualified_model_name}'
|
233
|
-
AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
|
234
|
-
statement_params=statement_params,
|
235
|
-
)
|
236
|
-
.has_column(FQ_MODEL_NAME_COL_NAME, allow_empty=True)
|
237
|
-
.has_column(VERSION_NAME_COL_NAME, allow_empty=True)
|
238
|
-
.validate()
|
239
|
-
)
|
240
|
-
return len(res) >= 1
|
241
|
-
|
242
|
-
def validate_monitor_warehouse(
|
243
|
-
self,
|
244
|
-
warehouse_name: sql_identifier.SqlIdentifier,
|
245
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
246
|
-
) -> None:
|
247
|
-
"""Validate warehouse provided for monitoring exists.
|
248
|
-
|
249
|
-
Args:
|
250
|
-
warehouse_name: Warehouse name
|
251
|
-
statement_params: Optional set of statement params to include in queries.
|
252
|
-
|
253
|
-
Raises:
|
254
|
-
ValueError: If warehouse does not exist.
|
255
|
-
"""
|
256
|
-
if not db_utils.db_object_exists(
|
257
|
-
session=self._sql_client._session,
|
258
|
-
object_type=db_utils.SnowflakeDbObjectType.WAREHOUSE,
|
259
|
-
object_name=warehouse_name,
|
260
|
-
statement_params=statement_params,
|
261
|
-
):
|
262
|
-
raise ValueError(f"Warehouse '{warehouse_name}' not found.")
|
263
|
-
|
264
|
-
def add_dashboard_udtfs(
|
265
|
-
self,
|
266
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
267
|
-
model_name: sql_identifier.SqlIdentifier,
|
268
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
269
|
-
task: type_hints.Task,
|
270
|
-
score_type: output_score_type.OutputScoreType,
|
271
|
-
output_columns: List[sql_identifier.SqlIdentifier],
|
272
|
-
ground_truth_columns: List[sql_identifier.SqlIdentifier],
|
273
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
274
|
-
) -> None:
|
275
|
-
udtf_name_query_map = self._create_dashboard_udtf_queries(
|
276
|
-
monitor_name,
|
277
|
-
model_name,
|
278
|
-
model_version_name,
|
279
|
-
task,
|
280
|
-
score_type,
|
281
|
-
output_columns,
|
282
|
-
ground_truth_columns,
|
283
|
-
)
|
284
|
-
for udtf_query in udtf_name_query_map.values():
|
285
|
-
query_result_checker.SqlResultValidator(
|
286
|
-
self._sql_client._session,
|
287
|
-
f"""{udtf_query}""",
|
288
|
-
statement_params=statement_params,
|
289
|
-
).validate()
|
290
|
-
|
291
|
-
def get_monitoring_table_fully_qualified_name(
|
292
|
-
self,
|
293
|
-
model_name: sql_identifier.SqlIdentifier,
|
294
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
295
|
-
) -> str:
|
296
|
-
table_name = f"{_SNOWML_MONITORING_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}"
|
297
|
-
return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name)
|
298
|
-
|
299
|
-
def get_accuracy_monitoring_table_fully_qualified_name(
|
300
|
-
self,
|
301
|
-
model_name: sql_identifier.SqlIdentifier,
|
302
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
303
|
-
) -> str:
|
304
|
-
table_name = f"{_SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX}_{model_name}_{model_version_name}"
|
305
|
-
return table_manager.get_fully_qualified_table_name(self._database_name, self._schema_name, table_name)
|
306
|
-
|
307
|
-
def _create_dashboard_udtf_queries(
|
308
|
-
self,
|
309
|
-
monitor_name: sql_identifier.SqlIdentifier,
|
310
|
-
model_name: sql_identifier.SqlIdentifier,
|
311
|
-
model_version_name: sql_identifier.SqlIdentifier,
|
312
|
-
task: type_hints.Task,
|
313
|
-
score_type: output_score_type.OutputScoreType,
|
314
|
-
output_columns: List[sql_identifier.SqlIdentifier],
|
315
|
-
ground_truth_columns: List[sql_identifier.SqlIdentifier],
|
316
|
-
) -> Mapping[str, str]:
|
317
|
-
query_files = files("snowflake.ml.monitoring._client")
|
318
|
-
# TODO(apgupta): Expand list of queries based on model objective and score type.
|
319
|
-
queries_list = []
|
320
|
-
queries_list.extend(_DASHBOARD_UDTFS_COMMON_LIST)
|
321
|
-
if task == type_hints.Task.TABULAR_REGRESSION:
|
322
|
-
queries_list.extend(_DASHBOARD_UDTFS_REGRESSION_LIST)
|
323
|
-
var_map = {
|
324
|
-
"MODEL_MONITOR_NAME": monitor_name,
|
325
|
-
"MONITORING_TABLE": self.get_monitoring_table_fully_qualified_name(model_name, model_version_name),
|
326
|
-
"MONITORING_PRED_LABEL_JOINED_TABLE": self.get_accuracy_monitoring_table_fully_qualified_name(
|
327
|
-
model_name, model_version_name
|
328
|
-
),
|
329
|
-
"OUTPUT_COLUMN_NAME": output_columns[0],
|
330
|
-
"GROUND_TRUTH_COLUMN_NAME": ground_truth_columns[0],
|
331
|
-
}
|
332
|
-
|
333
|
-
udf_name_query_map = {}
|
334
|
-
for q in queries_list:
|
335
|
-
q_template = query_files.joinpath(f"queries/{q}.ssql").read_text()
|
336
|
-
q_actual = string.Template(q_template).substitute(var_map)
|
337
|
-
udf_name_query_map[q] = q_actual
|
338
|
-
return udf_name_query_map
|
339
|
-
|
340
|
-
def _validate_columns_exist_in_source_table(
|
341
|
-
self,
|
342
|
-
*,
|
343
|
-
table_schema: Mapping[str, types.DataType],
|
344
|
-
source_table_name: sql_identifier.SqlIdentifier,
|
345
|
-
timestamp_column: sql_identifier.SqlIdentifier,
|
346
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
347
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
348
|
-
id_columns: List[sql_identifier.SqlIdentifier],
|
349
|
-
) -> None:
|
350
|
-
"""Ensures all columns exist in the source table.
|
351
|
-
|
352
|
-
Args:
|
353
|
-
table_schema: Dictionary of column names and types in the source table.
|
354
|
-
source_table_name: Name of the table with model data to monitor.
|
355
|
-
timestamp_column: Name of the timestamp column.
|
356
|
-
prediction_columns: List of prediction column names.
|
357
|
-
label_columns: List of label column names.
|
358
|
-
id_columns: List of id column names.
|
359
|
-
|
360
|
-
Raises:
|
361
|
-
ValueError: If any of the columns do not exist in the source table.
|
362
|
-
"""
|
363
|
-
|
364
|
-
if timestamp_column not in table_schema:
|
365
|
-
raise ValueError(f"Timestamp column {timestamp_column} does not exist in table {source_table_name}.")
|
366
|
-
|
367
|
-
if not all([column_name in table_schema for column_name in prediction_columns]):
|
368
|
-
raise ValueError(f"Prediction column(s): {prediction_columns} do not exist in table {source_table_name}.")
|
369
|
-
if not all([column_name in table_schema for column_name in label_columns]):
|
370
|
-
raise ValueError(f"Label column(s): {label_columns} do not exist in table {source_table_name}.")
|
371
|
-
if not all([column_name in table_schema for column_name in id_columns]):
|
372
|
-
raise ValueError(f"ID column(s): {id_columns} do not exist in table {source_table_name}.")
|
373
|
-
|
374
|
-
def _validate_timestamp_column_type(
|
375
|
-
self, table_schema: Mapping[str, types.DataType], timestamp_column: sql_identifier.SqlIdentifier
|
376
|
-
) -> None:
|
377
|
-
"""Ensures columns have the same type.
|
378
|
-
|
379
|
-
Args:
|
380
|
-
table_schema: Dictionary of column names and types in the source table.
|
381
|
-
timestamp_column: Name of the timestamp column.
|
382
|
-
|
383
|
-
Raises:
|
384
|
-
ValueError: If the timestamp column is not of type TimestampType.
|
385
|
-
"""
|
386
|
-
if not isinstance(table_schema[timestamp_column], types.TimestampType):
|
387
|
-
raise ValueError(
|
388
|
-
f"Timestamp column: {timestamp_column} must be TimestampType. "
|
389
|
-
f"Found: {table_schema[timestamp_column]}"
|
390
|
-
)
|
391
|
-
|
392
|
-
def _validate_id_columns_types(
|
393
|
-
self, table_schema: Mapping[str, types.DataType], id_columns: List[sql_identifier.SqlIdentifier]
|
394
|
-
) -> None:
|
395
|
-
"""Ensures id columns have the correct type.
|
396
|
-
|
397
|
-
Args:
|
398
|
-
table_schema: Dictionary of column names and types in the source table.
|
399
|
-
id_columns: List of id column names.
|
400
|
-
|
401
|
-
Raises:
|
402
|
-
ValueError: If the id column is not of type StringType.
|
403
|
-
"""
|
404
|
-
id_column_types = list({table_schema[column_name] for column_name in id_columns})
|
405
|
-
all_id_columns_string = all([isinstance(column_type, types.StringType) for column_type in id_column_types])
|
406
|
-
if not all_id_columns_string:
|
407
|
-
raise ValueError(f"Id columns must all be StringType. Found: {id_column_types}")
|
408
|
-
|
409
|
-
def _validate_prediction_columns_types(
|
410
|
-
self, table_schema: Mapping[str, types.DataType], prediction_columns: List[sql_identifier.SqlIdentifier]
|
411
|
-
) -> None:
|
412
|
-
"""Ensures prediction columns have the same type.
|
413
|
-
|
414
|
-
Args:
|
415
|
-
table_schema: Dictionary of column names and types in the source table.
|
416
|
-
prediction_columns: List of prediction column names.
|
417
|
-
|
418
|
-
Raises:
|
419
|
-
ValueError: If the prediction columns do not share the same type.
|
420
|
-
"""
|
421
|
-
|
422
|
-
prediction_column_types = {table_schema[column_name] for column_name in prediction_columns}
|
423
|
-
if len(prediction_column_types) > 1:
|
424
|
-
raise ValueError(f"Prediction column types must be the same. Found: {prediction_column_types}")
|
425
|
-
|
426
|
-
def _validate_label_columns_types(
|
427
|
-
self,
|
428
|
-
table_schema: Mapping[str, types.DataType],
|
429
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
430
|
-
) -> None:
|
431
|
-
"""Ensures label columns have the same type, and the correct type for the score type.
|
432
|
-
|
433
|
-
Args:
|
434
|
-
table_schema: Dictionary of column names and types in the source table.
|
435
|
-
label_columns: List of label column names.
|
436
|
-
|
437
|
-
Raises:
|
438
|
-
ValueError: If the label columns do not share the same type.
|
439
|
-
"""
|
440
|
-
label_column_types = {table_schema[column_name] for column_name in label_columns}
|
441
|
-
if len(label_column_types) > 1:
|
442
|
-
raise ValueError(f"Label column types must be the same. Found: {label_column_types}")
|
443
|
-
|
444
|
-
def _validate_column_types(
    self,
    *,
    table_schema: Mapping[str, types.DataType],
    timestamp_column: sql_identifier.SqlIdentifier,
    id_columns: List[sql_identifier.SqlIdentifier],
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
) -> None:
    """Run every per-role column type validator against the source table schema.

    The validators run in a fixed order (timestamp, id, prediction, label); the first
    one that raises aborts the remaining checks because nothing is caught here.

    Args:
        table_schema: Dictionary of column names and types in the source table.
        timestamp_column: Name of the timestamp column.
        id_columns: List of id column names.
        prediction_columns: List of prediction column names.
        label_columns: List of label column names.
    """
    self._validate_timestamp_column_type(table_schema, timestamp_column)
    self._validate_id_columns_types(table_schema, id_columns)
    self._validate_prediction_columns_types(table_schema, prediction_columns)
    self._validate_label_columns_types(table_schema, label_columns)
    # TODO(SNOW-1646693): Validate label makes sense with model task
|
467
|
-
|
468
|
-
def _validate_source_table_features_shape(
    self,
    table_schema: Mapping[str, types.DataType],
    special_columns: Set[sql_identifier.SqlIdentifier],
    model_function: model_manifest_schema.ModelFunctionInfo,
) -> None:
    """Check that the feature columns of the source table match the model function inputs by type.

    Columns listed in ``special_columns`` are excluded. The comparison is a multiset
    comparison of data types only: column names and ordering are not considered.

    Args:
        table_schema: Dictionary of column names and types in the source table.
        special_columns: Columns to leave out of the comparison.
        model_function: Model function info whose ``"signature"`` inputs are compared.

    Raises:
        ValueError: If the per-type counts of the remaining columns differ from the
            per-type counts of the model function inputs.
    """
    table_schema_without_special_columns = {}
    for column_name, column_type in table_schema.items():
        if sql_identifier.SqlIdentifier(column_name) in special_columns:
            continue
        table_schema_without_special_columns[column_name] = column_type

    # How many remaining table columns there are of each data type.
    schema_column_types_to_count: typing.Counter[types.DataType] = Counter(
        table_schema_without_special_columns.values()
    )

    # How many model inputs there are of each (Snowpark) data type.
    inputs = model_function["signature"].inputs
    function_input_types_to_count: typing.Counter[types.DataType] = Counter(
        [feature_spec.as_snowpark_type() for feature_spec in inputs]
    )

    if function_input_types_to_count == schema_column_types_to_count:
        return
    raise ValueError(
        "Model function input types do not match the source table input columns types. "
        f"Model function expected: {inputs} but got {table_schema_without_special_columns}"
    )
|
492
|
-
|
493
|
-
def get_model_monitor_by_name(
    self,
    monitor_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> _ModelMonitorParams:
    """Fetch metadata for a Model Monitor by name.

    Args:
        monitor_name: Name of ModelMonitor to fetch.
        statement_params: Optional set of statement_params to include with query.

    Returns:
        _ModelMonitorParams dict with name of monitor, fully qualified model name,
        model version name, model function name, prediction columns and label columns.

    Raises:
        ValueError: If the metadata query result fails validation (reported as no monitor
            found with that name), or if multiple ModelMonitors exist with the same name.
    """
    # NOTE(review): monitor_name is interpolated into a quoted SQL literal; this relies on
    # the caller passing a validated identifier.
    try:
        res = (
            query_result_checker.SqlResultValidator(
                self._sql_client._session,
                f"""SELECT {FQ_MODEL_NAME_COL_NAME}, {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME},
                {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME}
                FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
                WHERE {MONITOR_NAME_COL_NAME} = '{monitor_name}'""",
                statement_params=statement_params,
            )
            .has_column(FQ_MODEL_NAME_COL_NAME)
            .has_column(VERSION_NAME_COL_NAME)
            .has_column(FUNCTION_NAME_COL_NAME)
            .has_column(PREDICTION_COL_NAMES_COL_NAME)
            .has_column(LABEL_COL_NAMES_COL_NAME)
            .validate()
        )
    except errors.DataError as e:
        # Chain the validator error so the underlying cause stays visible in the traceback.
        raise ValueError(f"Failed to find any monitor with name '{monitor_name}'") from e

    if len(res) > 1:
        raise ValueError(f"Invalid state. Multiple Monitors exist with name '{monitor_name}'")

    row = res[0]
    return _ModelMonitorParams(
        monitor_name=str(monitor_name),
        fully_qualified_model_name=row[FQ_MODEL_NAME_COL_NAME],
        version_name=row[VERSION_NAME_COL_NAME],
        function_name=row[FUNCTION_NAME_COL_NAME],
        # Column-name lists are stored as JSON text in the metadata table.
        prediction_columns=[
            sql_identifier.SqlIdentifier(prediction_column)
            for prediction_column in json.loads(row[PREDICTION_COL_NAMES_COL_NAME])
        ],
        label_columns=[
            sql_identifier.SqlIdentifier(label_column)
            for label_column in json.loads(row[LABEL_COL_NAMES_COL_NAME])
        ],
    )
|
548
|
-
|
549
|
-
def get_model_monitor_by_model_version(
    self,
    *,
    model_db: sql_identifier.SqlIdentifier,
    model_schema: sql_identifier.SqlIdentifier,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> _ModelMonitorParams:
    """Fetch metadata for a Model Monitor by model version.

    Args:
        model_db: Database of model.
        model_schema: Schema of model.
        model_name: Model name.
        version_name: Model version name.
        statement_params: Optional set of statement_params to include with queries.

    Returns:
        _ModelMonitorParams dict with name of monitor, fully qualified model name,
        model version name, model function name, prediction columns and label columns.

    Raises:
        ValueError: If multiple ModelMonitors exist for the same model version.
    """
    # The prediction/label column lists are read from the result row below, so they must
    # be part of the SELECT list (and validated) just like in the lookup by monitor name.
    res = (
        query_result_checker.SqlResultValidator(
            self._sql_client._session,
            f"""SELECT {MONITOR_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
            {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME},
            {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME}
            FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
            WHERE {FQ_MODEL_NAME_COL_NAME} = '{model_db}.{model_schema}.{model_name}'
            AND {VERSION_NAME_COL_NAME} = '{version_name}'""",
            statement_params=statement_params,
        )
        .has_column(MONITOR_NAME_COL_NAME)
        .has_column(FQ_MODEL_NAME_COL_NAME)
        .has_column(VERSION_NAME_COL_NAME)
        .has_column(FUNCTION_NAME_COL_NAME)
        .has_column(PREDICTION_COL_NAMES_COL_NAME)
        .has_column(LABEL_COL_NAMES_COL_NAME)
        .validate()
    )
    if len(res) > 1:
        raise ValueError(
            f"Invalid state. Multiple Monitors exist for model: '{model_name}' and version: '{version_name}'"
        )

    row = res[0]
    return _ModelMonitorParams(
        monitor_name=row[MONITOR_NAME_COL_NAME],
        fully_qualified_model_name=row[FQ_MODEL_NAME_COL_NAME],
        version_name=row[VERSION_NAME_COL_NAME],
        function_name=row[FUNCTION_NAME_COL_NAME],
        # Column-name lists are stored as JSON text in the metadata table.
        prediction_columns=[
            sql_identifier.SqlIdentifier(prediction_column)
            for prediction_column in json.loads(row[PREDICTION_COL_NAMES_COL_NAME])
        ],
        label_columns=[
            sql_identifier.SqlIdentifier(label_column)
            for label_column in json.loads(row[LABEL_COL_NAMES_COL_NAME])
        ],
    )
|
608
|
-
|
609
|
-
def get_score_type(
    self,
    task: type_hints.Task,
    source_table_name: sql_identifier.SqlIdentifier,
    prediction_columns: List[sql_identifier.SqlIdentifier],
) -> output_score_type.OutputScoreType:
    """Infer score type given model task and prediction table columns.

    Looks up the column types of the source table in this client's database/schema and
    delegates the decision to ``OutputScoreType.deduce_score_type``.

    Args:
        task: Model task
        source_table_name: Source data table containing model outputs.
        prediction_columns: columns in source data table corresponding to model outputs.

    Returns:
        OutputScoreType for model.
    """
    table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
        self._sql_client._session,
        self._database_name,
        self._schema_name,
        source_table_name,
    )
    return output_score_type.OutputScoreType.deduce_score_type(table_schema, prediction_columns, task)
|
632
|
-
|
633
|
-
def validate_source_table(
    self,
    source_table_name: sql_identifier.SqlIdentifier,
    timestamp_column: sql_identifier.SqlIdentifier,
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
    id_columns: List[sql_identifier.SqlIdentifier],
    model_function: model_manifest_schema.ModelFunctionInfo,
) -> None:
    """Validate that a source table is usable for monitoring the given model function.

    Checks, in order: the table exists in this client's database/schema, the named
    columns exist in it, the named columns have acceptable types, and the remaining
    (feature) columns match the model function's input types.

    Args:
        source_table_name: Name of the source data table.
        timestamp_column: Name of the timestamp column.
        prediction_columns: List of prediction column names.
        label_columns: List of label column names.
        id_columns: List of id column names.
        model_function: Model function info whose input signature the features must match.

    Raises:
        ValueError: If the source table does not exist. The delegated validators may
            raise their own errors, which are not caught here.
    """
    # Validate source table exists
    if not table_manager.validate_table_exist(
        self._sql_client._session,
        source_table_name,
        f"{self._database_name}.{self._schema_name}",
    ):
        raise ValueError(
            f"Table {source_table_name} does not exist in schema {self._database_name}.{self._schema_name}."
        )
    table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
        self._sql_client._session,
        self._database_name,
        self._schema_name,
        source_table_name,
    )
    self._validate_columns_exist_in_source_table(
        table_schema=table_schema,
        source_table_name=source_table_name,
        timestamp_column=timestamp_column,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
        id_columns=id_columns,
    )
    self._validate_column_types(
        table_schema=table_schema,
        timestamp_column=timestamp_column,
        id_columns=id_columns,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
    )
    # Every column that is not a timestamp/id/prediction/label column is treated as a feature.
    self._validate_source_table_features_shape(
        table_schema=table_schema,
        special_columns={timestamp_column, *id_columns, *prediction_columns, *label_columns},
        model_function=model_function,
    )
|
677
|
-
|
678
|
-
def delete_monitor_metadata(
    self,
    name: str,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Delete the row in the metadata table corresponding to the given monitor name.

    Args:
        name: Name of the model monitor whose metadata should be deleted.
        statement_params: Optional set of statement_params to include with query.
    """
    # NOTE(review): `name` is interpolated directly into a quoted SQL literal; a value
    # containing a single quote would alter the statement. Confirm callers validate it.
    self._sql_client._session.sql(
        f"""DELETE FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}
        WHERE {MONITOR_NAME_COL_NAME} = '{name}'""",
    ).collect(statement_params=statement_params)
|
693
|
-
|
694
|
-
def delete_baseline_table(
    self,
    fully_qualified_model_name: str,
    version_name: str,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Delete the baseline table corresponding to a particular model and version.

    Args:
        fully_qualified_model_name: Fully qualified name of the model.
        version_name: Name of the model version.
        statement_params: Optional set of statement_params to include with query.
    """
    # The baseline table is created from the bare model name (see initialize_baseline_table
    # and materialize_baseline_dataframe), so strip the database/schema qualifiers before
    # deriving the table name; otherwise the DROP targets a different, non-existent name.
    _, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
    table_name = _create_baseline_table_name(model_name, version_name)
    self._sql_client._session.sql(
        f"""DROP TABLE IF EXISTS {self._database_name}.{self._schema_name}.{table_name}"""
    ).collect(statement_params=statement_params)
|
711
|
-
|
712
|
-
def delete_dynamic_tables(
    self,
    fully_qualified_model_name: str,
    version_name: str,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Drop the monitoring and accuracy dynamic tables of a model version, if they exist.

    Args:
        fully_qualified_model_name: Fully qualified name of the model.
        version_name: Name of the model version.
        statement_params: Optional set of statement_params to include with query.
    """
    _, _, parsed_model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
    model_id = sql_identifier.SqlIdentifier(parsed_model_name)
    version_id = sql_identifier.SqlIdentifier(version_name)

    # Monitoring table first, then the accuracy table; each name is resolved right
    # before its DROP statement is issued.
    for resolve_table_name in (
        self.get_monitoring_table_fully_qualified_name,
        self.get_accuracy_monitoring_table_fully_qualified_name,
    ):
        dynamic_table_name = resolve_table_name(model_id, version_id)
        drop_statement = f"""DROP DYNAMIC TABLE IF EXISTS {dynamic_table_name}"""
        self._sql_client._session.sql(drop_statement).collect(statement_params=statement_params)
|
736
|
-
|
737
|
-
def create_monitor_on_model_version(
    self,
    monitor_name: sql_identifier.SqlIdentifier,
    source_table_name: sql_identifier.SqlIdentifier,
    fully_qualified_model_name: str,
    version_name: sql_identifier.SqlIdentifier,
    function_name: str,
    timestamp_column: sql_identifier.SqlIdentifier,
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
    id_columns: List[sql_identifier.SqlIdentifier],
    task: type_hints.Task,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Registers a ModelMonitor for a Model Version by inserting one row into the monitoring metadata table.

    Args:
        monitor_name: Name of monitor object to create.
        source_table_name: Name of source data table to monitor.
        fully_qualified_model_name: fully qualified name of model to monitor '<db>.<schema>.<model_name>'.
        version_name: model version name to monitor.
        function_name: function_name to monitor in model version.
        timestamp_column: timestamp column name.
        prediction_columns: list of prediction column names.
        label_columns: list of label column names.
        id_columns: list of id column names.
        task: Task of the model, e.g. TABULAR_REGRESSION.
        statement_params: Optional dict of statement_params to include with queries.

    Raises:
        ValueError: If the model version is already monitored, or a monitor with the same name exists.
    """
    # Validate monitor does not already exist on model version.
    already_monitored = self.validate_existence(fully_qualified_model_name, version_name, statement_params)
    if already_monitored:
        raise ValueError(f"Model {fully_qualified_model_name} Version {version_name} is already monitored!")

    name_taken = self.validate_existence_by_name(monitor_name, statement_params)
    if name_taken:
        raise ValueError(f"Model Monitor with name '{monitor_name}' already exists!")

    # Column-name lists are rendered as SQL expressions, so they are not quoted below.
    prediction_columns_for_select = formatting.format_value_for_select(prediction_columns)
    label_columns_for_select = formatting.format_value_for_select(label_columns)
    id_columns_for_select = formatting.format_value_for_select(id_columns)

    metadata_table = f"{self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}"
    insert_query = textwrap.dedent(
        f"""INSERT INTO {metadata_table}
        ({MONITOR_NAME_COL_NAME}, {SOURCE_TABLE_NAME_COL_NAME}, {FQ_MODEL_NAME_COL_NAME},
        {VERSION_NAME_COL_NAME}, {FUNCTION_NAME_COL_NAME}, {TASK_COL_NAME},
        {MONITORING_ENABLED_COL_NAME}, {TIMESTAMP_COL_NAME_COL_NAME},
        {PREDICTION_COL_NAMES_COL_NAME}, {LABEL_COL_NAMES_COL_NAME},
        {ID_COL_NAMES_COL_NAME})
        SELECT '{monitor_name}', '{source_table_name}', '{fully_qualified_model_name}',
        '{version_name}', '{function_name}', '{task.value}', TRUE, '{timestamp_column}',
        {prediction_columns_for_select}, {label_columns_for_select}, {id_columns_for_select}"""
    )
    insert_validator = query_result_checker.SqlResultValidator(
        self._sql_client._session,
        insert_query,
        statement_params=statement_params,
    )
    # Exactly one metadata row must have been written.
    insert_validator.insertion_success(expected_num_rows=1).validate()
|
795
|
-
|
796
|
-
def initialize_baseline_table(
    self,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    source_table_name: str,
    columns_to_drop: Optional[List[sql_identifier.SqlIdentifier]] = None,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Creates the baseline table for a Model Version, copying its column layout from the source table.

    Args:
        model_name: name of model to monitor.
        version_name: model version name to monitor.
        source_table_name: name of the user's table containing their model data.
        columns_to_drop: special columns in the source table to be excluded from baseline tables.
        statement_params: Optional dict of statement_params to include with queries.
    """
    source_schema = table_manager.get_table_schema_types(
        self._sql_client._session,
        database=self._database_name,
        schema=self._schema_name,
        table_name=source_table_name,
    )

    excluded_columns = [] if columns_to_drop is None else columns_to_drop

    # Keep every source column that is not explicitly excluded, converting the
    # Snowpark type to its Snowflake type representation.
    baseline_columns = []
    for column_name, snowpark_type in source_schema.items():
        if sql_identifier.SqlIdentifier(column_name) in excluded_columns:
            continue
        baseline_columns.append((column_name, type_utils.convert_sp_to_sf_type(snowpark_type)))

    table_manager.create_single_table(
        self._sql_client._session,
        self._database_name,
        self._schema_name,
        _create_baseline_table_name(model_name, version_name),
        baseline_columns,
        statement_params=statement_params,
    )
|
836
|
-
|
837
|
-
def get_all_model_monitor_metadata(
    self,
    statement_params: Optional[Dict[str, Any]] = None,
) -> List[snowpark.Row]:
    """Get the metadata for all model monitors in the given schema.

    Selects every row and column of the monitoring metadata table in this client's
    database/schema.

    Args:
        statement_params: Optional dict of statement_params to include with queries.

    Returns:
        List of snowpark.Row containing metadata for each model monitor.
    """
    return query_result_checker.SqlResultValidator(
        self._sql_client._session,
        f"""SELECT *
        FROM {self._database_name}.{self._schema_name}.{SNOWML_MONITORING_METADATA_TABLE_NAME}""",
        statement_params=statement_params,
    ).validate()
|
855
|
-
|
856
|
-
def materialize_baseline_dataframe(
    self,
    baseline_df: DataFrame,
    fully_qualified_model_name: str,
    model_version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Materialize baseline dataframe to a permanent snowflake table. This method
    truncates (overwrite without dropping) any existing data in the baseline table.

    Args:
        baseline_df: dataframe containing baseline data that monitored data will be compared against.
        fully_qualified_model_name: fully qualified name of the model.
        model_version_name: model version name to monitor.
        statement_params: Optional dict of statement_params to include with queries.

    Raises:
        ValueError: If no baseline table was initialized, or if writing the dataframe
            into the baseline table fails.
    """

    # The baseline table name is derived from the bare model name, not the qualified one.
    _, _, model_name = sql_identifier.parse_fully_qualified_name(fully_qualified_model_name)
    baseline_table_name = _create_baseline_table_name(model_name, model_version_name)

    baseline_table_exists = db_utils.db_object_exists(
        self._sql_client._session,
        db_utils.SnowflakeDbObjectType.TABLE,
        sql_identifier.SqlIdentifier(baseline_table_name),
        database_name=self._database_name,
        schema_name=self._schema_name,
        statement_params=statement_params,
    )
    if not baseline_table_exists:
        raise ValueError(
            f"Baseline table '{baseline_table_name}' does not exist for model: "
            f"'{model_name}' and model_version: '{model_version_name}'"
        )

    fully_qualified_baseline_table_name = [self._database_name, self._schema_name, baseline_table_name]

    try:
        # Truncate overwrites by clearing the rows in the table, instead of dropping the table.
        # This lets us keep the schema to validate the baseline_df against.
        baseline_df.write.mode("truncate").save_as_table(
            fully_qualified_baseline_table_name, statement_params=statement_params
        )
    except exceptions.SnowparkSQLException as e:
        # Chain the Snowpark error so the original SQL failure stays attached to the ValueError.
        raise ValueError(
            f"""Failed to save baseline dataframe.
            Ensure that the baseline dataframe columns match those provided in your monitored table: {e}"""
        ) from e
|
907
|
-
|
908
|
-
def _alter_monitor_dynamic_tables(
    self,
    operation: str,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Apply SUSPEND or RESUME to both dynamic tables backing a model version's monitor.

    Args:
        operation: Either "SUSPEND" or "RESUME".
        model_name: Name of the monitored model.
        version_name: Name of the monitored model version.
        statement_params: Optional dict of statement_params to include with queries.

    Raises:
        ValueError: If ``operation`` is neither "SUSPEND" nor "RESUME".
    """
    supported_operations = {"SUSPEND", "RESUME"}
    if operation not in supported_operations:
        raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")

    # Monitoring table first, then the accuracy table; each ALTER result must be a
    # single row with a single "status" column.
    for resolve_table_name in (
        self.get_monitoring_table_fully_qualified_name,
        self.get_accuracy_monitoring_table_fully_qualified_name,
    ):
        fq_dynamic_table_name = resolve_table_name(model_name, version_name)
        alter_validator = query_result_checker.SqlResultValidator(
            self._sql_client._session,
            f"""ALTER DYNAMIC TABLE {fq_dynamic_table_name} {operation}""",
            statement_params=statement_params,
        )
        alter_validator.has_column("status").has_dimensions(1, 1).validate()
|
930
|
-
|
931
|
-
def suspend_monitor_dynamic_tables(
    self,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Suspend the dynamic tables backing the monitor of a model version.

    Args:
        model_name: Name of the monitored model.
        version_name: Name of the monitored model version.
        statement_params: Optional dict of statement_params to include with queries.
    """
    self._alter_monitor_dynamic_tables(
        operation="SUSPEND",
        model_name=model_name,
        version_name=version_name,
        statement_params=statement_params,
    )
|
943
|
-
|
944
|
-
def resume_monitor_dynamic_tables(
    self,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Resume the dynamic tables backing the monitor of a model version.

    Args:
        model_name: Name of the monitored model.
        version_name: Name of the monitored model version.
        statement_params: Optional dict of statement_params to include with queries.
    """
    self._alter_monitor_dynamic_tables(
        operation="RESUME",
        model_name=model_name,
        version_name=version_name,
        statement_params=statement_params,
    )
|
956
|
-
|
957
|
-
def create_dynamic_tables_for_monitor(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    model_version_name: sql_identifier.SqlIdentifier,
    task: type_hints.Task,
    source_table_name: sql_identifier.SqlIdentifier,
    refresh_interval: model_monitor_interval.ModelMonitorRefreshInterval,
    aggregation_window: model_monitor_interval.ModelMonitorAggregationWindow,
    warehouse_name: sql_identifier.SqlIdentifier,
    timestamp_column: sql_identifier.SqlIdentifier,
    id_columns: List[sql_identifier.SqlIdentifier],
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
    score_type: output_score_type.OutputScoreType,
) -> None:
    """Create the two dynamic tables (feature monitoring and accuracy) for a monitor.

    The feature-monitoring table is created first, then the accuracy table. Each
    CREATE statement's result is validated to be a single row with a "status" column.

    Args:
        model_name: Model name to monitor.
        model_version_name: Model version name to monitor.
        task: Task of the model; selects which accuracy query is generated.
        source_table_name: Name of source data table to monitor.
        refresh_interval: Refresh interval used for the dynamic tables.
        aggregation_window: Aggregation window used for the dynamic tables.
        warehouse_name: Warehouse name to use for the dynamic tables.
        timestamp_column: Timestamp column name.
        id_columns: List of id column names.
        prediction_columns: List of prediction column names.
        label_columns: List of label column names.
        score_type: Output score type forwarded to the accuracy query builder.
    """
    table_schema: Mapping[str, types.DataType] = table_manager.get_table_schema_types(
        self._sql_client._session,
        self._database_name,
        self._schema_name,
        source_table_name,
    )
    # Split the non-special source columns into numeric and categorical features.
    (numeric_features_names, categorical_feature_names) = _infer_numeric_categoric_feature_column_names(
        source_table_schema=table_schema,
        timestamp_column=timestamp_column,
        id_columns=id_columns,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
    )
    features_dynamic_table_query = self._monitoring_dynamic_table_query(
        model_name=model_name,
        model_version_name=model_version_name,
        source_table_name=source_table_name,
        refresh_interval=refresh_interval,
        aggregate_window=aggregation_window,
        warehouse_name=warehouse_name,
        timestamp_column=timestamp_column,
        numeric_features=numeric_features_names,
        categoric_features=categorical_feature_names,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
    )
    query_result_checker.SqlResultValidator(self._sql_client._session, features_dynamic_table_query).has_column(
        "status"
    ).has_dimensions(1, 1).validate()

    label_pred_join_table_query = self._monitoring_accuracy_table_query(
        model_name=model_name,
        model_version_name=model_version_name,
        task=task,
        source_table_name=source_table_name,
        refresh_interval=refresh_interval,
        aggregate_window=aggregation_window,
        warehouse_name=warehouse_name,
        timestamp_column=timestamp_column,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
        score_type=score_type,
    )
    query_result_checker.SqlResultValidator(self._sql_client._session, label_pred_join_table_query).has_column(
        "status"
    ).has_dimensions(1, 1).validate()
|
1019
|
-
|
1020
|
-
def _monitoring_dynamic_table_query(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    model_version_name: sql_identifier.SqlIdentifier,
    source_table_name: sql_identifier.SqlIdentifier,
    refresh_interval: ModelMonitorRefreshInterval,
    aggregate_window: ModelMonitorAggregationWindow,
    warehouse_name: sql_identifier.SqlIdentifier,
    timestamp_column: sql_identifier.SqlIdentifier,
    numeric_features: List[sql_identifier.SqlIdentifier],
    categoric_features: List[sql_identifier.SqlIdentifier],
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
) -> str:
    """
    Builds the CREATE DYNAMIC TABLE statement for the feature-monitoring table.

    Args:
        model_name: Model name to monitor.
        model_version_name: Model version name to monitor.
        source_table_name: Name of source data table to monitor.
        refresh_interval: Refresh interval; its ``minutes`` becomes the TARGET_LAG.
        aggregate_window: Aggregate window; its ``minutes`` is the TIME_SLICE width.
        warehouse_name: Warehouse name to use for dynamic table.
        timestamp_column: Timestamp column name.
        numeric_features: List of numeric features to capture.
        categoric_features: List of categoric features to capture.
        prediction_columns: List of columns that contain model inference outputs.
        label_columns: List of columns that contain ground truth values.

    Raises:
        ValueError: If there is not exactly one prediction column and exactly one label column.

    Returns:
        Dynamic table query.
    """
    # output and ground cols are list to keep interface extensible.
    # for prpr only one label and one output col will be supported
    if not (len(prediction_columns) == 1 and len(label_columns) == 1):
        raise ValueError("Multiple Output columns are not supported in monitoring")

    monitoring_dt_name = self.get_monitoring_table_fully_qualified_name(model_name, model_version_name)

    # Numeric-style columns (features, prediction, label) get a percentile sketch
    # plus count/min/max/sum aggregates packed into one OBJECT per column.
    numeric_aggregate_template = """
    OBJECT_CONSTRUCT(
        'sketch', APPROX_PERCENTILE_ACCUMULATE({col}),
        'count', count_if({col} is not null),
        'count_null', count_if({col} is null),
        'min', min({col}),
        'max', max({col}),
        'sum', sum({col})
    ) AS {col}"""
    column_expressions = [
        numeric_aggregate_template.format(col=numeric_column)
        for numeric_column in numeric_features + prediction_columns + label_columns
    ]
    # Categorical columns are aggregated through the schema-level OBJECT_SUM function.
    column_expressions.extend(
        f"""
    {self._database_name}.{self._schema_name}.OBJECT_SUM(to_varchar({categoric_column})) AS {categoric_column}"""
        for categoric_column in categoric_features
    )
    feature_cols_query = ",".join(column_expressions)

    return f"""
    CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
        TARGET_LAG = '{refresh_interval.minutes} minutes'
        WAREHOUSE = {warehouse_name}
        REFRESH_MODE = AUTO
        INITIALIZE = ON_CREATE
    AS
    SELECT
        TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,{feature_cols_query}
    FROM
        {source_table_name}
    GROUP BY
        1
    """
|
1101
|
-
|
1102
|
-
def _monitoring_accuracy_table_query(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    model_version_name: sql_identifier.SqlIdentifier,
    task: type_hints.Task,
    source_table_name: sql_identifier.SqlIdentifier,
    refresh_interval: ModelMonitorRefreshInterval,
    aggregate_window: ModelMonitorAggregationWindow,
    warehouse_name: sql_identifier.SqlIdentifier,
    timestamp_column: sql_identifier.SqlIdentifier,
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
    score_type: output_score_type.OutputScoreType,
) -> str:
    """
    Builds the accuracy dynamic table query, choosing the builder that matches the model task.

    Binary classification tasks use the classification builder (which also receives
    ``score_type``); every other task uses the regression builder.

    Args:
        model_name: Model name to monitor.
        model_version_name: Model version name to monitor.
        task: Task of the model.
        source_table_name: Name of source data table to monitor.
        refresh_interval: Refresh interval for the dynamic table.
        aggregate_window: Aggregate window for the dynamic table.
        warehouse_name: Warehouse name to use for dynamic table.
        timestamp_column: Timestamp column name.
        prediction_columns: List of prediction column names.
        label_columns: List of label column names.
        score_type: Output score type, used only for classification.

    Raises:
        ValueError: If there is not exactly one prediction column and exactly one label column.

    Returns:
        Dynamic table query.
    """
    # output and ground cols are list to keep interface extensible.
    # for prpr only one label and one output col will be supported
    if not (len(prediction_columns) == 1 and len(label_columns) == 1):
        raise ValueError("Multiple Output columns are not supported in monitoring")

    # Arguments accepted by both the classification and the regression builders.
    shared_kwargs = dict(
        model_name=model_name,
        model_version_name=model_version_name,
        source_table_name=source_table_name,
        refresh_interval=refresh_interval,
        aggregate_window=aggregate_window,
        warehouse_name=warehouse_name,
        timestamp_column=timestamp_column,
        prediction_columns=prediction_columns,
        label_columns=label_columns,
    )
    if task == type_hints.Task.TABULAR_BINARY_CLASSIFICATION:
        return self._monitoring_classification_accuracy_table_query(score_type=score_type, **shared_kwargs)
    return self._monitoring_regression_accuracy_table_query(**shared_kwargs)
|
1146
|
-
|
1147
|
-
def _monitoring_regression_accuracy_table_query(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    model_version_name: sql_identifier.SqlIdentifier,
    source_table_name: sql_identifier.SqlIdentifier,
    refresh_interval: ModelMonitorRefreshInterval,
    aggregate_window: ModelMonitorAggregationWindow,
    warehouse_name: sql_identifier.SqlIdentifier,
    timestamp_column: sql_identifier.SqlIdentifier,
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
) -> str:
    """
    Generates a dynamic table query for Monitoring - regression model accuracy.

    The generated statement is a ``CREATE DYNAMIC TABLE IF NOT EXISTS`` that buckets the source rows with
    ``TIME_SLICE`` and emits, per bucket, a constant ``label_class`` of ``'class_regression'``, an
    ``AGGREGATE_METRICS`` object of running sums, and percentile sketches of the prediction and label columns.
    Only the first entry of ``prediction_columns`` and ``label_columns`` is used.

    Args:
        model_name: Model name to monitor.
        model_version_name: Model version name to monitor.
        source_table_name: Name of source data table to monitor.
        refresh_interval: Refresh interval; its ``minutes`` attribute becomes the TARGET_LAG of the table.
        aggregate_window: Aggregation window; its ``minutes`` attribute is the TIME_SLICE width.
        warehouse_name: Warehouse name to use for dynamic table.
        timestamp_column: Timestamp column name.
        prediction_columns: List of output columns.
        label_columns: List of ground truth columns.

    Returns:
        Dynamic table query.

    Raises:
        ValueError: If the number of output columns differs from the number of ground truth columns,
            or if both lists are empty.

    """
    if len(prediction_columns) != len(label_columns):
        raise ValueError(f"Mismatch in output & ground truth columns: {prediction_columns} != {label_columns}")
    # Lengths are equal at this point, so checking one list covers both. Without this guard the
    # `[0]` lookups below would surface as an undocumented IndexError.
    if not prediction_columns:
        raise ValueError("At least one output column and one ground truth column are required for monitoring")

    prediction_column = prediction_columns[0]
    label_column = label_columns[0]
    monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)

    # NOTE(review): despite the "label_pred" suffix in the metric keys, every difference below is computed
    # as prediction - label.
    # The log-difference metric skips rows where either value is <= -1 because ln(x + 1) is undefined there.
    output_cols_query = f"""
        OBJECT_CONSTRUCT(
            'sum_difference_label_pred', sum({prediction_column} - {label_column}),
            'sum_log_difference_square_label_pred',
            sum(
                case
                    when {prediction_column} > -1 and {label_column} > -1
                    then pow(ln({prediction_column} + 1) - ln({label_column} + 1),2)
                    else null
                END
            ),
            'sum_difference_squares_label_pred',
            sum(
                pow(
                    {prediction_column} - {label_column},
                    2
                )
            ),
            'sum_absolute_regression_labels', sum(abs({label_column})),
            'sum_absolute_percentage_error',
            sum(
                abs(
                    div0null(
                        ({prediction_column} - {label_column}),
                        {label_column}
                    )
                )
            ),
            'sum_absolute_difference_label_pred',
            sum(
                abs({prediction_column} - {label_column})
            ),
            'sum_prediction', sum({prediction_column}),
            'sum_label', sum({label_column}),
            'count', count(*)
        ) AS AGGREGATE_METRICS,
        APPROX_PERCENTILE_ACCUMULATE({prediction_column}) prediction_sketch,
        APPROX_PERCENTILE_ACCUMULATE({label_column}) label_sketch"""

    # GROUP BY 1 groups on the first select expression, i.e. the TIME_SLICE bucket.
    return f"""
        CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
            TARGET_LAG = '{refresh_interval.minutes} minutes'
            WAREHOUSE = {warehouse_name}
            REFRESH_MODE = AUTO
            INITIALIZE = ON_CREATE
        AS
        SELECT
            TIME_SLICE({timestamp_column}, {aggregate_window.minutes}, 'MINUTE') timestamp,
            'class_regression' label_class,{output_cols_query}
        FROM
            {source_table_name}
        GROUP BY
            1
        """
|
1246
|
-
|
1247
|
-
def _monitoring_classification_accuracy_table_query(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    model_version_name: sql_identifier.SqlIdentifier,
    source_table_name: sql_identifier.SqlIdentifier,
    refresh_interval: ModelMonitorRefreshInterval,
    aggregate_window: ModelMonitorAggregationWindow,
    warehouse_name: sql_identifier.SqlIdentifier,
    timestamp_column: sql_identifier.SqlIdentifier,
    prediction_columns: List[sql_identifier.SqlIdentifier],
    label_columns: List[sql_identifier.SqlIdentifier],
    score_type: output_score_type.OutputScoreType,
) -> str:
    """
    Generates a dynamic table query for Monitoring - classification model accuracy.

    The generated statement is a ``CREATE DYNAMIC TABLE IF NOT EXISTS`` built in two stages. A
    ``filtered_data`` CTE projects the timestamp, prediction and label columns and tags each row as
    ``'class_positive'`` (label equals 1) or ``'class_negative'`` (anything else). The outer query groups
    by time slice and ``label_class`` and emits an ``AGGREGATE_METRICS`` object plus percentile sketches.
    Only the first entry of ``prediction_columns`` and ``label_columns`` is used.

    Args:
        model_name: Model name to monitor.
        model_version_name: Model version name to monitor.
        source_table_name: Name of source data table to monitor.
        refresh_interval: Refresh interval; its ``minutes`` attribute becomes the TARGET_LAG of the table.
        aggregate_window: Aggregation window; its ``minutes`` attribute is the time_slice width.
        warehouse_name: Warehouse name to use for dynamic table.
        timestamp_column: Timestamp column name.
        prediction_columns: List of output columns.
        label_columns: List of ground truth columns.
        score_type: Kind of score stored in the prediction column. ``PROBITS`` adds a ``sum_log_loss``
            metric; any other value adds ``tp``/``tn``/``fp``/``fn`` counts instead.

    Returns:
        Dynamic table query.

    Raises:
        ValueError: If ``prediction_columns`` or ``label_columns`` is empty.

    """
    # Fail with a clear error instead of an IndexError from the `[0]` lookups below.
    if not prediction_columns or not label_columns:
        raise ValueError("At least one output column and one ground truth column are required for monitoring")

    prediction_column = prediction_columns[0]
    label_column = label_columns[0]
    monitoring_dt_name = self.get_accuracy_monitoring_table_fully_qualified_name(model_name, model_version_name)

    # Stage 1: per-row projection that derives label_class, so the outer query can group on it.
    cte_query = f"""
        WITH filtered_data AS (
            SELECT
                {timestamp_column} AS timestamp,
                {prediction_column},
                {label_column},
                CASE
                    WHEN {label_column} = 1 THEN 'class_positive'
                    ELSE 'class_negative'
                END AS label_class
            FROM
                {source_table_name}
        )"""

    # Stage 2: metrics that depend on the score type.
    if score_type == output_score_type.OutputScoreType.PROBITS:
        # Log loss: -ln(p) for positive rows and -ln(1 - p) for negative rows. label_class is a
        # GROUP BY key, so the CASE can pick the matching aggregate for each group.
        score_type_agg_clause = f"""
                'sum_log_loss',
                CASE
                    WHEN label_class = 'class_positive' THEN sum(-ln({prediction_column}))
                    ELSE sum(-ln(1 - {prediction_column}))
                END,"""
    else:
        # Confusion-matrix counts; predictions and labels are compared against the literals 0 and 1.
        score_type_agg_clause = f"""
                'tp', count_if({label_column} = 1 AND {prediction_column} = 1),
                'tn', count_if({label_column} = 0 AND {prediction_column} = 0),
                'fp', count_if({label_column} = 0 AND {prediction_column} = 1),
                'fn', count_if({label_column} = 1 AND {prediction_column} = 0),"""

    select_clause = f"""
            label_class,
            OBJECT_CONSTRUCT(
                'sum_prediction', sum({prediction_column}),
                'sum_label', sum({label_column}),{score_type_agg_clause}
                'count', count(*)
            ) AS AGGREGATE_METRICS,
            APPROX_PERCENTILE_ACCUMULATE({prediction_column}) prediction_sketch,
            APPROX_PERCENTILE_ACCUMULATE({label_column}) label_sketch"""

    # `group by 1, 2` refers to the time slice and label_class select expressions.
    return f"""
        CREATE DYNAMIC TABLE IF NOT EXISTS {monitoring_dt_name}
            TARGET_LAG = '{refresh_interval.minutes} minutes'
            WAREHOUSE = {warehouse_name}
            REFRESH_MODE = AUTO
            INITIALIZE = ON_CREATE
        AS{cte_query}
        select
            time_slice(timestamp, {aggregate_window.minutes}, 'MINUTE') timestamp,{select_clause}
        FROM
            filtered_data
        group by
            1,
            2
        """
|