MIDRC-MELODY 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MIDRC_MELODY/__init__.py +0 -0
- MIDRC_MELODY/__main__.py +4 -0
- MIDRC_MELODY/common/__init__.py +0 -0
- MIDRC_MELODY/common/data_loading.py +199 -0
- MIDRC_MELODY/common/data_preprocessing.py +134 -0
- MIDRC_MELODY/common/edit_config.py +156 -0
- MIDRC_MELODY/common/eod_aaod_metrics.py +292 -0
- MIDRC_MELODY/common/generate_eod_aaod_spiders.py +69 -0
- MIDRC_MELODY/common/generate_qwk_spiders.py +56 -0
- MIDRC_MELODY/common/matplotlib_spider.py +425 -0
- MIDRC_MELODY/common/plot_tools.py +132 -0
- MIDRC_MELODY/common/plotly_spider.py +217 -0
- MIDRC_MELODY/common/qwk_metrics.py +244 -0
- MIDRC_MELODY/common/table_tools.py +230 -0
- MIDRC_MELODY/gui/__init__.py +0 -0
- MIDRC_MELODY/gui/config_editor.py +200 -0
- MIDRC_MELODY/gui/data_loading.py +157 -0
- MIDRC_MELODY/gui/main_controller.py +154 -0
- MIDRC_MELODY/gui/main_window.py +545 -0
- MIDRC_MELODY/gui/matplotlib_spider_widget.py +204 -0
- MIDRC_MELODY/gui/metrics_model.py +62 -0
- MIDRC_MELODY/gui/plotly_spider_widget.py +56 -0
- MIDRC_MELODY/gui/qchart_spider_widget.py +272 -0
- MIDRC_MELODY/gui/shared/__init__.py +0 -0
- MIDRC_MELODY/gui/shared/react/__init__.py +0 -0
- MIDRC_MELODY/gui/shared/react/copyabletableview.py +100 -0
- MIDRC_MELODY/gui/shared/react/grabbablewidget.py +406 -0
- MIDRC_MELODY/gui/tqdm_handler.py +210 -0
- MIDRC_MELODY/melody.py +102 -0
- MIDRC_MELODY/melody_gui.py +111 -0
- MIDRC_MELODY/resources/MIDRC.ico +0 -0
- midrc_melody-0.3.3.dist-info/METADATA +151 -0
- midrc_melody-0.3.3.dist-info/RECORD +37 -0
- midrc_melody-0.3.3.dist-info/WHEEL +5 -0
- midrc_melody-0.3.3.dist-info/entry_points.txt +4 -0
- midrc_melody-0.3.3.dist-info/licenses/LICENSE +201 -0
- midrc_melody-0.3.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
# Copyright (c) 2025 Medical Imaging and Data Resource Center (MIDRC).
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
import math
|
|
17
|
+
from typing import List, Tuple, Any, Optional
|
|
18
|
+
|
|
19
|
+
import numpy as np
|
|
20
|
+
from matplotlib import pyplot as plt
|
|
21
|
+
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
|
22
|
+
from matplotlib.collections import PathCollection
|
|
23
|
+
import mplcursors
|
|
24
|
+
|
|
25
|
+
from MIDRC_MELODY.common.plot_tools import SpiderPlotData, prepare_and_sort, get_full_theta, compute_angles
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def plot_spider_chart(spider_data: SpiderPlotData) -> plt.Figure:
    """
    Build a polar "spider" chart for one model and one metric.

    The groups are sorted and the loop is closed (``prepare_and_sort``), laid out around the circle
    (``compute_angles``), and drawn as a line with scatter markers. A hover cursor is attached to the
    markers, the band between the lower and upper bounds is shaded, and metric-specific overlays are added
    when ``spider_data.metric`` is non-empty.

    :arg spider_data: SpiderPlotData providing:
        - `model_name`: Name of the model (used in the title)
        - `groups` (List[str]): Group names
        - `values`: Value for each group
        - `lower_bounds` / `upper_bounds`: Bounds for each group (e.g. a confidence interval)
        - `ylim_min` / `ylim_max`: Dicts mapping metric -> radial axis minimum / maximum
        - `metric`: Metric being displayed
        - `plot_config`: Plot configuration dictionary (angle ordering, custom group orders)

    :returns: The newly created Matplotlib figure
    """
    metric = spider_data.metric
    chart_title = f"{spider_data.model_name} - {metric.upper()}"

    # Sorted, loop-closed series and one angle per entry (closing duplicate included)
    labels, point_values, lows, highs = prepare_and_sort(spider_data)
    theta = compute_angles(len(labels), spider_data.plot_config)
    fig, ax = _init_spider_axes(spider_data.ylim_min[metric],
                                spider_data.ylim_max[metric])

    _configure_axes(ax, theta, labels, chart_title)

    # Outline + markers on top, hover tooltips on the markers
    markers = _draw_main_series(ax, theta, point_values, zorder=9)
    _add_cursor_to_spider_plot(markers, fig.canvas, labels, point_values, lows, highs)

    # Shaded bounds band sits below the outline
    _fill_bounds(ax, theta, lows, highs, zorder=5)

    if metric:
        _apply_metric_overlay(ax, theta, metric, point_values, lows, highs,
                              zorder_bg=2, zorder_thresholds=10)

    fig.tight_layout()
    return fig
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _init_spider_axes(ymin: float, ymax: float) -> Tuple[plt.Figure, plt.Axes]:
    """
    Create an 8x8-inch figure holding a single polar Axes with radial limits ``(ymin, ymax)``.

    :arg ymin: Lower radial limit
    :arg ymax: Upper radial limit

    :returns: ``(figure, polar_axes)``
    """
    figure, polar_ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
    polar_ax.set_ylim(ymin, ymax)
    return figure, polar_ax
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _add_cursor_to_spider_plot(sc, canvas, groups, values, lower_bounds, upper_bounds) -> mplcursors.Cursor:
    """
    Attach an mplcursors hover tooltip to the spider plot's scatter markers.

    Hovering a marker shows its group name, value and ``[lower, upper]`` bounds to three decimals.
    Any mouse-button press on ``canvas`` removes all visible tooltips.

    :arg sc: Scatter collection returned by ``_draw_main_series``
    :arg canvas: Figure canvas whose ``button_press_event`` clears the tooltips
    :arg groups: Group name for each marker
    :arg values: Value for each marker
    :arg lower_bounds: Lower bound for each marker
    :arg upper_bounds: Upper bound for each marker

    :returns: The mplcursors Cursor
    """
    hover = mplcursors.cursor(sc, hover=True)

    def _describe(selection):
        idx = selection.index
        label = (
            f"{groups[idx]}\n"
            f"Median: {values[idx]:.3f} "
            f"[{lower_bounds[idx]:.3f}, {upper_bounds[idx]:.3f}]"
        )
        selection.annotation.set_text(label)

    hover.connect("add", _describe)

    def _clear_tooltips(_event):
        # Snapshot the current selections before removing them one by one
        for selection in tuple(hover.selections):
            hover.remove_selection(selection)

    canvas.mpl_connect('button_press_event', _clear_tooltips)
    return hover
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _draw_main_series(ax: plt.Axes, angles: List[float], values: List[float], *, zorder: Optional[float] = None) -> PathCollection:
    """
    Draw the spider outline and its vertex markers.

    A solid steel-blue line (width 2) joins the points, and blue circular markers are scattered on the
    same points. When ``zorder`` is given, the markers use ``zorder - 0.01`` so they sit just below the line.

    :arg ax: Matplotlib Axes object
    :arg angles: Angle of each point
    :arg values: Radial value of each point
    :arg zorder: Drawing order for the line (None keeps Matplotlib's default)

    :returns: The PathCollection created for the scatter markers
    """
    marker_zorder = zorder if zorder is None else zorder - 0.01
    ax.plot(angles, values, color='steelblue', linestyle='-', linewidth=2, zorder=zorder)
    return ax.scatter(angles, values, marker='o', color='b', zorder=marker_zorder)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _apply_metric_overlay(
    ax: plt.Axes,
    angles: List[float],
    metric: str,
    values: List[float],
    lower_bounds: List[float],
    upper_bounds: List[float],
    *,
    zorder_bg: Optional[float] = None,
    zorder_thresholds: Optional[float] = None,
) -> None:
    """
    Add metric-specific reference decorations to the spider plot.

    The metric name is matched case-insensitively; metrics other than QWK, EOD and AAOD are ignored.

    - QWK: dashed sea-green circle at 0. Groups whose lower bound is > 0 are flagged maroon, and groups
      whose upper bound is < 0 are flagged red.
    - EOD: light-green band from -0.1 to 0.1. Values > 0.1 are flagged maroon, values < -0.1 are flagged red.
    - AAOD: the radial lower limit is set to 0, plus a light-green band from 0 to 0.1. Values > 0.1 are
      flagged maroon.

    Flagging is done by ``_annotate``.

    :arg ax: Matplotlib Axes object
    :arg angles: Angle for each group
    :arg metric: Metric to decorate for (e.g. 'QWK', 'EOD', 'AAOD')
    :arg values: Value for each group
    :arg lower_bounds: Lower bound for each group
    :arg upper_bounds: Upper bound for each group
    :arg zorder_bg: zorder of the baseline line and the shaded band
    :arg zorder_thresholds: zorder of the threshold markers
    """
    key = metric.upper()
    theta = get_full_theta()
    spec = {
        'QWK': dict(
            baseline=dict(type='line', y=0, style='--', color='seagreen', linewidth=3, alpha=0.5),
            thresholds=[
                (lower_bounds, lambda v: v > 0, 'maroon'),
                (upper_bounds, lambda v: v < 0, 'red'),
            ],
        ),
        'EOD': dict(
            fill=dict(lo=-0.1, hi=0.1, color='lightgreen', alpha=0.4),
            thresholds=[
                (values, lambda v: v > 0.1, 'maroon'),
                (values, lambda v: v < -0.1, 'red'),
            ],
        ),
        'AAOD': dict(
            fill=dict(lo=0, hi=0.1, color='lightgreen', alpha=0.4),
            baseline=dict(type='ylim', lo=0),
            thresholds=[
                (values, lambda v: v > 0.1, 'maroon'),
            ],
        ),
    }.get(key)
    if not spec:
        return

    # Baseline first: the 'ylim' variant changes the radial limits that _annotate reads later
    baseline = spec.get('baseline')
    if baseline is not None:
        if baseline['type'] == 'line':
            ax.plot(theta, np.full_like(theta, baseline['y']), baseline['style'],
                    linewidth=baseline['linewidth'], alpha=baseline['alpha'],
                    color=baseline['color'], zorder=zorder_bg)
        elif baseline['type'] == 'ylim':
            ax.set_ylim(baseline['lo'], ax.get_ylim()[1])

    band = spec.get('fill')
    if band is not None:
        ax.fill_between(theta, band['lo'], band['hi'], color=band['color'],
                        alpha=band['alpha'], zorder=zorder_bg)

    for series, predicate, colour in spec['thresholds']:
        _annotate(ax, angles, series, predicate, colour, zorder=zorder_thresholds)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _annotate(
    ax: plt.Axes,
    angles: List[float],
    data: List[float],
    condition: Any,
    color: str,
    delta: float = 0.05,
    *,
    zorder: Optional[float] = None,
) -> None:
    """
    Highlight the groups whose value satisfies ``condition``.

    The last entry of ``data`` (the loop-closing duplicate) is skipped. For every other index ``i`` where
    ``condition(data[i])`` is truthy:

    - tick label ``i`` is made bold and given ``color``;
    - a short straight segment is drawn across the spoke at that value.

    The segment's angular half-width is ``delta * (ymax - ymin) / (value - ymin)``, so its length stays
    roughly constant wherever the point sits radially. The half-width is capped at pi/4. The radius is
    raised by ``1 / cos(half-width)`` so the chord's midpoint lands on the data point.

    Values at or below the radial minimum get no segment: they sit at or inside the centre of the plot,
    and the formula would divide by zero or invert. Their label is still highlighted.

    :arg ax: Polar Matplotlib Axes to draw on
    :arg angles: Angle (radians) of each group, aligned with ``data``
    :arg data: Value for each group, including the loop-closing duplicate at the end
    :arg condition: Predicate called with one value; truthy means "highlight this group"
    :arg color: Colour for the tick label and the segment
    :arg delta: Segment size as a fraction of the radial span
    :arg zorder: Drawing order for the segments
    """
    ymin, ymax = ax.get_ylim()
    full_span = ymax - ymin
    max_angle = 2 * np.pi
    # Keep the half-width well below pi/2 so cos() stays positive and the radius finite
    max_d_angle = np.pi / 4
    labels = ax.get_xticklabels()

    for i in range(len(data) - 1):
        raw_val = data[i]
        if not condition(raw_val):
            continue

        # color the i-th tick label
        if i < len(labels):
            labels[i].set_fontweight('bold')
            labels[i].set_color(color)
        else:
            print(f"Warning: Not enough labels ({len(labels)}) for {len(data) - 1} data points. "
                  "Please check the input data.")

        radius = raw_val - ymin
        if radius <= 0:
            # At or below the plot centre: no meaningful chord can be drawn here
            continue

        angle = angles[i]
        d_angle = min(delta * full_span / radius, max_d_angle)
        # compute radial value so the chord crosses through the true data point
        r_val = ymin + radius / math.cos(d_angle)
        start, end = angle - d_angle, angle + d_angle

        # handle wrap-around in [0, 2π)
        if start < 0 or end >= max_angle:
            segments = [(start % max_angle, end % max_angle)]
        else:
            segments = [(start, end)]

        for a0, a1 in segments:
            ax.plot([a0, a1], [r_val, r_val], color=color, linewidth=1.3, solid_capstyle='butt', zorder=zorder)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _fill_bounds(
    ax: plt.Axes,
    angles: List[float],
    lower_bounds: List[float],
    upper_bounds: List[float],
    *,
    zorder: Optional[float] = None,
) -> None:
    """
    Shade the band between the per-group lower and upper bounds (steel blue, alpha 0.2).

    :arg ax: Matplotlib Axes object
    :arg angles: Angle for each group
    :arg lower_bounds: Lower bound for each group
    :arg upper_bounds: Upper bound for each group
    :arg zorder: Drawing order for the band
    """
    band_style = {'color': 'steelblue', 'alpha': 0.2}
    ax.fill_between(angles, lower_bounds, upper_bounds, zorder=zorder, **band_style)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _configure_axes(
    ax: plt.Axes,
    angles: List[float],
    groups: List[str],
    title: str,
) -> None:
    """
    Put one labelled tick on each spoke and set a bold 14-pt title.

    The last element of ``angles`` and ``groups`` (the loop-closing duplicate) is excluded, so each spoke
    gets exactly one tick.

    :arg ax: Matplotlib Axes object
    :arg angles: Angle for each group, including the closing duplicate
    :arg groups: Group names, including the closing duplicate
    :arg title: Title for the spider plot
    """
    spokes = angles[:-1]
    spoke_labels = groups[:-1]
    ax.set_xticks(spokes)
    ax.set_xticklabels(spoke_labels)
    ax.set_title(title, size=14, weight='bold')
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def figure_to_image(fig: plt.Figure) -> np.ndarray:
    """
    Convert a Matplotlib figure to an RGB numpy array.

    Strategies, in order:

    1. Draw the figure's existing canvas and read ``buffer_rgba()`` (RGBA).
    2. Otherwise read ``tostring_argb()`` from that canvas (ARGB bytes: alpha first, then R, G, B).
    3. As a last resort, render on a temporary Agg canvas and read its ``buffer_rgba()``.
       Constructing a canvas for a figure re-attaches the figure to it, so the figure's original canvas
       is re-attached afterwards.

    A buffer that already carries an ``(H, W, 4)`` shape (e.g. Agg's memoryview) is used as-is.
    A flat buffer is reshaped using ``get_width_height()``, allowing an integer HiDPI scale factor
    (the buffer holding ``scale**2`` times the reported pixel count).

    :arg fig: Matplotlib figure object

    :returns: An (H, W, 3) RGB array (uint8 for Agg buffers). It is a copy, independent of the canvas buffer.

    :raises RuntimeError: if the last-resort Agg rendering fails
    """
    # Prefer the figure's existing canvas so we don't reattach an Agg canvas
    canvas = getattr(fig, "canvas", None)

    def _try_shape(buf1d: np.ndarray, w: int, h: int) -> Optional[np.ndarray]:
        """Reshape a flat 4-bytes-per-pixel buffer to (rows, cols, 4), or return None if it does not fit."""
        expected = w * h
        if expected <= 0 or buf1d.size % 4 != 0:
            return None
        pixels = buf1d.size // 4

        # exact match
        if pixels == expected:
            return buf1d.reshape((h, w, 4))

        # integer scale factor (HiDPI / retina): pixels == expected * scale^2
        scale = int(round(math.sqrt(pixels / expected)))
        if scale > 1 and pixels == expected * scale * scale:
            return buf1d.reshape((h * scale, w * scale, 4))

        # nothing matched
        return None

    def _to_pixels(raw: Any, w: int, h: int) -> Optional[np.ndarray]:
        """Return ``raw`` as an (H, W, 4) array, or None if its layout cannot be inferred."""
        if isinstance(raw, np.ndarray):
            arr = raw
        elif isinstance(raw, memoryview) and raw.ndim == 3:
            # keep the (H, W, 4) shape the buffer already carries
            arr = np.asarray(raw)
        else:
            arr = np.frombuffer(raw, dtype=np.uint8)
        if arr.ndim == 3 and arr.shape[2] == 4:
            return arr
        return _try_shape(arr.reshape(-1), w, h)

    if canvas is not None:
        try:
            canvas.draw()
        except Exception:
            pass  # best effort: a buffer from an earlier draw may still be readable

        # 1) Native RGBA buffer: drop the alpha channel
        if hasattr(canvas, "buffer_rgba"):
            try:
                w, h = canvas.get_width_height()
                rgba = _to_pixels(canvas.buffer_rgba(), w, h)
                if rgba is not None:
                    return rgba[:, :, :3].copy()
            except Exception:
                pass  # fall through to the next strategy

        # 2) ARGB bytes: channel 0 is alpha, channels 1..3 are already R, G, B (no reordering needed)
        if hasattr(canvas, "tostring_argb"):
            try:
                w, h = canvas.get_width_height()
                argb = _to_pixels(canvas.tostring_argb(), w, h)
                if argb is not None:
                    return argb[:, :, 1:4].copy()
            except Exception:
                pass  # fall through to the last resort

    # 3) Last resort: render on a temporary Agg canvas, then give the figure its original canvas back
    try:
        agg = FigureCanvas(fig)
        agg.draw()
        w, h = agg.get_width_height()
        rgba = _to_pixels(agg.buffer_rgba(), w, h)
        if rgba is None:
            raise ValueError(f"unexpected Agg buffer size for a {w}x{h} canvas")
        return rgba[:, :, :3].copy()
    except Exception as exc:
        raise RuntimeError(f"Failed to convert figure to image: {exc}") from exc
    finally:
        if canvas is not None and getattr(fig, "canvas", None) is not canvas:
            fig.set_canvas(canvas)
