snowflake-ml-python 1.6.4__py3-none-any.whl → 1.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_complete.py +107 -64
  3. snowflake/cortex/_finetune.py +273 -0
  4. snowflake/cortex/_sse_client.py +91 -28
  5. snowflake/cortex/_util.py +30 -1
  6. snowflake/ml/_internal/telemetry.py +4 -2
  7. snowflake/ml/_internal/type_utils.py +3 -3
  8. snowflake/ml/_internal/utils/import_utils.py +31 -0
  9. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +13 -0
  10. snowflake/ml/data/__init__.py +5 -0
  11. snowflake/ml/data/_internal/arrow_ingestor.py +8 -0
  12. snowflake/ml/data/data_connector.py +1 -1
  13. snowflake/ml/data/torch_utils.py +33 -14
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +5 -3
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +7 -5
  16. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +4 -2
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +3 -1
  18. snowflake/ml/feature_store/examples/example_helper.py +6 -3
  19. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +4 -2
  20. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +4 -2
  21. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +3 -1
  22. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +3 -1
  23. snowflake/ml/feature_store/feature_store.py +1 -2
  24. snowflake/ml/feature_store/feature_view.py +5 -1
  25. snowflake/ml/model/_client/model/model_version_impl.py +145 -11
  26. snowflake/ml/model/_client/ops/model_ops.py +56 -16
  27. snowflake/ml/model/_client/ops/service_ops.py +46 -30
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +19 -8
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  30. snowflake/ml/model/_client/sql/service.py +25 -1
  31. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  32. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  34. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  35. snowflake/ml/model/_model_composer/model_method/model_method.py +1 -1
  36. snowflake/ml/model/_packager/model_env/model_env.py +12 -0
  37. snowflake/ml/model/_packager/model_handlers/_utils.py +6 -2
  38. snowflake/ml/model/_packager/model_handlers/catboost.py +4 -7
  39. snowflake/ml/model/_packager/model_handlers/custom.py +5 -1
  40. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +10 -1
  41. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -7
  42. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -1
  43. snowflake/ml/model/_packager/model_handlers/sklearn.py +51 -7
  44. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +8 -66
  45. snowflake/ml/model/_packager/model_handlers/tensorflow.py +23 -6
  46. snowflake/ml/model/_packager/model_handlers/torchscript.py +14 -14
  47. snowflake/ml/model/_packager/model_handlers/xgboost.py +10 -40
  48. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +2 -3
  49. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +5 -0
  50. snowflake/ml/model/_packager/model_packager.py +0 -11
  51. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -10
  52. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -9
  53. snowflake/ml/model/_packager/{model_handlers/model_objective_utils.py → model_task/model_task_utils.py} +14 -26
  54. snowflake/ml/model/_signatures/core.py +63 -16
  55. snowflake/ml/model/_signatures/pandas_handler.py +87 -27
  56. snowflake/ml/model/_signatures/pytorch_handler.py +2 -2
  57. snowflake/ml/model/_signatures/snowpark_handler.py +2 -1
  58. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -2
  59. snowflake/ml/model/_signatures/utils.py +4 -0
  60. snowflake/ml/model/custom_model.py +47 -7
  61. snowflake/ml/model/model_signature.py +40 -9
  62. snowflake/ml/model/type_hints.py +9 -1
  63. snowflake/ml/modeling/_internal/estimator_utils.py +13 -0
  64. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +7 -2
  65. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +16 -5
  66. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +8 -2
  67. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -3
  68. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -8
  69. snowflake/ml/modeling/cluster/agglomerative_clustering.py +17 -19
  70. snowflake/ml/modeling/cluster/dbscan.py +5 -2
  71. snowflake/ml/modeling/cluster/feature_agglomeration.py +7 -19
  72. snowflake/ml/modeling/cluster/k_means.py +14 -19
  73. snowflake/ml/modeling/cluster/mini_batch_k_means.py +3 -3
  74. snowflake/ml/modeling/cluster/optics.py +6 -6
  75. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -3
  76. snowflake/ml/modeling/compose/column_transformer.py +15 -5
  77. snowflake/ml/modeling/compose/transformed_target_regressor.py +7 -6
  78. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  79. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  80. snowflake/ml/modeling/covariance/min_cov_det.py +2 -2
  81. snowflake/ml/modeling/covariance/oas.py +1 -1
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +2 -2
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +5 -12
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +5 -12
  85. snowflake/ml/modeling/decomposition/pca.py +28 -15
  86. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -0
  87. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -12
  88. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -11
  89. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -8
  90. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -8
  91. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +21 -2
  92. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +18 -2
  93. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +2 -0
  94. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +2 -0
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +21 -8
  96. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +21 -11
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +21 -2
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +18 -2
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +2 -1
  100. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +5 -3
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +2 -2
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +2 -4
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +2 -4
  104. snowflake/ml/modeling/linear_model/ard_regression.py +5 -10
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +5 -11
  106. snowflake/ml/modeling/linear_model/elastic_net.py +3 -0
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lars.py +0 -10
  109. snowflake/ml/modeling/linear_model/lars_cv.py +1 -11
  110. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars.py +0 -10
  112. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -11
  113. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +0 -10
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +28 -22
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +30 -24
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +4 -13
  119. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +4 -4
  120. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/perceptron.py +3 -3
  122. snowflake/ml/modeling/linear_model/ransac_regressor.py +3 -2
  123. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +14 -6
  124. snowflake/ml/modeling/linear_model/ridge_cv.py +17 -11
  125. snowflake/ml/modeling/linear_model/sgd_classifier.py +2 -2
  126. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +5 -1
  127. snowflake/ml/modeling/linear_model/sgd_regressor.py +12 -3
  128. snowflake/ml/modeling/manifold/isomap.py +1 -1
  129. snowflake/ml/modeling/manifold/mds.py +3 -3
  130. snowflake/ml/modeling/manifold/tsne.py +10 -4
  131. snowflake/ml/modeling/metrics/classification.py +12 -16
  132. snowflake/ml/modeling/metrics/ranking.py +3 -3
  133. snowflake/ml/modeling/metrics/regression.py +3 -3
  134. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +3 -3
  135. snowflake/ml/modeling/naive_bayes/categorical_nb.py +3 -3
  136. snowflake/ml/modeling/naive_bayes/complement_nb.py +3 -3
  137. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +3 -3
  138. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +10 -4
  139. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +5 -2
  140. snowflake/ml/modeling/neighbors/local_outlier_factor.py +2 -2
  141. snowflake/ml/modeling/neighbors/nearest_centroid.py +7 -14
  142. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  143. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -1
  144. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  145. snowflake/ml/modeling/neural_network/mlp_classifier.py +7 -1
  146. snowflake/ml/modeling/neural_network/mlp_regressor.py +3 -0
  147. snowflake/ml/modeling/pipeline/pipeline.py +16 -14
  148. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +8 -4
  149. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -7
  150. snowflake/ml/modeling/svm/linear_svc.py +25 -16
  151. snowflake/ml/modeling/svm/linear_svr.py +23 -17
  152. snowflake/ml/modeling/svm/nu_svc.py +5 -3
  153. snowflake/ml/modeling/svm/nu_svr.py +3 -1
  154. snowflake/ml/modeling/svm/svc.py +9 -5
  155. snowflake/ml/modeling/svm/svr.py +3 -1
  156. snowflake/ml/modeling/tree/decision_tree_classifier.py +21 -2
  157. snowflake/ml/modeling/tree/decision_tree_regressor.py +18 -2
  158. snowflake/ml/modeling/tree/extra_tree_classifier.py +28 -9
  159. snowflake/ml/modeling/tree/extra_tree_regressor.py +18 -2
  160. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +448 -0
  161. snowflake/ml/monitoring/_manager/model_monitor_manager.py +238 -0
  162. snowflake/ml/monitoring/entities/model_monitor_config.py +10 -10
  163. snowflake/ml/monitoring/model_monitor.py +37 -0
  164. snowflake/ml/registry/_manager/model_manager.py +15 -1
  165. snowflake/ml/registry/registry.py +32 -37
  166. snowflake/ml/version.py +1 -1
  167. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/METADATA +104 -12
  168. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/RECORD +172 -171
  169. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/WHEEL +1 -1
  170. snowflake/ml/monitoring/_client/model_monitor.py +0 -126
  171. snowflake/ml/monitoring/_client/model_monitor_manager.py +0 -361
  172. snowflake/ml/monitoring/_client/monitor_sql_client.py +0 -1335
  173. snowflake/ml/monitoring/entities/model_monitor_interval.py +0 -46
  174. /snowflake/ml/monitoring/{_client/model_monitor_version.py → model_monitor_version.py} +0 -0
  175. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/LICENSE.txt +0 -0
  176. {snowflake_ml_python-1.6.4.dist-info → snowflake_ml_python-1.7.1.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.2.0)
