validmind 2.0.7__py3-none-any.whl → 2.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__init__.py +3 -3
- validmind/__version__.py +1 -1
- validmind/ai.py +7 -11
- validmind/api_client.py +29 -27
- validmind/client.py +10 -3
- validmind/datasets/credit_risk/__init__.py +11 -0
- validmind/datasets/credit_risk/datasets/lending_club_loan_data_2007_2014_clean.csv.gz +0 -0
- validmind/datasets/credit_risk/lending_club.py +394 -0
- validmind/logging.py +9 -2
- validmind/template.py +2 -2
- validmind/test_suites/__init__.py +4 -2
- validmind/tests/__init__.py +97 -50
- validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +3 -1
- validmind/tests/data_validation/PiTCreditScoresHistogram.py +1 -1
- validmind/tests/data_validation/ScatterPlot.py +8 -2
- validmind/tests/decorator.py +138 -14
- validmind/tests/model_validation/BertScore.py +1 -1
- validmind/tests/model_validation/BertScoreAggregate.py +1 -1
- validmind/tests/model_validation/BleuScore.py +1 -1
- validmind/tests/model_validation/ClusterSizeDistribution.py +1 -1
- validmind/tests/model_validation/ContextualRecall.py +1 -1
- validmind/tests/model_validation/FeaturesAUC.py +110 -0
- validmind/tests/model_validation/MeteorScore.py +1 -1
- validmind/tests/model_validation/RegardHistogram.py +1 -1
- validmind/tests/model_validation/RegardScore.py +1 -1
- validmind/tests/model_validation/RegressionResidualsPlot.py +127 -0
- validmind/tests/model_validation/RougeMetrics.py +1 -1
- validmind/tests/model_validation/RougeMetricsAggregate.py +1 -1
- validmind/tests/model_validation/SelfCheckNLIScore.py +1 -1
- validmind/tests/model_validation/TokenDisparity.py +1 -1
- validmind/tests/model_validation/ToxicityHistogram.py +1 -1
- validmind/tests/model_validation/ToxicityScore.py +1 -1
- validmind/tests/model_validation/embeddings/ClusterDistribution.py +1 -1
- validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +1 -3
- validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +1 -1
- validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +1 -1
- validmind/tests/model_validation/sklearn/ClassifierPerformance.py +15 -18
- validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +1 -1
- validmind/tests/model_validation/sklearn/ClusterPerformance.py +2 -2
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +21 -3
- validmind/tests/model_validation/sklearn/MinimumAccuracy.py +1 -1
- validmind/tests/model_validation/sklearn/MinimumF1Score.py +1 -1
- validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +1 -1
- validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +5 -4
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +2 -2
- validmind/tests/model_validation/sklearn/ROCCurve.py +6 -12
- validmind/tests/model_validation/sklearn/RegressionErrors.py +2 -2
- validmind/tests/model_validation/sklearn/RegressionModelsPerformanceComparison.py +6 -4
- validmind/tests/model_validation/sklearn/RegressionR2Square.py +2 -2
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +33 -3
- validmind/tests/model_validation/sklearn/SilhouettePlot.py +1 -1
- validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +2 -2
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +2 -2
- validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +140 -0
- validmind/tests/model_validation/statsmodels/GINITable.py +22 -45
- validmind/tests/model_validation/statsmodels/{LogisticRegPredictionHistogram.py → PredictionProbabilitiesHistogram.py} +67 -92
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +2 -2
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +2 -2
- validmind/tests/model_validation/statsmodels/RegressionModelInsampleComparison.py +1 -1
- validmind/tests/model_validation/statsmodels/RegressionModelOutsampleComparison.py +1 -1
- validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +1 -1
- validmind/tests/model_validation/statsmodels/RegressionModelsPerformance.py +1 -1
- validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +128 -0
- validmind/tests/model_validation/statsmodels/ScorecardHistogram.py +70 -103
- validmind/tests/test_providers.py +14 -124
- validmind/unit_metrics/__init__.py +76 -69
- validmind/unit_metrics/classification/sklearn/Accuracy.py +14 -0
- validmind/unit_metrics/classification/sklearn/F1.py +13 -0
- validmind/unit_metrics/classification/sklearn/Precision.py +13 -0
- validmind/unit_metrics/classification/sklearn/ROC_AUC.py +13 -0
- validmind/unit_metrics/classification/sklearn/Recall.py +13 -0
- validmind/unit_metrics/composite.py +24 -71
- validmind/unit_metrics/regression/GiniCoefficient.py +20 -26
- validmind/unit_metrics/regression/HuberLoss.py +12 -16
- validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +18 -24
- validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +7 -13
- validmind/unit_metrics/regression/MeanBiasDeviation.py +5 -14
- validmind/unit_metrics/regression/QuantileLoss.py +6 -16
- validmind/unit_metrics/regression/sklearn/AdjustedRSquaredScore.py +12 -18
- validmind/unit_metrics/regression/sklearn/MeanAbsoluteError.py +6 -15
- validmind/unit_metrics/regression/sklearn/MeanSquaredError.py +5 -14
- validmind/unit_metrics/regression/sklearn/RSquaredScore.py +6 -15
- validmind/unit_metrics/regression/sklearn/RootMeanSquaredError.py +11 -14
- validmind/utils.py +18 -45
- validmind/vm_models/__init__.py +0 -2
- validmind/vm_models/dataset.py +255 -16
- validmind/vm_models/test/metric.py +1 -2
- validmind/vm_models/test/result_wrapper.py +12 -13
- validmind/vm_models/test/test.py +2 -1
- validmind/vm_models/test/threshold_test.py +1 -2
- validmind/vm_models/test_suite/summary.py +3 -3
- validmind/vm_models/test_suite/test_suite.py +2 -1
- {validmind-2.0.7.dist-info → validmind-2.1.1.dist-info}/METADATA +10 -6
- {validmind-2.0.7.dist-info → validmind-2.1.1.dist-info}/RECORD +97 -96
- validmind/tests/__types__.py +0 -62
- validmind/tests/model_validation/statsmodels/LogRegressionConfusionMatrix.py +0 -128
- validmind/tests/model_validation/statsmodels/LogisticRegCumulativeProb.py +0 -172
- validmind/tests/model_validation/statsmodels/ScorecardBucketHistogram.py +0 -181
- validmind/tests/model_validation/statsmodels/ScorecardProbabilitiesHistogram.py +0 -175
- validmind/unit_metrics/sklearn/classification/Accuracy.py +0 -22
- validmind/unit_metrics/sklearn/classification/F1.py +0 -24
- validmind/unit_metrics/sklearn/classification/Precision.py +0 -24
- validmind/unit_metrics/sklearn/classification/ROC_AUC.py +0 -22
- validmind/unit_metrics/sklearn/classification/Recall.py +0 -22
- validmind/vm_models/test/unit_metric.py +0 -88
- {validmind-2.0.7.dist-info → validmind-2.1.1.dist-info}/LICENSE +0 -0
- {validmind-2.0.7.dist-info → validmind-2.1.1.dist-info}/WHEEL +0 -0
- {validmind-2.0.7.dist-info → validmind-2.1.1.dist-info}/entry_points.txt +0 -0
validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py (new file)
@@ -0,0 +1,128 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from dataclasses import dataclass
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from sklearn.metrics import r2_score
+from sklearn.utils import check_random_state
+
+from validmind.errors import SkipTestError
+from validmind.logging import get_logger
+from validmind.vm_models import Figure, Metric
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class RegressionPermutationFeatureImportance(Metric):
+    """
+    Assesses the significance of each feature in a model by evaluating the impact on model performance when feature
+    values are randomly rearranged. Specifically designed for use with statsmodels, this metric offers insight into the
+    importance of features based on the decrease in model's predictive accuracy, typically R².
+
+    **Purpose**: The primary purpose of this metric is to determine which features significantly impact the performance
+    of a regression model developed using statsmodels. The metric measures how much the prediction accuracy deteriorates
+    when each feature's values are permuted.
+
+    **Test Mechanism**: This metric shuffles the values of each feature one at a time in the dataset, computes the model's
+    performance after each permutation, and compares it to the baseline performance. A significant decrease in performance
+    indicates the importance of the feature.
+
+    **Signs of High Risk**:
+    - Significant reliance on a feature that when permuted leads to a substantial decrease in performance, suggesting
+      overfitting or high model dependency on that feature.
+    - Features identified as unimportant despite known impacts from domain knowledge, suggesting potential issues in
+      model training or data preprocessing.
+
+    **Strengths**:
+    - Directly assesses the impact of each feature on model performance, providing clear insights into model dependencies.
+    - Model-agnostic within the scope of statsmodels, applicable to any regression model that outputs predictions.
+
+    **Limitations**:
+    - The metric is specific to statsmodels and cannot be used with other types of models without adaptation.
+    - It does not capture interactions between features, which can lead to underestimating the importance of correlated
+      features.
+    - Assumes independence of features when calculating importance, which might not always hold true.
+    """
+
+    name = "regression_pfi"
+    required_inputs = ["model", "dataset"]
+    default_params = {
+        "fontsize": 12,
+        "figure_height": 500,
+    }
+    metadata = {
+        "task_types": ["regression"],
+        "tags": [
+            "statsmodels",
+            "feature_importance",
+            "visualization",
+        ],
+    }
+
+    def run(self):
+        x = self.inputs.dataset.x_df()
+        y = self.inputs.dataset.y_df()
+
+        model = self.inputs.model.model
+        if not hasattr(model, "predict"):
+            raise SkipTestError(
+                "Model does not support 'predict' method required for PFI"
+            )
+
+        # Calculate baseline performance
+        baseline_performance = r2_score(y, model.predict(x))
+        importances = pd.DataFrame(index=x.columns, columns=["Importance", "Std Dev"])
+
+        for column in x.columns:
+            shuffled_scores = []
+            for _ in range(30):  # Default number of shuffles
+                x_shuffled = x.copy()
+                x_shuffled[column] = check_random_state(0).permutation(
+                    x_shuffled[column]
+                )
+                permuted_performance = r2_score(y, model.predict(x_shuffled))
+                shuffled_scores.append(baseline_performance - permuted_performance)
+
+            importances.loc[column] = {
+                "Importance": np.mean(shuffled_scores),
+                "Std Dev": np.std(shuffled_scores),
+            }
+
+        sorted_idx = importances["Importance"].argsort()
+
+        # Plotting the results
+        fig = go.Figure()
+        fig.add_trace(
+            go.Bar(
+                y=importances.index[sorted_idx],
+                x=importances.loc[importances.index[sorted_idx], "Importance"],
+                orientation="h",
+                error_x=dict(
+                    type="data",
+                    array=importances.loc[importances.index[sorted_idx], "Std Dev"],
+                ),
+            )
+        )
+        fig.update_layout(
+            title_text="Permutation Feature Importances",
+            yaxis=dict(
+                tickmode="linear", dtick=1, tickfont=dict(size=self.params["fontsize"])
+            ),
+            height=self.params["figure_height"],
+        )
+
+        return self.cache_results(
+            metric_value=importances.to_dict(),
+            figures=[
+                Figure(
+                    for_object=self,
+                    key="regression_pfi",
+                    figure=fig,
+                ),
+            ],
+        )
validmind/tests/model_validation/statsmodels/ScorecardHistogram.py
@@ -4,10 +4,8 @@
 
 from dataclasses import dataclass
 
