validmind 2.4.10__py3-none-any.whl → 2.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. validmind/__version__.py +1 -1
  2. validmind/api_client.py +1 -0
  3. validmind/client.py +0 -2
  4. validmind/input_registry.py +8 -0
  5. validmind/tests/__types__.py +4 -0
  6. validmind/tests/data_validation/DatasetDescription.py +1 -0
  7. validmind/tests/model_validation/sklearn/ClassifierPerformance.py +15 -6
  8. validmind/tests/model_validation/sklearn/ClusterPerformance.py +2 -2
  9. validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +10 -3
  10. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +349 -291
  11. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +1 -1
  12. validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +36 -37
  13. validmind/tests/ongoing_monitoring/FeatureDrift.py +182 -0
  14. validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py +76 -0
  15. validmind/tests/ongoing_monitoring/PredictionCorrelation.py +91 -0
  16. validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +57 -0
  17. validmind/tests/run.py +35 -19
  18. validmind/unit_metrics/__init__.py +1 -1
  19. validmind/unit_metrics/classification/sklearn/ROC_AUC.py +22 -1
  20. validmind/utils.py +1 -1
  21. validmind/vm_models/__init__.py +2 -0
  22. validmind/vm_models/dataset/dataset.py +55 -14
  23. validmind/vm_models/input.py +31 -0
  24. validmind/vm_models/model.py +4 -2
  25. validmind/vm_models/test_context.py +9 -2
  26. {validmind-2.4.10.dist-info → validmind-2.5.1.dist-info}/METADATA +1 -1
  27. {validmind-2.4.10.dist-info → validmind-2.5.1.dist-info}/RECORD +30 -25
  28. {validmind-2.4.10.dist-info → validmind-2.5.1.dist-info}/LICENSE +0 -0
  29. {validmind-2.4.10.dist-info → validmind-2.5.1.dist-info}/WHEEL +0 -0
  30. {validmind-2.4.10.dist-info → validmind-2.5.1.dist-info}/entry_points.txt +0 -0

validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py
@@ -65,7 +65,7 @@ class PrecisionRecallCurve(Metric):
             raise SkipTestError("Skipping PrecisionRecallCurve for Foundation models")
 
         y_true = self.inputs.dataset.y
-        y_pred = self.inputs.model.predict_proba(self.inputs.dataset.x)
+        y_pred = self.inputs.dataset.y_prob(self.inputs.model)
 
         # PR curve is only supported for binary classification
         if len(np.unique(y_true)) > 2:
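
The one-line change above reflects this release's broader move toward reading predictions off the `VMDataset` instead of calling the model inside each test (see also `vm_models/dataset/dataset.py` in the file list). A minimal sketch of that pattern follows; only `dataset.y_prob(model)` itself appears in this diff, while the `init_dataset`/`init_model`/`assign_predictions` workflow is an assumption based on the rest of the library.

```python
# Sketch of the dataset-centric prediction pattern (assumed workflow; only
# dataset.y_prob(model) is shown in this diff).
import pandas as pd
import validmind as vm
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=500, n_features=5, random_state=0)
df = pd.DataFrame(X, columns=[f"f{i}" for i in range(5)])
df["target"] = y

model = LogisticRegression().fit(df.drop(columns="target"), df["target"])

vm_model = vm.init_model(model, input_id="log_reg")
vm_ds = vm.init_dataset(df, input_id="test_ds", target_column="target")

# Compute prediction/probability columns on the dataset once up front...
vm_ds.assign_predictions(vm_model)

# ...so a test like PrecisionRecallCurve can read them back instead of
# calling model.predict_proba(dataset.x) itself:
y_true = vm_ds.y
y_prob = vm_ds.y_prob(vm_model)
```

Caching predictions this way lets a single inference pass feed every test that needs probabilities.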

validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py
@@ -12,6 +12,7 @@ import pandas as pd
 import seaborn as sns
 from sklearn import metrics
 
+from validmind.errors import MissingOrInvalidModelPredictFnError
 from validmind.vm_models import (
     Figure,
     ResultSummary,
@@ -22,6 +23,7 @@ from validmind.vm_models import (
 )
 
 
+# TODO: make this support regression and classification as well as more performance metrics
 @dataclass
 class RobustnessDiagnosis(ThresholdTest):
     """
@@ -39,13 +41,13 @@ class RobustnessDiagnosis(ThresholdTest):
 
     This test is conducted by adding Gaussian noise, proportional to a particular standard deviation scale, to numeric
     input features of both the training and testing datasets. The model performance in the face of these perturbed
-    features is then evaluated using metrics (default: 'accuracy'). This process is iterated over a range of scale
-    factors. The resulting accuracy trend against the amount of noise introduced is illustrated with a line chart. A
-    predetermined threshold determines what level of accuracy decay due to perturbation is considered acceptable.
+    features is then evaluated using the ROC AUC score. This process is iterated over a range of scale
+    factors. The resulting AUC trend against the amount of noise introduced is illustrated with a line chart. A
+    predetermined threshold determines what level of AUC decay due to perturbation is considered acceptable.
 
     **Signs of High Risk**:
-    - Substantial decreases in accuracy when noise is introduced to feature inputs.
-    - The decay in accuracy surpasses the configured threshold, indicating that the model is not robust against input
+    - Substantial decreases in AUC when noise is introduced to feature inputs.
+    - The decay in AUC surpasses the configured threshold, indicating that the model is not robust against input
     noise.
     - Instances where one or more elements provided in the features list don't match with the training dataset's
     numerical feature columns.
@@ -57,15 +59,12 @@ class RobustnessDiagnosis(ThresholdTest):
     - Detailed results visualization helps in interpreting the outcome of robustness testing.
 
     **Limitations**:
+    - The default threshold for AUC decay is set to 0.05, which is unlikely to be optimal for most use cases and
+    should be adjusted based on domain expertise to suit the needs of the specific model.
     - Only numerical features are perturbed, leaving out non-numerical features, which can lead to an incomplete
     analysis of robustness.
-    - The default metric used is accuracy, which might not always give the best measure of a model's success,
-    particularly for imbalanced datasets.
     - The test is contingent on the assumption that the added Gaussian noise sufficiently represents potential data
     corruption or incompleteness in real-world scenarios.
-    - There might be a requirement to fine-tune the set decay threshold for accuracy with the help of domain knowledge
-    or specific project requisites.
-    - The robustness test might not deliver the expected results for datasets with a text column.
     """
 
     name = "robustness"
@@ -73,9 +72,9 @@ class RobustnessDiagnosis(ThresholdTest):
     default_params = {
         "features_columns": None,
         "scaling_factor_std_dev_list": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
-        "accuracy_decay_threshold": 4,
+        "auc_decay_threshold": 0.05,
     }
-    tasks = ["classification", "text_classification"]
+    tasks = ["classification"]
     tags = [
         "sklearn",
         "binary_classification",
@@ -84,17 +83,15 @@ class RobustnessDiagnosis(ThresholdTest):
         "visualization",
     ]
 
-    default_metrics = {"accuracy": metrics.accuracy_score}
-
     def run(self):
         # Validate X std deviation parameter
         if "scaling_factor_std_dev_list" not in self.params:
             raise ValueError("scaling_factor_std_dev_list must be provided in params")
         x_std_dev_list = self.params["scaling_factor_std_dev_list"]
 
-        if self.params["accuracy_decay_threshold"] is None:
-            raise ValueError("accuracy_decay_threshold must be provided in params")
-        accuracy_threshold = self.params["accuracy_decay_threshold"]
+        if self.params["auc_decay_threshold"] is None:
+            raise ValueError("auc_decay_threshold must be provided in params")
+        auc_threshold = self.params["auc_decay_threshold"]
 
         if self.inputs.model is None:
             raise ValueError("model must of provided to run this test")
@@ -131,9 +128,7 @@ class RobustnessDiagnosis(ThresholdTest):
         test_results = []
         test_figures = []
 
-        results_headers = ["Perturbation Size", "Dataset Type", "Records"] + list(
-            self.default_metrics.keys()
-        )
+        results_headers = ["Perturbation Size", "Dataset Type", "Records", "AUC"]
         results = {k: [] for k in results_headers}
         # Iterate scaling factor for the standard deviation list
         for x_std_dev in x_std_dev_list:
@@ -159,32 +154,32 @@ class RobustnessDiagnosis(ThresholdTest):
         test_figures.append(
             Figure(
                 for_object=self,
-                key=f"{self.name}:accuracy",
+                key=f"{self.name}:auc",
                 figure=fig,
                 metadata={
-                    "metric": "accuracy",
+                    "metric": "AUC",
                     "features_list": features_list,
                 },
             )
         )
 
-        train_acc = df.loc[(df["Dataset Type"] == "Training"), "accuracy"].values[0]
-        test_acc = df.loc[(df["Dataset Type"] == "Test"), "accuracy"].values[0]
+        train_auc = df.loc[(df["Dataset Type"] == "Training"), "AUC"].values[0]
+        test_auc = df.loc[(df["Dataset Type"] == "Test"), "AUC"].values[0]
 
         df["Passed"] = np.where(
             (df["Dataset Type"] == "Training")
-            & (df["accuracy"] >= (train_acc - accuracy_threshold)),
+            & (df["AUC"] >= (train_auc - auc_threshold)),
             True,
             np.where(
                 (df["Dataset Type"] == "Test")
-                & (df["accuracy"] >= (test_acc - accuracy_threshold)),
+                & (df["AUC"] >= (test_auc - auc_threshold)),
                 True,
                 False,
             ),
         )
         test_results.append(
             ThresholdTestResult(
-                test_name="accuracy",
+                test_name="AUC",
                 column=features_list,
                 passed=True,
                 values={"records": df.to_dict("records")},
@@ -194,7 +189,7 @@ class RobustnessDiagnosis(ThresholdTest):
             test_results, passed=df["Passed"].all(), figures=test_figures
         )
 
-    def summary(self, results: List[ThresholdTestResult], all_passed: bool):
+    def summary(self, results: List[ThresholdTestResult], _):
         results_table = [
             record for result in results for record in result.values["records"]
         ]
@@ -229,9 +224,13 @@ class RobustnessDiagnosis(ThresholdTest):
         results["Dataset Type"].append(dataset_type)
         results["Perturbation Size"].append(x_std_dev)
         results["Records"].append(df.shape[0])
-        y_prediction = self.inputs.model.predict(df)
-        for metric, metric_fn in self.default_metrics.items():
-            results[metric].append(metric_fn(y_true, y_prediction) * 100)
+
+        try:
+            y_proba = self.inputs.model.predict_proba(df)
+        except MissingOrInvalidModelPredictFnError:
+            y_proba = self.inputs.model.predict(df)
+
+        results["AUC"].append(metrics.roc_auc_score(y_true, y_proba))
 
     def _add_noise_std_dev(
         self, values: List[float], x_std_dev: float
@@ -256,14 +255,14 @@ class RobustnessDiagnosis(ThresholdTest):
 
     def _plot_robustness(self, results: dict, features_columns: List[str]):
         """
-        Plots the model's accuracy under feature perturbations.
+        Plots the model's AUC under feature perturbations.
         Args:
             results (dict): A dictionary containing the results of the evaluation.
                 It has the following keys:
                 - 'Dataset Type': the type of dataset evaluated, e.g. 'Training' or 'Test'.
                 - 'Perturbation Size': the size of the perturbation applied to the features.
                 - 'Records': the number of records evaluated.
-                - Any other metric used for evaluation as keys, e.g. 'accuracy', 'precision', 'recall'.
+                - 'AUC': the ROC AUC score obtained for the evaluation.
                 The values of each key are lists containing the results for each evaluation.
             features_columns (list[str]): A list containing the names of the features perturbed.
         Returns:
@@ -277,7 +276,7 @@ class RobustnessDiagnosis(ThresholdTest):
         sns.lineplot(
             data=df,
             x="Perturbation Size",
-            y="accuracy",
+            y="AUC",
             hue="Dataset Type",
             style="Dataset Type",
             linewidth=3,
@@ -288,7 +287,7 @@ class RobustnessDiagnosis(ThresholdTest):
             ax=ax,
         )
         ax.tick_params(axis="x")
-        ax.set_ylabel("Accuracy", weight="bold", fontsize=18)
+        ax.set_ylabel("AUC", weight="bold", fontsize=18)
         ax.legend(fontsize=18)
         ax.set_xlabel(
             "Perturbation Size (X * Standard Deviation)", weight="bold", fontsize=18
@@ -321,9 +320,9 @@ class RobustnessDiagnosis(ThresholdTest):
             assert isinstance(test_result.values, dict)
             assert "records" in test_result.values
 
-            # For unperturbed training dataset, accuracy should be present
+            # For unperturbed training dataset, AUC should be present
             if (
                 test_result.column == self.params["features_columns"]
                 and 0.0 in test_result.values["records"][0]["Perturbation Size"]
             ):
-                assert "accuracy" in test_result.values["records"][0]
+                assert "AUC" in test_result.values["records"][0]

validmind/tests/ongoing_monitoring/FeatureDrift.py
@@ -0,0 +1,182 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+
+from validmind import tags, tasks
+
+
+@tags("visualization")
+@tasks("monitoring")
+def FeatureDrift(
+    datasets, bins=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], feature_columns=None
+):
+    """
+    **Purpose**:
+
+    The Feature Drift test aims to evaluate how much the distribution of features has shifted over time between two
+    datasets, typically training and monitoring datasets. It uses the Population Stability Index (PSI) to quantify this
+    change, providing insights into the model's robustness and the necessity for retraining or feature engineering.
+
+    **Test Mechanism**:
+
+    This test calculates the PSI by:
+    - Bucketing the distributions of each feature in both datasets.
+    - Comparing the percentage of observations in each bucket between the two datasets.
+    - Aggregating the differences across all buckets for each feature to produce the PSI score for that feature.
+
+    The PSI score is interpreted as:
+    - PSI < 0.1: No significant population change.
+    - PSI < 0.2: Moderate population change.
+    - PSI >= 0.2: Significant population change.
+
+    **Signs of High Risk**:
+
+    - PSI >= 0.2 for any feature, indicating a significant distribution shift.
+    - Consistently high PSI scores across multiple features.
+    - Sudden spikes in PSI in recent monitoring data compared to historical data.
+
+    **Strengths**:
+
+    - Provides a quantitative measure of feature distribution changes.
+    - Easily interpretable thresholds for decision-making.
+    - Helps in early detection of data drift, prompting timely interventions.
+
+    **Limitations**:
+
+    - May not capture more intricate changes in data distribution nuances.
+    - Assumes that bucket thresholds (quantiles) adequately represent distribution shifts.
+    - PSI score interpretation can be overly simplistic for complex datasets.
+    """
+
+    # Feature columns for both datasets should be the same if not given
+    default_feature_columns = datasets[0].feature_columns
+    feature_columns = feature_columns or default_feature_columns
+
+    x_train_df = datasets[0].x_df()
+    x_test_df = datasets[1].x_df()
+
+    quantiles_train = x_train_df[feature_columns].quantile(
+        bins, method="single", interpolation="nearest"
+    )
+    PSI_QUANTILES = quantiles_train.to_dict()
+
+    PSI_BUCKET_FRAC, col, n = get_psi_buckets(
+        x_test_df, x_train_df, feature_columns, bins, PSI_QUANTILES
+    )
+
+    def nest(d: dict) -> dict:
+        result = {}
+        for key, value in d.items():
+            target = result
+            for k in key[:-1]:  # traverse all keys but the last
+                target = target.setdefault(k, {})
+            target[key[-1]] = value
+        return result
+
+    PSI_BUCKET_FRAC = nest(PSI_BUCKET_FRAC)
+
+    PSI_SCORES = {}
+    for col in feature_columns:
+        psi = 0
+        for n in bins:
+            actual = PSI_BUCKET_FRAC["test"][col][n]
+            expected = PSI_BUCKET_FRAC["train"][col][n]
+            psi_of_bucket = (actual - expected) * np.log(
+                (actual + 1e-6) / (expected + 1e-6)
+            )
+            psi += psi_of_bucket
+        PSI_SCORES[col] = psi
+
+    psi_df = pd.DataFrame(list(PSI_SCORES.items()), columns=["Features", "PSI Score"])
+
+    psi_df.sort_values(by=["PSI Score"], inplace=True, ascending=False)
+
+    psi_table = [
+        {"Features": values["Features"], "PSI Score": values["PSI Score"]}
+        for i, values in enumerate(psi_df.to_dict(orient="records"))
+    ]
+
+    save_fig = plot_hist(PSI_BUCKET_FRAC, bins)
+
+    final_psi = pd.DataFrame(psi_table)
+
+    return (final_psi, *save_fig)
+
+
+def get_psi_buckets(x_test_df, x_train_df, feature_columns, bins, PSI_QUANTILES):
+    DATA = {"test": x_test_df, "train": x_train_df}
+    PSI_BUCKET_FRAC = {}
+    for table in DATA.keys():
+        total_count = DATA[table].shape[0]
+        for col in feature_columns:
+            count_sum = 0
+            for n in bins:
+                if n == 0:
+                    bucket_count = (DATA[table][col] < PSI_QUANTILES[col][n]).sum()
+                elif n < 9:
+                    bucket_count = (
+                        total_count
+                        - count_sum
+                        - ((DATA[table][col] >= PSI_QUANTILES[col][n]).sum())
+                    )
+                elif n == 9:
+                    bucket_count = total_count - count_sum
+                count_sum += bucket_count
+                PSI_BUCKET_FRAC[table, col, n] = bucket_count / total_count
+    return PSI_BUCKET_FRAC, col, n
+
+
+def plot_hist(PSI_BUCKET_FRAC, bins):
+    bin_table_psi = pd.DataFrame(PSI_BUCKET_FRAC)
+    save_fig = []
+    for i in range(len(bin_table_psi)):
+
+        x = pd.DataFrame(
+            bin_table_psi.iloc[i]["test"].items(),
+            columns=["Bin", "Population % Reference"],
+        )
+        y = pd.DataFrame(
+            bin_table_psi.iloc[i]["train"].items(),
+            columns=["Bin", "Population % Monitoring"],
+        )
+        xy = x.merge(y, on="Bin")
+        xy.index = xy["Bin"]
+        xy = xy.drop(columns="Bin", axis=1)
+        feature_name = bin_table_psi.index[i]
+
+        n = len(bins)
+        r = np.arange(n)
+        width = 0.25
+
+        fig = plt.figure()
+
+        plt.bar(
+            r,
+            xy["Population % Reference"],
+            color="b",
+            width=width,
+            edgecolor="black",
+            label="Reference {0}".format(feature_name),
+        )
+        plt.bar(
+            r + width,
+            xy["Population % Monitoring"],
+            color="g",
+            width=width,
+            edgecolor="black",
+            label="Monitoring {0}".format(feature_name),
+        )
+
+        plt.xlabel("Bin")
+        plt.ylabel("Population %")
+        plt.title("Histogram of Population Differences {0}".format(feature_name))
+        plt.legend()
+        plt.tight_layout()
+        plt.close()
+        save_fig.append(fig)
+    return save_fig
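
For readers less familiar with PSI, here is a compact, self-contained version of the quantile-bucket calculation described in the docstring above. It is an illustration only, not the FeatureDrift implementation:

```python
# Compact PSI illustration (numpy/pandas only); thresholds follow the docstring above.
import numpy as np
import pandas as pd


def psi(reference: pd.Series, monitoring: pd.Series, n_buckets: int = 10) -> float:
    # Bucket edges come from the reference distribution's quantiles
    edges = np.quantile(reference, np.linspace(0, 1, n_buckets + 1))
    expected = np.histogram(reference, bins=edges)[0] / len(reference)
    # Clip monitoring values into the reference range so out-of-range points
    # land in the outermost buckets instead of being dropped
    actual = np.histogram(monitoring.clip(edges[0], edges[-1]), bins=edges)[0] / len(monitoring)
    # Epsilon keeps empty buckets from producing log(0)
    return float(np.sum((actual - expected) * np.log((actual + 1e-6) / (expected + 1e-6))))


rng = np.random.default_rng(0)
reference = pd.Series(rng.normal(0.0, 1.0, 5_000))
monitoring = pd.Series(rng.normal(0.3, 1.0, 5_000))  # shifted mean -> some drift

print(f"PSI = {psi(reference, monitoring):.3f}")
# < 0.1: no significant change, 0.1-0.2: moderate, >= 0.2: significant
```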

validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py
@@ -0,0 +1,76 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+
+import matplotlib.pyplot as plt
+
+from validmind import tags, tasks
+
+
+@tags("visualization")
+@tasks("monitoring")
+def PredictionAcrossEachFeature(datasets, model):
+    """
+    **Purpose:**
+    This test shows visually the prediction using reference data and monitoring data across each individual feature. If
+    there are significant differences in predictions across feature values from reference to monitoring dataset, then
+    further investigation is needed as the model is producing predictions that are different than what was observed
+    during the training of the model.
+
+    **Test Mechanism:**
+    The test creates scatter plots for each feature, comparing the reference dataset (used for training) with the
+    monitoring dataset (used in production). Each plot has two subplots: one for the reference data and one for the
+    monitoring data, visualizing the prediction probabilities. This allows for a visual comparison of the model's
+    behavior across different datasets.
+
+    **Signs of High Risk:**
+    - Significant discrepancies between the reference and monitoring subplots for the same feature
+    - Unexpected patterns or trends in monitoring data that weren't present in reference data
+
+    **Strengths:**
+    - Provides a clear visual representation of model performance across different features
+    - Allows for easy identification of features where the model's predictions have changed
+    - Facilitates quick detection of potential issues with the model when deployed in production
+
+    **Limitations:**
+    - Interpretation of scatter plots can be subjective and may require expertise
+    - Visualizations do not provide quantitative metrics for objective evaluation
+    - May not capture all types of distribution changes or issues with the model's predictions
+    """
+
+    """
+    This test shows visually the prediction using reference data and monitoring data
+    across each individual feature. If there are significant differences in predictions
+    across feature values from reference to monitoring dataset then further investigation
+    is needed as the model is producing predictions that are different than what was
+    observed during the training of the model.
+    """
+
+    df_reference = datasets[0]._df
+    df_monitoring = datasets[1]._df
+
+    figures_to_save = []
+    for column in df_reference:
+        prediction_prob_column = f"{model.input_id}_probabilities"
+        prediction_column = f"{model.input_id}_prediction"
+        if column == prediction_prob_column or column == prediction_column:
+            pass
+        else:
+            fig, axs = plt.subplots(1, 2, figsize=(20, 10), sharey="row")
+
+            ax1, ax2 = axs
+
+            ax1.scatter(df_reference[column], df_reference[prediction_prob_column])
+            ax2.scatter(df_monitoring[column], df_monitoring[prediction_prob_column])
+
+            ax1.set_title("Reference")
+            ax1.set_xlabel(column)
+            ax1.set_ylabel("Prediction Value")
+
+            ax2.set_title("Monitoring")
+            ax2.set_xlabel(column)
+            figures_to_save.append(fig)
+            plt.close()
+
+    return tuple(figures_to_save)

validmind/tests/ongoing_monitoring/PredictionCorrelation.py
@@ -0,0 +1,91 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from validmind import tags, tasks
+
+
+@tags("visualization")
+@tasks("monitoring")
+def PredictionCorrelation(datasets, model):
+    """
+    **Purpose:**
+    The test is used to assess the correlation pairs for each feature between model predictions from reference and
+    monitoring datasets. The primary goal is to detect significant changes in these pairs, which may signal target
+    drift, leading to lower model performance.
+
+    **Test Mechanism:**
+    The test calculates the correlation of each feature with model predictions for both reference and monitoring
+    datasets. The test then compares these correlations side-by-side via a bar plot and a correlation table. Features
+    with significant changes in correlation pairs highlight potential risks of model drift.
+
+    **Signs of High Risk:**
+    - Significant changes in correlation pairs between the reference and monitoring predictions.
+    - Notable correlation differences indicating a potential shift in the relationship between features and the target
+    variable.
+
+    **Strengths:**
+    - Allows for visual identification of drift in feature relationships with model predictions.
+    - Comparison via a clear bar plot assists in understanding model stability over time.
+    - Helps in early detection of target drift, enabling timely interventions.
+
+    **Limitations:**
+    - May require substantial reference and monitoring data for accurate comparison.
+    - Correlation does not imply causation, and other factors might influence changes.
+    - The method solely focuses on linear relationships, potentially missing non-linear interactions.
+    """
+
+    prediction_prob_column = f"{model.input_id}_probabilities"
+    prediction_column = f"{model.input_id}_prediction"
+
+    df_corr = datasets[0]._df.corr()
+    df_corr = df_corr[[prediction_prob_column]]
+
+    df_corr2 = datasets[1]._df.corr()
+    df_corr2 = df_corr2[[prediction_prob_column]]
+
+    corr_final = df_corr.merge(df_corr2, left_index=True, right_index=True)
+    corr_final.columns = ["Reference Predictions", "Monitoring Predictions"]
+    corr_final = corr_final.drop(index=[prediction_column, prediction_prob_column])
+
+    n = len(corr_final)
+    r = np.arange(n)
+    width = 0.25
+
+    fig = plt.figure()
+
+    plt.bar(
+        r,
+        corr_final["Reference Predictions"],
+        color="b",
+        width=width,
+        edgecolor="black",
+        label="Reference Prediction Correlation",
+    )
+    plt.bar(
+        r + width,
+        corr_final["Monitoring Predictions"],
+        color="g",
+        width=width,
+        edgecolor="black",
+        label="Monitoring Prediction Correlation",
+    )
+
+    plt.xlabel("Features")
+    plt.ylabel("Correlation")
+    plt.title("Correlation between Predictions and Features")
+
+    features = corr_final.index.to_list()
+    plt.xticks(r + width / 2, features, rotation=45)
+    plt.legend()
+    plt.tight_layout()
+
+    corr_final["Features"] = corr_final.index
+    corr_final = corr_final[
+        ["Features", "Reference Predictions", "Monitoring Predictions"]
+    ]
+    return ({"Correlation Pair Table": corr_final}, fig)

validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py
@@ -0,0 +1,57 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from validmind import tags, tasks
+
+
+@tags("visualization")
+@tasks("monitoring")
+def TargetPredictionDistributionPlot(datasets, model):
+    """
+    **Purpose:**
+    This test provides the prediction distributions from the reference dataset and the new monitoring dataset. If there
+    are significant differences in the distributions, it might indicate different underlying data characteristics that
+    warrant further investigation into the root causes.
+
+    **Test Mechanism:**
+    The methodology involves generating Kernel Density Estimation (KDE) plots for the prediction probabilities from
+    both the reference and monitoring datasets. By comparing these KDE plots, one can visually assess any significant
+    differences in the prediction distributions between the two datasets.
+
+    **Signs of High Risk:**
+    - Significant divergence between the distribution curves of the reference and monitoring predictions
+    - Unusual shifts or bimodal distribution in the monitoring predictions compared to the reference predictions
+
+    **Strengths:**
+    - Visual representation makes it easy to spot differences in prediction distributions
+    - Useful for identifying potential data drift or changes in underlying data characteristics
+    - Simple and efficient to implement using standard plotting libraries
+
+    **Limitations:**
+    - Subjective interpretation of the visual plots
+    - Might not pinpoint the exact cause of distribution changes
+    - Less effective if the differences in distributions are subtle and not easily visible
+    """
+
+    pred_ref = datasets[0].y_prob_df(model)
+    pred_ref.columns = ["Reference Prediction"]
+    pred_monitor = datasets[1].y_prob_df(model)
+    pred_monitor.columns = ["Monitoring Prediction"]
+
+    fig = plt.figure()
+    plot = sns.kdeplot(
+        pred_ref["Reference Prediction"], shade=True, label="Reference Prediction"
+    )
+    plot = sns.kdeplot(
+        pred_monitor["Monitoring Prediction"], shade=True, label="Monitor Prediction"
+    )
+    plot.set(
+        xlabel="Prediction", title="Distribution of Reference & Monitor Predictions"
+    )
+    plot.legend()
+
+    return fig
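
Taken together, the four new ongoing_monitoring tests expect a reference dataset and a monitoring dataset with predictions already assigned for the model under review. A hedged end-to-end sketch follows; the test IDs and the `run_test` call are assumptions based on how other validmind tests are addressed, and only the test function signatures and the `"<input_id>_prediction"`/`"<input_id>_probabilities"` column naming come from this diff:

```python
# End-to-end sketch (assumptions noted inline; only the test function signatures
# and the prediction-column naming convention are taken from this diff).
import pandas as pd
import validmind as vm
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


def make_df(seed):
    X, y = make_classification(n_samples=500, n_features=4, random_state=seed)
    df = pd.DataFrame(X, columns=[f"f{i}" for i in range(4)])
    df["target"] = y
    return df


reference_df, monitoring_df = make_df(0), make_df(1)
model = LogisticRegression().fit(
    reference_df.drop(columns="target"), reference_df["target"]
)

vm_model = vm.init_model(model, input_id="champion_model")
reference_ds = vm.init_dataset(reference_df, input_id="reference_ds", target_column="target")
monitoring_ds = vm.init_dataset(monitoring_df, input_id="monitoring_ds", target_column="target")

# The monitoring tests read prediction columns off each dataset, so predictions
# are assigned to both datasets up front.
reference_ds.assign_predictions(vm_model)
monitoring_ds.assign_predictions(vm_model)

# Assumed test IDs, following the usual validmind.<module>.<TestName> convention.
vm.tests.run_test(
    "validmind.ongoing_monitoring.FeatureDrift",
    inputs={"datasets": [reference_ds, monitoring_ds]},
)
for test_id in [
    "validmind.ongoing_monitoring.PredictionAcrossEachFeature",
    "validmind.ongoing_monitoring.PredictionCorrelation",
    "validmind.ongoing_monitoring.TargetPredictionDistributionPlot",
]:
    vm.tests.run_test(
        test_id, inputs={"datasets": [reference_ds, monitoring_ds], "model": vm_model}
    )
```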