validmind 2.5.6__py3-none-any.whl → 2.5.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
validmind/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "2.5.6"
1
+ __version__ = "2.5.8"
validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py CHANGED
@@ -7,9 +7,9 @@ from dataclasses import dataclass
7
7
  from operator import add
8
8
  from typing import List, Tuple
9
9
 
10
- import matplotlib.pyplot as plt
11
10
  import numpy as np
12
11
  import pandas as pd
12
+ import plotly.graph_objects as go
13
13
  import seaborn as sns
14
14
  from sklearn import metrics
15
15
 
@@ -132,24 +132,28 @@ def _combine_results(results: List[dict]):
132
132
 
133
133
 
134
134
  def _plot_robustness(
135
- results: pd.DataFrame, metric: str, threshold: float, columns: List[str]
135
+ results: pd.DataFrame, metric: str, threshold: float, columns: List[str], model: str
136
136
  ):
137
- fig, ax = plt.subplots()
138
-
139
- pallete = sns.color_palette("muted", len(results["Dataset"].unique()))
140
- sns.lineplot(
141
- data=results,
142
- x="Perturbation Size",
143
- y=metric.upper(),
144
- hue="Dataset",
145
- style="Dataset",
146
- linewidth=3,
147
- markers=True,
148
- markersize=10,
149
- dashes=False,
150
- palette=pallete,
151
- ax=ax,
152
- )
137
+ fig = go.Figure()
138
+
139
+ datasets = results["Dataset"].unique()
140
+ pallete = [
141
+ f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
142
+ for r, g, b in sns.color_palette("husl", len(datasets))
143
+ ]
144
+
145
+ for i, dataset in enumerate(datasets):
146
+ dataset_results = results[results["Dataset"] == dataset]
147
+ fig.add_trace(
148
+ go.Scatter(
149
+ x=dataset_results["Perturbation Size"],
150
+ y=dataset_results[metric.upper()],
151
+ mode="lines+markers",
152
+ name=dataset,
153
+ line=dict(width=3, color=pallete[i]),
154
+ marker=dict(size=10),
155
+ )
156
+ )
153
157
 
154
158
  if PERFORMANCE_METRICS[metric]["is_lower_better"]:
155
159
  y_label = f"{metric.upper()} (lower is better)"