-import numpy as np
-import pandas as pd
 import plotly.graph_objects as go
-from …
+from matplotlib import cm
 
 from validmind.vm_models import Figure, Metric
 
@@ -53,120 +51,89 @@ class ScorecardHistogram(Metric):
     """
 
     name = "scorecard_histogram"
-    required_inputs = ["…
+    required_inputs = ["datasets"]
    metadata = {
        "task_types": ["classification"],
        "tags": ["tabular_data", "visualization", "credit_risk"],
    }
    default_params = {
        "title": "Histogram of Scores",
-        "…
-        "target_odds": 50,
-        "pdo": 20,
+        "score_column": "score",
    }

    @staticmethod
-    def …
    (… 40 deleted lines, old lines 70-109, not rendered in this diff view …)
-        fig.add_trace(trace_train_0, row=1, col=1)
-        fig.add_trace(trace_train_1, row=1, col=1)
-        fig.add_trace(trace_test_0, row=1, col=2)
-        fig.add_trace(trace_test_1, row=1, col=2)
-
-        fig.update_layout(barmode="overlay", title_text=title)
-
-        return fig
+    def plot_score_histogram(dataframes, dataset_titles, score_col, target_col, title):
+        figures = []
+        # Generate a colormap and convert to Plotly-accepted color format
+        # Adjust 'viridis' to any other matplotlib colormap if desired
+        colormap = cm.get_cmap("viridis")
+
+        for _, (df, dataset_title) in enumerate(zip(dataframes, dataset_titles)):
+            fig = go.Figure()
+
+            # Get unique classes and assign colors
+            classes = sorted(df[target_col].unique())
+            colors = [
+                colormap(i / len(classes))[:3] for i in range(len(classes))
+            ]  # RGB
+            color_dict = {
+                cls: f"rgb({int(rgb[0]*255)}, {int(rgb[1]*255)}, {int(rgb[2]*255)})"
+                for cls, rgb in zip(classes, colors)
+            }
+
+            for class_value in sorted(df[target_col].unique()):
+                scores_class = df[df[target_col] == class_value][score_col]
+                fig.add_trace(
+                    go.Histogram(
+                        x=scores_class,
+                        opacity=0.75,
+                        name=f"{dataset_title} {target_col} = {class_value}",
+                        marker=dict(
+                            color=color_dict[class_value],
+                        ),
+                    )
+                )
+            fig.update_layout(
+                barmode="overlay",
+                title_text=f"{title} - {dataset_title}",
+                xaxis_title="Score",
+                yaxis_title="Frequency",
+                legend_title=target_col,
+            )
+            figures.append(fig)
+        return figures
 
     def run(self):
-        model = (
-            self.inputs.model[0]
-            if isinstance(self.inputs.model, list)
-            else self.inputs.model
-        )
-
-        target_column = model.train_ds.target_column
         title = self.params["title"]
    (… 11 deleted lines, old lines 128-138, not rendered in this diff view …)
+        score_column = self.params["score_column"]
+        dataset_titles = [dataset.input_id for dataset in self.inputs.datasets]
+        target_column = self.inputs.datasets[0].target_column
+
+        dataframes = []
+        metric_value = {"score_histogram": {}}
+        for dataset in self.inputs.datasets:
+            df = dataset.df.copy()
+            # Check if the score_column exists in the DataFrame
+            if score_column not in df.columns:
+                raise ValueError(
+                    f"The required column '{score_column}' is not present in the dataset with input_id {dataset.input_id}"
+                )
 
    (… 3 deleted lines, old lines 140-142, not rendered in this diff view …)
+            df[score_column] = dataset.get_extra_column(score_column)
+            dataframes.append(df)
+            metric_value["score_histogram"][dataset.input_id] = list(df[score_column])
 
    (… 2 deleted lines, old lines 144-145, not rendered in this diff view …)
-        )
-        X_test_scores = self.compute_scores(
-            model, X_test, target_score, target_odds, pdo
+        figures = self.plot_score_histogram(
+            dataframes, dataset_titles, score_column, target_column, title
         )
 
    (… 6 deleted lines, old lines 151-156, not rendered in this diff view …)
+        figures_list = [
+            Figure(
+                for_object=self,
+                key=f"score_histogram_{title.replace(' ', '_')}_{i+1}",
+                figure=fig,
+            )
+            for i, fig in enumerate(figures)
+        ]
 
-        return self.cache_results(
-            metric_value={
-                "score_histogram": {
-                    "train_scores": list(X_train_scores["score"]),
-                    "test_scores": list(X_test_scores["score"]),
-                },
-            },
-            figures=[
-                Figure(
-                    for_object=self,
-                    key="score_histogram",
-                    figure=fig,
-                )
-            ],
-        )
+        return self.cache_results(metric_value=metric_value, figures=figures_list)
validmind/tests/test_providers.py
@@ -5,44 +5,30 @@
 import importlib.util
 import os
 import sys
-
-import requests
+from typing import Protocol
 
 from validmind.logging import get_logger
 
 logger = get_logger(__name__)
 
 
-class …
-    """
-    When the remote file can't be downloaded from the repo.
-    """
-
-    pass
-
+class TestProvider(Protocol):
+    """Protocol for user-defined test providers"""
 
    (… 2 deleted lines, old lines 24-25, not rendered in this diff view …)
-    When the remote file can't be downloaded from the repo.
-    """
-
-    pass
-
-
-class GithubTestProviderLoadModuleError(Exception):
-    """
-    When the remote file was downloaded but the module can't be loaded.
-    """
-
-    pass
+    def load_test(self, test_id: str):
+        """Load the test by test ID
 
+        Args:
+            test_id (str): The test ID (does not contain the namespace under which
+                the test is registered)
 
-
-    …
-    When the module was loaded but the test class can't be located.
-    """
+        Returns:
+            Test: A test class or function
 
