validmind-2.3.1-py3-none-any.whl → validmind-2.3.5-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- validmind/__init__.py +2 -1
- validmind/__version__.py +1 -1
- validmind/datasets/regression/fred_timeseries.py +272 -0
- validmind/test_suites/__init__.py +0 -2
- validmind/tests/__init__.py +7 -7
- validmind/tests/__types__.py +180 -0
- validmind/tests/data_validation/SeasonalDecompose.py +68 -40
- validmind/tests/data_validation/TimeSeriesDescription.py +74 -0
- validmind/tests/data_validation/TimeSeriesDescriptiveStatistics.py +76 -0
- validmind/tests/data_validation/TimeSeriesHistogram.py +29 -45
- validmind/tests/data_validation/TimeSeriesOutliers.py +30 -41
- validmind/tests/decorator.py +12 -0
- validmind/tests/model_validation/ModelMetadataComparison.py +59 -0
- validmind/tests/model_validation/ModelPredictionResiduals.py +103 -0
- validmind/tests/model_validation/TimeSeriesPredictionWithCI.py +131 -0
- validmind/tests/model_validation/TimeSeriesPredictionsPlot.py +76 -0
- validmind/tests/model_validation/TimeSeriesR2SquareBySegments.py +103 -0
- validmind/tests/model_validation/sklearn/FeatureImportanceComparison.py +83 -0
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +1 -1
- validmind/tests/model_validation/sklearn/RegressionErrorsComparison.py +76 -0
- validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py +63 -0
- validmind/utils.py +34 -0
- {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/METADATA +70 -36
- {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/RECORD +28 -16
- /validmind/datasets/regression/datasets/{lending_club_loan_rates.csv → leanding_club_loan_rates.csv} +0 -0
- {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/LICENSE +0 -0
- {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/WHEEL +0 -0
- {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/entry_points.txt +0 -0
@@ -4,10 +4,10 @@
|
|
4
4
|
|
5
5
|
import warnings
|
6
6
|
|
7
|
-
import matplotlib.pyplot as plt
|
8
7
|
import numpy as np
|
9
8
|
import pandas as pd
|
10
|
-
import
|
9
|
+
import plotly.graph_objects as go
|
10
|
+
from plotly.subplots import make_subplots
|
11
11
|
from scipy import stats
|
12
12
|
from statsmodels.tsa.seasonal import seasonal_decompose
|
13
13
|
|
@@ -132,7 +132,6 @@ class SeasonalDecompose(Metric):
|
|
132
132
|
inferred_freq = pd.infer_freq(series.index)
|
133
133
|
|
134
134
|
if inferred_freq is not None:
|
135
|
-
logger.info(f"Frequency of {col}: {inferred_freq}")
|
136
135
|
|
137
136
|
# Only take finite values to seasonal_decompose
|
138
137
|
sd = seasonal_decompose(
|
@@ -142,58 +141,87 @@ class SeasonalDecompose(Metric):
|
|
142
141
|
|
143
142
|
results[col] = self.serialize_seasonal_decompose(sd)
|
144
143
|
|
145
|
-
# Create subplots
|
146
|
-
fig
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
144
|
+
# Create subplots using Plotly
|
145
|
+
fig = make_subplots(
|
146
|
+
rows=3,
|
147
|
+
cols=2,
|
148
|
+
subplot_titles=(
|
149
|
+
"Observed",
|
150
|
+
"Trend",
|
151
|
+
"Seasonal",
|
152
|
+
"Residuals",
|
153
|
+
"Histogram and KDE of Residuals",
|
154
|
+
"Normal Q-Q Plot of Residuals",
|
155
|
+
),
|
156
|
+
vertical_spacing=0.1,
|
155
157
|
)
|
156
158
|
|
157
|
-
# Original seasonal decomposition plots
|
158
159
|
# Observed
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
160
|
+
fig.add_trace(
|
161
|
+
go.Scatter(x=sd.observed.index, y=sd.observed, name="Observed"),
|
162
|
+
row=1,
|
163
|
+
col=1,
|
164
|
+
)
|
163
165
|
|
164
166
|
# Trend
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
167
|
+
fig.add_trace(
|
168
|
+
go.Scatter(x=sd.trend.index, y=sd.trend, name="Trend"),
|
169
|
+
row=1,
|
170
|
+
col=2,
|
171
|
+
)
|
169
172
|
|
170
173
|
# Seasonal
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
174
|
+
fig.add_trace(
|
175
|
+
go.Scatter(x=sd.seasonal.index, y=sd.seasonal, name="Seasonal"),
|
176
|
+
row=2,
|
177
|
+
col=1,
|
178
|
+
)
|
175
179
|
|
176
180
|
# Residuals
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
+
fig.add_trace(
|
182
|
+
go.Scatter(x=sd.resid.index, y=sd.resid, name="Residuals"),
|
183
|
+
row=2,
|
184
|
+
col=2,
|
185
|
+
)
|
181
186
|
|
182
187
|
# Histogram with KDE
|
183
188
|
residuals = sd.resid.dropna()
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
189
|
+
fig.add_trace(
|
190
|
+
go.Histogram(x=residuals, nbinsx=100, name="Residuals"),
|
191
|
+
row=3,
|
192
|
+
col=1,
|
193
|
+
)
|
188
194
|
|
189
195
|
# Normal Q-Q plot
|
190
|
-
stats.probplot(residuals, plot=
|
191
|
-
|
192
|
-
|
193
|
-
|
196
|
+
qq = stats.probplot(residuals, plot=None)
|
197
|
+
qq_line_slope, qq_line_intercept = stats.linregress(
|
198
|
+
qq[0][0], qq[0][1]
|
199
|
+
)[:2]
|
200
|
+
qq_line = qq_line_slope * np.array(qq[0][0]) + qq_line_intercept
|
201
|
+
|
202
|
+
fig.add_trace(
|
203
|
+
go.Scatter(
|
204
|
+
x=qq[0][0], y=qq[0][1], mode="markers", name="QQ plot"
|
205
|
+
),
|
206
|
+
row=3,
|
207
|
+
col=2,
|
208
|
+
)
|
209
|
+
fig.add_trace(
|
210
|
+
go.Scatter(
|
211
|
+
x=qq[0][0],
|
212
|
+
y=qq_line,
|
213
|
+
mode="lines",
|
214
|
+
name="QQ line",
|
215
|
+
),
|
216
|
+
row=3,
|
217
|
+
col=2,
|
218
|
+
)
|
194
219
|
|
195
|
-
|
196
|
-
|
220
|
+
fig.update_layout(
|
221
|
+
height=1000,
|
222
|
+
title_text=f"Seasonal Decomposition for {col}",
|
223
|
+
showlegend=False,
|
224
|
+
)
|
197
225
|
|
198
226
|
figures.append(
|
199
227
|
Figure(
|
@@ -0,0 +1,74 @@
|
|
1
|
+
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
|
2
|
+
# See the LICENSE file in the root of this repository for details.
|
3
|
+
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
|
7
|
+
from validmind import tags, tasks
|
8
|
+
|
9
|
+
|
10
|
+
@tags("time_series_data", "analysis")
|
11
|
+
@tasks("regression")
|
12
|
+
def TimeSeriesDescription(dataset):
|
13
|
+
"""
|
14
|
+
Generates a detailed analysis for the provided time series dataset.
|
15
|
+
|
16
|
+
**Purpose**: The purpose of the TimeSeriesDescription function is to analyze an individual time series
|
17
|
+
by providing a summary of key statistics. This helps in understanding trends, patterns, and data quality issues
|
18
|
+
within the time series.
|
19
|
+
|
20
|
+
**Test Mechanism**: The function extracts the time series data and provides a summary of key statistics.
|
21
|
+
The dataset is expected to have a datetime index. The function checks this and raises an error if the index is
|
22
|
+
not in datetime format. For each variable (column) in the dataset, appropriate statistics including start date,
|
23
|
+
end date, frequency, number of missing values, count, min, and max values are calculated.
|
24
|
+
|
25
|
+
**Signs of High Risk**:
|
26
|
+
- If the index of the dataset is not in datetime format, it could lead to errors in time-series analysis.
|
27
|
+
- Inconsistent or missing data within the dataset might affect the analysis of trends and patterns.
|
28
|
+
|
29
|
+
**Strengths**:
|
30
|
+
- This function provides a comprehensive summary of key statistics for each variable, helping to identify data quality
|
31
|
+
issues such as missing values.
|
32
|
+
- The function helps in understanding the distribution and range of the data by including min and max values.
|
33
|
+
|
34
|
+
**Limitations**:
|
35
|
+
- This function assumes that the dataset is provided as a DataFrameDataset object with a .df attribute to access
|
36
|
+
the pandas DataFrame.
|
37
|
+
- It only analyzes datasets with a datetime index and will raise an error for other types of indices.
|
38
|
+
- The function does not handle large datasets efficiently, and performance may degrade with very large datasets.
|
39
|
+
"""
|
40
|
+
|
41
|
+
summary = []
|
42
|
+
|
43
|
+
df = (
|
44
|
+
dataset.df
|
45
|
+
) # Assuming DataFrameDataset objects have a .df attribute to get the pandas DataFrame
|
46
|
+
|
47
|
+
if not pd.api.types.is_datetime64_any_dtype(df.index):
|
48
|
+
raise ValueError(f"Dataset {dataset.input_id} must have a datetime index")
|
49
|
+
|
50
|
+
for column in df.columns:
|
51
|
+
start_date = df.index.min().strftime("%Y-%m-%d")
|
52
|
+
end_date = df.index.max().strftime("%Y-%m-%d")
|
53
|
+
frequency = pd.infer_freq(df.index)
|
54
|
+
num_missing_values = df[column].isna().sum()
|
55
|
+
count = df[column].count()
|
56
|
+
min_value = df[column].min()
|
57
|
+
max_value = df[column].max()
|
58
|
+
|
59
|
+
summary.append(
|
60
|
+
{
|
61
|
+
"Variable": column,
|
62
|
+
"Start Date": start_date,
|
63
|
+
"End Date": end_date,
|
64
|
+
"Frequency": frequency,
|
65
|
+
"Num of Missing Values": num_missing_values,
|
66
|
+
"Count": count,
|
67
|
+
"Min Value": min_value,
|
68
|
+
"Max Value": max_value,
|
69
|
+
}
|
70
|
+
)
|
71
|
+
|
72
|
+
result_df = pd.DataFrame(summary)
|
73
|
+
|
74
|
+
return result_df
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
|
2
|
+
# See the LICENSE file in the root of this repository for details.
|
3
|
+
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
from scipy.stats import kurtosis, skew
|
7
|
+
|
8
|
+
from validmind import tags, tasks
|
9
|
+
|
10
|
+
|
11
|
+
@tags("time_series_data", "analysis")
|
12
|
+
@tasks("regression")
|
13
|
+
def TimeSeriesDescriptiveStatistics(dataset):
|
14
|
+
"""
|
15
|
+
Generates a detailed table of descriptive statistics for the provided time series dataset.
|
16
|
+
|
17
|
+
**Purpose**: The purpose of the TimeSeriesDescriptiveStatistics function is to analyze an individual time series
|
18
|
+
by providing a summary of key descriptive statistics. This helps in understanding trends, patterns, and data quality issues
|
19
|
+
within the time series.
|
20
|
+
|
21
|
+
**Test Mechanism**: The function extracts the time series data and provides a summary of key descriptive statistics.
|
22
|
+
The dataset is expected to have a datetime index. The function checks this and raises an error if the index is
|
23
|
+
not in datetime format. For each variable (column) in the dataset, appropriate statistics including start date,
|
24
|
+
end date, min, mean, max, skewness, kurtosis, and count are calculated.
|
25
|
+
|
26
|
+
**Signs of High Risk**:
|
27
|
+
- If the index of the dataset is not in datetime format, it could lead to errors in time-series analysis.
|
28
|
+
- Inconsistent or missing data within the dataset might affect the analysis of trends and patterns.
|
29
|
+
|
30
|
+
**Strengths**:
|
31
|
+
- This function provides a comprehensive summary of key descriptive statistics for each variable, helping to identify data quality
|
32
|
+
issues and understand the distribution of the data.
|
33
|
+
|
34
|
+
**Limitations**:
|
35
|
+
- This function assumes that the dataset is provided as a DataFrameDataset object with a .df attribute to access
|
36
|
+
the pandas DataFrame.
|
37
|
+
- It only analyzes datasets with a datetime index and will raise an error for other types of indices.
|
38
|
+
- The function does not handle large datasets efficiently, and performance may degrade with very large datasets.
|
39
|
+
"""
|
40
|
+
|
41
|
+
summary = []
|
42
|
+
|
43
|
+
df = (
|
44
|
+
dataset.df
|
45
|
+
) # Assuming DataFrameDataset objects have a .df attribute to get the pandas DataFrame
|
46
|
+
|
47
|
+
if not pd.api.types.is_datetime64_any_dtype(df.index):
|
48
|
+
raise ValueError(f"Dataset {dataset.input_id} must have a datetime index")
|
49
|
+
|
50
|
+
for column in df.columns:
|
51
|
+
start_date = df.index.min().strftime("%Y-%m-%d")
|
52
|
+
end_date = df.index.max().strftime("%Y-%m-%d")
|
53
|
+
count = df[column].count()
|
54
|
+
min_value = df[column].min()
|
55
|
+
mean_value = df[column].mean()
|
56
|
+
max_value = df[column].max()
|
57
|
+
skewness_value = skew(df[column].dropna())
|
58
|
+
kurtosis_value = kurtosis(df[column].dropna())
|
59
|
+
|
60
|
+
summary.append(
|
61
|
+
{
|
62
|
+
"Variable": column,
|
63
|
+
"Start Date": start_date,
|
64
|
+
"End Date": end_date,
|
65
|
+
"Min": min_value,
|
66
|
+
"Mean": mean_value,
|
67
|
+
"Max": max_value,
|
68
|
+
"Skewness": skewness_value,
|
69
|
+
"Kurtosis": kurtosis_value,
|
70
|
+
"Count": count,
|
71
|
+
}
|
72
|
+
)
|
73
|
+
|
74
|
+
result_df = pd.DataFrame(summary)
|
75
|
+
|
76
|
+
return result_df
|
@@ -2,14 +2,14 @@
|
|
2
2
|
# See the LICENSE file in the root of this repository for details.
|
3
3
|
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
4
|
|
5
|
-
import matplotlib.pyplot as plt
|
6
|
-
import pandas as pd
|
7
|
-
import seaborn as sns
|
5
|
+
import plotly.express as px
|
8
6
|
|
9
|
-
from validmind.vm_models import Figure, Metric
|
7
|
+
from validmind import tags, tasks
|
10
8
|
|
11
9
|
|
12
|
-
|
10
|
+
@tags("data_validation", "visualization")
|
11
|
+
@tasks("regression", "time_series_forecasting")
|
12
|
+
def TimeSeriesHistogram(dataset, nbins=30):
|
13
13
|
"""
|
14
14
|
Visualizes distribution of time-series data using histograms and Kernel Density Estimation (KDE) lines.
|
15
15
|
|
@@ -20,7 +20,7 @@ class TimeSeriesHistogram(Metric):
|
|
20
20
|
(kurtosis) underlying the data.
|
21
21
|
|
22
22
|
**Test Mechanism**: This test operates on a specific column within the dataset that is required to have a datetime
|
23
|
-
type index. It goes through each column in the given dataset, creating a histogram with
|
23
|
+
type index. It goes through each column in the given dataset, creating a histogram with Plotly's histplot
|
24
24
|
function. In cases where the dataset includes more than one time-series (i.e., more than one column with a datetime
|
25
25
|
type index), a distinct histogram is plotted for each series. Additionally, a kernel density estimate (KDE) line is
|
26
26
|
drawn for each histogram, providing a visualization of the data's underlying probability distribution. The x and
|
@@ -48,46 +48,30 @@ class TimeSeriesHistogram(Metric):
|
|
48
48
|
- The histogram's shape may be sensitive to the number of bins used.
|
49
49
|
"""
|
50
50
|
|
51
|
-
|
52
|
-
required_inputs = ["dataset"]
|
53
|
-
metadata = {
|
54
|
-
"task_types": ["regression"],
|
55
|
-
"tags": ["time_series_data", "visualization"],
|
56
|
-
}
|
51
|
+
df = dataset.df
|
57
52
|
|
58
|
-
|
59
|
-
# Check if index is datetime
|
60
|
-
if not pd.api.types.is_datetime64_any_dtype(self.inputs.dataset.df.index):
|
61
|
-
raise ValueError("Index must be a datetime type")
|
53
|
+
columns = list(dataset.df.columns)
|
62
54
|
|
63
|
-
|
55
|
+
if not set(columns).issubset(set(df.columns)):
|
56
|
+
raise ValueError("Provided 'columns' must exist in the dataset")
|
64
57
|
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
figures.append(
|
82
|
-
Figure(
|
83
|
-
for_object=self,
|
84
|
-
key=f"{self.key}:{col}",
|
85
|
-
figure=fig,
|
86
|
-
)
|
87
|
-
)
|
88
|
-
|
89
|
-
plt.close("all")
|
90
|
-
|
91
|
-
return self.cache_results(
|
92
|
-
figures=figures,
|
58
|
+
figures = []
|
59
|
+
for col in columns:
|
60
|
+
fig = px.histogram(
|
61
|
+
df, x=col, marginal="violin", nbins=nbins, title=f"Histogram for {col}"
|
62
|
+
)
|
63
|
+
fig.update_layout(
|
64
|
+
title={
|
65
|
+
"text": f"Histogram for {col}",
|
66
|
+
"y": 0.9,
|
67
|
+
"x": 0.5,
|
68
|
+
"xanchor": "center",
|
69
|
+
"yanchor": "top",
|
70
|
+
},
|
71
|
+
xaxis_title="",
|
72
|
+
yaxis_title="",
|
73
|
+
font=dict(size=18),
|
93
74
|
)
|
75
|
+
figures.append(fig)
|
76
|
+
|
77
|
+
return tuple(figures)
|
@@ -4,11 +4,8 @@
|
|
4
4
|
|
5
5
|
from dataclasses import dataclass
|
6
6
|
|
7
|
-
import matplotlib.pyplot as plt
|
8
7
|
import pandas as pd
|
9
|
-
import seaborn as sns
|
10
|
-
from ydata_profiling.config import Settings
|
11
|
-
from ydata_profiling.model.typeset import ProfilingTypeSet
|
8
|
+
import plotly.graph_objects as go
|
12
9
|
|
13
10
|
from validmind.vm_models import (
|
14
11
|
Figure,
|
@@ -93,7 +90,8 @@ class TimeSeriesOutliers(ThresholdTest):
|
|
93
90
|
zScores = first_result.values["z-score"]
|
94
91
|
dates = first_result.values["Date"]
|
95
92
|
passFail = [
|
96
|
-
"Pass" if z < self.params["zscore_threshold"] else "Fail"
|
93
|
+
"Pass" if abs(z) < self.params["zscore_threshold"] else "Fail"
|
94
|
+
for z in zScores
|
97
95
|
]
|
98
96
|
|
99
97
|
return ResultSummary(
|
@@ -116,25 +114,26 @@ class TimeSeriesOutliers(ThresholdTest):
|
|
116
114
|
)
|
117
115
|
|
118
116
|
def run(self):
|
117
|
+
# Initialize the test_results list
|
118
|
+
test_results = []
|
119
|
+
|
119
120
|
# Check if the index of dataframe is datetime
|
120
121
|
is_datetime = pd.api.types.is_datetime64_any_dtype(self.inputs.dataset.df.index)
|
121
122
|
if not is_datetime:
|
122
123
|
raise ValueError("Dataset must be provided with datetime index")
|
123
124
|
|
124
|
-
# Validate threshold
|
125
|
+
# Validate threshold parameter
|
125
126
|
if "zscore_threshold" not in self.params:
|
126
127
|
raise ValueError("zscore_threshold must be provided in params")
|
127
128
|
zscore_threshold = self.params["zscore_threshold"]
|
128
129
|
|
129
130
|
temp_df = self.inputs.dataset.df.copy()
|
130
131
|
# temp_df = temp_df.dropna()
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
k for k, v in dataset_types.items() if str(v) == "Numeric"
|
137
|
-
]
|
132
|
+
|
133
|
+
# Infer numeric columns
|
134
|
+
num_features_columns = temp_df.select_dtypes(
|
135
|
+
include=["number"]
|
136
|
+
).columns.tolist()
|
138
137
|
|
139
138
|
outliers_table = self.identify_outliers(
|
140
139
|
temp_df[num_features_columns], zscore_threshold
|
@@ -196,49 +195,39 @@ class TimeSeriesOutliers(ThresholdTest):
|
|
196
195
|
df (pandas.DataFrame): Input data with time series.
|
197
196
|
outliers_table (pandas.DataFrame): DataFrame with identified outliers.
|
198
197
|
Returns:
|
199
|
-
|
198
|
+
list: A list of Figure objects with subplots for each variable.
|
200
199
|
"""
|
201
|
-
sns.set(style="darkgrid")
|
202
|
-
columns = list(self.inputs.dataset.df.columns)
|
203
200
|
figures = []
|
204
201
|
|
205
|
-
for col in columns:
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
ax = sns.lineplot(data=df.reset_index(), x=column_index_name, y=col)
|
202
|
+
for col in df.columns:
|
203
|
+
fig = go.Figure()
|
204
|
+
|
205
|
+
fig.add_trace(go.Scatter(x=df.index, y=df[col], mode="lines", name=col))
|
210
206
|
|
211
207
|
if not outliers_table.empty:
|
212
208
|
variable_outliers = outliers_table[outliers_table["Variable"] == col]
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
s=100,
|
221
|
-
c="red",
|
222
|
-
label="Outlier" if idx == 0 else "",
|
209
|
+
fig.add_trace(
|
210
|
+
go.Scatter(
|
211
|
+
x=variable_outliers["Date"],
|
212
|
+
y=df.loc[variable_outliers["Date"], col],
|
213
|
+
mode="markers",
|
214
|
+
marker=dict(color="red", size=10),
|
215
|
+
name="Outlier",
|
223
216
|
)
|
217
|
+
)
|
224
218
|
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
ax.set_title(
|
230
|
-
f"Time Series with Outliers for {col}", weight="bold", fontsize=20
|
219
|
+
fig.update_layout(
|
220
|
+
title=f"Time Series with Outliers for {col}",
|
221
|
+
xaxis_title="Date",
|
222
|
+
yaxis_title=col,
|
231
223
|
)
|
232
224
|
|
233
|
-
ax.legend()
|
234
225
|
figures.append(
|
235
226
|
Figure(
|
236
227
|
for_object=self,
|
237
|
-
key=f"{self.name}:{col}",
|
228
|
+
key=f"{self.name}:{col}_{self.inputs.dataset.input_id}",
|
238
229
|
figure=fig,
|
239
230
|
)
|
240
231
|
)
|
241
232
|
|
242
|
-
# Do this if you want to prevent the figure from being displayed
|
243
|
-
plt.close("all")
|
244
233
|
return figures
|
validmind/tests/decorator.py
CHANGED
@@ -226,6 +226,18 @@ def _get_save_func(func, test_id):
|
|
226
226
|
|
227
227
|
|
228
228
|
def metric(func_or_id):
|
229
|
+
"""
|
230
|
+
DEPRECATED, use @vm.test instead
|
231
|
+
"""
|
232
|
+
# print a deprecation notice and call the test() function instead
|
233
|
+
logger.warning(
|
234
|
+
"The @vm.metric decorator is deprecated and will be removed in a future release. "
|
235
|
+
"Please use @vm.test instead."
|
236
|
+
)
|
237
|
+
return test(func_or_id)
|
238
|
+
|
239
|
+
|
240
|
+
def test(func_or_id):
|
229
241
|
"""Decorator for creating and registering metrics with the ValidMind framework.
|
230
242
|
|
231
243
|
Creates a metric object and registers it with ValidMind under the provided ID. If
|
@@ -0,0 +1,59 @@
|
|
1
|
+
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
|
2
|
+
# See the LICENSE file in the root of this repository for details.
|
3
|
+
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
|
+
|
5
|
+
import pandas as pd
|
6
|
+
|
7
|
+
from validmind import tags, tasks
|
8
|
+
from validmind.utils import get_model_info
|
9
|
+
|
10
|
+
|
11
|
+
@tags("model_training", "metadata")
|
12
|
+
@tasks("regression", "time_series_forecasting")
|
13
|
+
def ModelMetadataComparison(models):
|
14
|
+
"""
|
15
|
+
Compare metadata of different models and generate a summary table with the results.
|
16
|
+
|
17
|
+
**Purpose**: The purpose of this function is to compare the metadata of different models, including information about their architecture, framework, framework version, and programming language.
|
18
|
+
|
19
|
+
**Test Mechanism**: The function retrieves the metadata for each model using `get_model_info`, renames columns according to a predefined set of labels, and compiles this information into a summary table.
|
20
|
+
|
21
|
+
**Signs of High Risk**:
|
22
|
+
- Inconsistent or missing metadata across models can indicate potential issues in model documentation or management.
|
23
|
+
- Significant differences in framework versions or programming languages might pose challenges in model integration and deployment.
|
24
|
+
|
25
|
+
**Strengths**:
|
26
|
+
- Provides a clear comparison of essential model metadata.
|
27
|
+
- Standardizes metadata labels for easier interpretation and comparison.
|
28
|
+
- Helps identify potential compatibility or consistency issues across models.
|
29
|
+
|
30
|
+
**Limitations**:
|
31
|
+
- Assumes that the `get_model_info` function returns all necessary metadata fields.
|
32
|
+
- Relies on the correctness and completeness of the metadata provided by each model.
|
33
|
+
- Does not include detailed parameter information, focusing instead on high-level metadata.
|
34
|
+
"""
|
35
|
+
column_labels = {
|
36
|
+
"architecture": "Modeling Technique",
|
37
|
+
"framework": "Modeling Framework",
|
38
|
+
"framework_version": "Framework Version",
|
39
|
+
"language": "Programming Language",
|
40
|
+
}
|
41
|
+
|
42
|
+
description = []
|
43
|
+
|
44
|
+
for model in models:
|
45
|
+
model_info = get_model_info(model)
|
46
|
+
|
47
|
+
# Rename columns based on provided labels
|
48
|
+
model_info_renamed = {
|
49
|
+
column_labels.get(k, k): v for k, v in model_info.items() if k != "params"
|
50
|
+
}
|
51
|
+
|
52
|
+
# Add model name or identifier if available
|
53
|
+
model_info_renamed = {"Model Name": model.input_id, **model_info_renamed}
|
54
|
+
|
55
|
+
description.append(model_info_renamed)
|
56
|
+
|
57
|
+
description_df = pd.DataFrame(description)
|
58
|
+
|
59
|
+
return description_df
|