@@ -157,33 +161,64 @@ def _plot_robustness(
157
161
  threshold = -threshold
158
162
  y_label = f"{metric.upper()} (higher is better)"
159
163
 
160
- # add dotted threshold line
161
- for i in range(len(results["Dataset"].unique())):
162
- baseline = results[results["Dataset"] == results["Dataset"].unique()[i]][
163
- metric.upper()
164
- ].iloc[0]
165
- ax.axhline(
166
- y=baseline + threshold,
167
- color=pallete[i],
168
- linestyle="dotted",
164
+ # add threshold lines
165
+ for i, dataset in enumerate(datasets):
166
+ baseline = results[results["Dataset"] == dataset][metric.upper()].iloc[0]
167
+ fig.add_trace(
168
+ go.Scatter(
169
+ x=results["Perturbation Size"].unique(),
170
+ y=[baseline + threshold] * len(results["Perturbation Size"].unique()),
171
+ mode="lines",
172
+ name=f"threshold_{dataset}",
173
+ line=dict(dash="dash", width=2, color=pallete[i]),
174
+ showlegend=True,
175
+ )
169
176
  )
170
177
 
171
- ax.tick_params(axis="x")
172
- ax.set_ylabel(y_label, weight="bold", fontsize=18)
173
- ax.legend(fontsize=18)
174
- ax.set_xlabel(
175
- "Perturbation Size (X * Standard Deviation)", weight="bold", fontsize=18
176
- )
177
- ax.set_title(
178
- f"Perturbed Features: {', '.join(columns)}",
179
- weight="bold",
180
- fontsize=20,
181
- wrap=True,
178
+ columns_lines = [""]
179
+ for column in columns:
180
+ # keep adding to the last line in list until character limit (40)
181
+ if len(columns_lines[-1]) + len(column) < 40:
182
+ columns_lines[-1] += f"{column}, "
183
+ else:
184
+ columns_lines.append(f"{column}, ")
185
+
186
+ fig.update_layout(
187
+ title=dict(
188
+ text=(
189
+ f"Model Robustness for '{model}'<br><sup>As determined by calculating "
190
+ f"{metric.upper()} decay in the presence of random gaussian noise</sup>"
191
+ ),
192
+ font=dict(size=20),
193
+ x=0.5,
194
+ xanchor="center",
195
+ ),
196
+ xaxis_title=dict(
197
+ text="Perturbation Size (X * Standard Deviation)",
198
+ ),
199
+ yaxis_title=dict(text=y_label),
200
+ plot_bgcolor="white",
201
+ margin=dict(t=60, b=80, r=20, l=60),
202
+ xaxis=dict(showgrid=True, gridcolor="lightgrey"),
203
+ yaxis=dict(showgrid=True, gridcolor="lightgrey"),
204
+ annotations=[
205
+ go.layout.Annotation(
206
+ text=f"Perturbed Features:<br><sup>{'<br>'.join(columns_lines)}</sup>",
207
+ align="left",
208
+ font=dict(size=14),
209
+ bordercolor="lightgrey",
210
+ borderwidth=1,
211
+ borderpad=4,
212
+ showarrow=False,
213
+ x=1.025,
214
+ xref="paper",
215
+ xanchor="left",
216
+ y=-0.15,
217
+ yref="paper",
218
+ )
219
+ ],
182
220
  )
183
221
 
184
- # prevent the figure from being displayed
185
- plt.close("all")
186
-
187
222
  return fig
188
223
 
189
224
 
@@ -267,6 +302,7 @@ def robustness_diagnosis(
267
302
  metric=metric,
268
303
  threshold=performance_decay_threshold,
269
304
  columns=datasets[0].feature_columns_numeric,
305
+ model=model.input_id,
270
306
  )
271
307
 
272
308
  # rename perturbation size for baseline
validmind-2.5.8.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: validmind
3
- Version: 2.5.6
3
+ Version: 2.5.8
4
4
  Summary: ValidMind Developer Framework
5
5
  License: Commercial License
6
6
  Author: Andres Rodriguez
validmind-2.5.8.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
1
1
  validmind/__init__.py,sha256=UfmzPwUCdUWbWq3zPqqmq4jw0_kfl3hX4U72p_seE4I,3700
2
- validmind/__version__.py,sha256=1X5kk-wx4BPu-LJgtjkfia2ZtDoOHOAsSyEsLKKQCY0,22
2
+ validmind/__version__.py,sha256=mNA8KAyMUolRKqUZCQp6s1ZGetufDZcybBUJHOyKaZA,22
3
3
  validmind/ai/test_descriptions.py,sha256=Q1Ftus4x5eiVLKWJu7hqPLukBQZzhy-dARqq_6_JWtk,9464
4
4
  validmind/ai/utils.py,sha256=TEXII_S5CpkpczzSyHwTlqLcPMLnPBJWEBR6QFMKh1U,3421
5
5
  validmind/api_client.py,sha256=JZIJWuYtvl-VEVi_AK4c839Fn7cGa40J2d4_4FUZcno,17483
@@ -226,7 +226,7 @@ validmind/tests/model_validation/sklearn/RegressionErrorsComparison.py,sha256=CH
226
226
  validmind/tests/model_validation/sklearn/RegressionModelsPerformanceComparison.py,sha256=ELYhY_My1YqS4_i2fnHgL5Dg7vKUIa0wska0bkAFkuU,5737
227
227
  validmind/tests/model_validation/sklearn/RegressionR2Square.py,sha256=Ojm5sz3re4rk17u7xiezn1P_rp7wcA3etKgzdhGYH-s,4906
228
228
  validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py,sha256=tGJKpfeTvU2xBxsYbQSC5GPDcCS2_j0FcT3uceXZduI,2761
229
- validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py,sha256=KXBQ5-3ZDLil3WBZR-zWySelN_zb5Ob4Uvuoi1qfFaI,12821
229
+ validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py,sha256=-DyGzQ0PItOISGqtgn2b0WVGG3hycg3lRdgjFM_jPdk,14400
230
230
  validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py,sha256=ECYjHHIz5kfnLi2XlzWOKquRf23_77kdcPK8Xw2qwQk,8887
231
231
  validmind/tests/model_validation/sklearn/SilhouettePlot.py,sha256=6PZ_sqiPBpL4_fyRE_sg0bSWWrDkryh_v-88KK4i3RQ,6185
232
232
  validmind/tests/model_validation/sklearn/TrainingTestDegradation.py,sha256=K3F8Ev7nIaIjwLHC9ljnMp07YwZeqo4RLui5C6IDuR8,7209
@@ -311,8 +311,8 @@ validmind/vm_models/test_suite/runner.py,sha256=aewxadRfoOPH48jes2Gtb3Ju_FWFfVM_
311
311
  validmind/vm_models/test_suite/summary.py,sha256=GQRNe2ZvvqjQN0yKmaN7ohAUjRFQIN4YYUYxfOuWN6M,4682
312
312
  validmind/vm_models/test_suite/test.py,sha256=_GfbK36l98SjzgVcucmp0OKBJKqMW3neO7SqJ3EWeps,5049
313
313
  validmind/vm_models/test_suite/test_suite.py,sha256=Cns2wL54v0T5Mv5_HJb3kMeaa4rtycdqT8KxK9_rWEU,6279
314
- validmind-2.5.6.dist-info/LICENSE,sha256=XonPUfwjvrC5Ombl3y-ko0Wubb1xdG_7nzvIbkZRKHw,35772
315
- validmind-2.5.6.dist-info/METADATA,sha256=cPAO_Hlc8esuZm1W96GFv09amk1VNxn7Oh33iUPUfbI,4242
316
- validmind-2.5.6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
317
- validmind-2.5.6.dist-info/entry_points.txt,sha256=HuW7YyOv9u_OEWpViQXtv0nfoI67uieJHawKWA4Hv9A,76
318
- validmind-2.5.6.dist-info/RECORD,,
314
+ validmind-2.5.8.dist-info/LICENSE,sha256=XonPUfwjvrC5Ombl3y-ko0Wubb1xdG_7nzvIbkZRKHw,35772
315
+ validmind-2.5.8.dist-info/METADATA,sha256=YrAvv1MV1wQ1q4FaqUSvJNVP3ZSC_P9AeY4GY0pFiEI,4242
316
+ validmind-2.5.8.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
317
+ validmind-2.5.8.dist-info/entry_points.txt,sha256=HuW7YyOv9u_OEWpViQXtv0nfoI67uieJHawKWA4Hv9A,76
318
+ validmind-2.5.8.dist-info/RECORD,,