-
+        Raises:
+            FileNotFoundError: If the test is not found
+        """
+        ...
 
 
 class LocalTestProviderLoadModuleError(Exception):
@@ -61,102 +47,6 @@ class LocalTestProviderLoadTestError(Exception):
     pass
 
 
-class GithubTestProvider:
-    """
-    A class used to download python files from a Github repository and
-    dynamically load and execute the tests from those files.
-    """
-
-    BASE_URL = "https://api.github.com/repos"
-
-    def __init__(self, org: str, repo: str, token: str):
-        """
-        Initialize the GithubTestProvider with the given org, repo, and token.
-
-        Args:
-            org (str): The Github organization.
-            repo (str): The Github repository.
-            token (str): The Github access token.
-        """
-        self.org = org
-        self.repo = repo
-        self.token = token
-
-    def _download_file(self, test_path: str) -> str:
-        """
-        Download the file at the given test_path from the Github repository.
-
-        Args:
-            test_path (str): The path of the file in the repository.
-
-        Returns:
-            str: The local file path where the file was downloaded.
-
-        Raises:
-            Exception: If the file can't be downloaded or written.
-        """
-        url = f"{self.BASE_URL}/{self.org}/{self.repo}/contents/{test_path}"
-
-        headers = {
-            "Authorization": f"token {self.token}",
-            "Accept": "application/vnd.github.v3.raw",
-            "X-Github-Api-Version": "2022-11-28",
-        }
-
-        try:
-            response = requests.get(url, headers=headers)
-            response.raise_for_status()
-        except requests.RequestException as e:
-            raise GithubTestProviderDownloadError(
-                f"Failed to download the file at {url}. Error: {str(e)}"
-            )
-
-        file_path = f"/tmp/{os.path.basename(test_path)}"
-        try:
-            with open(file_path, "w") as file:
-                file.write(response.text)
-        except IOError as e:
-            raise GithubTestProviderWriteFileError(
-                f"Failed to write the file to {file_path}. Error: {str(e)}"
-            )
-
-        return file_path
-
-    def load_test(self, test_id):
-        """
-        Load the test identified by the given test_id.
-
-        Args:
-            test_id (str): The identifier of the test. This corresponds to the
-                relative path of the python file in the repository, with slashes replaced by dots.
-
-        Returns:
-            The test class that matches the last part of the test_id.
-
-        Raises:
-            Exception: If the test can't be imported or loaded.
-        """
-        test_path = f"{test_id.replace('.', '/')}.py"
-        file_path = self._download_file(test_path)
-
-        try:
-            spec = importlib.util.spec_from_file_location(test_id, file_path)
-            module = importlib.util.module_from_spec(spec)
-            spec.loader.exec_module(module)
-        except Exception as e:
-            raise GithubTestProviderLoadModuleError(
-                f"Failed to load the module from {file_path}. Error: {str(e)}"
-            )
-
-        try:
-            # find the test class that matches the last part of the test_id
-            return getattr(module, test_id.split(".")[-1])
-        except AttributeError as e:
-            raise GithubTestProviderLoadTestError(
-                f"Failed to find the test class in the module. Error: {str(e)}"
-            )
-
-
 class LocalTestProvider:
     """
     Test providers in ValidMind are responsible for loading tests from different sources,
