validmind 2.3.3__py3-none-any.whl → 2.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__version__.py +1 -1
- validmind/datasets/regression/fred_timeseries.py +272 -0
- validmind/tests/__types__.py +10 -0
- validmind/tests/data_validation/SeasonalDecompose.py +68 -40
- validmind/tests/data_validation/TimeSeriesDescription.py +74 -0
- validmind/tests/data_validation/TimeSeriesDescriptiveStatistics.py +76 -0
- validmind/tests/data_validation/TimeSeriesHistogram.py +29 -45
- validmind/tests/data_validation/TimeSeriesOutliers.py +30 -41
- validmind/tests/model_validation/ModelMetadataComparison.py +59 -0
- validmind/tests/model_validation/ModelPredictionResiduals.py +103 -0
- validmind/tests/model_validation/TimeSeriesPredictionWithCI.py +131 -0
- validmind/tests/model_validation/TimeSeriesPredictionsPlot.py +76 -0
- validmind/tests/model_validation/TimeSeriesR2SquareBySegments.py +103 -0
- validmind/tests/model_validation/sklearn/FeatureImportanceComparison.py +83 -0
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +1 -1
- validmind/tests/model_validation/sklearn/RegressionErrorsComparison.py +76 -0
- validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py +63 -0
- {validmind-2.3.3.dist-info → validmind-2.3.5.dist-info}/METADATA +70 -36
- {validmind-2.3.3.dist-info → validmind-2.3.5.dist-info}/RECORD +23 -12
- /validmind/datasets/regression/datasets/{lending_club_loan_rates.csv → leanding_club_loan_rates.csv} +0 -0
- {validmind-2.3.3.dist-info → validmind-2.3.5.dist-info}/LICENSE +0 -0
- {validmind-2.3.3.dist-info → validmind-2.3.5.dist-info}/WHEEL +0 -0
- {validmind-2.3.3.dist-info → validmind-2.3.5.dist-info}/entry_points.txt +0 -0
```diff
--- a/validmind/tests/data_validation/TimeSeriesHistogram.py
+++ b/validmind/tests/data_validation/TimeSeriesHistogram.py
@@ -2,14 +2,14 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-import
-import pandas as pd
-import seaborn as sns
+import plotly.express as px
 
-from validmind
+from validmind import tags, tasks
 
 
-
+@tags("data_validation", "visualization")
+@tasks("regression", "time_series_forecasting")
+def TimeSeriesHistogram(dataset, nbins=30):
     """
     Visualizes distribution of time-series data using histograms and Kernel Density Estimation (KDE) lines.
 
@@ -20,7 +20,7 @@ class TimeSeriesHistogram(Metric):
     (kurtosis) underlying the data.
 
     **Test Mechanism**: This test operates on a specific column within the dataset that is required to have a datetime
-    type index. It goes through each column in the given dataset, creating a histogram with
+    type index. It goes through each column in the given dataset, creating a histogram with Plotly's histplot
     function. In cases where the dataset includes more than one time-series (i.e., more than one column with a datetime
     type index), a distinct histogram is plotted for each series. Additionally, a kernel density estimate (KDE) line is
     drawn for each histogram, providing a visualization of the data's underlying probability distribution. The x and
@@ -48,46 +48,30 @@ class TimeSeriesHistogram(Metric):
     - The histogram's shape may be sensitive to the number of bins used.
     """
 
-
-    required_inputs = ["dataset"]
-    metadata = {
-        "task_types": ["regression"],
-        "tags": ["time_series_data", "visualization"],
-    }
+    df = dataset.df
 
-
-    # Check if index is datetime
-    if not pd.api.types.is_datetime64_any_dtype(self.inputs.dataset.df.index):
-        raise ValueError("Index must be a datetime type")
+    columns = list(dataset.df.columns)
 
-
+    if not set(columns).issubset(set(df.columns)):
+        raise ValueError("Provided 'columns' must exist in the dataset")
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            figures.append(
-                Figure(
-                    for_object=self,
-                    key=f"{self.key}:{col}",
-                    figure=fig,
-                )
-            )
-
-        plt.close("all")
-
-        return self.cache_results(
-            figures=figures,
+    figures = []
+    for col in columns:
+        fig = px.histogram(
+            df, x=col, marginal="violin", nbins=nbins, title=f"Histogram for {col}"
+        )
+        fig.update_layout(
+            title={
+                "text": f"Histogram for {col}",
+                "y": 0.9,
+                "x": 0.5,
+                "xanchor": "center",
+                "yanchor": "top",
+            },
+            xaxis_title="",
+            yaxis_title="",
+            font=dict(size=18),
         )
+        figures.append(fig)
+
+    return tuple(figures)
```
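With the `Metric` subclass replaced by a plain `@tags`/`@tasks`-decorated function, the test can be exercised directly. Below is a minimal sketch under two assumptions: the 2.3.5 import path matches the file location above, and a simple stand-in object exposing the one attribute the function reads (`.df`) is acceptable outside the ValidMind harness. `StubDataset` and the synthetic series are hypothetical.

```python
import numpy as np
import pandas as pd

from validmind.tests.data_validation.TimeSeriesHistogram import TimeSeriesHistogram


class StubDataset:
    """Hypothetical stand-in: the refactored test only reads `.df`."""

    def __init__(self, df):
        self.df = df


# Synthetic daily series with a datetime index, as the docstring expects.
idx = pd.date_range("2020-01-01", periods=365, freq="D")
df = pd.DataFrame(
    {"loan_rate": np.random.default_rng(0).normal(5.0, 0.5, 365)}, index=idx
)

figures = TimeSeriesHistogram(StubDataset(df), nbins=20)
for fig in figures:
    fig.show()  # one Plotly histogram (with violin marginal) per column
```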
```diff
--- a/validmind/tests/data_validation/TimeSeriesOutliers.py
+++ b/validmind/tests/data_validation/TimeSeriesOutliers.py
@@ -4,11 +4,8 @@
 
 from dataclasses import dataclass
 
-import matplotlib.pyplot as plt
 import pandas as pd
-import
-from ydata_profiling.config import Settings
-from ydata_profiling.model.typeset import ProfilingTypeSet
+import plotly.graph_objects as go
 
 from validmind.vm_models import (
     Figure,
@@ -93,7 +90,8 @@ class TimeSeriesOutliers(ThresholdTest):
         zScores = first_result.values["z-score"]
         dates = first_result.values["Date"]
         passFail = [
-            "Pass" if z < self.params["zscore_threshold"] else "Fail"
+            "Pass" if abs(z) < self.params["zscore_threshold"] else "Fail"
+            for z in zScores
         ]
 
         return ResultSummary(
```
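The `passFail` change is a genuine bug fix, not a style tweak: the old predicate compared the signed z-score against the threshold, so large negative z-scores (downward outliers) were labeled "Pass". A toy check of both predicates:

```python
zscore_threshold = 3.0
zScores = [0.2, -4.1, 3.5]

old = ["Pass" if z < zscore_threshold else "Fail" for z in zScores]
new = ["Pass" if abs(z) < zscore_threshold else "Fail" for z in zScores]

print(old)  # ['Pass', 'Pass', 'Fail'] -- the -4.1 outlier slips through
print(new)  # ['Pass', 'Fail', 'Fail']
```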
```diff
@@ -116,25 +114,26 @@ class TimeSeriesOutliers(ThresholdTest):
         )
 
     def run(self):
+        # Initialize the test_results list
+        test_results = []
+
         # Check if the index of dataframe is datetime
         is_datetime = pd.api.types.is_datetime64_any_dtype(self.inputs.dataset.df.index)
         if not is_datetime:
             raise ValueError("Dataset must be provided with datetime index")
 
-        # Validate threshold
+        # Validate threshold parameter
         if "zscore_threshold" not in self.params:
             raise ValueError("zscore_threshold must be provided in params")
         zscore_threshold = self.params["zscore_threshold"]
 
         temp_df = self.inputs.dataset.df.copy()
         # temp_df = temp_df.dropna()
-
-
-
-
-
-            k for k, v in dataset_types.items() if str(v) == "Numeric"
-        ]
+
+        # Infer numeric columns
+        num_features_columns = temp_df.select_dtypes(
+            include=["number"]
+        ).columns.tolist()
 
         outliers_table = self.identify_outliers(
             temp_df[num_features_columns], zscore_threshold
```
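The `run()` hunk also drops ydata_profiling's type inference in favor of pandas' built-in dtype selection. A quick illustration of what `select_dtypes(include=["number"])` keeps, using made-up columns:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "loan_rate": [5.1, 5.3, 5.2],  # float64 -> kept
        "volume": [100, 120, 90],      # int64   -> kept
        "region": ["US", "EU", "US"],  # object  -> dropped
    }
)

num_features_columns = df.select_dtypes(include=["number"]).columns.tolist()
print(num_features_columns)  # ['loan_rate', 'volume']
```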
```diff
@@ -196,49 +195,39 @@ class TimeSeriesOutliers(ThresholdTest):
             df (pandas.DataFrame): Input data with time series.
             outliers_table (pandas.DataFrame): DataFrame with identified outliers.
         Returns:
-
+            list: A list of Figure objects with subplots for each variable.
         """
-        sns.set(style="darkgrid")
-        columns = list(self.inputs.dataset.df.columns)
         figures = []
 
-        for col in columns:
-
-
-
-            ax = sns.lineplot(data=df.reset_index(), x=column_index_name, y=col)
+        for col in df.columns:
+            fig = go.Figure()
+
+            fig.add_trace(go.Scatter(x=df.index, y=df[col], mode="lines", name=col))
 
             if not outliers_table.empty:
                 variable_outliers = outliers_table[outliers_table["Variable"] == col]
-
-
-
-
-
-
-
-                        s=100,
-                        c="red",
-                        label="Outlier" if idx == 0 else "",
+                fig.add_trace(
+                    go.Scatter(
+                        x=variable_outliers["Date"],
+                        y=df.loc[variable_outliers["Date"], col],
+                        mode="markers",
+                        marker=dict(color="red", size=10),
+                        name="Outlier",
                     )
+                )
 
-
-
-
-
-            ax.set_title(
-                f"Time Series with Outliers for {col}", weight="bold", fontsize=20
+            fig.update_layout(
+                title=f"Time Series with Outliers for {col}",
+                xaxis_title="Date",
+                yaxis_title=col,
             )
 
-            ax.legend()
             figures.append(
                 Figure(
                     for_object=self,
-                    key=f"{self.name}:{col}",
+                    key=f"{self.name}:{col}_{self.inputs.dataset.input_id}",
                     figure=fig,
                 )
             )
 
-        # Do this if you want to prevent the figure from being displayed
-        plt.close("all")
         return figures
```
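The plotting rewrite swaps seaborn line plots for Plotly traces: one line per series, with red markers overlaid at the flagged dates. The same pattern in isolation, on made-up data with one planted outlier:

```python
import pandas as pd
import plotly.graph_objects as go

idx = pd.date_range("2021-01-01", periods=60, freq="D")
series = pd.Series(range(60), index=idx, dtype=float)
series.iloc[30] = 200.0  # plant an obvious outlier
outlier_dates = [idx[30]]

fig = go.Figure()
fig.add_trace(go.Scatter(x=idx, y=series, mode="lines", name="value"))
fig.add_trace(
    go.Scatter(
        x=outlier_dates,
        y=series.loc[outlier_dates],
        mode="markers",
        marker=dict(color="red", size=10),
        name="Outlier",
    )
)
fig.update_layout(
    title="Time Series with Outliers for value", xaxis_title="Date", yaxis_title="value"
)
fig.show()
```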
```diff
--- /dev/null
+++ b/validmind/tests/model_validation/ModelMetadataComparison.py
@@ -0,0 +1,59 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import pandas as pd
+
+from validmind import tags, tasks
+from validmind.utils import get_model_info
+
+
+@tags("model_training", "metadata")
+@tasks("regression", "time_series_forecasting")
+def ModelMetadataComparison(models):
+    """
+    Compare metadata of different models and generate a summary table with the results.
+
+    **Purpose**: The purpose of this function is to compare the metadata of different models, including information about their architecture, framework, framework version, and programming language.
+
+    **Test Mechanism**: The function retrieves the metadata for each model using `get_model_info`, renames columns according to a predefined set of labels, and compiles this information into a summary table.
+
+    **Signs of High Risk**:
+    - Inconsistent or missing metadata across models can indicate potential issues in model documentation or management.
+    - Significant differences in framework versions or programming languages might pose challenges in model integration and deployment.
+
+    **Strengths**:
+    - Provides a clear comparison of essential model metadata.
+    - Standardizes metadata labels for easier interpretation and comparison.
+    - Helps identify potential compatibility or consistency issues across models.
+
+    **Limitations**:
+    - Assumes that the `get_model_info` function returns all necessary metadata fields.
+    - Relies on the correctness and completeness of the metadata provided by each model.
+    - Does not include detailed parameter information, focusing instead on high-level metadata.
+    """
+    column_labels = {
+        "architecture": "Modeling Technique",
+        "framework": "Modeling Framework",
+        "framework_version": "Framework Version",
+        "language": "Programming Language",
+    }
+
+    description = []
+
+    for model in models:
+        model_info = get_model_info(model)
+
+        # Rename columns based on provided labels
+        model_info_renamed = {
+            column_labels.get(k, k): v for k, v in model_info.items() if k != "params"
+        }
+
+        # Add model name or identifier if available
+        model_info_renamed = {"Model Name": model.input_id, **model_info_renamed}
+
+        description.append(model_info_renamed)
+
+    description_df = pd.DataFrame(description)
+
+    return description_df
```
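Like the other new tests in this release, `ModelMetadataComparison` is a functional test that takes its inputs as arguments. A sketch of one way to invoke it through validmind's `run_test` entry point; the model input IDs here are hypothetical, and each model's `input_id` becomes the "Model Name" column of the returned table:

```python
import validmind as vm

# "model_a" and "model_b" are hypothetical input_ids assigned when the
# models were registered with vm.init_model().
result = vm.tests.run_test(
    "validmind.model_validation.ModelMetadataComparison",
    inputs={"models": ["model_a", "model_b"]},
)
result.log()  # optionally log the result to the ValidMind platform
```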
```diff
--- /dev/null
+++ b/validmind/tests/model_validation/ModelPredictionResiduals.py
@@ -0,0 +1,103 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import pandas as pd
+import plotly.graph_objects as go
+from scipy.stats import kstest
+
+from validmind import tags, tasks
+
+
+@tags("regression")
+@tasks("residual_analysis", "visualization")
+def ModelPredictionResiduals(
+    datasets, models, nbins=100, p_value_threshold=0.05, start_date=None, end_date=None
+):
+    """
+    Plot the residuals and histograms for each model, and generate a summary table
+    with the Kolmogorov-Smirnov normality test results.
+
+    **Purpose**: The purpose of this function is to visualize the residuals of model predictions and
+    assess the normality of residuals using the Kolmogorov-Smirnov test.
+
+    **Test Mechanism**: The function iterates through each dataset-model pair, calculates residuals, and generates
+    two figures for each model: one for the time series of residuals and one for the histogram of residuals.
+    It also calculates the KS test for normality and summarizes the results in a table.
+
+    **Signs of High Risk**:
+    - If the residuals are not normally distributed, it could indicate issues with model assumptions.
+    - High skewness or kurtosis in the residuals may indicate model misspecification.
+
+    **Strengths**:
+    - Provides a clear visualization of residuals over time and their distribution.
+    - Includes statistical tests to assess the normality of residuals.
+
+    **Limitations**:
+    - Assumes that the dataset is provided as a DataFrameDataset object with a .df attribute to access
+    the pandas DataFrame.
+    - Only generates plots for datasets with a datetime index, and will raise an error for other types of indices.
+    """
+
+    figures = []
+    summary = []
+
+    for dataset, model in zip(datasets, models):
+        df = dataset.df.copy()
+
+        # Filter DataFrame by date range if specified
+        if start_date:
+            df = df[df.index >= pd.to_datetime(start_date)]
+        if end_date:
+            df = df[df.index <= pd.to_datetime(end_date)]
+
+        y_true = dataset.y
+        y_pred = dataset.y_pred(model)
+        residuals = y_true - y_pred
+
+        # Plot residuals
+        residuals_fig = go.Figure()
+        residuals_fig.add_trace(
+            go.Scatter(x=df.index, y=residuals, mode="lines", name="Residuals")
+        )
+        residuals_fig.update_layout(
+            title=f"Residuals for {model.input_id}",
+            xaxis_title="Date",
+            yaxis_title="Residuals",
+            font=dict(size=16),
+            showlegend=False,
+        )
+        figures.append(residuals_fig)
+
+        # Plot histogram of residuals
+        hist_fig = go.Figure()
+        hist_fig.add_trace(go.Histogram(x=residuals, nbinsx=nbins, name="Residuals"))
+        hist_fig.update_layout(
+            title=f"Histogram of Residuals for {model.input_id}",
+            xaxis_title="Residuals",
+            yaxis_title="Frequency",
+            font=dict(size=16),
+            showlegend=False,
+        )
+        figures.append(hist_fig)
+
+        # Perform KS normality test
+        ks_stat, p_value = kstest(
+            residuals, "norm", args=(residuals.mean(), residuals.std())
+        )
+        ks_normality = "Normal" if p_value > p_value_threshold else "Not Normal"
+
+        summary.append(
+            {
+                "Model": model.input_id,
+                "KS Statistic": ks_stat,
+                "p-value": p_value,
+                "KS Normality": ks_normality,
+                "p-value Threshold": p_value_threshold,
+            }
+        )
+
+    # Create a summary DataFrame for the KS normality test results
+    summary_df = pd.DataFrame(summary)
+
+    return (summary_df, *figures)
```
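The normality verdict rests on `scipy.stats.kstest` with the residuals' own mean and standard deviation plugged in as the reference normal's parameters. A self-contained check of that exact call on synthetic residuals:

```python
import numpy as np
from scipy.stats import kstest

rng = np.random.default_rng(42)
residuals = rng.normal(loc=0.0, scale=1.5, size=500)  # synthetic, genuinely normal

ks_stat, p_value = kstest(residuals, "norm", args=(residuals.mean(), residuals.std()))
print("Normal" if p_value > 0.05 else "Not Normal")  # expected: Normal
```

One caveat worth noting: estimating the normal's parameters from the same sample makes the KS p-value optimistic (the Lilliefors test corrects for this), so borderline "Normal" verdicts in the summary table deserve some skepticism.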
```diff
--- /dev/null
+++ b/validmind/tests/model_validation/TimeSeriesPredictionWithCI.py
@@ -0,0 +1,131 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from scipy.stats import norm
+
+from validmind import tags, tasks
+
+
+@tags("model_predictions", "visualization")
+@tasks("regression", "time_series_forecasting")
+def TimeSeriesPredictionWithCI(dataset, model, confidence=0.95):
+    """
+    Plot actual vs predicted values for a time series with confidence intervals and compute breaches.
+
+    **Purpose**: The purpose of this function is to visualize the actual versus predicted values for time series data, including confidence intervals, and to compute and report the number of breaches beyond these intervals.
+
+    **Test Mechanism**: The function calculates the standard deviation of prediction errors, determines the confidence intervals, and counts the number of actual values that fall outside these intervals (breaches). It then generates a plot with the actual values, predicted values, and confidence intervals, and returns a DataFrame summarizing the breach information.
+
+    **Signs of High Risk**:
+    - A high number of breaches indicates that the model's predictions are not reliable within the specified confidence level.
+    - Significant deviations between actual and predicted values may highlight model inadequacies or issues with data quality.
+
+    **Strengths**:
+    - Provides a visual representation of prediction accuracy and the uncertainty around predictions.
+    - Includes a statistical measure of prediction reliability through confidence intervals.
+    - Computes and reports breaches, offering a quantitative assessment of prediction performance.
+
+    **Limitations**:
+    - Assumes that the dataset is provided as a DataFrameDataset object with a datetime index.
+    - Requires that `dataset.y_pred(model)` returns the predicted values for the model.
+    - The calculation of confidence intervals assumes normally distributed errors, which may not hold for all datasets.
+    """
+    dataset_name = dataset.input_id
+    model_name = model.input_id
+    time_index = dataset.df.index  # Assuming the index of the dataset is datetime
+
+    # Get actual and predicted values
+    y_true = dataset.y
+    y_pred = dataset.y_pred(model)
+
+    # Compute the standard deviation of the errors
+    errors = y_true - y_pred
+    std_error = np.std(errors)
+
+    # Compute z-score for the given confidence level
+    z_score = norm.ppf(1 - (1 - confidence) / 2)
+
+    # Compute confidence intervals
+    lower_conf = y_pred - z_score * std_error
+    upper_conf = y_pred + z_score * std_error
+
+    # Calculate breaches
+    upper_breaches = (y_true > upper_conf).sum()
+    lower_breaches = (y_true < lower_conf).sum()
+    total_breaches = upper_breaches + lower_breaches
+
+    # Create DataFrame
+    breaches_df = pd.DataFrame(
+        {
+            "Confidence Level": [confidence],
+            "Total Breaches": [total_breaches],
+            "Upper Breaches": [upper_breaches],
+            "Lower Breaches": [lower_breaches],
+        }
+    )
+
+    # Plotting
+    fig = go.Figure()
+
+    # Plot actual values
+    fig.add_trace(
+        go.Scatter(
+            x=time_index,
+            y=y_true,
+            mode="lines",
+            name="Actual Values",
+            line=dict(color="blue"),
+        )
+    )
+
+    # Plot predicted values
+    fig.add_trace(
+        go.Scatter(
+            x=time_index,
+            y=y_pred,
+            mode="lines",
+            name=f"Predicted by {model_name}",
+            line=dict(color="red"),
+        )
+    )
+
+    # Add confidence interval lower bound as an invisible line
+    fig.add_trace(
+        go.Scatter(
+            x=time_index,
+            y=lower_conf,
+            mode="lines",
+            line=dict(width=0),
+            showlegend=False,
+            name="CI Lower",
+        )
+    )
+
+    # Add confidence interval upper bound and fill area
+    fig.add_trace(
+        go.Scatter(
+            x=time_index,
+            y=upper_conf,
+            mode="lines",
+            fill="tonexty",
+            fillcolor="rgba(200, 200, 200, 0.5)",
+            line=dict(width=0),
+            showlegend=True,
+            name="Confidence Interval",
+        )
+    )
+
+    # Update layout
+    fig.update_layout(
+        title=f"Time Series Actual vs Predicted Values for {dataset_name} and {model_name}",
+        xaxis_title="Time",
+        yaxis_title="Values",
+        legend_title="Legend",
+        template="plotly_white",
+    )
+
+    return fig, breaches_df
```
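The interval construction is the standard normal-error band: `norm.ppf(1 - (1 - confidence) / 2)` is the two-sided critical value, and the band is `y_pred ± z_score * std_error`. A compact numeric sanity check on synthetic errors:

```python
import numpy as np
from scipy.stats import norm

confidence = 0.95
z_score = norm.ppf(1 - (1 - confidence) / 2)
print(round(z_score, 2))  # 1.96 -- the familiar two-sided 95% critical value

# With truly normal errors, about 5% of actuals should breach the band.
rng = np.random.default_rng(1)
y_pred = np.zeros(10_000)
y_true = rng.normal(0.0, 1.0, 10_000)
std_error = np.std(y_true - y_pred)

upper_breaches = (y_true > y_pred + z_score * std_error).sum()
lower_breaches = (y_true < y_pred - z_score * std_error).sum()
print((upper_breaches + lower_breaches) / 10_000)  # ~= 0.05
```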
```diff
--- /dev/null
+++ b/validmind/tests/model_validation/TimeSeriesPredictionsPlot.py
@@ -0,0 +1,76 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import plotly.express as px
+import plotly.graph_objects as go
+
+from validmind import tags, tasks
+
+
+@tags("model_predictions", "visualization")
+@tasks("regression", "time_series_forecasting")
+def TimeSeriesPredictionsPlot(datasets, models):
+    """
+    Plot actual vs predicted values for time series data and generate a visual comparison for each model.
+
+    **Purpose**: The purpose of this function is to visualize the actual versus predicted values for time series data across different models.
+
+    **Test Mechanism**: The function iterates through each dataset-model pair, plots the actual values from the dataset, and overlays the predicted values from each model using Plotly for interactive visualization.
+
+    **Signs of High Risk**:
+    - Large discrepancies between actual and predicted values indicate poor model performance.
+    - Systematic deviations in predicted values can highlight model bias or issues with data patterns.
+
+    **Strengths**:
+    - Provides a clear visual comparison of model predictions against actual values.
+    - Uses Plotly for interactive and visually appealing plots.
+    - Can handle multiple models and datasets, displaying them with distinct colors.
+
+    **Limitations**:
+    - Assumes that the dataset is provided as a DataFrameDataset object with a datetime index.
+    - Requires that `dataset.y_pred(model)` returns the predicted values for the model.
+    - Visualization might become cluttered with a large number of models or datasets.
+    """
+    fig = go.Figure()
+
+    # Use Plotly's color sequence for different model predictions
+    colors = px.colors.qualitative.Plotly
+
+    # Plot actual values from the first dataset
+    dataset = datasets[0]
+    time_index = dataset.df.index  # Assuming the index of the dataset is datetime
+    fig.add_trace(
+        go.Scatter(
+            x=time_index,
+            y=dataset.y,
+            mode="lines",
+            name="Actual Values",
+            line=dict(color="blue"),
+        )
+    )
+
+    # Plot predicted values for each dataset-model pair
+    for idx, (dataset, model) in enumerate(zip(datasets, models)):
+        model_name = model.input_id
+        y_pred = dataset.y_pred(model)
+        fig.add_trace(
+            go.Scatter(
+                x=time_index,
+                y=y_pred,
+                mode="lines",
+                name=f"Predicted by {model_name}",
+                line=dict(color=colors[idx % len(colors)]),
+            )
+        )
+
+    # Update layout
+    fig.update_layout(
+        title="Time Series Actual vs Predicted Values",
+        xaxis_title="Time",
+        yaxis_title="Values",
+        legend_title="Legend",
+        template="plotly_white",
+    )
+
+    return fig
```
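A closing sketch of how this plot might be driven end to end, again via `run_test` with hypothetical input IDs. Note from the code above that actual values are drawn from the first dataset only, while one predicted trace is overlaid per dataset-model pair:

```python
import validmind as vm

# Hypothetical input_ids: one test dataset paired with two candidate models.
result = vm.tests.run_test(
    "validmind.model_validation.TimeSeriesPredictionsPlot",
    inputs={
        "datasets": ["test_ds", "test_ds"],
        "models": ["model_a", "model_b"],
    },
)
result.log()
```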