validmind 2.3.1__py3-none-any.whl → 2.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. validmind/__init__.py +2 -1
  2. validmind/__version__.py +1 -1
  3. validmind/datasets/regression/fred_timeseries.py +272 -0
  4. validmind/test_suites/__init__.py +0 -2
  5. validmind/tests/__init__.py +7 -7
  6. validmind/tests/__types__.py +180 -0
  7. validmind/tests/data_validation/SeasonalDecompose.py +68 -40
  8. validmind/tests/data_validation/TimeSeriesDescription.py +74 -0
  9. validmind/tests/data_validation/TimeSeriesDescriptiveStatistics.py +76 -0
  10. validmind/tests/data_validation/TimeSeriesHistogram.py +29 -45
  11. validmind/tests/data_validation/TimeSeriesOutliers.py +30 -41
  12. validmind/tests/decorator.py +12 -0
  13. validmind/tests/model_validation/ModelMetadataComparison.py +59 -0
  14. validmind/tests/model_validation/ModelPredictionResiduals.py +103 -0
  15. validmind/tests/model_validation/TimeSeriesPredictionWithCI.py +131 -0
  16. validmind/tests/model_validation/TimeSeriesPredictionsPlot.py +76 -0
  17. validmind/tests/model_validation/TimeSeriesR2SquareBySegments.py +103 -0
  18. validmind/tests/model_validation/sklearn/FeatureImportanceComparison.py +83 -0
  19. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +1 -1
  20. validmind/tests/model_validation/sklearn/RegressionErrorsComparison.py +76 -0
  21. validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py +63 -0
  22. validmind/utils.py +34 -0
  23. {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/METADATA +70 -36
  24. {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/RECORD +28 -16
  25. /validmind/datasets/regression/datasets/{lending_club_loan_rates.csv → leanding_club_loan_rates.csv} +0 -0
  26. {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/LICENSE +0 -0
  27. {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/WHEEL +0 -0
  28. {validmind-2.3.1.dist-info → validmind-2.3.5.dist-info}/entry_points.txt +0 -0

validmind/tests/data_validation/SeasonalDecompose.py

@@ -4,10 +4,10 @@
 
 import warnings
 
-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import seaborn as sns
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
 from scipy import stats
 from statsmodels.tsa.seasonal import seasonal_decompose
 
@@ -132,7 +132,6 @@ class SeasonalDecompose(Metric):
             inferred_freq = pd.infer_freq(series.index)
 
             if inferred_freq is not None:
-                logger.info(f"Frequency of {col}: {inferred_freq}")
 
                 # Only take finite values to seasonal_decompose
                 sd = seasonal_decompose(
@@ -142,58 +141,87 @@ class SeasonalDecompose(Metric):
 
                 results[col] = self.serialize_seasonal_decompose(sd)
 
-                # Create subplots
-                fig, axes = plt.subplots(3, 2)
-                width, _ = fig.get_size_inches()
-                fig.set_size_inches(width, 15)
-                fig.subplots_adjust(hspace=0.3)
-                fig.suptitle(
-                    f"Seasonal Decomposition for {col}",
-                    fontsize=20,
-                    weight="bold",
-                    y=0.95,
+                # Create subplots using Plotly
+                fig = make_subplots(
+                    rows=3,
+                    cols=2,
+                    subplot_titles=(
+                        "Observed",
+                        "Trend",
+                        "Seasonal",
+                        "Residuals",
+                        "Histogram and KDE of Residuals",
+                        "Normal Q-Q Plot of Residuals",
+                    ),
+                    vertical_spacing=0.1,
                 )
 
-                # Original seasonal decomposition plots
                 # Observed
-                sd.observed.plot(ax=axes[0, 0])
-                axes[0, 0].set_title("Observed", fontsize=18)
-                axes[0, 0].set_xlabel("")
-                axes[0, 0].tick_params(axis="both", labelsize=18)
+                fig.add_trace(
+                    go.Scatter(x=sd.observed.index, y=sd.observed, name="Observed"),
+                    row=1,
+                    col=1,
+                )
 
                 # Trend
-                sd.trend.plot(ax=axes[0, 1])
-                axes[0, 1].set_title("Trend", fontsize=18)
-                axes[0, 1].set_xlabel("")
-                axes[0, 1].tick_params(axis="both", labelsize=18)
+                fig.add_trace(
+                    go.Scatter(x=sd.trend.index, y=sd.trend, name="Trend"),
+                    row=1,
+                    col=2,
+                )
 
                 # Seasonal
-                sd.seasonal.plot(ax=axes[1, 0])
-                axes[1, 0].set_title("Seasonal", fontsize=18)
-                axes[1, 0].set_xlabel("")
-                axes[1, 0].tick_params(axis="both", labelsize=18)
+                fig.add_trace(
+                    go.Scatter(x=sd.seasonal.index, y=sd.seasonal, name="Seasonal"),
+                    row=2,
+                    col=1,
+                )
 
                 # Residuals
-                sd.resid.plot(ax=axes[1, 1])
-                axes[1, 1].set_title("Residuals", fontsize=18)
-                axes[1, 1].set_xlabel("")
-                axes[1, 1].tick_params(axis="both", labelsize=18)
+                fig.add_trace(
+                    go.Scatter(x=sd.resid.index, y=sd.resid, name="Residuals"),
+                    row=2,
+                    col=2,
+                )
 
                 # Histogram with KDE
                 residuals = sd.resid.dropna()
-                sns.histplot(residuals, kde=True, ax=axes[2, 0])
-                axes[2, 0].set_title("Histogram and KDE of Residuals", fontsize=18)
-                axes[2, 0].set_xlabel("")
-                axes[2, 0].tick_params(axis="both", labelsize=18)
+                fig.add_trace(
+                    go.Histogram(x=residuals, nbinsx=100, name="Residuals"),
+                    row=3,
+                    col=1,
+                )
 
                 # Normal Q-Q plot
-                stats.probplot(residuals, plot=axes[2, 1])
-                axes[2, 1].set_title("Normal Q-Q Plot of Residuals", fontsize=18)
-                axes[2, 1].set_xlabel("")
-                axes[2, 1].tick_params(axis="both", labelsize=18)
+                qq = stats.probplot(residuals, plot=None)
+                qq_line_slope, qq_line_intercept = stats.linregress(
+                    qq[0][0], qq[0][1]
+                )[:2]
+                qq_line = qq_line_slope * np.array(qq[0][0]) + qq_line_intercept
+
+                fig.add_trace(
+                    go.Scatter(
+                        x=qq[0][0], y=qq[0][1], mode="markers", name="QQ plot"
+                    ),
+                    row=3,
+                    col=2,
+                )
+                fig.add_trace(
+                    go.Scatter(
+                        x=qq[0][0],
+                        y=qq_line,
+                        mode="lines",
+                        name="QQ line",
+                    ),
+                    row=3,
+                    col=2,
+                )
 
-                # Do this if you want to prevent the figure from being displayed
-                plt.close("all")
+                fig.update_layout(
+                    height=1000,
+                    title_text=f"Seasonal Decomposition for {col}",
+                    showlegend=False,
+                )
 
                 figures.append(
                     Figure(
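
For context, the Plotly rewrite can no longer hand an axes object to `stats.probplot(..., plot=ax)`, so the Q-Q reference line is now computed by hand. A minimal standalone sketch of that construction, using synthetic residuals rather than the package's data:

# Editor's sketch (not package source): rebuild the Q-Q panel with scipy + plotly.
import numpy as np
import plotly.graph_objects as go
from scipy import stats

residuals = np.random.default_rng(0).normal(size=500)

# probplot(plot=None) returns ((theoretical_quantiles, ordered_values), fit_params)
qq = stats.probplot(residuals, plot=None)
slope, intercept = stats.linregress(qq[0][0], qq[0][1])[:2]

fig = go.Figure()
fig.add_trace(go.Scatter(x=qq[0][0], y=qq[0][1], mode="markers", name="QQ plot"))
fig.add_trace(
    go.Scatter(
        x=qq[0][0],
        y=slope * np.array(qq[0][0]) + intercept,
        mode="lines",
        name="QQ line",
    )
)
fig.show()

Note that `probplot` already returns fit parameters as its second element when `fit=True` (the default), so the separate `linregress` call in the new code duplicates a fit that `qq[1]` provides.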

validmind/tests/data_validation/TimeSeriesDescription.py (new file)

@@ -0,0 +1,74 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import pandas as pd
+
+from validmind import tags, tasks
+
+
+@tags("time_series_data", "analysis")
+@tasks("regression")
+def TimeSeriesDescription(dataset):
+    """
+    Generates a detailed analysis for the provided time series dataset.
+
+    **Purpose**: The purpose of the TimeSeriesDescription function is to analyze an individual time series
+    by providing a summary of key statistics. This helps in understanding trends, patterns, and data quality issues
+    within the time series.
+
+    **Test Mechanism**: The function extracts the time series data and provides a summary of key statistics.
+    The dataset is expected to have a datetime index. The function checks this and raises an error if the index is
+    not in datetime format. For each variable (column) in the dataset, appropriate statistics including start date,
+    end date, frequency, number of missing values, count, min, and max values are calculated.
+
+    **Signs of High Risk**:
+    - If the index of the dataset is not in datetime format, it could lead to errors in time-series analysis.
+    - Inconsistent or missing data within the dataset might affect the analysis of trends and patterns.
+
+    **Strengths**:
+    - This function provides a comprehensive summary of key statistics for each variable, helping to identify data quality
+    issues such as missing values.
+    - The function helps in understanding the distribution and range of the data by including min and max values.
+
+    **Limitations**:
+    - This function assumes that the dataset is provided as a DataFrameDataset object with a .df attribute to access
+    the pandas DataFrame.
+    - It only analyzes datasets with a datetime index and will raise an error for other types of indices.
+    - The function does not handle large datasets efficiently, and performance may degrade with very large datasets.
+    """
+
+    summary = []
+
+    df = (
+        dataset.df
+    )  # Assuming DataFrameDataset objects have a .df attribute to get the pandas DataFrame
+
+    if not pd.api.types.is_datetime64_any_dtype(df.index):
+        raise ValueError(f"Dataset {dataset.input_id} must have a datetime index")
+
+    for column in df.columns:
+        start_date = df.index.min().strftime("%Y-%m-%d")
+        end_date = df.index.max().strftime("%Y-%m-%d")
+        frequency = pd.infer_freq(df.index)
+        num_missing_values = df[column].isna().sum()
+        count = df[column].count()
+        min_value = df[column].min()
+        max_value = df[column].max()
+
+        summary.append(
+            {
+                "Variable": column,
+                "Start Date": start_date,
+                "End Date": end_date,
+                "Frequency": frequency,
+                "Num of Missing Values": num_missing_values,
+                "Count": count,
+                "Min Value": min_value,
+                "Max Value": max_value,
+            }
+        )
+
+    result_df = pd.DataFrame(summary)
+
+    return result_df
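
Since the new tests are plain decorated functions, they can be exercised directly with any object exposing the `.df` and `.input_id` attributes the code reads. A minimal sketch with a hypothetical stand-in class (not a ValidMind API):

# Editor's sketch: StubDataset is an illustrative stand-in, not part of validmind.
import pandas as pd

from validmind.tests.data_validation.TimeSeriesDescription import TimeSeriesDescription

class StubDataset:
    def __init__(self, df, input_id):
        self.df = df
        self.input_id = input_id

idx = pd.date_range("2020-01-01", periods=24, freq="MS")
data = pd.DataFrame({"loan_rate": [1.0 + 0.1 * i for i in range(24)]}, index=idx)

print(TimeSeriesDescription(StubDataset(data, "loan_rates")))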

validmind/tests/data_validation/TimeSeriesDescriptiveStatistics.py (new file)

@@ -0,0 +1,76 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import pandas as pd
+from scipy.stats import kurtosis, skew
+
+from validmind import tags, tasks
+
+
+@tags("time_series_data", "analysis")
+@tasks("regression")
+def TimeSeriesDescriptiveStatistics(dataset):
+    """
+    Generates a detailed table of descriptive statistics for the provided time series dataset.
+
+    **Purpose**: The purpose of the TimeSeriesDescriptiveStatistics function is to analyze an individual time series
+    by providing a summary of key descriptive statistics. This helps in understanding trends, patterns, and data quality issues
+    within the time series.
+
+    **Test Mechanism**: The function extracts the time series data and provides a summary of key descriptive statistics.
+    The dataset is expected to have a datetime index. The function checks this and raises an error if the index is
+    not in datetime format. For each variable (column) in the dataset, appropriate statistics including start date,
+    end date, min, mean, max, skewness, kurtosis, and count are calculated.
+
+    **Signs of High Risk**:
+    - If the index of the dataset is not in datetime format, it could lead to errors in time-series analysis.
+    - Inconsistent or missing data within the dataset might affect the analysis of trends and patterns.
+
+    **Strengths**:
+    - This function provides a comprehensive summary of key descriptive statistics for each variable, helping to identify data quality
+    issues and understand the distribution of the data.
+
+    **Limitations**:
+    - This function assumes that the dataset is provided as a DataFrameDataset object with a .df attribute to access
+    the pandas DataFrame.
+    - It only analyzes datasets with a datetime index and will raise an error for other types of indices.
+    - The function does not handle large datasets efficiently, and performance may degrade with very large datasets.
+    """
+
+    summary = []
+
+    df = (
+        dataset.df
+    )  # Assuming DataFrameDataset objects have a .df attribute to get the pandas DataFrame
+
+    if not pd.api.types.is_datetime64_any_dtype(df.index):
+        raise ValueError(f"Dataset {dataset.input_id} must have a datetime index")
+
+    for column in df.columns:
+        start_date = df.index.min().strftime("%Y-%m-%d")
+        end_date = df.index.max().strftime("%Y-%m-%d")
+        count = df[column].count()
+        min_value = df[column].min()
+        mean_value = df[column].mean()
+        max_value = df[column].max()
+        skewness_value = skew(df[column].dropna())
+        kurtosis_value = kurtosis(df[column].dropna())
+
+        summary.append(
+            {
+                "Variable": column,
+                "Start Date": start_date,
+                "End Date": end_date,
+                "Min": min_value,
+                "Mean": mean_value,
+                "Max": max_value,
+                "Skewness": skewness_value,
+                "Kurtosis": kurtosis_value,
+                "Count": count,
+            }
+        )
+
+    result_df = pd.DataFrame(summary)
+
+    return result_df
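
One detail for reading the Skewness and Kurtosis columns: `scipy.stats.kurtosis` returns excess (Fisher) kurtosis by default, so normally distributed data scores near 0 rather than 3. A quick synthetic check:

# Editor's sketch: verify the conventions of the statistics used above.
import numpy as np
from scipy.stats import kurtosis, skew

x = np.random.default_rng(1).normal(size=100_000)
print(round(skew(x), 3), round(kurtosis(x), 3))  # both close to 0 for normal data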

validmind/tests/data_validation/TimeSeriesHistogram.py

@@ -2,14 +2,14 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-import matplotlib.pyplot as plt
-import pandas as pd
-import seaborn as sns
+import plotly.express as px
 
-from validmind.vm_models import Figure, Metric
+from validmind import tags, tasks
 
 
-class TimeSeriesHistogram(Metric):
+@tags("data_validation", "visualization")
+@tasks("regression", "time_series_forecasting")
+def TimeSeriesHistogram(dataset, nbins=30):
     """
     Visualizes distribution of time-series data using histograms and Kernel Density Estimation (KDE) lines.
 
@@ -20,7 +20,7 @@ class TimeSeriesHistogram(Metric):
     (kurtosis) underlying the data.
 
     **Test Mechanism**: This test operates on a specific column within the dataset that is required to have a datetime
-    type index. It goes through each column in the given dataset, creating a histogram with Seaborn's histplot
+    type index. It goes through each column in the given dataset, creating a histogram with Plotly's histplot
     function. In cases where the dataset includes more than one time-series (i.e., more than one column with a datetime
     type index), a distinct histogram is plotted for each series. Additionally, a kernel density estimate (KDE) line is
     drawn for each histogram, providing a visualization of the data's underlying probability distribution. The x and
@@ -48,46 +48,30 @@ class TimeSeriesHistogram(Metric):
     - The histogram's shape may be sensitive to the number of bins used.
     """
 
-    name = "time_series_histogram"
-    required_inputs = ["dataset"]
-    metadata = {
-        "task_types": ["regression"],
-        "tags": ["time_series_data", "visualization"],
-    }
+    df = dataset.df
 
-    def run(self):
-        # Check if index is datetime
-        if not pd.api.types.is_datetime64_any_dtype(self.inputs.dataset.df.index):
-            raise ValueError("Index must be a datetime type")
+    columns = list(dataset.df.columns)
 
-        columns = list(self.inputs.dataset.df.columns)
+    if not set(columns).issubset(set(df.columns)):
+        raise ValueError("Provided 'columns' must exist in the dataset")
 
-        df = self.inputs.dataset.df
-
-        if not set(columns).issubset(set(df.columns)):
-            raise ValueError("Provided 'columns' must exist in the dataset")
-
-        figures = []
-        for col in columns:
-            plt.figure()
-            fig, _ = plt.subplots()
-            ax = sns.histplot(data=df, x=col, kde=True)
-            plt.title(f"Histogram for {col}", weight="bold", fontsize=20)
-
-            plt.xticks(fontsize=18)
-            plt.yticks(fontsize=18)
-            ax.set_xlabel("")
-            ax.set_ylabel("")
-            figures.append(
-                Figure(
-                    for_object=self,
-                    key=f"{self.key}:{col}",
-                    figure=fig,
-                )
-            )
-
-            plt.close("all")
-
-        return self.cache_results(
-            figures=figures,
+    figures = []
+    for col in columns:
+        fig = px.histogram(
+            df, x=col, marginal="violin", nbins=nbins, title=f"Histogram for {col}"
+        )
+        fig.update_layout(
+            title={
+                "text": f"Histogram for {col}",
+                "y": 0.9,
+                "x": 0.5,
+                "xanchor": "center",
+                "yanchor": "top",
+            },
+            xaxis_title="",
+            yaxis_title="",
+            font=dict(size=18),
         )
+        figures.append(fig)
+
+    return tuple(figures)
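
Note that the rewritten test passes `marginal="violin"` to `px.histogram`, so the distribution overlay is now a violin marginal rather than the KDE line the (unchanged) docstring still describes. A standalone sketch of the call on synthetic data:

# Editor's sketch of the px.histogram call the new test relies on.
import numpy as np
import pandas as pd
import plotly.express as px

idx = pd.date_range("2021-01-01", periods=365, freq="D")
df = pd.DataFrame({"value": np.random.default_rng(2).normal(size=365)}, index=idx)

fig = px.histogram(
    df, x="value", marginal="violin", nbins=30, title="Histogram for value"
)
fig.show()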

validmind/tests/data_validation/TimeSeriesOutliers.py

@@ -4,11 +4,8 @@
 
 from dataclasses import dataclass
 
-import matplotlib.pyplot as plt
 import pandas as pd
-import seaborn as sns
-from ydata_profiling.config import Settings
-from ydata_profiling.model.typeset import ProfilingTypeSet
+import plotly.graph_objects as go
 
 from validmind.vm_models import (
     Figure,
@@ -93,7 +90,8 @@ class TimeSeriesOutliers(ThresholdTest):
         zScores = first_result.values["z-score"]
         dates = first_result.values["Date"]
         passFail = [
-            "Pass" if z < self.params["zscore_threshold"] else "Fail" for z in zScores
+            "Pass" if abs(z) < self.params["zscore_threshold"] else "Fail"
+            for z in zScores
         ]
 
         return ResultSummary(
@@ -116,25 +114,26 @@ class TimeSeriesOutliers(ThresholdTest):
         )
 
     def run(self):
+        # Initialize the test_results list
+        test_results = []
+
         # Check if the index of dataframe is datetime
         is_datetime = pd.api.types.is_datetime64_any_dtype(self.inputs.dataset.df.index)
         if not is_datetime:
             raise ValueError("Dataset must be provided with datetime index")
 
-        # Validate threshold paremeter
+        # Validate threshold parameter
        if "zscore_threshold" not in self.params:
             raise ValueError("zscore_threshold must be provided in params")
         zscore_threshold = self.params["zscore_threshold"]
 
         temp_df = self.inputs.dataset.df.copy()
         # temp_df = temp_df.dropna()
-        typeset = ProfilingTypeSet(Settings())
-        dataset_types = typeset.infer_type(temp_df)
-        test_results = []
-        test_figures = []
-        num_features_columns = [
-            k for k, v in dataset_types.items() if str(v) == "Numeric"
-        ]
+
+        # Infer numeric columns
+        num_features_columns = temp_df.select_dtypes(
+            include=["number"]
+        ).columns.tolist()
 
         outliers_table = self.identify_outliers(
             temp_df[num_features_columns], zscore_threshold
@@ -196,49 +195,39 @@ class TimeSeriesOutliers(ThresholdTest):
             df (pandas.DataFrame): Input data with time series.
             outliers_table (pandas.DataFrame): DataFrame with identified outliers.
         Returns:
-            matplotlib.figure.Figure: A matplotlib figure object with subplots for each variable.
+            list: A list of Figure objects with subplots for each variable.
         """
-        sns.set(style="darkgrid")
-        columns = list(self.inputs.dataset.df.columns)
         figures = []
 
-        for col in columns:
-            plt.figure()
-            fig, _ = plt.subplots()
-            column_index_name = df.index.name
-            ax = sns.lineplot(data=df.reset_index(), x=column_index_name, y=col)
+        for col in df.columns:
+            fig = go.Figure()
+
+            fig.add_trace(go.Scatter(x=df.index, y=df[col], mode="lines", name=col))
 
             if not outliers_table.empty:
                 variable_outliers = outliers_table[outliers_table["Variable"] == col]
-                for idx, row in variable_outliers.iterrows():
-                    date = row["Date"]
-                    outlier_value = df.loc[date, col]
-                    ax.scatter(
-                        date,
-                        outlier_value,
-                        marker="o",
-                        s=100,
-                        c="red",
-                        label="Outlier" if idx == 0 else "",
+                fig.add_trace(
+                    go.Scatter(
+                        x=variable_outliers["Date"],
+                        y=df.loc[variable_outliers["Date"], col],
+                        mode="markers",
+                        marker=dict(color="red", size=10),
+                        name="Outlier",
                     )
+                )
 
-            plt.xticks(fontsize=18)
-            plt.yticks(fontsize=18)
-            ax.set_xlabel("")
-            ax.set_ylabel("")
-            ax.set_title(
-                f"Time Series with Outliers for {col}", weight="bold", fontsize=20
+            fig.update_layout(
+                title=f"Time Series with Outliers for {col}",
+                xaxis_title="Date",
+                yaxis_title=col,
             )
 
-            ax.legend()
             figures.append(
                 Figure(
                     for_object=self,
-                    key=f"{self.name}:{col}",
+                    key=f"{self.name}:{col}_{self.inputs.dataset.input_id}",
                    figure=fig,
                 )
             )
 
-        # Do this if you want to prevent the figure from being displayed
-        plt.close("all")
         return figures
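
The `abs(z)` change in the summary method is a behavioral fix, not a refactor: z-scores are signed, and the old comparison let large negative outliers slip through as "Pass". A two-list illustration:

# Editor's sketch: signed vs. absolute z-score comparison, threshold = 3.
threshold = 3.0
z_scores = [4.2, -4.2, 0.5]

print(["Pass" if z < threshold else "Fail" for z in z_scores])
# ['Fail', 'Pass', 'Pass']  <- the -4.2 outlier incorrectly passes

print(["Pass" if abs(z) < threshold else "Fail" for z in z_scores])
# ['Fail', 'Fail', 'Pass']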

validmind/tests/decorator.py

@@ -226,6 +226,18 @@ def _get_save_func(func, test_id):
 
 
 def metric(func_or_id):
+    """
+    DEPRECATED, use @vm.test instead
+    """
+    # print a deprecation notice and call the test() function instead
+    logger.warning(
+        "The @vm.metric decorator is deprecated and will be removed in a future release. "
+        "Please use @vm.test instead."
+    )
+    return test(func_or_id)
+
+
+def test(func_or_id):
     """Decorator for creating and registering metrics with the ValidMind framework.
 
     Creates a metric object and registers it with ValidMind under the provided ID. If
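
In practice the shim keeps old code working while steering users to the new name. A sketch of both spellings, with the decorator-with-ID usage assumed from the `func_or_id` parameter:

# Editor's sketch: the test IDs here are hypothetical.
import validmind as vm

@vm.test("my_custom_tests.MyTest")  # preferred going forward
def my_test(dataset):
    ...

@vm.metric("my_custom_tests.MyLegacyMetric")  # still works; logs a deprecation warning
def my_legacy_metric(dataset):
    ...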

validmind/tests/model_validation/ModelMetadataComparison.py (new file)

@@ -0,0 +1,59 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import pandas as pd
+
+from validmind import tags, tasks
+from validmind.utils import get_model_info
+
+
+@tags("model_training", "metadata")
+@tasks("regression", "time_series_forecasting")
+def ModelMetadataComparison(models):
+    """
+    Compare metadata of different models and generate a summary table with the results.
+
+    **Purpose**: The purpose of this function is to compare the metadata of different models, including information about their architecture, framework, framework version, and programming language.
+
+    **Test Mechanism**: The function retrieves the metadata for each model using `get_model_info`, renames columns according to a predefined set of labels, and compiles this information into a summary table.
+
+    **Signs of High Risk**:
+    - Inconsistent or missing metadata across models can indicate potential issues in model documentation or management.
+    - Significant differences in framework versions or programming languages might pose challenges in model integration and deployment.
+
+    **Strengths**:
+    - Provides a clear comparison of essential model metadata.
+    - Standardizes metadata labels for easier interpretation and comparison.
+    - Helps identify potential compatibility or consistency issues across models.
+
+    **Limitations**:
+    - Assumes that the `get_model_info` function returns all necessary metadata fields.
+    - Relies on the correctness and completeness of the metadata provided by each model.
+    - Does not include detailed parameter information, focusing instead on high-level metadata.
+    """
+    column_labels = {
+        "architecture": "Modeling Technique",
+        "framework": "Modeling Framework",
+        "framework_version": "Framework Version",
+        "language": "Programming Language",
+    }
+
+    description = []
+
+    for model in models:
+        model_info = get_model_info(model)
+
+        # Rename columns based on provided labels
+        model_info_renamed = {
+            column_labels.get(k, k): v for k, v in model_info.items() if k != "params"
+        }
+
+        # Add model name or identifier if available
+        model_info_renamed = {"Model Name": model.input_id, **model_info_renamed}
+
+        description.append(model_info_renamed)
+
+    description_df = pd.DataFrame(description)
+
+    return description_df
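
A hypothetical invocation through the test runner; the `run_test` call pattern and the model input IDs below are assumptions for illustration, not taken from this diff:

# Editor's sketch: compare two previously initialized models by input_id.
import validmind as vm

vm.tests.run_test(
    "validmind.model_validation.ModelMetadataComparison",
    inputs={"models": ["model_a", "model_b"]},
)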