validmind/unit_metrics/__init__.py
@@ -3,14 +3,13 @@
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
 import hashlib
-import importlib
 import json
+from importlib import import_module
 
 import numpy as np
 
-from …
-
-from ..utils import get_model_info
+from ..tests.decorator import _build_result, _inspect_signature
+from ..utils import get_model_info, test_id_to_name
 
 unit_metric_results_cache = {}
 
@@ -134,52 +133,6 @@ def _fast_hash(df, sample_size=1000, model_and_prediction_info=None):
     return hash_obj.hexdigest()
 
 
-def _get_metric_class(metric_id):
-    """Get the metric class by metric_id
-
-    This function will load the metric class by metric_id.
-
-    Args:
-        metric_id (str): The full metric id (e.g. 'validmind.vm_models.test.v2.model_validation.sklearn.F1')
-
-    Returns:
-        Metric: The metric class
-    """
-
-    metric_module = importlib.import_module(f"{metric_id}")
-
-    class_name = metric_id.split(".")[-1]
-
-    # Access the class within the F1 module
-    metric_class = getattr(metric_module, class_name)
-
-    return metric_class
-
-
-def get_input_type(input_obj):
-    """
-    Determines whether the input object is a 'dataset' or 'model' based on its class module path.
-
-    Args:
-        input_obj: The object to type check.
-
-    Returns:
-        str: 'dataset' or 'model' depending on the object's module, or raises ValueError.
-    """
-    # Obtain the class object of input_obj (for clarity and debugging)
-    class_obj = input_obj.__class__
-
-    # Obtain the module name as a string from the class object
-    class_module = class_obj.__module__
-
-    if "validmind.vm_models.dataset" in class_module:
-        return "dataset"
-    elif "validmind.models" in class_module:
-        return "model"
-    else:
-        raise ValueError("Input must be of type validmind Dataset or Model")
-
-
 def get_metric_cache_key(metric_id, params, inputs):
     cache_elements = [metric_id]
 