|
|
394
|
+
|
|
395
|
+
def display_figures_grid(
    figures: List[plt.Figure],
    n_cols: int = 3,
    *,
    anchor_fig: Optional[plt.Figure] = None,
    anchor_index: int = 0,
) -> None:
    """
    Rasterize several figures and tile them into a single new grid figure.

    Each figure is converted with ``figure_to_image`` and shown, with axes hidden, in a cell of an
    ``n_rows x n_cols`` grid of 8x8-inch cells. Unused cells are removed. The grid figure is created
    through pyplot and laid out with ``tight_layout``; this function neither shows nor returns it,
    and it does not close the source figures. Does nothing when ``figures`` is empty.

    :arg figures: List of Matplotlib figure objects
    :arg n_cols: Number of columns in the grid
    :arg anchor_fig: Accepted but not used by this implementation
    :arg anchor_index: Accepted but not used by this implementation
    """
    if not figures:
        return

    total = len(figures)
    rows = int(np.ceil(total / n_cols))
    grid_fig, cells = plt.subplots(rows, n_cols, figsize=(8 * n_cols, 8 * rows), squeeze=False)
    cells = cells.ravel()

    for cell, source in zip(cells, figures):
        cell.imshow(figure_to_image(source))
        cell.axis('off')

    for spare in cells[total:]:
        spare.remove()

    grid_fig.tight_layout()
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# Copyright (c) 2025 Medical Imaging and Data Resource Center (MIDRC).
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
""" Plotting tools for visualizing model performance metrics. """
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from typing import Any, Dict, List, Tuple
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class SpiderPlotData:
    """
    Inputs for a single spider (radar) chart.

    ``groups``, ``values``, ``lower_bounds`` and ``upper_bounds`` are parallel lists, one entry per group.
    ``prepare_and_sort`` zips them together, so they should have equal lengths. ``ylim_min`` and
    ``ylim_max`` are indexed by ``metric`` when the radial limits are set. Every list/dict field gets a
    fresh, unshared default per instance.
    """
    model_name: str = ""  # shown in the chart title
    groups: List[str] = field(default_factory=list)  # group labels, typically "<attribute>: <group>"
    values: List[float] = field(default_factory=list)  # value per group
    lower_bounds: List[float] = field(default_factory=list)  # lower bound per group
    upper_bounds: List[float] = field(default_factory=list)  # upper bound per group
    ylim_min: Dict[str, float] = field(default_factory=dict)  # metric -> radial axis minimum
    ylim_max: Dict[str, float] = field(default_factory=dict)  # metric -> radial axis maximum
    metric: str = ""  # metric name, e.g. 'QWK', 'EOD', 'AAOD'
    plot_config: Dict[str, Any] = field(default_factory=dict)  # e.g. 'clockwise', 'start', 'custom_orders'
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_angle_rot(start_loc: str) -> float:
    """
    Map a starting-location name to a rotation in radians.

    Only the first character matters (case-sensitive): 't' -> pi/2 (top), 'l' -> pi (left),
    'b' -> 3*pi/2 (bottom). Anything else, including 'right' and '', -> 0.0.

    :arg start_loc: Starting location string

    :returns: Angle rotation in radians
    """
    for prefix, rotation in (('t', np.pi / 2), ('l', np.pi), ('b', 3 * np.pi / 2)):
        if start_loc.startswith(prefix):
            return rotation
    return 0.0
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_angles(num_axes: int, plot_config: dict) -> List[float]:
    """
    Get the angles for the spider chart axes.

    The axes are evenly spaced around the circle. The first axis is placed at the rotation given by
    ``plot_config['start']`` (via ``get_angle_rot``; default ``'right'``, i.e. 0 rad). The remaining axes
    follow counter-clockwise, or clockwise when ``plot_config['clockwise']`` is truthy.

    :arg num_axes: Number of axes
    :arg plot_config: Plot configuration dictionary (keys read: 'clockwise', 'start')

    :returns: List of angles in radians, reduced modulo 2π
    """
    angles = np.linspace(0, 2 * np.pi, num_axes, endpoint=False).tolist()
    if plot_config.get('clockwise', False):
        # Step in the negative direction from the start so the first axis stays at the start location.
        # (Reversing the list instead would put the first axis one step away from the start.)
        angles = [-a for a in angles]
    rot = get_angle_rot(plot_config.get('start', 'right'))
    return [(a + rot) % (2 * np.pi) for a in angles]
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def prepare_and_sort(plot_data: SpiderPlotData) -> Tuple[List[str], List[float], List[float], List[float]]:
    """
    Sort the groups for plotting and close the loop.

    The ordering comes from ``plot_config['custom_orders']``, or from the built-in default when that is
    missing or empty:

    - Labels of the form ``"<attribute>: <group>"`` whose attribute appears in that ordering come first,
      attribute by attribute, with groups in the configured order.
    - Groups not listed for a known attribute follow that attribute's listed groups, keeping their input
      order (the sort is stable).
    - All other labels, including labels without the ``": "`` separator, come last, sorted by label.

    Finally, the first entry of each list is appended to its end so the plotted polygon closes.

    :arg plot_data: SpiderPlotData whose groups, values, lower_bounds and upper_bounds are parallel lists

    :returns: (groups, values, lower_bounds, upper_bounds), each one element longer than the input

    :raises ValueError: if there are no groups to plot
    """
    custom_orders = plot_data.plot_config.get('custom_orders') or {
        'age_binned': ['18-29', '30-39', '40-49', '50-64', '65-74', '75-84', '85+'],
        'sex': ['Male', 'Female'],
        'race': ['White', 'Asian', 'Black or African American', 'Other'],
        'ethnicity': ['Hispanic or Latino', 'Not Hispanic or Latino'],
        'intersectional_race_ethnicity': ['White', 'Not White or Hispanic or Latino'],
    }
    attributes = list(custom_orders.keys())

    def sort_key(label: str) -> Any:
        # partition() never raises, unlike unpacking split(): labels without ': ' fall into "other"
        attr, sep, grp = label.partition(': ')
        if sep and attr in custom_orders:
            order = custom_orders[attr]
            rank = order.index(grp) if grp in order else len(order)
            return (attributes.index(attr), rank)
        # Other items sort after custom-ordered, by string label
        return (len(attributes), label)

    rows = sorted(
        zip(plot_data.groups, plot_data.values, plot_data.lower_bounds, plot_data.upper_bounds),
        key=lambda row: sort_key(row[0]),
    )
    if not rows:
        raise ValueError("No groups to plot: SpiderPlotData.groups and its value/bound lists are empty.")
    groups, values, lower_bounds, upper_bounds = map(list, zip(*rows))

    # Close the loop for spider plot
    groups.append(groups[0])
    values.append(values[0])
    lower_bounds.append(lower_bounds[0])
    upper_bounds.append(upper_bounds[0])

    return groups, values, lower_bounds, upper_bounds
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_full_theta() -> np.ndarray:
    """
    Return 100 evenly spaced angles covering the closed interval [0, 2π].

    Both endpoints are included, so a curve drawn over these angles closes on itself.

    :returns: 1-D float array of 100 angles in radians
    """
    return np.linspace(0.0, 2.0 * np.pi, num=100, endpoint=True)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def compute_angles(num_axes_with_close: int, plot_config: dict) -> List[float]:
    """
    Compute one angle per entry of a loop-closed groups list.

    ``prepare_and_sort`` appends a copy of the first group to close the polygon. The angles are therefore
    computed for ``num_axes_with_close - 1`` real axes, and the first angle is repeated at the end.

    :arg num_axes_with_close: Length of the groups list, including the closing duplicate
    :arg plot_config: Configuration dict for angle ordering (passed to ``get_angles``)

    :returns: Angles list with the same length as the groups list
    """
    open_angles = get_angles(num_axes_with_close - 1, plot_config)
    return [*open_angles, open_angles[0]]
|