MIDRC-MELODY 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- MIDRC_MELODY/__init__.py +0 -0
- MIDRC_MELODY/__main__.py +4 -0
- MIDRC_MELODY/common/__init__.py +0 -0
- MIDRC_MELODY/common/data_loading.py +199 -0
- MIDRC_MELODY/common/data_preprocessing.py +134 -0
- MIDRC_MELODY/common/edit_config.py +156 -0
- MIDRC_MELODY/common/eod_aaod_metrics.py +292 -0
- MIDRC_MELODY/common/generate_eod_aaod_spiders.py +69 -0
- MIDRC_MELODY/common/generate_qwk_spiders.py +56 -0
- MIDRC_MELODY/common/matplotlib_spider.py +425 -0
- MIDRC_MELODY/common/plot_tools.py +132 -0
- MIDRC_MELODY/common/plotly_spider.py +217 -0
- MIDRC_MELODY/common/qwk_metrics.py +244 -0
- MIDRC_MELODY/common/table_tools.py +230 -0
- MIDRC_MELODY/gui/__init__.py +0 -0
- MIDRC_MELODY/gui/config_editor.py +200 -0
- MIDRC_MELODY/gui/data_loading.py +157 -0
- MIDRC_MELODY/gui/main_controller.py +154 -0
- MIDRC_MELODY/gui/main_window.py +545 -0
- MIDRC_MELODY/gui/matplotlib_spider_widget.py +204 -0
- MIDRC_MELODY/gui/metrics_model.py +62 -0
- MIDRC_MELODY/gui/plotly_spider_widget.py +56 -0
- MIDRC_MELODY/gui/qchart_spider_widget.py +272 -0
- MIDRC_MELODY/gui/shared/__init__.py +0 -0
- MIDRC_MELODY/gui/shared/react/__init__.py +0 -0
- MIDRC_MELODY/gui/shared/react/copyabletableview.py +100 -0
- MIDRC_MELODY/gui/shared/react/grabbablewidget.py +406 -0
- MIDRC_MELODY/gui/tqdm_handler.py +210 -0
- MIDRC_MELODY/melody.py +102 -0
- MIDRC_MELODY/melody_gui.py +111 -0
- MIDRC_MELODY/resources/MIDRC.ico +0 -0
- midrc_melody-0.3.3.dist-info/METADATA +151 -0
- midrc_melody-0.3.3.dist-info/RECORD +37 -0
- midrc_melody-0.3.3.dist-info/WHEEL +5 -0
- midrc_melody-0.3.3.dist-info/entry_points.txt +4 -0
- midrc_melody-0.3.3.dist-info/licenses/LICENSE +201 -0
- midrc_melody-0.3.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# Copyright (c) 2025 Medical Imaging and Data Resource Center (MIDRC).
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
from numpy import pi as np_pi
|
|
17
|
+
import plotly.graph_objects as go
|
|
18
|
+
import plotly.io as pio
|
|
19
|
+
|
|
20
|
+
from MIDRC_MELODY.common.plot_tools import SpiderPlotData, get_full_theta, compute_angles, prepare_and_sort
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def spider_to_html(spider_data: SpiderPlotData) -> str:
    """
    Render a SpiderPlotData as an embeddable Plotly radar chart.

    Returns an HTML ``<div>`` fragment (no ``<html>``/``<head>``; Plotly.js is loaded from
    the CDN) containing:

    - The median ``values`` drawn as a steelblue line with circle markers at the discrete
      category angles returned by ``compute_angles``. Every trace sets
      ``thetaunit="radians"``, so those angles are interpreted as radians.
    - A borderless shaded band between ``lower_bounds`` and ``upper_bounds`` at the same
      category angles.
    - Metric-specific overlays (QWK, EOD, AAOD):
      - a dashed baseline and/or a light-green "safe band", both drawn over the dense angle
        sweep from ``get_full_theta()``;
      - a radial floor of 0 for AAOD;
      - 1.5 px threshold ticks at categories whose value crosses the metric's threshold.
    - Each tick spans ``angle ± Δθ`` with
      ``Δθ = delta * (r_max - radius) / (r_max - r_min)``.
      Δθ is largest (``delta``) at the radial minimum and shrinks linearly to 0 at the
      radial maximum.

    :arg spider_data: Plot data. ``metric``, ``model_name``, ``plot_config`` and the
        ``ylim_min``/``ylim_max`` dicts (keyed by ``metric``) are read directly; the series
        come from ``prepare_and_sort``.
    :returns: HTML fragment containing the chart.
    """
    raw_metric: str = spider_data.metric
    metric_display: str = spider_data.metric.upper()
    groups, values, lower_bounds, upper_bounds = prepare_and_sort(spider_data)

    # 1) Number of categories; thresholds only inspect the first N entries of each series.
    N = len(groups)

    # 2) Dense angle sweep for the circular baseline / safe-band traces.
    full_theta = list(get_full_theta())

    # 3) Discrete category angles (radians) and their labels ("Category: Group" -> "Group").
    # Materialise as a list so "+" below concatenates even if a NumPy array is returned.
    cat_angles = list(compute_angles(len(groups), spider_data.plot_config))
    cat_labels = [g.split(": ", 1)[-1] for g in groups]

    # 4) Radial-axis limits configured for this metric (may be missing -> None).
    radial_min = spider_data.ylim_min.get(raw_metric, None)
    radial_max = spider_data.ylim_max.get(raw_metric, None)

    fig = go.Figure()

    # 5) CI band: upper bounds forward, then lower bounds backward, filled to itself.
    # list(...) guards against array inputs, where "+" would add element-wise instead of
    # concatenating and silently corrupt the polygon.
    theta_ci = cat_angles + cat_angles[::-1]
    r_ci = list(upper_bounds) + list(lower_bounds)[::-1]
    fig.add_trace(
        go.Scatterpolar(
            r=r_ci,
            theta=theta_ci,
            thetaunit="radians",  # treat ALL numeric theta as radians
            mode="none",
            fill="toself",
            fillcolor="rgba(70,130,180,0.2)",  # semi-transparent steelblue
            line=dict(color="rgba(0,0,0,0)"),
            hoverinfo="skip",
            showlegend=False,
        )
    )

    # 6) Median "values" trace (lines + circle markers) with CI shown on hover.
    fig.add_trace(
        go.Scatterpolar(
            r=values,
            theta=cat_angles,
            thetaunit="radians",
            mode="lines+markers",
            line=dict(color="steelblue", width=2),
            marker=dict(symbol="circle", size=6, color="steelblue"),
            customdata=list(zip(groups, lower_bounds, upper_bounds)),
            hovertemplate="%{customdata[0]}<br>Median: %{r:.3f} [%{customdata[1]:.3f}, %{customdata[2]:.3f}]<extra></extra>",
            showlegend=False,
        )
    )

    # 7) Metric-specific overlay rules (matching plot_tools._apply_metric_overlay)
    overlay_config = {
        "QWK": {
            "baseline": {"type": "line", "y": 0, "color": "seagreen", "width": 3, "dash": "dash", "alpha": 0.8},
            "thresholds": [
                (lower_bounds[:N], lambda v: v > 0, "maroon"),
                (upper_bounds[:N], lambda v: v < 0, "red"),
            ],
        },
        "EOD": {
            "fill": {"lo": -0.1, "hi": 0.1, "color": "lightgreen", "alpha": 0.5},
            "thresholds": [
                (values[:N], lambda v: v > 0.1, "maroon"),
                (values[:N], lambda v: v < -0.1, "red"),
            ],
        },
        "AAOD": {
            "fill": {"lo": 0.0, "hi": 0.1, "color": "lightgreen", "alpha": 0.5},
            "baseline": {"type": "ylim", "lo": 0.0},
            "thresholds": [
                (values[:N], lambda v: v > 0.1, "maroon"),
            ],
        },
    }
    cfg = overlay_config.get(metric_display, None)
    if cfg:
        # 7a) Baseline: either a circular reference line or a forced radial minimum.
        if "baseline" in cfg:
            base = cfg["baseline"]
            if base["type"] == "line":
                fig.add_trace(
                    go.Scatterpolar(
                        r=[base["y"]] * len(full_theta),
                        theta=full_theta,
                        thetaunit="radians",  # treat ALL numeric theta as radians
                        mode="lines",
                        line=dict(color=base["color"], dash=base["dash"], width=base["width"]),
                        opacity=base["alpha"],
                        hoverinfo="skip",
                        showlegend=False,
                    )
                )
            elif base["type"] == "ylim":
                # Override the radial minimum used by the tick scaling and the layout.
                radial_min = base["lo"]

        # 7b) "Safe-band" annulus between fill["lo"] and fill["hi"].
        if "fill" in cfg:
            f = cfg["fill"]
            hi_vals = [f["hi"]] * len(full_theta)
            lo_vals = [f["lo"]] * len(full_theta)
            fig.add_trace(
                go.Scatterpolar(
                    r=hi_vals + lo_vals[::-1],
                    theta=full_theta + full_theta[::-1],
                    thetaunit="radians",  # treat ALL numeric theta as radians
                    mode="none",
                    fill="toself",
                    fillcolor=f"rgba({_hex_to_rgb(f['color'])}, {f['alpha']})",
                    line=dict(color="rgba(0,0,0,0)"),
                    hoverinfo="skip",
                    showlegend=False,
                )
            )

        # 7c) Threshold ticks: short arcs of half-width Δθ centred on each flagged category.
        # NOTE(review): the original intent was a roughly constant on-screen tick length, but
        # with the radial range's minimum at the centre the screen arc length is proportional
        # to (r - r_min) * (r_max - r), i.e. 0 at both extremes. Confirm the desired formula.
        delta = 0.15  # Adjust this value to change the length of the threshold ticks (in radians)
        # Tick scaling needs numeric limits. When a limit is not configured, fall back to the
        # data extent; the layout below still receives the configured (possibly None) values.
        data_extent = [*lower_bounds, *values, *upper_bounds]
        scale_min = radial_min if radial_min is not None else min(data_extent, default=0.0)
        scale_max = radial_max if radial_max is not None else max(data_extent, default=0.0)
        for data_list, cond, color_name in cfg.get("thresholds", []):
            for i, v in enumerate(data_list):
                if not cond(v):
                    continue
                angle = cat_angles[i]
                if scale_max == scale_min:
                    d_theta = 0  # degenerate range: avoid division by zero
                else:
                    d_theta = delta * (scale_max - v) / (scale_max - scale_min)
                fig.add_trace(
                    go.Scatterpolar(
                        r=[v, v],
                        theta=[angle - d_theta, angle + d_theta],
                        thetaunit="radians",  # treat ALL numeric theta as radians
                        mode="lines",
                        line=dict(color=color_name, width=1.5),
                        hoverinfo="skip",
                        showlegend=False,
                    )
                )

    # 8) Final polar layout adjustments - Tick angles must be degrees for Plotly
    fig.update_layout(
        title=f"{spider_data.model_name} – {metric_display}",
        polar=dict(
            radialaxis=dict(range=[radial_min, radial_max], visible=True),
            angularaxis=dict(
                tickmode="array",
                tickvals=[ang * 180.0 / np_pi for ang in cat_angles],
                ticktext=cat_labels,
            ),
        ),
        showlegend=False,
        margin=dict(l=50, r=50, t=50, b=50),
    )

    # 9) Export only the <div> (omit full HTML <head>), using CDN for Plotly.js
    return pio.to_html(fig, full_html=False, include_plotlyjs="cdn")
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _hex_to_rgb(css_color: str) -> str:
    """
    Convert a Matplotlib-recognised color into an ``"R,G,B"`` string of 0-255 integers.

    The color may be a CSS name such as ``"lightgreen"`` or a hex string such as
    ``"#90ee90"``. The result is meant to be embedded in a Plotly ``"rgba(R,G,B,alpha)"``
    fill color; the alpha channel of the input is ignored.

    Each channel is rounded to the nearest integer, the same quantisation
    ``matplotlib.colors.to_hex`` uses. Truncating with ``int()`` can map a float product that
    lands just below an integer to the next-lower value.

    :arg css_color: Any color specification accepted by ``matplotlib.colors.to_rgba``.
    :returns: Comma-separated integer channels, e.g. ``"144,238,144"``.
    :raises ValueError: If Matplotlib does not recognise ``css_color``.
    """
    import matplotlib.colors as mcolors

    red, green, blue, _alpha = mcolors.to_rgba(css_color)
    return ",".join(str(round(255 * channel)) for channel in (red, green, blue))
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
# Copyright (c) 2025 Medical Imaging and Data Resource Center (MIDRC).
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
#
|
|
15
|
+
|
|
16
|
+
""" Module for calculating quadratic weighted kappa and delta kappa values with confidence intervals. """
|
|
17
|
+
|
|
18
|
+
from dataclasses import replace
|
|
19
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
20
|
+
|
|
21
|
+
from joblib import delayed, Parallel
|
|
22
|
+
import matplotlib.pyplot as plt
|
|
23
|
+
import numpy as np
|
|
24
|
+
import pandas as pd
|
|
25
|
+
from sklearn.metrics import cohen_kappa_score
|
|
26
|
+
from sklearn.utils import resample
|
|
27
|
+
from tqdm import tqdm
|
|
28
|
+
from tqdm_joblib import tqdm_joblib
|
|
29
|
+
|
|
30
|
+
from MIDRC_MELODY.common.data_loading import TestAndDemographicData
|
|
31
|
+
from MIDRC_MELODY.common.plot_tools import SpiderPlotData
|
|
32
|
+
from MIDRC_MELODY.common.matplotlib_spider import plot_spider_chart, display_figures_grid
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def calculate_kappas_and_intervals(
    test_data: TestAndDemographicData
) -> Tuple[Dict[str, float], Dict[str, Tuple[float, float]]]:
    """
    Compute each model's overall quadratic weighted kappa plus a bootstrap 95% CI.

    A single generator seeded with ``test_data.base_seed`` draws ``test_data.n_iter``
    row-index resamples per model (models are processed in order). The CI is the
    2.5th/97.5th percentile of the resampled kappas. A summary line per model is printed.

    :arg test_data: TestAndDemographicData object containing the test and demographic data.

    :returns: ``(kappas, intervals)`` - dicts keyed by model column holding the point kappa
              and the ``(lower, upper)`` 95% confidence bounds respectively.
    """
    columns = test_data.test_cols if isinstance(test_data.test_cols, list) else [test_data.test_cols]
    frame = test_data.matched_df
    truth = frame[test_data.truth_col].to_numpy(dtype=int)
    n_rows = len(truth)
    generator = np.random.default_rng(test_data.base_seed)
    divider = '-' * 50

    kappas: Dict[str, float] = {}
    intervals: Dict[str, Tuple[float, float]] = {}

    print(divider)
    print("Overall Quadratic Weighted Kappa (κ) Scores:")
    for column in columns:
        predicted = frame[column].to_numpy(dtype=int)
        point_kappa = cohen_kappa_score(truth, predicted, weights='quadratic')
        kappas[column] = point_kappa

        boot = np.empty(test_data.n_iter)
        for it in range(test_data.n_iter):
            picks = generator.integers(0, n_rows, size=n_rows)
            boot[it] = cohen_kappa_score(truth[picks], predicted[picks], weights='quadratic')
        low, high = np.percentile(boot, [2.5, 97.5])
        intervals[column] = (low, high)
        print(f"Model: {column} | Kappa (κ): {point_kappa:.4f} | 95% CI: ({low:.4f}, {high:.4f}) N: {n_rows}")
    print(divider)

    return kappas, intervals
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def bootstrap_kappa(test_data: TestAndDemographicData, n_jobs: int = -1) -> Dict[str, List[float]]:
    """
    Perform bootstrap estimation of quadratic weighted kappa scores for each model in parallel.

    One seed per iteration is drawn from ``np.random.default_rng(test_data.base_seed)``. Each
    iteration resamples the rows of ``test_data.matched_df`` with replacement
    (``sklearn.utils.resample``) and scores every model on that same resample.

    :arg test_data: TestAndDemographicData object containing the test and demographic data.
    :arg n_jobs: Number of parallel jobs passed to ``joblib.Parallel``.

    :returns: Dictionary mapping each model name to a list of ``test_data.n_iter`` bootstrap
              kappa scores, in seed order. Every model is present (with an empty list when
              ``n_iter`` is 0).
    """
    models = test_data.test_cols
    if not isinstance(models, list):
        models = [models]
    rng = np.random.default_rng(test_data.base_seed)
    seeds = rng.integers(0, 1_000_000, size=test_data.n_iter)

    def resample_and_compute_kappa(df: pd.DataFrame, truth_col: str, _models: List[str], seed: int) -> List[float]:
        # One bootstrap iteration: a single row resample shared by all models.
        sampled_df = resample(df, replace=True, random_state=seed)
        return [
            cohen_kappa_score(sampled_df[truth_col].to_numpy(dtype=int),
                              sampled_df[model].to_numpy(dtype=int),
                              weights='quadratic')
            for model in _models
        ]

    with tqdm_joblib(total=test_data.n_iter, desc="Bootstrapping", leave=False):
        kappas_2d = Parallel(n_jobs=n_jobs)(
            delayed(resample_and_compute_kappa)(test_data.matched_df, test_data.truth_col, models, seed)
            for seed in seeds
        )

    # Transpose iteration-major results into per-model lists. Unlike
    # dict(zip(models, zip(*kappas_2d))) this yields lists (matching the annotation) and keeps
    # every model key even when no iterations were run.
    return {model: [row[j] for row in kappas_2d] for j, model in enumerate(models)}
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def calculate_delta_kappa(
    test_data: TestAndDemographicData
) -> Dict[str, Dict[str, Dict[Any, Tuple[float, Tuple[float, float]]]]]:
    """
    Calculate delta kappa (difference between group and reference) with bootstrap confidence intervals.

    For every category in ``test_data.categories`` that has an entry in
    ``test_data.valid_groups``:

    - each valid group present in the data is bootstrapped separately with
      ``bootstrap_kappa``;
    - iteration-wise differences against the category's reference group
      (``test_data.reference_groups[category]``) are summarised by their median and
      2.5th/97.5th percentiles.

    :arg test_data: TestAndDemographicData object containing the test and demographic data.

    :returns: ``{category: {model: {group: (median_delta, (lower, upper))}}}``; the reference
              group itself is not included.
    :raises KeyError: If a category's reference group has no bootstrap results (it is absent
              from the data or not listed in ``valid_groups``).
    """
    # Normalise like bootstrap_kappa: a bare column name means a single model. Iterating the
    # raw value would otherwise walk the characters of a string.
    models = test_data.test_cols
    if not isinstance(models, list):
        models = [models]

    delta_kappas: Dict[str, Dict[str, Dict[Any, Tuple[float, Tuple[float, float]]]]] = {}
    df = test_data.matched_df

    for category in tqdm(test_data.categories, desc="Categories", position=0):
        if category not in test_data.valid_groups:
            continue

        delta_kappas[category] = {model: {} for model in models}
        unique_values = df[category].unique().tolist()

        kappa_dicts: Dict[Any, Dict[str, List[float]]] = {}
        for value in tqdm(unique_values, desc=f"Category \033[1m{category}\033[0m Groups", leave=False, position=1):
            if value not in test_data.valid_groups[category]:
                continue

            filtered_df = df[df[category] == value]

            # Create a shallow copy of test_data and update matched_df with filtered_df
            filtered_test_data = replace(test_data, matched_df=filtered_df)

            kappa_dicts[value] = bootstrap_kappa(filtered_test_data, n_jobs=-1)

        # Remove and store reference bootstraps.
        reference = test_data.reference_groups[category]
        if reference not in kappa_dicts:
            raise KeyError(
                f"Reference group {reference!r} for category {category!r} has no bootstrap results; "
                "it must be present in the data and listed in valid_groups."
            )
        ref_bootstraps = kappa_dicts.pop(reference)

        # Pair bootstrap iteration i of each group with iteration i of the reference.
        for value, kappa_dict in kappa_dicts.items():
            for model in models:
                deltas = np.array(kappa_dict[model]) - np.array(ref_bootstraps[model])
                lower_value, upper_value = np.percentile(deltas, [2.5, 97.5])
                delta_kappas[category][model][value] = (
                    float(np.median(deltas)),
                    (float(lower_value), float(upper_value))
                )
    return delta_kappas
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def extract_plot_data(delta_kappas: Dict[str, Dict[str, Dict[Any, Tuple[float, Tuple[float, float]]]]],
                      model_name: str) -> Tuple[List[str], List[float], List[float], List[float]]:
    """
    Flatten one model's delta-kappa results into parallel lists for plotting.

    Categories are visited in dict order; categories that have no entry for ``model_name``
    are skipped. Each group label is rendered as ``"<category>: <group>"``.

    :arg delta_kappas: Dictionary of delta kappa values with 95% confidence intervals.
    :arg model_name: Name of the AI model.

    :returns: ``(groups, values, lower_bounds, upper_bounds)`` as four equally long lists.
    """
    records = [
        (f"{category}: {group}", value, lower_ci, upper_ci)
        for category, per_model in delta_kappas.items()
        if model_name in per_model
        for group, (value, (lower_ci, upper_ci)) in per_model[model_name].items()
    ]
    groups: List[str] = [rec[0] for rec in records]
    values: List[float] = [rec[1] for rec in records]
    lower_bounds: List[float] = [rec[2] for rec in records]
    upper_bounds: List[float] = [rec[3] for rec in records]
    return groups, values, lower_bounds, upper_bounds
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def create_spider_plot_data_qwk(
    delta_kappas: Dict[str, Dict[str, Dict[Any, Tuple[float, Tuple[float, float]]]]],
    ai_models: List[str],
    plot_config: Optional[Dict[str, Any]] = None
) -> List[SpiderPlotData]:
    """
    Create one SpiderPlotData per AI model from delta kappas, sharing a common radial scale.

    The shared limits span all models: the smallest lower CI bound minus 0.05 and the
    largest upper CI bound plus 0.05. They are stored under the ``"QWK"`` metric key.

    :arg delta_kappas: Dictionary of delta kappa values with 95% confidence intervals.
    :arg ai_models: List of test columns (AI model names).
    :arg plot_config: Optional configuration dictionary for plotting.

    :returns: List of SpiderPlotData instances, one per entry of ``ai_models`` in the same order.
    :raises ValueError: If no model has any delta-kappa entries (no bounds to scale by).
    """
    metric = "QWK"

    # Extract each model's series once; they feed both the shared limits and the plots.
    per_model = [(model, extract_plot_data(delta_kappas, model)) for model in ai_models]

    all_lower = [lo for _, (_, _, lowers, _) in per_model for lo in lowers]
    all_upper = [hi for _, (_, _, _, uppers) in per_model for hi in uppers]
    global_min = min(all_lower) - 0.05
    global_max = max(all_upper) + 0.05

    base_plot_data = SpiderPlotData(
        ylim_min={metric: global_min},
        ylim_max={metric: global_max},
        plot_config=plot_config,
        metric=metric,
    )

    plot_data_list: List[SpiderPlotData] = []
    for model, (groups, values, lower_bounds, upper_bounds) in per_model:
        # New instance from the base's attributes (shallow: the ylim dicts are presumably
        # shared across instances - verify before mutating them per model).
        plot_data = SpiderPlotData(**base_plot_data.__dict__)
        plot_data.model_name = model
        plot_data.groups = groups
        plot_data.values = values
        plot_data.lower_bounds = lower_bounds
        plot_data.upper_bounds = upper_bounds
        plot_data_list.append(plot_data)

    return plot_data_list
