validmind 2.5.15__py3-none-any.whl → 2.5.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. validmind/__version__.py +1 -1
  2. validmind/ai/test_descriptions.py +54 -112
  3. validmind/ai/test_result_description/config.yaml +29 -0
  4. validmind/ai/test_result_description/context.py +73 -0
  5. validmind/ai/test_result_description/image_processing.py +124 -0
  6. validmind/ai/test_result_description/system.jinja +39 -0
  7. validmind/ai/test_result_description/user.jinja +25 -0
  8. validmind/datasets/credit_risk/__init__.py +1 -0
  9. validmind/datasets/credit_risk/datasets/lending_club_biased.csv.gz +0 -0
  10. validmind/datasets/credit_risk/lending_club_bias.py +142 -0
  11. validmind/tests/__types__.py +19 -10
  12. validmind/tests/{model_validation/statsmodels → data_validation}/BoxPierce.py +20 -24
  13. validmind/tests/data_validation/ChiSquaredFeaturesTable.py +4 -1
  14. validmind/tests/{model_validation/statsmodels → data_validation}/JarqueBera.py +22 -30
  15. validmind/tests/{model_validation/statsmodels → data_validation}/LJungBox.py +23 -27
  16. validmind/tests/data_validation/ProtectedClassesCombination.py +197 -0
  17. validmind/tests/data_validation/ProtectedClassesDescription.py +130 -0
  18. validmind/tests/data_validation/ProtectedClassesDisparity.py +133 -0
  19. validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py +172 -0
  20. validmind/tests/{model_validation/statsmodels → data_validation}/RunsTest.py +17 -20
  21. validmind/tests/{model_validation/statsmodels → data_validation}/ShapiroWilk.py +20 -22
  22. validmind/tests/data_validation/nlp/Hashtags.py +15 -20
  23. validmind/tests/data_validation/nlp/TextDescription.py +3 -1
  24. validmind/tests/model_validation/ContextualRecall.py +3 -0
  25. validmind/tests/model_validation/ragas/AspectCritique.py +5 -6
  26. validmind/tests/model_validation/ragas/ContextUtilization.py +155 -0
  27. validmind/tests/model_validation/ragas/NoiseSensitivity.py +152 -0
  28. validmind/tests/model_validation/sklearn/FeatureImportance.py +3 -3
  29. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +1 -1
  30. validmind/tests/model_validation/sklearn/RegressionR2Square.py +1 -2
  31. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +59 -0
  32. validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +40 -20
  33. validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +0 -1
  34. validmind/tests/model_validation/statsmodels/RegressionCoeffs.py +1 -1
  35. validmind/utils.py +4 -0
  36. validmind/vm_models/test/metric.py +1 -0
  37. validmind/vm_models/test/result_wrapper.py +50 -26
  38. validmind/vm_models/test/threshold_test.py +1 -0
  39. {validmind-2.5.15.dist-info → validmind-2.5.18.dist-info}/METADATA +4 -3
  40. {validmind-2.5.15.dist-info → validmind-2.5.18.dist-info}/RECORD +43 -30
  41. {validmind-2.5.15.dist-info → validmind-2.5.18.dist-info}/LICENSE +0 -0
  42. {validmind-2.5.15.dist-info → validmind-2.5.18.dist-info}/WHEEL +0 -0
  43. {validmind-2.5.15.dist-info → validmind-2.5.18.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,155 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import warnings
+
+ import plotly.express as px
+ from datasets import Dataset
+
+ from validmind import tags, tasks
+
+ from .utils import get_ragas_config, get_renamed_columns
+
+
+ @tags("ragas", "llm", "retrieval_performance")
+ @tasks("text_qa", "text_generation", "text_summarization", "text_classification")
+ def ContextUtilization(
+     dataset,
+     question_column: str = "question",
+     contexts_column: str = "contexts",
+     answer_column: str = "answer",
+ ):  # noqa: B950
+     """
+     Assesses how effectively relevant context chunks are utilized in generating answers by evaluating their ranking
+     within the provided contexts.
+
+     ### Purpose
+
+     The Context Utilization test evaluates whether all of the answer-relevant items present in the contexts are ranked
+     higher within the provided retrieval results. This metric is essential for assessing the performance of models,
+     especially those involved in tasks such as text QA, text generation, text summarization, and text classification.
+
+     ### Test Mechanism
+
+     The test calculates Context Utilization using the formula:
+
+     $$
+     \\text{Context Utilization@K} = \\frac{\\sum_{k=1}^{K} \\left( \\text{Precision@k} \\times v_k \\right)}{\\text{Total number of relevant items in the top } K \\text{ results}}
+     $$
+     $$
+     \\text{Precision@k} = {\\text{true positives@k} \\over (\\text{true positives@k} + \\text{false positives@k})}
+     $$
+
+     Where $K$ is the total number of chunks in `contexts` and $v_k \\in \\{0, 1\\}$ is the relevance indicator at rank $k$.
+
+
+     This test uses columns for questions, contexts, and answers from the dataset and computes context utilization
+     scores, generating a histogram and box plot for visualization.
+
+     #### Configuring Columns
+
+     This metric requires the following columns in your dataset:
+
+     - `question` (str): The text query that was input into the model.
+     - `contexts` (List[str]): A list of text contexts which are retrieved and which will be evaluated to
+       make sure they contain relevant info in the correct order.
+     - `answer` (str): The LLM-generated response for the input `question`.
+
+     If the above data is not in the appropriate column, you can specify different column
+     names for these fields using the parameters `question_column`, `contexts_column`
+     and `answer_column`.
+
+     For example, if your dataset has this data stored in different columns, you can
+     pass the following parameters:
+     ```python
+     {
+         "question_column": "question",
+         "contexts_column": "context_info",
+         "answer_column": "my_answer_col",
+     }
+     ```
+
+     If the data is stored as a dictionary in another column, specify the column and key
+     like this:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": f"{pred_col}.contexts",
+         "answer_column": f"{pred_col}.answer",
+     }
+     ```
+
+     For more complex situations, you can use a function to extract the data:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": lambda x: [x[pred_col]["context_message"]],
+         "answer_column": "my_answer_col",
+     }
+     ```
+
+     ### Signs of High Risk
+
+     - Very low mean or median context utilization scores, indicating poor usage of retrieved contexts.
+     - High standard deviation, suggesting inconsistent model performance.
+     - Low or minimal max scores, pointing to the model's failure to rank relevant contexts at top positions.
+
+     ### Strengths
+
+     - Quantifies the rank of relevant context chunks in generating responses.
+     - Provides clear visualizations through histograms and box plots for ease of interpretation.
+     - Adapts to different dataset schemas by allowing configurable column names.
+
+     ### Limitations
+
+     - Assumes the relevance of context chunks is binary and may not capture nuances of partial relevance.
+     - Requires proper context retrieval to be effective; irrelevant context chunks can skew the results.
+     - Dependent on large sample sizes to provide stable and reliable estimates of utilization performance.
+     """
+     try:
+         from ragas import evaluate
+         from ragas.metrics import context_utilization
+     except ImportError:
+         raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
+
+     warnings.filterwarnings(
+         "ignore",
+         category=FutureWarning,
+         message="promote has been superseded by promote_options='default'.",
+     )
+
+     required_columns = {
+         "question": question_column,
+         "contexts": contexts_column,
+         "answer": answer_column,
+     }
+
+     df = get_renamed_columns(dataset._df, required_columns)
+
+     result_df = evaluate(
+         Dataset.from_pandas(df), metrics=[context_utilization], **get_ragas_config()
+     ).to_pandas()
+
+     fig_histogram = px.histogram(x=result_df["context_utilization"].to_list(), nbins=10)
+     fig_box = px.box(x=result_df["context_utilization"].to_list())
+
+     return (
+         {
+             # "Scores (will not be uploaded to UI)": result_df[
+             #     ["question", "contexts", "answer", "context_utilization"]
+             # ],
+             "Aggregate Scores": [
+                 {
+                     "Mean Score": result_df["context_utilization"].mean(),
+                     "Median Score": result_df["context_utilization"].median(),
+                     "Max Score": result_df["context_utilization"].max(),
+                     "Min Score": result_df["context_utilization"].min(),
+                     "Standard Deviation": result_df["context_utilization"].std(),
+                     "Count": result_df.shape[0],
+                 }
+             ],
+         },
+         fig_histogram,
+         fig_box,
+     )
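The new ContextUtilization test above is invoked like any other ValidMind test. A minimal usage sketch follows; it assumes `vm.init()` has already been called with valid credentials, the DataFrame contents and `input_id` are illustrative only, and the ragas backend additionally needs the LLM extras (`pip install validmind[llm]`) plus a configured LLM/embeddings provider.

```python
import pandas as pd
import validmind as vm

# Hypothetical RAG evaluation set with the three columns the test expects by default.
df = pd.DataFrame(
    {
        "question": ["What is the capital of France?"],
        "contexts": [["Paris is the capital of France.", "France borders Spain."]],
        "answer": ["The capital of France is Paris."],
    }
)

vm_rag_ds = vm.init_dataset(dataset=df, input_id="rag_eval_ds")

# Column names only need to be remapped via `params` if they differ from the defaults.
result = vm.tests.run_test(
    "validmind.model_validation.ragas.ContextUtilization",
    inputs={"dataset": vm_rag_ds},
)
result.log()
```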
@@ -0,0 +1,152 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import warnings
+
+ import plotly.express as px
+ from datasets import Dataset
+
+ from validmind import tags, tasks
+
+ from .utils import get_ragas_config, get_renamed_columns
+
+
+ @tags("ragas", "llm", "rag_performance")
+ @tasks("text_qa", "text_generation", "text_summarization")
+ def NoiseSensitivity(
+     dataset,
+     answer_column="answer",
+     contexts_column="contexts",
+     ground_truth_column="ground_truth",
+ ):
+     """
+     Assesses the sensitivity of a Large Language Model (LLM) to noise in retrieved context by measuring how often it
+     generates incorrect responses.
+
+     ### Purpose
+
+     The Noise Sensitivity test aims to measure how sensitive an LLM is to irrelevant or noisy information within the
+     contextual data used to generate its responses. A lower noise sensitivity score suggests better model robustness in
+     generating accurate answers from given contexts.
+
+     ### Test Mechanism
+
+     This test evaluates the model's answers by comparing the claims made in the generated response against the ground
+     truth and the retrieved context. The noise sensitivity score is calculated as:
+
+     $$
+     \\text{noise sensitivity} = {|\\text{Number of incorrect claims in answer}| \\over |\\text{Number of total claims in answer}|}
+     $$
+
+     The formula computes the fraction of incorrect claims to the total claims in the answer, using a dataset where
+     'answer', 'contexts', and 'ground_truth' columns are specified.
+
+     #### Configuring Columns
+
+     This metric requires the following columns in your dataset:
+
+     - `contexts` (List[str]): A list of text contexts which are retrieved to generate
+       the answer.
+     - `answer` (str): The response generated by the model.
+     - `ground_truth` (str): The "correct" answer to the question.
+
+     If the above data is not in the appropriate column, you can specify different column
+     names for these fields using the parameters `contexts_column`, `answer_column`, and `ground_truth_column`.
+
+     For example, if your dataset has this data stored in different columns, you can
+     pass the following parameters:
+     ```python
+     {
+         "contexts_column": "context_info",
+         "answer_column": "my_answer_col",
+     }
+     ```
+
+     If the data is stored as a dictionary in another column, specify the column and key
+     like this:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": f"{pred_col}.contexts",
+         "answer_column": f"{pred_col}.answer",
+     }
+     ```
+
+     For more complex situations, you can use a function to extract the data:
+     ```python
+     pred_col = dataset.prediction_column(model)
+     params = {
+         "contexts_column": lambda row: [row[pred_col]["context_message"]],
+         "answer_column": lambda row: "\\n\\n".join(row[pred_col]["messages"]),
+     }
+     ```
+
+     ### Signs of High Risk
+
+     - High noise sensitivity scores across multiple samples.
+     - Significant deviation between mean and median noise sensitivity scores.
+     - High standard deviation indicating inconsistency in the model's performance.
+
+     ### Strengths
+
+     - Provides a quantitative measure of how well the LLM handles noisy or irrelevant context.
+     - Easy integration and configuration using column parameters.
+     - Utilizes both histogram and box plot visualizations to analyze score distribution.
+
+     ### Limitations
+
+     - Requires accurate ground truth that aligns with the generated answers.
+     - Assumes the context provided is sufficiently granular to assess noise sensitivity.
+     - Primarily applicable to tasks like text QA, text generation, and text summarization where contextual relevance is
+       critical.
+     """
+     try:
+         from ragas import evaluate
+         from ragas.metrics import noise_sensitivity_relevant
+     except ImportError:
+         raise ImportError("Please run `pip install validmind[llm]` to use LLM tests")
+
+     warnings.filterwarnings(
+         "ignore",
+         category=FutureWarning,
+         message="promote has been superseded by promote_options='default'.",
+     )
+
+     required_columns = {
+         "answer": answer_column,
+         "contexts": contexts_column,
+         "ground_truth": ground_truth_column,
+     }
+
+     df = get_renamed_columns(dataset._df, required_columns)
+
+     result_df = evaluate(
+         Dataset.from_pandas(df),
+         metrics=[noise_sensitivity_relevant],
+         **get_ragas_config(),
+     ).to_pandas()
+
+     fig_histogram = px.histogram(
+         x=result_df["noise_sensitivity_relevant"].to_list(), nbins=10
+     )
+     fig_box = px.box(x=result_df["noise_sensitivity_relevant"].to_list())
+
+     return (
+         {
+             # "Scores (will not be uploaded to UI)": result_df[
+             #     ["contexts", "answer", "ground_truth", "noise_sensitivity_relevant"]
+             # ],
+             "Aggregate Scores": [
+                 {
+                     "Mean Score": result_df["noise_sensitivity_relevant"].mean(),
+                     "Median Score": result_df["noise_sensitivity_relevant"].median(),
+                     "Max Score": result_df["noise_sensitivity_relevant"].max(),
+                     "Min Score": result_df["noise_sensitivity_relevant"].min(),
+                     "Standard Deviation": result_df["noise_sensitivity_relevant"].std(),
+                     "Count": result_df.shape[0],
+                 }
+             ],
+         },
+         fig_histogram,
+         fig_box,
+     )
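As a quick sanity check on the formula above, consider a toy case (claim extraction and verification are normally delegated to an LLM judge inside ragas): an answer that makes four claims, one of which is not supported by the ground truth, scores 1/4 = 0.25, and lower is better.

```python
# Toy illustration of the noise sensitivity ratio only; not the ragas implementation.
claims_in_answer = [
    ("Paris is the capital of France.", True),          # supported by ground truth
    ("The Seine flows through Paris.", True),
    ("France borders Spain.", True),
    ("France has a population of 90 million.", False),  # incorrect claim
]

incorrect = sum(1 for _, is_correct in claims_in_answer if not is_correct)
noise_sensitivity = incorrect / len(claims_in_answer)
print(noise_sensitivity)  # 0.25
```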
@@ -81,9 +81,9 @@ def FeatureImportance(dataset, model, num_features=3):
      # Dynamically add feature columns to the result
      for i in range(num_features):
          if i < len(top_features):
-             result[f"Feature {i + 1}"] = (
-                 f"[{top_features[i][0]}; {top_features[i][1]:.4f}]"
-             )
+             result[
+                 f"Feature {i + 1}"
+             ] = f"[{top_features[i][0]}; {top_features[i][1]:.4f}]"
          else:
              result[f"Feature {i + 1}"] = None
 
@@ -109,7 +109,7 @@ class PermutationFeatureImportance(Metric):
              )
          )
          fig.update_layout(
-             title_text="Permutation Importances (train set)",
+             title_text="Permutation Importances",
              yaxis=dict(
                  tickmode="linear",  # set tick mode to linear
                  dtick=1,  # set interval between ticks
@@ -3,11 +3,10 @@
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
  import pandas as pd
-
  from sklearn import metrics
 
- from validmind.tests.model_validation.statsmodels.statsutils import adj_r2_score
  from validmind import tags, tasks
+ from validmind.tests.model_validation.statsmodels.statsutils import adj_r2_score
 
 
  @tags("sklearn", "model_performance")
@@ -78,6 +78,7 @@ class SHAPGlobalImportance(Metric):
      default_params = {
          "kernel_explainer_samples": 10,
          "tree_or_linear_explainer_samples": 200,
+         "class_of_interest": None,
      }
 
      def _generate_shap_plot(self, type_, shap_values, x_test):
@@ -107,6 +108,7 @@ class SHAPGlobalImportance(Metric):
                  shap_values / max_shap_value * 100
              )  # scaling factor to make the top feature 100%
              summary_plot_extra_args = {"plot_type": "bar"}
+
              shap.summary_plot(
                  shap_values, x_test, show=False, **summary_plot_extra_args
              )
@@ -192,6 +194,10 @@ class SHAPGlobalImportance(Metric):
 
          shap_values = explainer.shap_values(shap_sample)
 
+         # Select the SHAP values for the specified class (classification) or for the regression output.
+         class_of_interest = self.params["class_of_interest"]
+         shap_values = _select_shap_values(shap_values, class_of_interest)
+
          figures = [
              self._generate_shap_plot("mean", shap_values, shap_sample),
              self._generate_shap_plot("summary", shap_values, shap_sample),
@@ -214,3 +220,56 @@ class SHAPGlobalImportance(Metric):
          for fig_num, type_ in enumerate(["mean", "summary"], start=1):
              assert isinstance(self.result.figures[fig_num - 1], Figure)
              assert self.result.figures[fig_num - 1].metadata["type"] == type_
+
+
+ def _select_shap_values(shap_values, class_of_interest=None):
+     """
+     Selects SHAP values for binary or multiclass classification. For regression models,
+     returns the SHAP values directly as there are no classes.
+
+     Parameters:
+     -----------
+     shap_values : list or numpy.ndarray
+         The SHAP values returned by the SHAP explainer. For multiclass classification,
+         this will be a list where each element corresponds to a class. For regression,
+         this will be a single array of SHAP values.
+
+     class_of_interest : int, optional
+         The class index for which to retrieve SHAP values. If None (default), the function
+         will assume binary classification and use class 1 by default.
+
+     Returns:
+     --------
+     numpy.ndarray
+         The SHAP values for the specified class (classification) or for the regression output.
+
+     Raises:
+     -------
+     ValueError
+         If class_of_interest is specified and is out of bounds for the number of classes.
+     """
+     # Check if we are dealing with a multiclass classification
+     if isinstance(shap_values, list):
+         num_classes = len(shap_values)
+
+         # Default to class 1 for binary classification
+         if num_classes == 2 and class_of_interest is None:
+             logger.info(
+                 "Binary classification detected: using SHAP values for class 1 (positive class)."
+             )
+             return shap_values[1]
+         else:
+             # Multiclass classification: use the specified class_of_interest
+             if class_of_interest is not None and 0 <= class_of_interest < num_classes:
+                 logger.info(
+                     f"Multiclass classification: using SHAP values for class {class_of_interest}."
+                 )
+                 return shap_values[class_of_interest]
+             else:
+                 raise ValueError(
+                     f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
+                 )
+     else:
+         # For regression, return the SHAP values as they are
+         logger.info("Regression model detected: returning SHAP values as-is.")
+         return shap_values
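For multiclass classifiers, the new `class_of_interest` parameter selects which class's SHAP values are plotted; binary classification still defaults to the positive class. A sketch of how the parameter might be passed when running the test is below; the input names, the class index, and the placeholders `vm_model`/`vm_test_ds` (assumed to have been created earlier with `vm.init_model()`/`vm.init_dataset()`) are illustrative.

```python
import validmind as vm

# Sketch: plot SHAP global importance for class 2 of a hypothetical 3-class model.
result = vm.tests.run_test(
    "validmind.model_validation.sklearn.SHAPGlobalImportance",
    inputs={"model": vm_model, "dataset": vm_test_ds},
    params={"class_of_interest": 2},
)
result.log()
```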
@@ -2,15 +2,15 @@
  # See the LICENSE file in the root of this repository for details.
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
- from dataclasses import dataclass
-
+ import pandas as pd
  from statsmodels.stats.stattools import durbin_watson
 
- from validmind.vm_models import Metric
+ from validmind import tags, tasks
 
 
- @dataclass
- class DurbinWatsonTest(Metric):
+ @tasks("regression")
+ @tags("time_series_data", "forecasting", "statistical_test", "statsmodels")
+ def DurbinWatsonTest(dataset, model, threshold=[1.5, 2.5]):
      """
      Assesses autocorrelation in time series data features using the Durbin-Watson statistic.
 
@@ -49,18 +49,38 @@ class DurbinWatsonTest(Metric):
      to detect higher-order autocorrelation.
      """
 
-     name = "durbin_watson"
-     required_inputs = ["dataset"]
-     tasks = ["regression"]
-     tags = ["time_series_data", "forecasting", "statistical_test", "statsmodels"]
-
-     def run(self):
-         """
-         Calculates DB for each of the dataset features
-         """
-         x_train = self.inputs.dataset.df
-         dw_values = {}
-         for col in x_train.columns:
-             dw_values[col] = durbin_watson(x_train[col].values)
-
-         return self.cache_results(dw_values)
+     # Validate threshold values
+     if not (0 < threshold[0] < threshold[1] < 4):
+         raise ValueError(
+             "Invalid threshold. It should be in the form [a, b] where 0 < a < b < 4."
+         )
+
+     # Check if threshold values are around 2
+     if abs(2 - threshold[0]) > 1 or abs(2 - threshold[1]) > 1:
+         raise ValueError(
+             "Threshold values should be around 2 for meaningful Durbin-Watson test results."
+         )
+
+     y_true = dataset.y
+     y_pred = dataset.y_pred(model)
+     residuals = y_true - y_pred
+
+     dw_statistic = durbin_watson(residuals)
+
+     def get_autocorrelation(dw_value, threshold):
+         if dw_value < threshold[0]:
+             return "Positive autocorrelation"
+         elif dw_value > threshold[1]:
+             return "Negative autocorrelation"
+         else:
+             return "No autocorrelation"
+
+     results = pd.DataFrame(
+         {
+             "dw_statistic": [dw_statistic],
+             "threshold": [str(threshold)],
+             "autocorrelation": [get_autocorrelation(dw_statistic, threshold)],
+         }
+     )
+
+     return results
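The rewritten DurbinWatsonTest scores the model's residuals rather than every raw feature column. For reference, a standalone sketch of the same statistic and the default `[1.5, 2.5]` decision bands, using made-up residuals:

```python
import numpy as np
from statsmodels.stats.stattools import durbin_watson

# Made-up residuals with a slow drift, which tends to produce positive autocorrelation.
rng = np.random.default_rng(0)
residuals = np.sin(np.linspace(0, 3, 50)) + rng.normal(0, 0.05, 50)

dw_statistic = durbin_watson(residuals)
threshold = [1.5, 2.5]

if dw_statistic < threshold[0]:
    label = "Positive autocorrelation"
elif dw_statistic > threshold[1]:
    label = "Negative autocorrelation"
else:
    label = "No autocorrelation"

print(f"Durbin-Watson statistic: {dw_statistic:.2f} -> {label}")
```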
@@ -6,7 +6,6 @@
  import plotly.graph_objects as go
  from matplotlib import cm
 
-
  from validmind import tags, tasks
 
 
@@ -7,8 +7,8 @@ import pandas as pd
  import plotly.graph_objects as go
  from scipy import stats
 
- from validmind.errors import SkipTestError
  from validmind import tags, tasks
+ from validmind.errors import SkipTestError
 
 
  @tags("tabular_data", "visualization", "model_training")
validmind/utils.py CHANGED
@@ -175,6 +175,10 @@ def format_records(df):
              continue
          not_zero = df[col][df[col] != 0]
          min_number = not_zero.min()
+         if math.isnan(min_number) or math.isinf(min_number):
+             df[col] = df[col].round(DEFAULT_SMALL_NUMBER_DECIMALS)
+             continue
+
          _, min_scale = precision_and_scale(min_number)
 
          if min_number >= 10:
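The new guard covers columns where every value is zero or missing: `not_zero` is then an empty selection and `.min()` returns NaN, which the precision/scale logic further down is not written to handle, so such columns are now simply rounded and skipped. A minimal, standalone illustration of the empty-selection case:

```python
import math

import pandas as pd

col = pd.Series([0, 0, 0])     # e.g. a summary column that happens to be all zeros
not_zero = col[col != 0]       # empty selection
min_number = not_zero.min()    # NaN for an empty Series

print(math.isnan(min_number))  # True -> the new branch rounds the column and continues
```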
@@ -77,6 +77,7 @@ class Metric(Test):
 
          self.result = MetricResultWrapper(
              result_id=self.test_id,
+             result_description=self.description(),
              result_metadata=[
                  (
                      get_description_metadata(
@@ -128,6 +128,8 @@ class ResultWrapper(ABC):
      # id of the result, can be set by the subclass. This helps
      # looking up results later on
      result_id: str = None
+     # Text description from test or metric (docstring usually)
+     result_description: str = None
      # Text metadata about the result, can include description, etc.
      result_metadata: List[dict] = None
      # Output template to use for rendering the result
@@ -300,38 +302,60 @@ class MetricResultWrapper(ResultWrapper):
          return VBox(vbox_children)
 
      def _get_filtered_summary(self):
-         """Check if the metric summary has columns from input datasets"""
-         dataset_columns = set()
-
-         for input in self.inputs:
-             input_id = input if isinstance(input, str) else input.input_id
-             input_obj = input_registry.get(input_id)
-             if isinstance(input_obj, VMDataset):
-                 dataset_columns.update(input_obj.columns)
-
-         for table in [*self.metric.summary.results]:
-             columns = set()
+         """Check if the metric summary has columns from input datasets with matching row counts."""
+         dataset_columns = self._get_dataset_columns()
+         filtered_results = []
+
+         for table in self.metric.summary.results:
+             table_columns = self._get_table_columns(table)
+             sensitive_columns = self._find_sensitive_columns(
+                 dataset_columns, table_columns
+             )
 
-             if isinstance(table.data, pd.DataFrame):
-                 columns.update(table.data.columns)
-             elif isinstance(table.data, list):
-                 columns.update(table.data[0].keys())
+             if sensitive_columns:
+                 self._log_sensitive_data_warning(sensitive_columns)
              else:
-                 raise ValueError("Invalid data type in summary table")
+                 filtered_results.append(table)
 
-             if bool(columns.intersection(dataset_columns)):
-                 logger.warning(
-                     "Sensitive data in metric summary table. Not logging to API automatically."
-                     " Pass `unsafe=True` to result.log() method to override manually."
-                 )
-                 logger.warning(
-                     f"The following columns are present in the table: {columns}"
-                     f" and also present in the dataset: {dataset_columns}"
+         self.metric.summary.results = filtered_results
+         return self.metric.summary
+
+     def _get_dataset_columns(self):
+         dataset_columns = {}
+         for input_item in self.inputs:
+             input_id = (
+                 input_item if isinstance(input_item, str) else input_item.input_id
+             )
+             input_obj = input_registry.get(input_id)
+             if isinstance(input_obj, VMDataset):
+                 dataset_columns.update(
+                     {col: len(input_obj.df) for col in input_obj.columns}
                  )
+         return dataset_columns
 
-             self.metric.summary.results.remove(table)
+     def _get_table_columns(self, table):
+         if isinstance(table.data, pd.DataFrame):
+             return {col: len(table.data) for col in table.data.columns}
+         elif isinstance(table.data, list) and table.data:
+             return {col: len(table.data) for col in table.data[0].keys()}
+         else:
+             raise ValueError("Invalid data type in summary table")
 
-         return self.metric.summary
+     def _find_sensitive_columns(self, dataset_columns, table_columns):
+         return [
+             col
+             for col, row_count in table_columns.items()
+             if col in dataset_columns and row_count == dataset_columns[col]
+         ]
+
+     def _log_sensitive_data_warning(self, sensitive_columns):
+         logger.warning(
+             "Sensitive data in metric summary table. Not logging to API automatically. "
+             "Pass `unsafe=True` to result.log() method to override manually."
+         )
+         logger.warning(
+             f"The following columns are present in the table with matching row counts: {sensitive_columns}"
+         )
 
      async def log_async(
          self, section_id: str = None, position: int = None, unsafe=False
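The refactored filter above flags a summary table as sensitive only when a column both shares its name with an input dataset column and has the same number of rows, so aggregate tables that merely reuse a column name are no longer dropped. A toy, standalone illustration of that heuristic (the column names are hypothetical):

```python
# Dataset column -> row count, mirroring _get_dataset_columns()
dataset_columns = {"loan_amnt": 10_000, "int_rate": 10_000}

# Summary tables: column -> row count, mirroring _get_table_columns()
aggregate_table = {"loan_amnt": 1, "Mean": 1}               # one aggregated row
raw_rows_table = {"loan_amnt": 10_000, "int_rate": 10_000}  # rows copied from the dataset


def find_sensitive(table_columns):
    return [
        col
        for col, row_count in table_columns.items()
        if col in dataset_columns and row_count == dataset_columns[col]
    ]


print(find_sensitive(aggregate_table))  # [] -> table is kept and logged
print(find_sensitive(raw_rows_table))   # ['loan_amnt', 'int_rate'] -> table is filtered out
```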
@@ -80,6 +80,7 @@ class ThresholdTest(Test):
 
          self.result = ThresholdTestResultWrapper(
              result_id=self.test_id,
+             result_description=self.description(),
              result_metadata=[
                  get_description_metadata(
                      test_id=self.test_id,