snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_complete.py +107 -64
- snowflake/cortex/_finetune.py +273 -0
- snowflake/cortex/_sse_client.py +91 -28
- snowflake/cortex/_util.py +30 -1
- snowflake/ml/_internal/telemetry.py +4 -2
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/import_utils.py +31 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
- snowflake/ml/data/__init__.py +5 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/data/torch_utils.py +33 -14
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
- snowflake/ml/feature_store/examples/example_helper.py +6 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
- snowflake/ml/feature_store/feature_store.py +1 -2
- snowflake/ml/feature_store/feature_view.py +5 -1
- snowflake/ml/model/_client/model/model_version_impl.py +145 -11
- snowflake/ml/model/_client/ops/model_ops.py +56 -16
- snowflake/ml/model/_client/ops/service_ops.py +46 -30
- snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_client/sql/service.py +25 -1
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
- snowflake/ml/model/_packager/model_env/model_env.py +12 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
- snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
- snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
- snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
- snowflake/ml/model/_packager/model_packager.py +0 -11
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
- snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
- snowflake/ml/model/_signatures/core.py +63 -16
- snowflake/ml/model/_signatures/pandas_handler.py +87 -27
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
- snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/custom_model.py +47 -7
- snowflake/ml/model/model_signature.py +40 -9
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
- snowflake/ml/modeling/cluster/dbscan.py +5 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
- snowflake/ml/modeling/cluster/k_means.py +14 -19
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
- snowflake/ml/modeling/cluster/optics.py +6 -6
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
- snowflake/ml/modeling/compose/column_transformer.py +15 -5
- snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
- snowflake/ml/modeling/decomposition/pca.py +28 -15
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
- snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
- snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +0 -10
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
- snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +3 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
- snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
- snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +3 -3
- snowflake/ml/modeling/manifold/tsne.py +10 -4
- snowflake/ml/modeling/metrics/classification.py +12 -16
- snowflake/ml/modeling/metrics/ranking.py +3 -3
- snowflake/ml/modeling/metrics/regression.py +3 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +16 -14
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
- snowflake/ml/modeling/svm/linear_svc.py +25 -16
- snowflake/ml/modeling/svm/linear_svr.py +23 -17
- snowflake/ml/modeling/svm/nu_svc.py +5 -3
- snowflake/ml/modeling/svm/nu_svr.py +3 -1
- snowflake/ml/modeling/svm/svc.py +9 -5
- snowflake/ml/modeling/svm/svr.py +3 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
- snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
- snowflake/ml/monitoring/model_monitor.py +37 -0
- snowflake/ml/registry/_manager/model_manager.py +15 -1
- snowflake/ml/registry/registry.py +32 -37
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/_client/model_monitor.py +0 -126
- snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
- snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
- snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
- /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,126 +0,0 @@
|
|
1
|
-
from typing import List, Union
|
2
|
-
|
3
|
-
import pandas as pd
|
4
|
-
|
5
|
-
from snowflake import snowpark
|
6
|
-
from snowflake.ml._internal import telemetry
|
7
|
-
from snowflake.ml._internal.utils import sql_identifier
|
8
|
-
from snowflake.ml.monitoring._client import monitor_sql_client
|
9
|
-
|
10
|
-
|
11
|
-
class ModelMonitor:
|
12
|
-
"""Class to manage instrumentation of Model Monitoring and Observability"""
|
13
|
-
|
14
|
-
name: sql_identifier.SqlIdentifier
|
15
|
-
_model_monitor_client: monitor_sql_client._ModelMonitorSQLClient
|
16
|
-
_fully_qualified_model_name: str
|
17
|
-
_version_name: sql_identifier.SqlIdentifier
|
18
|
-
_function_name: sql_identifier.SqlIdentifier
|
19
|
-
_prediction_columns: List[sql_identifier.SqlIdentifier]
|
20
|
-
_label_columns: List[sql_identifier.SqlIdentifier]
|
21
|
-
|
22
|
-
def __init__(self) -> None:
|
23
|
-
raise RuntimeError("ModelMonitor's initializer is not meant to be used.")
|
24
|
-
|
25
|
-
@classmethod
|
26
|
-
def _ref(
|
27
|
-
cls,
|
28
|
-
model_monitor_client: monitor_sql_client._ModelMonitorSQLClient,
|
29
|
-
name: sql_identifier.SqlIdentifier,
|
30
|
-
*,
|
31
|
-
fully_qualified_model_name: str,
|
32
|
-
version_name: sql_identifier.SqlIdentifier,
|
33
|
-
function_name: sql_identifier.SqlIdentifier,
|
34
|
-
prediction_columns: List[sql_identifier.SqlIdentifier],
|
35
|
-
label_columns: List[sql_identifier.SqlIdentifier],
|
36
|
-
) -> "ModelMonitor":
|
37
|
-
self: "ModelMonitor" = object.__new__(cls)
|
38
|
-
self.name = name
|
39
|
-
self._model_monitor_client = model_monitor_client
|
40
|
-
self._fully_qualified_model_name = fully_qualified_model_name
|
41
|
-
self._version_name = version_name
|
42
|
-
self._function_name = function_name
|
43
|
-
self._prediction_columns = prediction_columns
|
44
|
-
self._label_columns = label_columns
|
45
|
-
return self
|
46
|
-
|
47
|
-
@telemetry.send_api_usage_telemetry(
|
48
|
-
project=telemetry.TelemetryProject.MLOPS.value,
|
49
|
-
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
50
|
-
)
|
51
|
-
def set_baseline(self, baseline_df: Union[pd.DataFrame, snowpark.DataFrame]) -> None:
|
52
|
-
"""
|
53
|
-
The baseline dataframe is compared with the monitored data once monitoring is enabled.
|
54
|
-
The columns of the dataframe should match the columns of the source table that the
|
55
|
-
ModelMonitor was configured with. Calling this method overwrites any existing baseline split data.
|
56
|
-
|
57
|
-
Args:
|
58
|
-
baseline_df: Snowpark dataframe containing baseline data.
|
59
|
-
|
60
|
-
Raises:
|
61
|
-
ValueError: baseline_df does not contain prediction or label columns
|
62
|
-
"""
|
63
|
-
statement_params = telemetry.get_statement_params(
|
64
|
-
project=telemetry.TelemetryProject.MLOPS.value,
|
65
|
-
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
66
|
-
)
|
67
|
-
|
68
|
-
if isinstance(baseline_df, pd.DataFrame):
|
69
|
-
baseline_df = self._model_monitor_client._sql_client._session.create_dataframe(baseline_df)
|
70
|
-
|
71
|
-
column_names_identifiers: List[sql_identifier.SqlIdentifier] = [
|
72
|
-
sql_identifier.SqlIdentifier(column_name) for column_name in baseline_df.columns
|
73
|
-
]
|
74
|
-
prediction_cols_not_found = any(
|
75
|
-
[prediction_col not in column_names_identifiers for prediction_col in self._prediction_columns]
|
76
|
-
)
|
77
|
-
label_cols_not_found = any(
|
78
|
-
[label_col.identifier() not in column_names_identifiers for label_col in self._label_columns]
|
79
|
-
)
|
80
|
-
|
81
|
-
if prediction_cols_not_found:
|
82
|
-
raise ValueError(
|
83
|
-
"Specified prediction columns were not found in the baseline dataframe. "
|
84
|
-
f"Columns provided were: {column_names_identifiers}. "
|
85
|
-
f"Configured prediction columns were: {self._prediction_columns}."
|
86
|
-
)
|
87
|
-
if label_cols_not_found:
|
88
|
-
raise ValueError(
|
89
|
-
"Specified label columns were not found in the baseline dataframe."
|
90
|
-
f"Columns provided in the baseline dataframe were: {column_names_identifiers}."
|
91
|
-
f"Configured label columns were: {self._label_columns}."
|
92
|
-
)
|
93
|
-
|
94
|
-
# Create the table by materializing the df
|
95
|
-
self._model_monitor_client.materialize_baseline_dataframe(
|
96
|
-
baseline_df,
|
97
|
-
self._fully_qualified_model_name,
|
98
|
-
self._version_name,
|
99
|
-
statement_params=statement_params,
|
100
|
-
)
|
101
|
-
|
102
|
-
def suspend(self) -> None:
|
103
|
-
"""Suspend pipeline for ModelMonitor"""
|
104
|
-
statement_params = telemetry.get_statement_params(
|
105
|
-
telemetry.TelemetryProject.MLOPS.value,
|
106
|
-
telemetry.TelemetrySubProject.MONITORING.value,
|
107
|
-
)
|
108
|
-
_, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name)
|
109
|
-
self._model_monitor_client.suspend_monitor_dynamic_tables(
|
110
|
-
model_name=model_name,
|
111
|
-
version_name=self._version_name,
|
112
|
-
statement_params=statement_params,
|
113
|
-
)
|
114
|
-
|
115
|
-
def resume(self) -> None:
|
116
|
-
"""Resume pipeline for ModelMonitor"""
|
117
|
-
statement_params = telemetry.get_statement_params(
|
118
|
-
telemetry.TelemetryProject.MLOPS.value,
|
119
|
-
telemetry.TelemetrySubProject.MONITORING.value,
|
120
|
-
)
|
121
|
-
_, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name)
|
122
|
-
self._model_monitor_client.resume_monitor_dynamic_tables(
|
123
|
-
model_name=model_name,
|
124
|
-
version_name=self._version_name,
|
125
|
-
statement_params=statement_params,
|
126
|
-
)
|
@@ -1,361 +0,0 @@
|
|
1
|
-
from typing import Any, Dict, List, Optional
|
2
|
-
|
3
|
-
from snowflake import snowpark
|
4
|
-
from snowflake.ml._internal import telemetry
|
5
|
-
from snowflake.ml._internal.utils import db_utils, sql_identifier
|
6
|
-
from snowflake.ml.model import type_hints
|
7
|
-
from snowflake.ml.model._client.model import model_version_impl
|
8
|
-
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
9
|
-
from snowflake.ml.monitoring._client import model_monitor, monitor_sql_client
|
10
|
-
from snowflake.ml.monitoring.entities import (
|
11
|
-
model_monitor_config,
|
12
|
-
model_monitor_interval,
|
13
|
-
)
|
14
|
-
from snowflake.snowpark import session
|
15
|
-
|
16
|
-
|
17
|
-
def _validate_name_constraints(model_version: model_version_impl.ModelVersion) -> None:
|
18
|
-
system_table_prefixes = [
|
19
|
-
monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX,
|
20
|
-
monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX,
|
21
|
-
]
|
22
|
-
|
23
|
-
max_allowed_model_name_and_version_length = (
|
24
|
-
db_utils.MAX_IDENTIFIER_LENGTH - max(len(prefix) for prefix in system_table_prefixes) - 1
|
25
|
-
) # -1 includes '_' between model_name + model_version
|
26
|
-
if len(model_version.model_name) + len(model_version.version_name) > max_allowed_model_name_and_version_length:
|
27
|
-
error_msg = f"Model name and version name exceeds maximum length of {max_allowed_model_name_and_version_length}"
|
28
|
-
raise ValueError(error_msg)
|
29
|
-
|
30
|
-
|
31
|
-
class ModelMonitorManager:
|
32
|
-
"""Class to manage internal operations for Model Monitor workflows.""" # TODO: Move to Registry.
|
33
|
-
|
34
|
-
@staticmethod
|
35
|
-
def setup(session: session.Session, database_name: str, schema_name: str) -> None:
|
36
|
-
"""Static method to set up schema for Model Monitoring resources.
|
37
|
-
|
38
|
-
Args:
|
39
|
-
session: The Snowpark Session to connect with Snowflake.
|
40
|
-
database_name: The name of the database. If None, the current database of the session
|
41
|
-
will be used. Defaults to None.
|
42
|
-
schema_name: The name of the schema. If None, the current schema of the session
|
43
|
-
will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
|
44
|
-
"""
|
45
|
-
statement_params = telemetry.get_statement_params(
|
46
|
-
project=telemetry.TelemetryProject.MLOPS.value,
|
47
|
-
subproject=telemetry.TelemetrySubProject.MONITORING.value,
|
48
|
-
)
|
49
|
-
database_name_id = sql_identifier.SqlIdentifier(database_name)
|
50
|
-
schema_name_id = sql_identifier.SqlIdentifier(schema_name)
|
51
|
-
monitor_sql_client._ModelMonitorSQLClient.initialize_monitoring_schema(
|
52
|
-
session, database_name_id, schema_name_id, statement_params=statement_params
|
53
|
-
)
|
54
|
-
|
55
|
-
def _fetch_task_from_model_version(
|
56
|
-
self,
|
57
|
-
model_version: model_version_impl.ModelVersion,
|
58
|
-
) -> type_hints.Task:
|
59
|
-
task = model_version.get_model_task()
|
60
|
-
if task == type_hints.Task.UNKNOWN:
|
61
|
-
raise ValueError("Registry model must be logged with task in order to be monitored.")
|
62
|
-
return task
|
63
|
-
|
64
|
-
def __init__(
|
65
|
-
self,
|
66
|
-
session: session.Session,
|
67
|
-
database_name: sql_identifier.SqlIdentifier,
|
68
|
-
schema_name: sql_identifier.SqlIdentifier,
|
69
|
-
*,
|
70
|
-
create_if_not_exists: bool = False,
|
71
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
72
|
-
) -> None:
|
73
|
-
"""
|
74
|
-
Opens a ModelMonitorManager for a given database and schema.
|
75
|
-
Optionally sets up the schema for Model Monitoring.
|
76
|
-
|
77
|
-
Args:
|
78
|
-
session: The Snowpark Session to connect with Snowflake.
|
79
|
-
database_name: The name of the database.
|
80
|
-
schema_name: The name of the schema.
|
81
|
-
create_if_not_exists: Flag whether to initialize resources in the schema needed for Model Monitoring.
|
82
|
-
statement_params: Optional set of statement params.
|
83
|
-
|
84
|
-
Raises:
|
85
|
-
ValueError: When there is no specified or active database in the session.
|
86
|
-
"""
|
87
|
-
self._database_name = database_name
|
88
|
-
self._schema_name = schema_name
|
89
|
-
self.statement_params = statement_params
|
90
|
-
self._model_monitor_client = monitor_sql_client._ModelMonitorSQLClient(
|
91
|
-
session,
|
92
|
-
database_name=self._database_name,
|
93
|
-
schema_name=self._schema_name,
|
94
|
-
)
|
95
|
-
if create_if_not_exists:
|
96
|
-
monitor_sql_client._ModelMonitorSQLClient.initialize_monitoring_schema(
|
97
|
-
session, self._database_name, self._schema_name, self.statement_params
|
98
|
-
)
|
99
|
-
elif not self._model_monitor_client._validate_is_initialized():
|
100
|
-
raise ValueError(
|
101
|
-
"Monitoring has not been setup. Set create_if_not_exists or call ModelMonitorManager.setup"
|
102
|
-
)
|
103
|
-
|
104
|
-
def _get_and_validate_model_function_from_model_version(
|
105
|
-
self, function: str, model_version: model_version_impl.ModelVersion
|
106
|
-
) -> model_manifest_schema.ModelFunctionInfo:
|
107
|
-
functions = model_version.show_functions()
|
108
|
-
for f in functions:
|
109
|
-
if f["target_method"] == function:
|
110
|
-
return f
|
111
|
-
existing_target_methods = {f["target_method"] for f in functions}
|
112
|
-
raise ValueError(
|
113
|
-
f"Function with name {function} does not exist in the given model version. "
|
114
|
-
f"Found: {existing_target_methods}."
|
115
|
-
)
|
116
|
-
|
117
|
-
def _validate_monitor_config_or_raise(
|
118
|
-
self,
|
119
|
-
table_config: model_monitor_config.ModelMonitorTableConfig,
|
120
|
-
model_monitor_config: model_monitor_config.ModelMonitorConfig,
|
121
|
-
) -> None:
|
122
|
-
"""Validate provided config for model monitor.
|
123
|
-
|
124
|
-
Args:
|
125
|
-
table_config: Config for model monitor tables.
|
126
|
-
model_monitor_config: Config for ModelMonitor.
|
127
|
-
|
128
|
-
Raises:
|
129
|
-
ValueError: If warehouse provided does not exist.
|
130
|
-
"""
|
131
|
-
|
132
|
-
# Validate naming will not exceed 255 chars
|
133
|
-
_validate_name_constraints(model_monitor_config.model_version)
|
134
|
-
|
135
|
-
if len(table_config.prediction_columns) != len(table_config.label_columns):
|
136
|
-
raise ValueError("Prediction and Label column names must be of the same length.")
|
137
|
-
# output and ground cols are list to keep interface extensible.
|
138
|
-
# for prpr only one label and one output col will be supported
|
139
|
-
if len(table_config.prediction_columns) != 1 or len(table_config.label_columns) != 1:
|
140
|
-
raise ValueError("Multiple Output columns are not supported in monitoring")
|
141
|
-
|
142
|
-
# Validate warehouse exists.
|
143
|
-
warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
|
144
|
-
self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
|
145
|
-
|
146
|
-
# Validate refresh interval.
|
147
|
-
try:
|
148
|
-
num_units, time_units = model_monitor_config.refresh_interval.strip().split(" ")
|
149
|
-
int(num_units) # try to cast
|
150
|
-
if time_units.lower() not in {"seconds", "minutes", "hours", "days"}:
|
151
|
-
raise ValueError(
|
152
|
-
"""Invalid time unit in refresh interval. Provide '<num> <seconds | minutes | hours | days>'.
|
153
|
-
See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
|
154
|
-
)
|
155
|
-
except Exception as e: # TODO: Link to DT page.
|
156
|
-
raise ValueError(
|
157
|
-
f"""Failed to parse refresh interval with exception {e}.
|
158
|
-
Provide '<num> <seconds | minutes | hours | days>'.
|
159
|
-
See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
|
160
|
-
)
|
161
|
-
|
162
|
-
def add_monitor(
|
163
|
-
self,
|
164
|
-
name: str,
|
165
|
-
table_config: model_monitor_config.ModelMonitorTableConfig,
|
166
|
-
model_monitor_config: model_monitor_config.ModelMonitorConfig,
|
167
|
-
*,
|
168
|
-
add_dashboard_udtfs: bool = False,
|
169
|
-
) -> model_monitor.ModelMonitor:
|
170
|
-
"""Add a new Model Monitor.
|
171
|
-
|
172
|
-
Args:
|
173
|
-
name: Name of Model Monitor to create.
|
174
|
-
table_config: Configuration options for the source table used in ModelMonitor.
|
175
|
-
model_monitor_config: Configuration options of ModelMonitor.
|
176
|
-
add_dashboard_udtfs: Add UDTFs useful for creating a dashboard.
|
177
|
-
|
178
|
-
Returns:
|
179
|
-
The newly added ModelMonitor object.
|
180
|
-
"""
|
181
|
-
# Validates configuration or raise.
|
182
|
-
self._validate_monitor_config_or_raise(table_config, model_monitor_config)
|
183
|
-
model_function = self._get_and_validate_model_function_from_model_version(
|
184
|
-
model_monitor_config.model_function_name, model_monitor_config.model_version
|
185
|
-
)
|
186
|
-
monitor_refresh_interval = model_monitor_interval.ModelMonitorRefreshInterval(
|
187
|
-
model_monitor_config.refresh_interval
|
188
|
-
)
|
189
|
-
name_id = sql_identifier.SqlIdentifier(name)
|
190
|
-
source_table_name_id = sql_identifier.SqlIdentifier(table_config.source_table)
|
191
|
-
prediction_columns = [
|
192
|
-
sql_identifier.SqlIdentifier(column_name) for column_name in table_config.prediction_columns
|
193
|
-
]
|
194
|
-
label_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.label_columns]
|
195
|
-
id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.id_columns]
|
196
|
-
ts_column = sql_identifier.SqlIdentifier(table_config.timestamp_column)
|
197
|
-
|
198
|
-
# Validate source table
|
199
|
-
self._model_monitor_client.validate_source_table(
|
200
|
-
source_table_name=source_table_name_id,
|
201
|
-
timestamp_column=ts_column,
|
202
|
-
prediction_columns=prediction_columns,
|
203
|
-
label_columns=label_columns,
|
204
|
-
id_columns=id_columns,
|
205
|
-
model_function=model_function,
|
206
|
-
)
|
207
|
-
|
208
|
-
task = self._fetch_task_from_model_version(model_version=model_monitor_config.model_version)
|
209
|
-
score_type = self._model_monitor_client.get_score_type(task, source_table_name_id, prediction_columns)
|
210
|
-
|
211
|
-
# Insert monitoring metadata for new model version.
|
212
|
-
self._model_monitor_client.create_monitor_on_model_version(
|
213
|
-
monitor_name=name_id,
|
214
|
-
source_table_name=source_table_name_id,
|
215
|
-
fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
|
216
|
-
version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
217
|
-
function_name=model_monitor_config.model_function_name,
|
218
|
-
timestamp_column=ts_column,
|
219
|
-
prediction_columns=prediction_columns,
|
220
|
-
label_columns=label_columns,
|
221
|
-
id_columns=id_columns,
|
222
|
-
task=task,
|
223
|
-
statement_params=self.statement_params,
|
224
|
-
)
|
225
|
-
|
226
|
-
# Create Dynamic tables for model monitor.
|
227
|
-
self._model_monitor_client.create_dynamic_tables_for_monitor(
|
228
|
-
model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
|
229
|
-
model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
230
|
-
task=task,
|
231
|
-
source_table_name=source_table_name_id,
|
232
|
-
refresh_interval=monitor_refresh_interval,
|
233
|
-
aggregation_window=model_monitor_config.aggregation_window,
|
234
|
-
warehouse_name=sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name),
|
235
|
-
timestamp_column=sql_identifier.SqlIdentifier(table_config.timestamp_column),
|
236
|
-
id_columns=id_columns,
|
237
|
-
prediction_columns=prediction_columns,
|
238
|
-
label_columns=label_columns,
|
239
|
-
score_type=score_type,
|
240
|
-
)
|
241
|
-
|
242
|
-
# Initialize baseline table.
|
243
|
-
self._model_monitor_client.initialize_baseline_table(
|
244
|
-
model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
|
245
|
-
version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
246
|
-
source_table_name=table_config.source_table,
|
247
|
-
columns_to_drop=[ts_column, *id_columns],
|
248
|
-
statement_params=self.statement_params,
|
249
|
-
)
|
250
|
-
|
251
|
-
# Add udtfs helpful for dashboard queries.
|
252
|
-
# TODO(apgupta) Make this true by default.
|
253
|
-
if add_dashboard_udtfs:
|
254
|
-
self._model_monitor_client.add_dashboard_udtfs(
|
255
|
-
name_id,
|
256
|
-
model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
|
257
|
-
model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
258
|
-
task=task,
|
259
|
-
score_type=score_type,
|
260
|
-
output_columns=prediction_columns,
|
261
|
-
ground_truth_columns=label_columns,
|
262
|
-
)
|
263
|
-
|
264
|
-
return model_monitor.ModelMonitor._ref(
|
265
|
-
model_monitor_client=self._model_monitor_client,
|
266
|
-
name=name_id,
|
267
|
-
fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
|
268
|
-
version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
|
269
|
-
function_name=sql_identifier.SqlIdentifier(model_monitor_config.model_function_name),
|
270
|
-
prediction_columns=prediction_columns,
|
271
|
-
label_columns=label_columns,
|
272
|
-
)
|
273
|
-
|
274
|
-
def get_monitor_by_model_version(
|
275
|
-
self, model_version: model_version_impl.ModelVersion
|
276
|
-
) -> model_monitor.ModelMonitor:
|
277
|
-
fq_model_name = model_version.fully_qualified_model_name
|
278
|
-
version_name = sql_identifier.SqlIdentifier(model_version.version_name)
|
279
|
-
if self._model_monitor_client.validate_existence(fq_model_name, version_name, self.statement_params):
|
280
|
-
model_db, model_schema, model_name = sql_identifier.parse_fully_qualified_name(fq_model_name)
|
281
|
-
if model_db is None or model_schema is None:
|
282
|
-
raise ValueError("Failed to parse model name")
|
283
|
-
|
284
|
-
model_monitor_params: monitor_sql_client._ModelMonitorParams = (
|
285
|
-
self._model_monitor_client.get_model_monitor_by_model_version(
|
286
|
-
model_db=model_db,
|
287
|
-
model_schema=model_schema,
|
288
|
-
model_name=model_name,
|
289
|
-
version_name=version_name,
|
290
|
-
statement_params=self.statement_params,
|
291
|
-
)
|
292
|
-
)
|
293
|
-
return model_monitor.ModelMonitor._ref(
|
294
|
-
model_monitor_client=self._model_monitor_client,
|
295
|
-
name=sql_identifier.SqlIdentifier(model_monitor_params["monitor_name"]),
|
296
|
-
fully_qualified_model_name=fq_model_name,
|
297
|
-
version_name=version_name,
|
298
|
-
function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
|
299
|
-
prediction_columns=model_monitor_params["prediction_columns"],
|
300
|
-
label_columns=model_monitor_params["label_columns"],
|
301
|
-
)
|
302
|
-
|
303
|
-
else:
|
304
|
-
raise ValueError(
|
305
|
-
f"ModelMonitor not found for model version {model_version.model_name} - {model_version.version_name}"
|
306
|
-
)
|
307
|
-
|
308
|
-
def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
|
309
|
-
"""Get a Model Monitor from the Registry
|
310
|
-
|
311
|
-
Args:
|
312
|
-
name: Name of Model Monitor to retrieve.
|
313
|
-
|
314
|
-
Raises:
|
315
|
-
ValueError: If model monitor is not found.
|
316
|
-
|
317
|
-
Returns:
|
318
|
-
The fetched ModelMonitor.
|
319
|
-
"""
|
320
|
-
name_id = sql_identifier.SqlIdentifier(name)
|
321
|
-
|
322
|
-
if not self._model_monitor_client.validate_existence_by_name(
|
323
|
-
monitor_name=name_id,
|
324
|
-
statement_params=self.statement_params,
|
325
|
-
):
|
326
|
-
raise ValueError(f"Unable to find model monitor '{name}'")
|
327
|
-
model_monitor_params: monitor_sql_client._ModelMonitorParams = (
|
328
|
-
self._model_monitor_client.get_model_monitor_by_name(name_id, statement_params=self.statement_params)
|
329
|
-
)
|
330
|
-
|
331
|
-
return model_monitor.ModelMonitor._ref(
|
332
|
-
model_monitor_client=self._model_monitor_client,
|
333
|
-
name=name_id,
|
334
|
-
fully_qualified_model_name=model_monitor_params["fully_qualified_model_name"],
|
335
|
-
version_name=sql_identifier.SqlIdentifier(model_monitor_params["version_name"]),
|
336
|
-
function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
|
337
|
-
prediction_columns=model_monitor_params["prediction_columns"],
|
338
|
-
label_columns=model_monitor_params["label_columns"],
|
339
|
-
)
|
340
|
-
|
341
|
-
def show_model_monitors(self) -> List[snowpark.Row]:
|
342
|
-
"""Show all model monitors in the registry.
|
343
|
-
|
344
|
-
Returns:
|
345
|
-
List of snowpark.Row containing metadata for each model monitor.
|
346
|
-
"""
|
347
|
-
return self._model_monitor_client.get_all_model_monitor_metadata()
|
348
|
-
|
349
|
-
def delete_monitor(self, name: str) -> None:
|
350
|
-
"""Delete a Model Monitor from the Registry
|
351
|
-
|
352
|
-
Args:
|
353
|
-
name: Name of the Model Monitor to delete.
|
354
|
-
"""
|
355
|
-
name_id = sql_identifier.SqlIdentifier(name)
|
356
|
-
monitor_params = self._model_monitor_client.get_model_monitor_by_name(name_id)
|
357
|
-
_, _, model = sql_identifier.parse_fully_qualified_name(monitor_params["fully_qualified_model_name"])
|
358
|
-
version = sql_identifier.SqlIdentifier(monitor_params["version_name"])
|
359
|
-
self._model_monitor_client.delete_monitor_metadata(name_id)
|
360
|
-
self._model_monitor_client.delete_baseline_table(model, version)
|
361
|
-
self._model_monitor_client.delete_dynamic_tables(model, version)
|