|
|
223
|
+
|
|
224
|
+
def generate_plots_from_delta_kappas(
    delta_kappas: Dict[str, Dict[str, Dict[Any, Tuple[float, Tuple[float, float]]]]],
    ai_models: List[str],
    plot_config: Optional[Dict[str, Any]] = None
) -> None:
    """
    Draw one Matplotlib spider chart per AI model and display them together in a grid.

    All charts share the radial scale computed by ``create_spider_plot_data_qwk``.

    :arg delta_kappas: Dictionary of delta kappa values with 95% confidence intervals.
    :arg ai_models: List of test columns (AI model names).
    :arg plot_config: Optional configuration dictionary for plotting
    """
    spider_inputs = create_spider_plot_data_qwk(delta_kappas, ai_models, plot_config)
    display_figures_grid([plot_spider_chart(item) for item in spider_inputs])
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# Copyright (c) 2025 Medical Imaging and Data Resource Center (MIDRC).
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Dict, Final, List, Tuple
|
|
16
|
+
|
|
17
|
+
from tabulate import tabulate
|
|
18
|
+
|
|
19
|
+
# Named highlight colors as (R, G, B) triples in 0-255. They are turned into 24-bit ANSI
# escape codes by _console_color (terminal tables) and into QColor objects for the GUI tables.
GLOBAL_COLORS: Final[Dict[str, Tuple[int, int, int]]] = {
    'eod_negative': (128, 0, 0),    # Maroon
    'eod_positive': (0, 128, 0),    # Green
    'aaod': (255, 165, 0),          # Orange
    'kappa_negative': (128, 0, 0),  # Maroon
    'kappa_positive': (0, 128, 0),  # Green
}