2
+ Generator: setuptools (75.3.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,126 +0,0 @@
1
- from typing import List, Union
2
-
3
- import pandas as pd
4
-
5
- from snowflake import snowpark
6
- from snowflake.ml._internal import telemetry
7
- from snowflake.ml._internal.utils import sql_identifier
8
- from snowflake.ml.monitoring._client import monitor_sql_client
9
-
10
-
11
- class ModelMonitor:
12
- """Class to manage instrumentation of Model Monitoring and Observability"""
13
-
14
- name: sql_identifier.SqlIdentifier
15
- _model_monitor_client: monitor_sql_client._ModelMonitorSQLClient
16
- _fully_qualified_model_name: str
17
- _version_name: sql_identifier.SqlIdentifier
18
- _function_name: sql_identifier.SqlIdentifier
19
- _prediction_columns: List[sql_identifier.SqlIdentifier]
20
- _label_columns: List[sql_identifier.SqlIdentifier]
21
-
22
- def __init__(self) -> None:
23
- raise RuntimeError("ModelMonitor's initializer is not meant to be used.")
24
-
25
- @classmethod
26
- def _ref(
27
- cls,
28
- model_monitor_client: monitor_sql_client._ModelMonitorSQLClient,
29
- name: sql_identifier.SqlIdentifier,
30
- *,
31
- fully_qualified_model_name: str,
32
- version_name: sql_identifier.SqlIdentifier,
33
- function_name: sql_identifier.SqlIdentifier,
34
- prediction_columns: List[sql_identifier.SqlIdentifier],
35
- label_columns: List[sql_identifier.SqlIdentifier],
36
- ) -> "ModelMonitor":
37
- self: "ModelMonitor" = object.__new__(cls)
38
- self.name = name
39
- self._model_monitor_client = model_monitor_client
40
- self._fully_qualified_model_name = fully_qualified_model_name
41
- self._version_name = version_name
42
- self._function_name = function_name
43
- self._prediction_columns = prediction_columns
44
- self._label_columns = label_columns
45
- return self
46
-
47
- @telemetry.send_api_usage_telemetry(
48
- project=telemetry.TelemetryProject.MLOPS.value,
49
- subproject=telemetry.TelemetrySubProject.MONITORING.value,
50
- )
51
- def set_baseline(self, baseline_df: Union[pd.DataFrame, snowpark.DataFrame]) -> None:
52
- """
53
- The baseline dataframe is compared with the monitored data once monitoring is enabled.
54
- The columns of the dataframe should match the columns of the source table that the
55
- ModelMonitor was configured with. Calling this method overwrites any existing baseline split data.
56
-
57
- Args:
58
- baseline_df: Snowpark dataframe containing baseline data.
59
-
60
- Raises:
61
- ValueError: baseline_df does not contain prediction or label columns
62
- """
63
- statement_params = telemetry.get_statement_params(
64
- project=telemetry.TelemetryProject.MLOPS.value,
65
- subproject=telemetry.TelemetrySubProject.MONITORING.value,
66
- )
67
-
68
- if isinstance(baseline_df, pd.DataFrame):
69
- baseline_df = self._model_monitor_client._sql_client._session.create_dataframe(baseline_df)
70
-
71
- column_names_identifiers: List[sql_identifier.SqlIdentifier] = [
72
- sql_identifier.SqlIdentifier(column_name) for column_name in baseline_df.columns
73
- ]
74
- prediction_cols_not_found = any(
75
- [prediction_col not in column_names_identifiers for prediction_col in self._prediction_columns]
76
- )
77
- label_cols_not_found = any(
78
- [label_col.identifier() not in column_names_identifiers for label_col in self._label_columns]
79
- )
80
-
81
- if prediction_cols_not_found:
82
- raise ValueError(
83
- "Specified prediction columns were not found in the baseline dataframe. "
84
- f"Columns provided were: {column_names_identifiers}. "
85
- f"Configured prediction columns were: {self._prediction_columns}."
86
- )
87
- if label_cols_not_found:
88
- raise ValueError(
89
- "Specified label columns were not found in the baseline dataframe."
90
- f"Columns provided in the baseline dataframe were: {column_names_identifiers}."
91
- f"Configured label columns were: {self._label_columns}."
92
- )
93
-
94
- # Create the table by materializing the df
95
- self._model_monitor_client.materialize_baseline_dataframe(
96
- baseline_df,
97
- self._fully_qualified_model_name,
98
- self._version_name,
99
- statement_params=statement_params,
100
- )
101
-
102
- def suspend(self) -> None:
103
- """Suspend pipeline for ModelMonitor"""
104
- statement_params = telemetry.get_statement_params(
105
- telemetry.TelemetryProject.MLOPS.value,
106
- telemetry.TelemetrySubProject.MONITORING.value,
107
- )
108
- _, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name)
109
- self._model_monitor_client.suspend_monitor_dynamic_tables(
110
- model_name=model_name,
111
- version_name=self._version_name,
112
- statement_params=statement_params,
113
- )
114
-
115
- def resume(self) -> None:
116
- """Resume pipeline for ModelMonitor"""
117
- statement_params = telemetry.get_statement_params(
118
- telemetry.TelemetryProject.MLOPS.value,
119
- telemetry.TelemetrySubProject.MONITORING.value,
120
- )
121
- _, _, model_name = sql_identifier.parse_fully_qualified_name(self._fully_qualified_model_name)
122
- self._model_monitor_client.resume_monitor_dynamic_tables(
123
- model_name=model_name,
124
- version_name=self._version_name,
125
- statement_params=statement_params,
126
- )
@@ -1,361 +0,0 @@
1
- from typing import Any, Dict, List, Optional
2
-
3
- from snowflake import snowpark
4
- from snowflake.ml._internal import telemetry
5
- from snowflake.ml._internal.utils import db_utils, sql_identifier
6
- from snowflake.ml.model import type_hints
7
- from snowflake.ml.model._client.model import model_version_impl
8
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
9
- from snowflake.ml.monitoring._client import model_monitor, monitor_sql_client
10
- from snowflake.ml.monitoring.entities import (
11
- model_monitor_config,
12
- model_monitor_interval,
13
- )
14
- from snowflake.snowpark import session
15
-
16
-
17
- def _validate_name_constraints(model_version: model_version_impl.ModelVersion) -> None:
18
- system_table_prefixes = [
19
- monitor_sql_client._SNOWML_MONITORING_TABLE_NAME_PREFIX,
20
- monitor_sql_client._SNOWML_MONITORING_ACCURACY_TABLE_NAME_PREFIX,
21
- ]
22
-
23
- max_allowed_model_name_and_version_length = (
24
- db_utils.MAX_IDENTIFIER_LENGTH - max(len(prefix) for prefix in system_table_prefixes) - 1
25
- ) # -1 includes '_' between model_name + model_version
26
- if len(model_version.model_name) + len(model_version.version_name) > max_allowed_model_name_and_version_length:
27
- error_msg = f"Model name and version name exceeds maximum length of {max_allowed_model_name_and_version_length}"
28
- raise ValueError(error_msg)
29
-
30
-
31
- class ModelMonitorManager:
32
- """Class to manage internal operations for Model Monitor workflows.""" # TODO: Move to Registry.
33
-
34
- @staticmethod
35
- def setup(session: session.Session, database_name: str, schema_name: str) -> None:
36
- """Static method to set up schema for Model Monitoring resources.
37
-
38
- Args:
39
- session: The Snowpark Session to connect with Snowflake.
40
- database_name: The name of the database. If None, the current database of the session
41
- will be used. Defaults to None.
42
- schema_name: The name of the schema. If None, the current schema of the session
43
- will be used. If there is no active schema, the PUBLIC schema will be used. Defaults to None.
44
- """
45
- statement_params = telemetry.get_statement_params(
46
- project=telemetry.TelemetryProject.MLOPS.value,
47
- subproject=telemetry.TelemetrySubProject.MONITORING.value,
48
- )
49
- database_name_id = sql_identifier.SqlIdentifier(database_name)
50
- schema_name_id = sql_identifier.SqlIdentifier(schema_name)
51
- monitor_sql_client._ModelMonitorSQLClient.initialize_monitoring_schema(
52
- session, database_name_id, schema_name_id, statement_params=statement_params
53
- )
54
-
55
- def _fetch_task_from_model_version(
56
- self,
57
- model_version: model_version_impl.ModelVersion,
58
- ) -> type_hints.Task:
59
- task = model_version.get_model_task()
60
- if task == type_hints.Task.UNKNOWN:
61
- raise ValueError("Registry model must be logged with task in order to be monitored.")
62
- return task
63
-
64
- def __init__(
65
- self,
66
- session: session.Session,
67
- database_name: sql_identifier.SqlIdentifier,
68
- schema_name: sql_identifier.SqlIdentifier,
69
- *,
70
- create_if_not_exists: bool = False,
71
- statement_params: Optional[Dict[str, Any]] = None,
72
- ) -> None:
73
- """
74
- Opens a ModelMonitorManager for a given database and schema.
75
- Optionally sets up the schema for Model Monitoring.
76
-
77
- Args:
78
- session: The Snowpark Session to connect with Snowflake.
79
- database_name: The name of the database.
80
- schema_name: The name of the schema.
81
- create_if_not_exists: Flag whether to initialize resources in the schema needed for Model Monitoring.
82
- statement_params: Optional set of statement params.
83
-
84
- Raises:
85
- ValueError: When there is no specified or active database in the session.
86
- """
87
- self._database_name = database_name
88
- self._schema_name = schema_name
89
- self.statement_params = statement_params
90
- self._model_monitor_client = monitor_sql_client._ModelMonitorSQLClient(
91
- session,
92
- database_name=self._database_name,
93
- schema_name=self._schema_name,
94
- )
95
- if create_if_not_exists:
96
- monitor_sql_client._ModelMonitorSQLClient.initialize_monitoring_schema(
97
- session, self._database_name, self._schema_name, self.statement_params
98
- )
99
- elif not self._model_monitor_client._validate_is_initialized():
100
- raise ValueError(
101
- "Monitoring has not been setup. Set create_if_not_exists or call ModelMonitorManager.setup"
102
- )
103
-
104
- def _get_and_validate_model_function_from_model_version(
105
- self, function: str, model_version: model_version_impl.ModelVersion
106
- ) -> model_manifest_schema.ModelFunctionInfo:
107
- functions = model_version.show_functions()
108
- for f in functions:
109
- if f["target_method"] == function:
110
- return f
111
- existing_target_methods = {f["target_method"] for f in functions}
112
- raise ValueError(
113
- f"Function with name {function} does not exist in the given model version. "
114
- f"Found: {existing_target_methods}."
115
- )
116
-
117
- def _validate_monitor_config_or_raise(
118
- self,
119
- table_config: model_monitor_config.ModelMonitorTableConfig,
120
- model_monitor_config: model_monitor_config.ModelMonitorConfig,
121
- ) -> None:
122
- """Validate provided config for model monitor.
123
-
124
- Args:
125
- table_config: Config for model monitor tables.
126
- model_monitor_config: Config for ModelMonitor.
127
-
128
- Raises:
129
- ValueError: If warehouse provided does not exist.
130
- """
131
-
132
- # Validate naming will not exceed 255 chars
133
- _validate_name_constraints(model_monitor_config.model_version)
134
-
135
- if len(table_config.prediction_columns) != len(table_config.label_columns):
136
- raise ValueError("Prediction and Label column names must be of the same length.")
137
- # output and ground cols are list to keep interface extensible.
138
- # for prpr only one label and one output col will be supported
139
- if len(table_config.prediction_columns) != 1 or len(table_config.label_columns) != 1:
140
- raise ValueError("Multiple Output columns are not supported in monitoring")
141
-
142
- # Validate warehouse exists.
143
- warehouse_name_id = sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name)
144
- self._model_monitor_client.validate_monitor_warehouse(warehouse_name_id, statement_params=self.statement_params)
145
-
146
- # Validate refresh interval.
147
- try:
148
- num_units, time_units = model_monitor_config.refresh_interval.strip().split(" ")
149
- int(num_units) # try to cast
150
- if time_units.lower() not in {"seconds", "minutes", "hours", "days"}:
151
- raise ValueError(
152
- """Invalid time unit in refresh interval. Provide '<num> <seconds | minutes | hours | days>'.
153
- See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
154
- )
155
- except Exception as e: # TODO: Link to DT page.
156
- raise ValueError(
157
- f"""Failed to parse refresh interval with exception {e}.
158
- Provide '<num> <seconds | minutes | hours | days>'.
159
- See https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table#required-parameters for more info."""
160
- )
161
-
162
- def add_monitor(
163
- self,
164
- name: str,
165
- table_config: model_monitor_config.ModelMonitorTableConfig,
166
- model_monitor_config: model_monitor_config.ModelMonitorConfig,
167
- *,
168
- add_dashboard_udtfs: bool = False,
169
- ) -> model_monitor.ModelMonitor:
170
- """Add a new Model Monitor.
171
-
172
- Args:
173
- name: Name of Model Monitor to create.
174
- table_config: Configuration options for the source table used in ModelMonitor.
175
- model_monitor_config: Configuration options of ModelMonitor.
176
- add_dashboard_udtfs: Add UDTFs useful for creating a dashboard.
177
-
178
- Returns:
179
- The newly added ModelMonitor object.
180
- """
181
- # Validates configuration or raise.
182
- self._validate_monitor_config_or_raise(table_config, model_monitor_config)
183
- model_function = self._get_and_validate_model_function_from_model_version(
184
- model_monitor_config.model_function_name, model_monitor_config.model_version
185
- )
186
- monitor_refresh_interval = model_monitor_interval.ModelMonitorRefreshInterval(
187
- model_monitor_config.refresh_interval
188
- )
189
- name_id = sql_identifier.SqlIdentifier(name)
190
- source_table_name_id = sql_identifier.SqlIdentifier(table_config.source_table)
191
- prediction_columns = [
192
- sql_identifier.SqlIdentifier(column_name) for column_name in table_config.prediction_columns
193
- ]
194
- label_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.label_columns]
195
- id_columns = [sql_identifier.SqlIdentifier(column_name) for column_name in table_config.id_columns]
196
- ts_column = sql_identifier.SqlIdentifier(table_config.timestamp_column)
197
-
198
- # Validate source table
199
- self._model_monitor_client.validate_source_table(
200
- source_table_name=source_table_name_id,
201
- timestamp_column=ts_column,
202
- prediction_columns=prediction_columns,
203
- label_columns=label_columns,
204
- id_columns=id_columns,
205
- model_function=model_function,
206
- )
207
-
208
- task = self._fetch_task_from_model_version(model_version=model_monitor_config.model_version)
209
- score_type = self._model_monitor_client.get_score_type(task, source_table_name_id, prediction_columns)
210
-
211
- # Insert monitoring metadata for new model version.
212
- self._model_monitor_client.create_monitor_on_model_version(
213
- monitor_name=name_id,
214
- source_table_name=source_table_name_id,
215
- fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
216
- version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
217
- function_name=model_monitor_config.model_function_name,
218
- timestamp_column=ts_column,
219
- prediction_columns=prediction_columns,
220
- label_columns=label_columns,
221
- id_columns=id_columns,
222
- task=task,
223
- statement_params=self.statement_params,
224
- )
225
-
226
- # Create Dynamic tables for model monitor.
227
- self._model_monitor_client.create_dynamic_tables_for_monitor(
228
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
229
- model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
230
- task=task,
231
- source_table_name=source_table_name_id,
232
- refresh_interval=monitor_refresh_interval,
233
- aggregation_window=model_monitor_config.aggregation_window,
234
- warehouse_name=sql_identifier.SqlIdentifier(model_monitor_config.background_compute_warehouse_name),
235
- timestamp_column=sql_identifier.SqlIdentifier(table_config.timestamp_column),
236
- id_columns=id_columns,
237
- prediction_columns=prediction_columns,
238
- label_columns=label_columns,
239
- score_type=score_type,
240
- )
241
-
242
- # Initialize baseline table.
243
- self._model_monitor_client.initialize_baseline_table(
244
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
245
- version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
246
- source_table_name=table_config.source_table,
247
- columns_to_drop=[ts_column, *id_columns],
248
- statement_params=self.statement_params,
249
- )
250
-
251
- # Add udtfs helpful for dashboard queries.
252
- # TODO(apgupta) Make this true by default.
253
- if add_dashboard_udtfs:
254
- self._model_monitor_client.add_dashboard_udtfs(
255
- name_id,
256
- model_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.model_name),
257
- model_version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
258
- task=task,
259
- score_type=score_type,
260
- output_columns=prediction_columns,
261
- ground_truth_columns=label_columns,
262
- )
263
-
264
- return model_monitor.ModelMonitor._ref(
265
- model_monitor_client=self._model_monitor_client,
266
- name=name_id,
267
- fully_qualified_model_name=model_monitor_config.model_version.fully_qualified_model_name,
268
- version_name=sql_identifier.SqlIdentifier(model_monitor_config.model_version.version_name),
269
- function_name=sql_identifier.SqlIdentifier(model_monitor_config.model_function_name),
270
- prediction_columns=prediction_columns,
271
- label_columns=label_columns,
272
- )
273
-
274
- def get_monitor_by_model_version(
275
- self, model_version: model_version_impl.ModelVersion
276
- ) -> model_monitor.ModelMonitor:
277
- fq_model_name = model_version.fully_qualified_model_name
278
- version_name = sql_identifier.SqlIdentifier(model_version.version_name)
279
- if self._model_monitor_client.validate_existence(fq_model_name, version_name, self.statement_params):
280
- model_db, model_schema, model_name = sql_identifier.parse_fully_qualified_name(fq_model_name)
281
- if model_db is None or model_schema is None:
282
- raise ValueError("Failed to parse model name")
283
-
284
- model_monitor_params: monitor_sql_client._ModelMonitorParams = (
285
- self._model_monitor_client.get_model_monitor_by_model_version(
286
- model_db=model_db,
287
- model_schema=model_schema,
288
- model_name=model_name,
289
- version_name=version_name,
290
- statement_params=self.statement_params,
291
- )
292
- )
293
- return model_monitor.ModelMonitor._ref(
294
- model_monitor_client=self._model_monitor_client,
295
- name=sql_identifier.SqlIdentifier(model_monitor_params["monitor_name"]),
296
- fully_qualified_model_name=fq_model_name,
297
- version_name=version_name,
298
- function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
299
- prediction_columns=model_monitor_params["prediction_columns"],
300
- label_columns=model_monitor_params["label_columns"],
301
- )
302
-
303
- else:
304
- raise ValueError(
305
- f"ModelMonitor not found for model version {model_version.model_name} - {model_version.version_name}"
306
- )
307
-
308
- def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
309
- """Get a Model Monitor from the Registry
310
-
311
- Args:
312
- name: Name of Model Monitor to retrieve.
313
-
314
- Raises:
315
- ValueError: If model monitor is not found.
316
-
317
- Returns:
318
- The fetched ModelMonitor.
319
- """
320
- name_id = sql_identifier.SqlIdentifier(name)
321
-
322
- if not self._model_monitor_client.validate_existence_by_name(
323
- monitor_name=name_id,
324
- statement_params=self.statement_params,
325
- ):
326
- raise ValueError(f"Unable to find model monitor '{name}'")
327
- model_monitor_params: monitor_sql_client._ModelMonitorParams = (
328
- self._model_monitor_client.get_model_monitor_by_name(name_id, statement_params=self.statement_params)
329
- )
330
-
331
- return model_monitor.ModelMonitor._ref(
332
- model_monitor_client=self._model_monitor_client,
333
- name=name_id,
334
- fully_qualified_model_name=model_monitor_params["fully_qualified_model_name"],
335
- version_name=sql_identifier.SqlIdentifier(model_monitor_params["version_name"]),
336
- function_name=sql_identifier.SqlIdentifier(model_monitor_params["function_name"]),
337
- prediction_columns=model_monitor_params["prediction_columns"],
338
- label_columns=model_monitor_params["label_columns"],
339
- )
340
-
341
- def show_model_monitors(self) -> List[snowpark.Row]:
342
- """Show all model monitors in the registry.
343
-
344
- Returns:
345
- List of snowpark.Row containing metadata for each model monitor.
346
- """
347
- return self._model_monitor_client.get_all_model_monitor_metadata()
348
-
349
- def delete_monitor(self, name: str) -> None:
350
- """Delete a Model Monitor from the Registry
351
-
352
- Args:
353
- name: Name of the Model Monitor to delete.
354
- """
355
- name_id = sql_identifier.SqlIdentifier(name)
356
- monitor_params = self._model_monitor_client.get_model_monitor_by_name(name_id)
357
- _, _, model = sql_identifier.parse_fully_qualified_name(monitor_params["fully_qualified_model_name"])
358
- version = sql_identifier.SqlIdentifier(monitor_params["version_name"])
359
- self._model_monitor_client.delete_monitor_metadata(name_id)
360
- self._model_monitor_client.delete_baseline_table(model, version)
361
- self._model_monitor_client.delete_dynamic_tables(model, version)