validmind 2.5.15__py3-none-any.whl → 2.5.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +54 -112
- validmind/ai/test_result_description/config.yaml +29 -0
- validmind/ai/test_result_description/context.py +73 -0
- validmind/ai/test_result_description/image_processing.py +124 -0
- validmind/ai/test_result_description/system.jinja +39 -0
- validmind/ai/test_result_description/user.jinja +25 -0
- validmind/datasets/credit_risk/__init__.py +1 -0
- validmind/datasets/credit_risk/datasets/lending_club_biased.csv.gz +0 -0
- validmind/datasets/credit_risk/lending_club_bias.py +142 -0
- validmind/errors.py +17 -0
- validmind/tests/__types__.py +19 -10
- validmind/tests/{model_validation/statsmodels → data_validation}/BoxPierce.py +20 -24
- validmind/tests/data_validation/ChiSquaredFeaturesTable.py +4 -1
- validmind/tests/{model_validation/statsmodels → data_validation}/JarqueBera.py +22 -30
- validmind/tests/{model_validation/statsmodels → data_validation}/LJungBox.py +23 -27
- validmind/tests/data_validation/ProtectedClassesCombination.py +205 -0
- validmind/tests/data_validation/ProtectedClassesDescription.py +130 -0
- validmind/tests/data_validation/ProtectedClassesDisparity.py +141 -0
- validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py +180 -0
- validmind/tests/{model_validation/statsmodels → data_validation}/RunsTest.py +17 -20
- validmind/tests/{model_validation/statsmodels → data_validation}/ShapiroWilk.py +20 -22
- validmind/tests/data_validation/nlp/Hashtags.py +15 -20
- validmind/tests/data_validation/nlp/TextDescription.py +3 -1
- validmind/tests/load.py +21 -5
- validmind/tests/model_validation/ContextualRecall.py +3 -0
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +12 -5
- validmind/tests/model_validation/ragas/AnswerRelevance.py +12 -6
- validmind/tests/model_validation/ragas/AnswerSimilarity.py +12 -6
- validmind/tests/model_validation/ragas/AspectCritique.py +22 -17
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +12 -6
- validmind/tests/model_validation/ragas/ContextPrecision.py +12 -6
- validmind/tests/model_validation/ragas/ContextRecall.py +12 -6
- validmind/tests/model_validation/ragas/ContextUtilization.py +161 -0
- validmind/tests/model_validation/ragas/Faithfulness.py +12 -6
- validmind/tests/model_validation/ragas/NoiseSensitivity.py +158 -0
- validmind/tests/model_validation/sklearn/FeatureImportance.py +3 -3
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +1 -1
- validmind/tests/model_validation/sklearn/RegressionR2Square.py +1 -2
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +59 -0
- validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +40 -20
- validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +0 -1
- validmind/tests/model_validation/statsmodels/RegressionCoeffs.py +1 -1
- validmind/utils.py +4 -0
- validmind/vm_models/test/metric.py +1 -0
- validmind/vm_models/test/result_wrapper.py +50 -26
- validmind/vm_models/test/threshold_test.py +1 -0
- {validmind-2.5.15.dist-info → validmind-2.5.19.dist-info}/METADATA +4 -3
- {validmind-2.5.15.dist-info → validmind-2.5.19.dist-info}/RECORD +52 -39
- {validmind-2.5.15.dist-info → validmind-2.5.19.dist-info}/WHEEL +1 -1
- {validmind-2.5.15.dist-info → validmind-2.5.19.dist-info}/LICENSE +0 -0
- {validmind-2.5.15.dist-info → validmind-2.5.19.dist-info}/entry_points.txt +0 -0
validmind/errors.py
CHANGED
@@ -207,6 +207,23 @@ class MissingRequiredTestInputError(BaseError):
|
|
207
207
|
pass
|
208
208
|
|
209
209
|
|
210
|
+
class MissingDependencyError(BaseError):
|
211
|
+
"""
|
212
|
+
When a required dependency is missing.
|
213
|
+
"""
|
214
|
+
|
215
|
+
def __init__(self, message="", required_dependencies=None, extra=None):
|
216
|
+
"""
|
217
|
+
Args:
|
218
|
+
message (str): The error message.
|
219
|
+
required_dependencies (list): A list of required dependencies.
|
220
|
+
extra (str): The particular validmind `extra` that will install the missing dependencies.
|
221
|
+
"""
|
222
|
+
super().__init__(message)
|
223
|
+
self.required_dependencies = required_dependencies or []
|
224
|
+
self.extra = extra
|
225
|
+
|
226
|
+
|
210
227
|
class MissingRExtrasError(BaseError):
|
211
228
|
"""
|
212
229
|
When the R extras have not been installed.
|
validmind/tests/__types__.py
CHANGED
@@ -33,7 +33,6 @@ TestID = Literal[
|
|
33
33
|
"validmind.model_validation.ClusterSizeDistribution",
|
34
34
|
"validmind.model_validation.TokenDisparity",
|
35
35
|
"validmind.model_validation.ToxicityScore",
|
36
|
-
"validmind.model_validation.ModelMetadata",
|
37
36
|
"validmind.model_validation.TimeSeriesR2SquareBySegments",
|
38
37
|
"validmind.model_validation.embeddings.CosineSimilarityComparison",
|
39
38
|
"validmind.model_validation.embeddings.EmbeddingsVisualization2D",
|
@@ -53,12 +52,13 @@ TestID = Literal[
|
|
53
52
|
"validmind.model_validation.ragas.ContextEntityRecall",
|
54
53
|
"validmind.model_validation.ragas.Faithfulness",
|
55
54
|
"validmind.model_validation.ragas.AspectCritique",
|
55
|
+
"validmind.model_validation.ragas.NoiseSensitivity",
|
56
56
|
"validmind.model_validation.ragas.AnswerSimilarity",
|
57
57
|
"validmind.model_validation.ragas.AnswerCorrectness",
|
58
58
|
"validmind.model_validation.ragas.ContextRecall",
|
59
59
|
"validmind.model_validation.ragas.ContextPrecision",
|
60
60
|
"validmind.model_validation.ragas.AnswerRelevance",
|
61
|
-
"validmind.model_validation.
|
61
|
+
"validmind.model_validation.ragas.ContextUtilization",
|
62
62
|
"validmind.model_validation.sklearn.AdjustedMutualInformation",
|
63
63
|
"validmind.model_validation.sklearn.SilhouettePlot",
|
64
64
|
"validmind.model_validation.sklearn.RobustnessDiagnosis",
|
@@ -77,35 +77,35 @@ TestID = Literal[
|
|
77
77
|
"validmind.model_validation.sklearn.ClassifierPerformance",
|
78
78
|
"validmind.model_validation.sklearn.VMeasure",
|
79
79
|
"validmind.model_validation.sklearn.MinimumF1Score",
|
80
|
+
"validmind.model_validation.sklearn.RegressionPerformance",
|
80
81
|
"validmind.model_validation.sklearn.ROCCurve",
|
81
82
|
"validmind.model_validation.sklearn.RegressionR2Square",
|
82
83
|
"validmind.model_validation.sklearn.RegressionErrors",
|
83
84
|
"validmind.model_validation.sklearn.ClusterPerformance",
|
84
|
-
"validmind.model_validation.sklearn.FeatureImportance",
|
85
85
|
"validmind.model_validation.sklearn.TrainingTestDegradation",
|
86
|
+
"validmind.model_validation.sklearn.RegressionErrorsComparison",
|
87
|
+
"validmind.model_validation.sklearn.FeatureImportance",
|
86
88
|
"validmind.model_validation.sklearn.HyperParametersTuning",
|
87
89
|
"validmind.model_validation.sklearn.KMeansClustersOptimization",
|
88
90
|
"validmind.model_validation.sklearn.ModelsPerformanceComparison",
|
89
91
|
"validmind.model_validation.sklearn.WeakspotsDiagnosis",
|
92
|
+
"validmind.model_validation.sklearn.RegressionR2SquareComparison",
|
90
93
|
"validmind.model_validation.sklearn.PopulationStabilityIndex",
|
91
94
|
"validmind.model_validation.sklearn.MinimumAccuracy",
|
92
|
-
"validmind.model_validation.statsmodels.
|
93
|
-
"validmind.model_validation.statsmodels.
|
94
|
-
"validmind.model_validation.statsmodels.RegressionCoeffsPlot",
|
95
|
+
"validmind.model_validation.statsmodels.RegressionModelSensitivityPlot",
|
96
|
+
"validmind.model_validation.statsmodels.RegressionModelForecastPlotLevels",
|
95
97
|
"validmind.model_validation.statsmodels.ScorecardHistogram",
|
96
|
-
"validmind.model_validation.statsmodels.LJungBox",
|
97
|
-
"validmind.model_validation.statsmodels.JarqueBera",
|
98
98
|
"validmind.model_validation.statsmodels.KolmogorovSmirnov",
|
99
|
-
"validmind.model_validation.statsmodels.ShapiroWilk",
|
100
99
|
"validmind.model_validation.statsmodels.CumulativePredictionProbabilities",
|
101
100
|
"validmind.model_validation.statsmodels.RegressionFeatureSignificance",
|
102
101
|
"validmind.model_validation.statsmodels.RegressionModelSummary",
|
102
|
+
"validmind.model_validation.statsmodels.RegressionCoeffs",
|
103
103
|
"validmind.model_validation.statsmodels.Lilliefors",
|
104
|
-
"validmind.model_validation.statsmodels.RunsTest",
|
105
104
|
"validmind.model_validation.statsmodels.RegressionPermutationFeatureImportance",
|
106
105
|
"validmind.model_validation.statsmodels.PredictionProbabilitiesHistogram",
|
107
106
|
"validmind.model_validation.statsmodels.AutoARIMA",
|
108
107
|
"validmind.model_validation.statsmodels.GINITable",
|
108
|
+
"validmind.model_validation.statsmodels.RegressionModelForecastPlot",
|
109
109
|
"validmind.model_validation.statsmodels.DurbinWatsonTest",
|
110
110
|
"validmind.ongoing_monitoring.PredictionCorrelation",
|
111
111
|
"validmind.ongoing_monitoring.PredictionAcrossEachFeature",
|
@@ -113,9 +113,11 @@ TestID = Literal[
|
|
113
113
|
"validmind.ongoing_monitoring.TargetPredictionDistributionPlot",
|
114
114
|
"validmind.data_validation.IQROutliersTable",
|
115
115
|
"validmind.data_validation.Skewness",
|
116
|
+
"validmind.data_validation.BoxPierce",
|
116
117
|
"validmind.data_validation.Duplicates",
|
117
118
|
"validmind.data_validation.MissingValuesBarPlot",
|
118
119
|
"validmind.data_validation.DatasetDescription",
|
120
|
+
"validmind.data_validation.ProtectedClassesCombination",
|
119
121
|
"validmind.data_validation.ZivotAndrewsArch",
|
120
122
|
"validmind.data_validation.ScatterPlot",
|
121
123
|
"validmind.data_validation.TimeSeriesOutliers",
|
@@ -123,7 +125,9 @@ TestID = Literal[
|
|
123
125
|
"validmind.data_validation.AutoStationarity",
|
124
126
|
"validmind.data_validation.DescriptiveStatistics",
|
125
127
|
"validmind.data_validation.TimeSeriesDescription",
|
128
|
+
"validmind.data_validation.LJungBox",
|
126
129
|
"validmind.data_validation.TargetRateBarPlots",
|
130
|
+
"validmind.data_validation.JarqueBera",
|
127
131
|
"validmind.data_validation.PearsonCorrelationMatrix",
|
128
132
|
"validmind.data_validation.FeatureTargetCorrelationPlot",
|
129
133
|
"validmind.data_validation.TabularNumericalHistograms",
|
@@ -133,9 +137,11 @@ TestID = Literal[
|
|
133
137
|
"validmind.data_validation.MissingValues",
|
134
138
|
"validmind.data_validation.PhillipsPerronArch",
|
135
139
|
"validmind.data_validation.RollingStatsPlot",
|
140
|
+
"validmind.data_validation.ProtectedClassesDisparity",
|
136
141
|
"validmind.data_validation.TabularDescriptionTables",
|
137
142
|
"validmind.data_validation.AutoMA",
|
138
143
|
"validmind.data_validation.UniqueRows",
|
144
|
+
"validmind.data_validation.ShapiroWilk",
|
139
145
|
"validmind.data_validation.TooManyZeroValues",
|
140
146
|
"validmind.data_validation.HighPearsonCorrelation",
|
141
147
|
"validmind.data_validation.ACFandPACFPlot",
|
@@ -146,10 +152,12 @@ TestID = Literal[
|
|
146
152
|
"validmind.data_validation.TimeSeriesLinePlot",
|
147
153
|
"validmind.data_validation.KPSS",
|
148
154
|
"validmind.data_validation.AutoSeasonality",
|
155
|
+
"validmind.data_validation.ProtectedClassesDescription",
|
149
156
|
"validmind.data_validation.BivariateScatterPlots",
|
150
157
|
"validmind.data_validation.EngleGrangerCoint",
|
151
158
|
"validmind.data_validation.TimeSeriesMissingValues",
|
152
159
|
"validmind.data_validation.TimeSeriesHistogram",
|
160
|
+
"validmind.data_validation.RunsTest",
|
153
161
|
"validmind.data_validation.LaggedCorrelationHeatmap",
|
154
162
|
"validmind.data_validation.SeasonalDecompose",
|
155
163
|
"validmind.data_validation.WOEBinPlots",
|
@@ -159,6 +167,7 @@ TestID = Literal[
|
|
159
167
|
"validmind.data_validation.TimeSeriesDescriptiveStatistics",
|
160
168
|
"validmind.data_validation.AutoAR",
|
161
169
|
"validmind.data_validation.TabularDateTimeHistograms",
|
170
|
+
"validmind.data_validation.ProtectedClassesThresholdOptimizer",
|
162
171
|
"validmind.data_validation.ADF",
|
163
172
|
"validmind.data_validation.nlp.Toxicity",
|
164
173
|
"validmind.data_validation.nlp.PolarityAndSubjectivity",
|
@@ -2,12 +2,15 @@
|
|
2
2
|
# See the LICENSE file in the root of this repository for details.
|
3
3
|
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
4
|
|
5
|
+
import pandas as pd
|
5
6
|
from statsmodels.stats.diagnostic import acorr_ljungbox
|
6
7
|
|
7
|
-
from validmind
|
8
|
+
from validmind import tags, tasks
|
8
9
|
|
9
10
|
|
10
|
-
|
11
|
+
@tasks("regression")
|
12
|
+
@tags("time_series_data", "forecasting", "statistical_test", "statsmodels")
|
13
|
+
def BoxPierce(dataset):
|
11
14
|
"""
|
12
15
|
Detects autocorrelation in time-series data through the Box-Pierce test to validate model performance.
|
13
16
|
|
@@ -51,25 +54,18 @@ class BoxPierce(Metric):
|
|
51
54
|
- Applicability is limited to time-series data, which limits its overall utility.
|
52
55
|
"""
|
53
56
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
)
|
70
|
-
box_pierce_values[col] = {
|
71
|
-
"stat": bp_results.iloc[0]["lb_stat"],
|
72
|
-
"pvalue": bp_results.iloc[0]["lb_pvalue"],
|
73
|
-
}
|
74
|
-
|
75
|
-
return self.cache_results(box_pierce_values)
|
57
|
+
df = dataset.df
|
58
|
+
|
59
|
+
box_pierce_values = {}
|
60
|
+
for col in df.columns:
|
61
|
+
bp_results = acorr_ljungbox(df[col].values, boxpierce=True, return_df=True)
|
62
|
+
box_pierce_values[col] = {
|
63
|
+
"stat": bp_results.iloc[0]["lb_stat"],
|
64
|
+
"pvalue": bp_results.iloc[0]["lb_pvalue"],
|
65
|
+
}
|
66
|
+
|
67
|
+
box_pierce_df = pd.DataFrame.from_dict(box_pierce_values, orient="index")
|
68
|
+
box_pierce_df.reset_index(inplace=True)
|
69
|
+
box_pierce_df.columns = ["column", "stat", "pvalue"]
|
70
|
+
|
71
|
+
return box_pierce_df
|
@@ -7,6 +7,7 @@ import pandas as pd
|
|
7
7
|
from scipy.stats import chi2_contingency
|
8
8
|
|
9
9
|
from validmind import tags, tasks
|
10
|
+
from validmind.errors import SkipTestError
|
10
11
|
|
11
12
|
|
12
13
|
@tags("tabular_data", "categorical_data", "statistical_test")
|
@@ -55,9 +56,11 @@ def ChiSquaredFeaturesTable(dataset, p_threshold=0.05):
|
|
55
56
|
"""
|
56
57
|
|
57
58
|
target_column = dataset.target_column
|
58
|
-
|
59
59
|
features = dataset.feature_columns_categorical
|
60
60
|
|
61
|
+
if not features:
|
62
|
+
raise SkipTestError("No categorical features found in dataset")
|
63
|
+
|
61
64
|
results_df = _chi_squared_categorical_feature_selection(
|
62
65
|
dataset.df, features, target_column, p_threshold
|
63
66
|
)
|
@@ -2,12 +2,15 @@
|
|
2
2
|
# See the LICENSE file in the root of this repository for details.
|
3
3
|
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
4
|
|
5
|
+
import pandas as pd
|
5
6
|
from statsmodels.stats.stattools import jarque_bera
|
6
7
|
|
7
|
-
from validmind
|
8
|
+
from validmind import tags, tasks
|
8
9
|
|
9
10
|
|
10
|
-
|
11
|
+
@tasks("classification", "regression")
|
12
|
+
@tags("tabular_data", "data_distribution", "statistical_test", "statsmodels")
|
13
|
+
def JarqueBera(dataset):
|
11
14
|
"""
|
12
15
|
Assesses normality of dataset features in an ML model using the Jarque-Bera test.
|
13
16
|
|
@@ -48,31 +51,20 @@ class JarqueBera(Metric):
|
|
48
51
|
even for minor deviations in larger datasets.
|
49
52
|
"""
|
50
53
|
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
for col in x_train.columns:
|
69
|
-
jb_stat, jb_pvalue, jb_skew, jb_kurtosis = jarque_bera(x_train[col].values)
|
70
|
-
|
71
|
-
jb_values[col] = {
|
72
|
-
"stat": jb_stat,
|
73
|
-
"pvalue": jb_pvalue,
|
74
|
-
"skew": jb_skew,
|
75
|
-
"kurtosis": jb_kurtosis,
|
76
|
-
}
|
77
|
-
|
78
|
-
return self.cache_results(jb_values)
|
54
|
+
df = dataset.df[dataset.feature_columns_numeric]
|
55
|
+
|
56
|
+
jb_values = {}
|
57
|
+
for col in df.columns:
|
58
|
+
jb_stat, jb_pvalue, jb_skew, jb_kurtosis = jarque_bera(df[col].values)
|
59
|
+
jb_values[col] = {
|
60
|
+
"stat": jb_stat,
|
61
|
+
"pvalue": jb_pvalue,
|
62
|
+
"skew": jb_skew,
|
63
|
+
"kurtosis": jb_kurtosis,
|
64
|
+
}
|
65
|
+
|
66
|
+
jb_df = pd.DataFrame.from_dict(jb_values, orient="index")
|
67
|
+
jb_df.reset_index(inplace=True)
|
68
|
+
jb_df.columns = ["column", "stat", "pvalue", "skew", "kurtosis"]
|
69
|
+
|
70
|
+
return jb_df
|
@@ -2,12 +2,15 @@
|
|
2
2
|
# See the LICENSE file in the root of this repository for details.
|
3
3
|
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
4
|
|
5
|
+
import pandas as pd
|
5
6
|
from statsmodels.stats.diagnostic import acorr_ljungbox
|
6
7
|
|
7
|
-
from validmind
|
8
|
+
from validmind import tags, tasks
|
8
9
|
|
9
10
|
|
10
|
-
|
11
|
+
@tasks("regression")
|
12
|
+
@tags("time_series_data", "forecasting", "statistical_test", "statsmodels")
|
13
|
+
def LJungBox(dataset):
|
11
14
|
"""
|
12
15
|
Assesses autocorrelations in dataset features by performing a Ljung-Box test on each feature.
|
13
16
|
|
@@ -20,11 +23,11 @@ class LJungBox(Metric):
|
|
20
23
|
|
21
24
|
### Test Mechanism
|
22
25
|
|
23
|
-
The test operates by iterating over each feature within the
|
26
|
+
The test operates by iterating over each feature within the dataset and applying the `acorr_ljungbox`
|
24
27
|
function from the `statsmodels.stats.diagnostic` library. This function calculates the Ljung-Box statistic and
|
25
|
-
p-value for each feature. These results are then stored in a
|
26
|
-
|
27
|
-
|
28
|
+
p-value for each feature. These results are then stored in a pandas DataFrame where the columns are the feature names,
|
29
|
+
statistic, and p-value respectively. Generally, a lower p-value indicates a higher likelihood of significant
|
30
|
+
autocorrelations within the feature.
|
28
31
|
|
29
32
|
### Signs of High Risk
|
30
33
|
|
@@ -41,30 +44,23 @@ class LJungBox(Metric):
|
|
41
44
|
### Limitations
|
42
45
|
|
43
46
|
- Cannot detect all types of non-linearity or complex interrelationships among variables.
|
44
|
-
- Testing individual features may not fully encapsulate the dynamics of the data if features interact with each
|
45
|
-
other.
|
47
|
+
- Testing individual features may not fully encapsulate the dynamics of the data if features interact with each other.
|
46
48
|
- Designed more for traditional statistical models and may not be fully compatible with certain types of complex
|
47
|
-
|
49
|
+
machine learning models.
|
48
50
|
"""
|
49
51
|
|
50
|
-
|
51
|
-
required_inputs = ["dataset"]
|
52
|
-
tasks = ["regression"]
|
53
|
-
tags = ["time_series_data", "forecasting", "statistical_test", "statsmodels"]
|
52
|
+
df = dataset.df
|
54
53
|
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
54
|
+
ljung_box_values = {}
|
55
|
+
for col in df.columns:
|
56
|
+
lb_results = acorr_ljungbox(df[col].values, return_df=True)
|
57
|
+
ljung_box_values[col] = {
|
58
|
+
"stat": lb_results.iloc[0]["lb_stat"],
|
59
|
+
"pvalue": lb_results.iloc[0]["lb_pvalue"],
|
60
|
+
}
|
60
61
|
|
61
|
-
|
62
|
-
|
63
|
-
|
62
|
+
ljung_box_df = pd.DataFrame.from_dict(ljung_box_values, orient="index")
|
63
|
+
ljung_box_df.reset_index(inplace=True)
|
64
|
+
ljung_box_df.columns = ["column", "stat", "pvalue"]
|
64
65
|
|
65
|
-
|
66
|
-
"stat": lb_results["lb_stat"].values[0],
|
67
|
-
"pvalue": lb_results["lb_pvalue"].values[0],
|
68
|
-
}
|
69
|
-
|
70
|
-
return self.cache_results(ljung_box_values)
|
66
|
+
return ljung_box_df
|
@@ -0,0 +1,205 @@
|
|
1
|
+
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
|
2
|
+
# See the LICENSE file in the root of this repository for details.
|
3
|
+
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
|
+
|
5
|
+
import sys
|
6
|
+
|
7
|
+
import pandas as pd
|
8
|
+
import plotly.graph_objects as go
|
9
|
+
import plotly.subplots as sp
|
10
|
+
|
11
|
+
from validmind import tags, tasks
|
12
|
+
from validmind.errors import MissingDependencyError
|
13
|
+
from validmind.logging import get_logger
|
14
|
+
|
15
|
+
try:
|
16
|
+
from fairlearn.metrics import (
|
17
|
+
MetricFrame,
|
18
|
+
count,
|
19
|
+
demographic_parity_ratio,
|
20
|
+
equalized_odds_ratio,
|
21
|
+
false_positive_rate,
|
22
|
+
selection_rate,
|
23
|
+
true_positive_rate,
|
24
|
+
)
|
25
|
+
except ImportError as e:
|
26
|
+
raise MissingDependencyError(
|
27
|
+
"Missing required package `fairlearn` for ProtectedClassesCombination.",
|
28
|
+
required_dependencies=["fairlearn"],
|
29
|
+
) from e
|
30
|
+
|
31
|
+
logger = get_logger(__name__)
|
32
|
+
|
33
|
+
|
34
|
+
@tags("bias_and_fairness")
|
35
|
+
@tasks("classification", "regression")
|
36
|
+
def ProtectedClassesCombination(dataset, model, protected_classes=None):
|
37
|
+
"""
|
38
|
+
Visualizes combinations of protected classes and their corresponding error metric differences.
|
39
|
+
|
40
|
+
### Purpose
|
41
|
+
|
42
|
+
This test aims to provide insights into how different combinations of protected classes affect various error metrics,
|
43
|
+
particularly the false negative rate (FNR) and false positive rate (FPR). By visualizing these combinations,
|
44
|
+
it helps identify potential biases or disparities in model performance across different intersectional groups.
|
45
|
+
|
46
|
+
### Test Mechanism
|
47
|
+
|
48
|
+
The test performs the following steps:
|
49
|
+
1. Combines the specified protected class columns to create a single multi-class category.
|
50
|
+
2. Calculates error metrics (FNR, FPR, etc.) for each combination of protected classes.
|
51
|
+
3. Generates visualizations showing the distribution of these metrics across all class combinations.
|
52
|
+
|
53
|
+
### Signs of High Risk
|
54
|
+
|
55
|
+
- Large disparities in FNR or FPR across different protected class combinations.
|
56
|
+
- Consistent patterns of higher error rates for specific combinations of protected attributes.
|
57
|
+
- Unexpected or unexplainable variations in error metrics between similar group combinations.
|
58
|
+
|
59
|
+
### Strengths
|
60
|
+
|
61
|
+
- Provides a comprehensive view of intersectional fairness across multiple protected attributes.
|
62
|
+
- Allows for easy identification of potentially problematic combinations of protected classes.
|
63
|
+
- Visualizations make it easier to spot patterns or outliers in model performance across groups.
|
64
|
+
|
65
|
+
### Limitations
|
66
|
+
|
67
|
+
- May become complex and difficult to interpret with a large number of protected classes or combinations.
|
68
|
+
- Does not provide statistical significance of observed differences.
|
69
|
+
- Visualization alone may not capture all nuances of intersectional fairness.
|
70
|
+
"""
|
71
|
+
|
72
|
+
if sys.version_info < (3, 9):
|
73
|
+
raise RuntimeError("This test requires Python 3.9 or higher.")
|
74
|
+
|
75
|
+
if protected_classes is None:
|
76
|
+
logger.warning(
|
77
|
+
"No protected classes provided. Please pass the 'protected_classes' parameter to run this test."
|
78
|
+
)
|
79
|
+
return pd.DataFrame()
|
80
|
+
|
81
|
+
# Construct a function dictionary for figures
|
82
|
+
my_metrics = {
|
83
|
+
"fpr": false_positive_rate,
|
84
|
+
"tpr": true_positive_rate,
|
85
|
+
"selection rate": selection_rate,
|
86
|
+
"count": count,
|
87
|
+
}
|
88
|
+
|
89
|
+
# Construct a MetricFrame for figures
|
90
|
+
mf = MetricFrame(
|
91
|
+
metrics=my_metrics,
|
92
|
+
y_true=dataset.y,
|
93
|
+
y_pred=dataset.y_pred(model),
|
94
|
+
sensitive_features=dataset._df[protected_classes],
|
95
|
+
)
|
96
|
+
|
97
|
+
# Combine protected class columns to create a single multi-class category for the x-axis
|
98
|
+
metrics_by_group = mf.by_group.reset_index()
|
99
|
+
metrics_by_group["class_combination"] = metrics_by_group[protected_classes].apply(
|
100
|
+
lambda row: ", ".join(row.values.astype(str)), axis=1
|
101
|
+
)
|
102
|
+
|
103
|
+
# Create the subplots for the bar plots
|
104
|
+
fig = sp.make_subplots(
|
105
|
+
rows=2,
|
106
|
+
cols=2,
|
107
|
+
subplot_titles=[
|
108
|
+
"False Positive Rate",
|
109
|
+
"True Positive Rate",
|
110
|
+
"Selection Rate",
|
111
|
+
"Count",
|
112
|
+
],
|
113
|
+
)
|
114
|
+
|
115
|
+
# Add bar plots for each metric
|
116
|
+
fig.add_trace(
|
117
|
+
go.Bar(
|
118
|
+
x=metrics_by_group["class_combination"],
|
119
|
+
y=metrics_by_group["fpr"],
|
120
|
+
name="FPR",
|
121
|
+
),
|
122
|
+
row=1,
|
123
|
+
col=1,
|
124
|
+
)
|
125
|
+
fig.add_trace(
|
126
|
+
go.Bar(
|
127
|
+
x=metrics_by_group["class_combination"],
|
128
|
+
y=metrics_by_group["tpr"],
|
129
|
+
name="TPR",
|
130
|
+
),
|
131
|
+
row=1,
|
132
|
+
col=2,
|
133
|
+
)
|
134
|
+
fig.add_trace(
|
135
|
+
go.Bar(
|
136
|
+
x=metrics_by_group["class_combination"],
|
137
|
+
y=metrics_by_group["selection rate"],
|
138
|
+
name="Selection Rate",
|
139
|
+
),
|
140
|
+
row=2,
|
141
|
+
col=1,
|
142
|
+
)
|
143
|
+
fig.add_trace(
|
144
|
+
go.Bar(
|
145
|
+
x=metrics_by_group["class_combination"],
|
146
|
+
y=metrics_by_group["count"],
|
147
|
+
name="Count",
|
148
|
+
),
|
149
|
+
row=2,
|
150
|
+
col=2,
|
151
|
+
)
|
152
|
+
|
153
|
+
# Update layout of the figure to match the original style
|
154
|
+
fig.update_layout(
|
155
|
+
title="Show all metrics",
|
156
|
+
height=800,
|
157
|
+
width=900,
|
158
|
+
barmode="group",
|
159
|
+
legend=dict(orientation="h", yanchor="bottom", y=-0.3, xanchor="center", x=0.5),
|
160
|
+
margin=dict(t=50),
|
161
|
+
font=dict(size=12),
|
162
|
+
)
|
163
|
+
|
164
|
+
# Rotate x-axis labels for better readability
|
165
|
+
fig.update_xaxes(tickangle=45, row=1, col=1)
|
166
|
+
fig.update_xaxes(tickangle=45, row=1, col=2)
|
167
|
+
fig.update_xaxes(tickangle=45, row=2, col=1)
|
168
|
+
fig.update_xaxes(tickangle=45, row=2, col=2)
|
169
|
+
|
170
|
+
# Extract demographic parity ratio and equalized odds ratio
|
171
|
+
m_dpr = []
|
172
|
+
m_eqo = []
|
173
|
+
for protected_class in protected_classes:
|
174
|
+
m_dpr.append(
|
175
|
+
demographic_parity_ratio(
|
176
|
+
y_true=dataset.y,
|
177
|
+
y_pred=dataset.y_pred(model),
|
178
|
+
sensitive_features=dataset._df[[protected_class]],
|
179
|
+
)
|
180
|
+
)
|
181
|
+
m_eqo.append(
|
182
|
+
equalized_odds_ratio(
|
183
|
+
y_true=dataset.y,
|
184
|
+
y_pred=dataset.y_pred(model),
|
185
|
+
sensitive_features=dataset._df[[protected_class]],
|
186
|
+
)
|
187
|
+
)
|
188
|
+
|
189
|
+
# Create a DataFrame for the demographic parity and equalized odds ratio
|
190
|
+
dpr_eor_df = pd.DataFrame(
|
191
|
+
columns=protected_classes,
|
192
|
+
index=["demographic parity ratio", "equal odds ratio"],
|
193
|
+
)
|
194
|
+
|
195
|
+
for i in range(len(m_dpr)):
|
196
|
+
dpr_eor_df[protected_classes[i]]["demographic parity ratio"] = round(
|
197
|
+
m_dpr[i], 2
|
198
|
+
)
|
199
|
+
dpr_eor_df[protected_classes[i]]["equal odds ratio"] = round(m_eqo[i], 2)
|
200
|
+
|
201
|
+
return (
|
202
|
+
{"Class Combination Table": metrics_by_group},
|
203
|
+
{"DPR and EOR table": dpr_eor_df},
|
204
|
+
fig,
|
205
|
+
)
|