# ANSI "reset all attributes" sequence appended after colored console text.
_CONSOLE_RESET: Final = "\033[0m"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _console_color(color: str | Tuple[int, int, int]) -> str:
|
|
32
|
+
"""
|
|
33
|
+
Convert RGB color tuple or color name to ANSI escape code string.
|
|
34
|
+
"""
|
|
35
|
+
if isinstance(color, tuple):
|
|
36
|
+
return f"\033[38;2;{color[0]};{color[1]};{color[2]}m"
|
|
37
|
+
if isinstance(color, str):
|
|
38
|
+
if color in GLOBAL_COLORS:
|
|
39
|
+
rgb = GLOBAL_COLORS[color]
|
|
40
|
+
return _console_color(rgb)
|
|
41
|
+
return ""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _format_console_string(value: str, color: str) -> str:
    """
    Wrap *value* between the ANSI escape *color* and a reset code.

    ``None`` means "no color" and returns *value* unchanged. Any other *color*, including an
    empty string, is prefixed verbatim, and the reset suffix is always appended.
    """
    if color is None:
        return value
    return f"{color}{value}{_CONSOLE_RESET}"
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _format_console_value(value: float, color: str) -> str:
    """
    Render *value* with four decimals, wrapped in *color* unless *color* is ``None``.
    """
    return _format_console_string(format(value, ".4f"), color)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _sort_rows(rows: List[List[str]]) -> List[List[str]]:
    """
    Return a new list of *rows* ordered by their first four cells.

    The input is left untouched and ties keep their original order (stable sort).
    """
    ordered = list(rows)
    ordered.sort(key=lambda row: (*row[:4],))
    return ordered
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _build_eod_aaod_tables_generic(
        eod_aaod: Dict[str, Dict[str, Dict[Any, Dict[str, Tuple[float, Tuple[float, float]]]]]], *,
        console: bool
) -> Tuple[
    List[Tuple[List[Any], Any]],
    List[Tuple[List[Any], Any]],
    List[Tuple[List[Any], Any]],
]:
    """
    Build sorted EOD, AAOD and "meeting criteria" tables of ``(row, color)`` pairs.

    Rows:

    - ``row`` is ``[model, category, group, median, lower_ci, upper_ci]``, with the numbers
      formatted to four decimals.
    - Rows in the filtered table carry the upper-cased metric name inserted at index 3.

    Qualifying and colors:

    - An EOD entry qualifies when ``|median| > 0.1``; an AAOD entry when ``median > 0.1``.
      Qualifying entries are also copied into the filtered table.
    - ``color`` is ``None`` for non-qualifying rows.
    - Otherwise it is an ANSI escape string when ``console`` is True (see
      ``_build_eod_aaod_tables_console`` for applying it), or a ``PySide6.QtGui.QColor``
      when it is False.

    :arg eod_aaod: Nested mapping category -> model -> group -> {"eod"/"aaod": (median, (lo, hi))}.
    :arg console: Select ANSI-escape (True) or QColor (False) colors.
    :returns: ``(all_eod, all_aaod, filtered)``, each sorted by the first four cells of its rows.
    :raises ImportError: If ``console`` is False and PySide6 cannot be imported.
    """
    # Set up color formatting based on mode.
    if console:
        color_fn = _console_color
    else:
        try:
            from PySide6.QtGui import QColor
        except ImportError as err:
            raise ImportError("PySide6 is required for GUI table generation.") from err

        def color_fn(key: str) -> Any:
            rgb = GLOBAL_COLORS.get(key)
            return QColor(*rgb) if rgb is not None else None

    def format_fn(v: float) -> str:
        return f"{v:.4f}"

    all_eod: List[Tuple[List[Any], Any]] = []
    all_aaod: List[Tuple[List[Any], Any]] = []
    filtered: List[Tuple[List[Any], Any]] = []

    for category, model_data in eod_aaod.items():
        for model, groups in model_data.items():
            for group, metrics in groups.items():
                for metric in ('eod', 'aaod'):
                    if metric not in metrics:
                        continue
                    median, (ci_lo, ci_hi) = metrics[metric]
                    if metric == 'eod':
                        qualifies = abs(median) > 0.1
                        if qualifies:
                            color = color_fn('eod_negative') if median < 0 else color_fn('eod_positive')
                        else:
                            color = None
                        target_list = all_eod
                    else:
                        qualifies = median > 0.1
                        color = color_fn('aaod') if qualifies else None
                        target_list = all_aaod

                    row = [model, category, group, format_fn(median), format_fn(ci_lo), format_fn(ci_hi)]
                    target_list.append((row, color))

                    if qualifies:
                        # For filtered rows, insert the metric name.
                        row_f = row.copy()
                        row_f.insert(3, metric.upper())
                        filtered.append((row_f, color))

    def sort_key(entry: Tuple[List[Any], Any]) -> Tuple[Any, ...]:
        return tuple(entry[0][:4])

    return (
        sorted(all_eod, key=sort_key),
        sorted(all_aaod, key=sort_key),
        sorted(filtered, key=sort_key),
    )
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _print_section(title: str, rows: List[List[str]], headers: List[str], tablefmt: str) -> None:
    """Print *title*, then *rows* rendered by ``tabulate``, then a blank line."""
    print(title)
    rendered = tabulate(rows, headers=headers, tablefmt=tablefmt)
    print(rendered, end="\n\n")
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def _build_eod_aaod_tables_console(
        eod_aaod: Dict[str, Dict[str, Dict[Any, Dict[str, Tuple[float, Tuple[float, float]]]]]]
) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
    """
    Build the three EOD/AAOD console tables as plain row lists.

    For every colored row, the last three cells (median and CI bounds) are wrapped in the
    row's ANSI color plus a reset code. Uncolored rows are passed through as-is.
    """
    tables = _build_eod_aaod_tables_generic(eod_aaod, console=True)

    def _colorize(entries: List[Tuple[List[str], str]]) -> List[List[str]]:
        out: List[List[str]] = []
        for row, color in entries:
            if color is None:
                out.append(row)
                continue
            head, tail = row[:-3], row[-3:]
            out.append(head + [f"{color}{cell}{_CONSOLE_RESET}" for cell in tail])
        return out

    eod_rows, aaod_rows, filtered_rows = (_colorize(table) for table in tables)
    return eod_rows, aaod_rows, filtered_rows
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def print_table_of_nonzero_eod_aaod(
        eod_aaod: Dict[str, Dict[str, Dict[Any, Dict[str, Tuple[float, Tuple[float, float]]]]]],
        tablefmt: str = 'grid'
) -> None:
    """
    Print every EOD and AAOD median with its 95% CI, then the rows meeting the criteria.

    Qualifying rows (as decided by the table builder) are ANSI-colored. If none qualify, a
    one-line notice is printed instead of the third table.

    :arg eod_aaod: Nested mapping category -> model -> group -> {"eod"/"aaod": (median, (lo, hi))}.
    :arg tablefmt: ``tabulate`` table format name.
    """
    eod_rows, aaod_rows, flagged_rows = _build_eod_aaod_tables_console(eod_aaod)

    base_headers = ['Model', 'Category', 'Group', 'Median', 'Lower CI', 'Upper CI']
    flagged_headers = base_headers[:3] + ['Metric'] + base_headers[3:]

    for section_title, section_rows in (('All EOD median values:', eod_rows),
                                        ('All AAOD median values:', aaod_rows)):
        _print_section(section_title, section_rows, base_headers, tablefmt)

    if not flagged_rows:
        print('No model/group combinations meeting the specified criteria for EOD/AAOD.')
        return
    _print_section('EOD/AAOD median values meeting criteria:', flagged_rows, flagged_headers, tablefmt)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _build_delta_tables(
        delta_kappas: Dict[str, Dict[str, Dict[Any, Tuple[float, Tuple[float, float]]]]]
) -> Tuple[List[List[str]], List[List[str]]]:
    """
    Build sorted Δκ console rows plus the subset whose 95% CI excludes zero.

    Rows are ``[model, category, group, delta, lower_ci, upper_ci]``, with numbers formatted
    to four decimals. When the CI excludes zero, the numeric cells are wrapped in the
    ``kappa_negative`` or ``kappa_positive`` ANSI color, chosen by the sign of delta.
    """
    every_row: List[List[str]] = []
    significant: List[List[str]] = []

    for category, model_data in delta_kappas.items():
        for model, groups in model_data.items():
            for group, (delta, (ci_lo, ci_hi)) in groups.items():
                excludes_zero = ci_lo > 0 or ci_hi < 0
                if excludes_zero:
                    color = _console_color('kappa_negative' if delta < 0 else 'kappa_positive')
                else:
                    color = None

                numeric_cells = [_format_console_value(x, color) for x in (delta, ci_lo, ci_hi)]
                row = [model, category, group] + numeric_cells
                every_row.append(row)
                if excludes_zero:
                    significant.append(row)

    return _sort_rows(every_row), _sort_rows(significant)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def print_table_of_nonzero_deltas(
        delta_kappas: Dict[str, Dict[str, Dict[Any, Tuple[float, Tuple[float, float]]]]],
        tablefmt: str = 'grid'
) -> None:
    """
    Print all Δκ values, then only those whose 95% CI excludes zero.

    If no row qualifies, a one-line notice replaces the second table.

    :arg delta_kappas: Nested mapping category -> model -> group -> (delta, (lo, hi)).
    :arg tablefmt: ``tabulate`` table format name.
    """
    every_row, significant = _build_delta_tables(delta_kappas)
    column_titles = ['Model', 'Category', 'Group', 'Δκ', 'Lower CI', 'Upper CI']

    _print_section('All Δκ Values:', every_row, column_titles, tablefmt)

    if not significant:
        print('No model/group combinations meeting the specified criteria for Δκ.')
        return
    _print_section('Δκ values with 95% CI excluding zero:', significant, column_titles, tablefmt)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
try:
    from PySide6.QtGui import QColor
except ImportError:
    # PySide6 is not installed: keep the public name bound, but to None.
    build_eod_aaod_tables_gui = None
else:
    def build_eod_aaod_tables_gui(
            eod_aaod: Dict[str, Dict[str, Dict[Any, Dict[str, Tuple[float, Tuple[float, float]]]]]]
    ) -> tuple[list[tuple[list[str], "QColor | None"]], list[tuple[list[str], "QColor | None"]], list[tuple[list[str], "QColor | None"]]]:
        """
        Build the EOD, AAOD and filtered tables for the Qt GUI.

        Each table is a sorted list of ``(row, color)`` pairs. ``color`` is a ``QColor`` for
        rows meeting the criteria and ``None`` otherwise. This function is only defined when
        PySide6 imports successfully.
        """
        tables = _build_eod_aaod_tables_generic(eod_aaod, console=False)
        return tables
|
|
File without changes
|