@@ -209,34 +162,88 @@ def get_metric_cache_key(metric_id, params, inputs):
     return key
 
 
-def …
-    """
-
-    This function provides a high level interface for running a single metric. A metric
-    is a single test that calculates a value based on the input data.
+def load_metric(metric_id):
+    """Load a metric class from a string
 
     Args:
-        metric_id (str): The metric …
-        params (dict): A dictionary of the metric parameters
+        metric_id (str): The metric id (e.g. 'validmind.unit_metrics.classification.sklearn.F1')
 
     Returns:
-
+        callable: The metric function
     """
-
+    return getattr(import_module(metric_id), metric_id.split(".")[-1])
 
-    # Check if the metric value already exists in the global variable
-    if cache_key in unit_metric_results_cache:
-        return unit_metric_results_cache[cache_key]
 
-
-
+def run_metric(metric_id, inputs=None, params=None, show=True, value_only=False):
+    """Run a single metric and cache the results
+
+    Args:
+        metric_id (str): The metric id (e.g. 'validmind.unit_metrics.classification.sklearn.F1')
+        inputs (dict): A dictionary of the metric inputs
+        params (dict): A dictionary of the metric parameters
+        show (bool): Whether to display the results
+        value_only (bool): Whether to return only the value
+    """
+    inputs = inputs or {}
+    params = params or {}
+
+    cache_key = get_metric_cache_key(metric_id, params, inputs)
+
+    if cache_key not in unit_metric_results_cache:
+        metric = load_metric(metric_id)
+        _inputs, _params = _inspect_signature(metric)
+
+        result = metric(
+            **{k: v for k, v in inputs.items() if k in _inputs.keys()},
+            **{k: v for k, v in params.items() if k in _params.keys()},
+        )
+        unit_metric_results_cache[cache_key] = (result, list(_inputs.keys()))
+
+    value = unit_metric_results_cache[cache_key][0]
+
+    if value_only:
+        return value
+
+    output_template = f"""
+    <table>
+        <thead>
+            <tr>
+                <th>Metric</th>
+                <th>Value</th>
+            </tr>
+        </thead>
+        <tbody>
+            <tr>
+                <td><strong>{test_id_to_name(metric_id)}</strong></td>
+                <td>{value:.4f}</td>
+            </tr>
+        </tbody>
+    </table>
+    <style>
+        th, td {{
+            padding: 5px;
+            text-align: left;
+        }}
+    </style>
+    """
+    result = _build_result(
+        results=value,
+        test_id=metric_id,
+        description="",
+        output_template=output_template,
+        inputs=unit_metric_results_cache[cache_key][1],
+    )
 
-    # …
-
+    # in case the user tries to log the result object
+    def log(self):
+        raise Exception(
+            "Cannot log unit metrics directly..."
+            "You can run this unit metric as part of a composite metric and log that"
+        )
 
-
-    result = metric.run()
+    result.log = log
 
-
+    if show:
+        result.show()
 
     return result
validmind/unit_metrics/classification/sklearn/Accuracy.py (new file)
@@ -0,0 +1,14 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from sklearn.metrics import accuracy_score
+
+from validmind import tags, tasks
+
+
+@tags("classification", "sklearn", "unit_metric")
+@tasks("classification")
+def Accuracy(dataset, model):
+    """Calculates the accuracy of a model"""
+    return accuracy_score(dataset.y, dataset.y_pred(model))