fastMONAI 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +171 -27
- fastMONAI/dataset_info.py +190 -45
- fastMONAI/external_data.py +1 -1
- fastMONAI/utils.py +101 -18
- fastMONAI/vision_all.py +3 -2
- fastMONAI/vision_augmentation.py +133 -29
- fastMONAI/vision_core.py +29 -132
- fastMONAI/vision_data.py +6 -6
- fastMONAI/vision_inference.py +35 -9
- fastMONAI/vision_metrics.py +420 -19
- fastMONAI/vision_patch.py +1125 -0
- fastMONAI/vision_plot.py +1 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/METADATA +5 -5
- fastmonai-0.5.4.dist-info/RECORD +21 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/WHEEL +1 -1
- fastmonai-0.5.3.dist-info/RECORD +0 -20
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/entry_points.txt +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {fastmonai-0.5.3.dist-info → fastmonai-0.5.4.dist-info}/top_level.txt +0 -0
fastMONAI/vision_inference.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
# %% auto 0
|
|
4
4
|
__all__ = ['save_series_pred', 'load_system_resources', 'inference', 'compute_binary_tumor_volume', 'refine_binary_pred_mask',
|
|
5
|
-
'gradio_image_classifier']
|
|
5
|
+
'keep_largest', 'gradio_image_classifier']
|
|
6
6
|
|
|
7
7
|
# %% ../nbs/06_vision_inference.ipynb 1
|
|
8
8
|
from copy import copy
|
|
@@ -67,18 +67,18 @@ def load_system_resources(models_path, learner_fn, variables_fn):
|
|
|
67
67
|
|
|
68
68
|
learn = load_learner(models_path / learner_fn, cpu=True)
|
|
69
69
|
vars_fn = models_path / variables_fn
|
|
70
|
-
_,
|
|
70
|
+
_, apply_reorder, target_spacing = load_variables(pkl_fn=vars_fn)
|
|
71
71
|
|
|
72
|
-
return learn,
|
|
72
|
+
return learn, apply_reorder, target_spacing
|
|
73
73
|
|
|
74
74
|
# %% ../nbs/06_vision_inference.ipynb 8
|
|
75
|
-
def inference(learn_inf,
|
|
75
|
+
def inference(learn_inf, apply_reorder, target_spacing, fn: (str, Path) = '',
|
|
76
76
|
save_path: (str, Path) = None, org_img=None, input_img=None,
|
|
77
77
|
org_size=None):
|
|
78
78
|
"""Predict on new data using exported model."""
|
|
79
79
|
|
|
80
80
|
if None in [org_img, input_img, org_size]:
|
|
81
|
-
org_img, input_img, org_size = med_img_reader(fn,
|
|
81
|
+
org_img, input_img, org_size = med_img_reader(fn, apply_reorder, target_spacing,
|
|
82
82
|
only_tensor=False)
|
|
83
83
|
else:
|
|
84
84
|
org_img, input_img = copy(org_img), copy(input_img)
|
|
@@ -148,6 +148,10 @@ def refine_binary_pred_mask(pred_mask,
|
|
|
148
148
|
if verbose:
|
|
149
149
|
print(n_components)
|
|
150
150
|
|
|
151
|
+
# Handle empty mask case (no foreground components)
|
|
152
|
+
if n_components == 0:
|
|
153
|
+
return torch.zeros_like(torch.Tensor(pred_mask)).float()
|
|
154
|
+
|
|
151
155
|
if remove_size is None:
|
|
152
156
|
sizes = np.bincount(labeled_mask.ravel())
|
|
153
157
|
max_label = sizes[1:].argmax() + 1
|
|
@@ -157,14 +161,36 @@ def refine_binary_pred_mask(pred_mask,
|
|
|
157
161
|
processed_mask = remove_small_objects(
|
|
158
162
|
labeled_mask, min_size=small_objects_threshold)
|
|
159
163
|
|
|
160
|
-
return torch.Tensor(processed_mask > 0).float()
|
|
164
|
+
return torch.Tensor(processed_mask > 0).float()
|
|
165
|
+
|
|
166
|
+
# %% ../nbs/06_vision_inference.ipynb 12
|
|
167
|
+
def keep_largest(pred_mask: torch.Tensor) -> torch.Tensor:
|
|
168
|
+
"""Keep only the largest connected component in a binary mask.
|
|
169
|
+
|
|
170
|
+
Args:
|
|
171
|
+
pred_mask: Binary prediction mask tensor.
|
|
172
|
+
|
|
173
|
+
Returns:
|
|
174
|
+
Binary mask with only the largest connected component.
|
|
175
|
+
"""
|
|
176
|
+
mask_np = pred_mask.numpy() if isinstance(pred_mask, torch.Tensor) else pred_mask
|
|
177
|
+
labeled_mask, n_components = label(mask_np)
|
|
178
|
+
|
|
179
|
+
if n_components == 0:
|
|
180
|
+
return torch.zeros_like(pred_mask) if isinstance(pred_mask, torch.Tensor) else mask_np
|
|
181
|
+
|
|
182
|
+
sizes = np.bincount(labeled_mask.ravel())
|
|
183
|
+
largest_label = sizes[1:].argmax() + 1 # Skip background (label 0)
|
|
184
|
+
|
|
185
|
+
result = (labeled_mask == largest_label).astype(np.float32)
|
|
186
|
+
return torch.from_numpy(result) if isinstance(pred_mask, torch.Tensor) else result
|
|
161
187
|
|
|
162
|
-
# %% ../nbs/06_vision_inference.ipynb
|
|
163
|
-
def gradio_image_classifier(file_obj, learn,
|
|
188
|
+
# %% ../nbs/06_vision_inference.ipynb 14
|
|
189
|
+
def gradio_image_classifier(file_obj, learn, apply_reorder, target_spacing):
|
|
164
190
|
"""Predict on images using exported learner and return the result as a dictionary."""
|
|
165
191
|
|
|
166
192
|
img_path = Path(file_obj.name)
|
|
167
|
-
img = med_img_reader(img_path,
|
|
193
|
+
img = med_img_reader(img_path, apply_reorder=apply_reorder, target_spacing=target_spacing)
|
|
168
194
|
|
|
169
195
|
_, _, predictions = learn.predict(img)
|
|
170
196
|
prediction_dict = {index: value.item() for index, value in enumerate(predictions)}
|
fastMONAI/vision_metrics.py
CHANGED
|
@@ -2,28 +2,76 @@
|
|
|
2
2
|
|
|
3
3
|
# %% auto 0
|
|
4
4
|
__all__ = ['calculate_dsc', 'calculate_haus', 'binary_dice_score', 'multi_dice_score', 'binary_hausdorff_distance',
|
|
5
|
-
'multi_hausdorff_distance'
|
|
5
|
+
'multi_hausdorff_distance', 'calculate_confusion_metrics', 'binary_sensitivity', 'multi_sensitivity',
|
|
6
|
+
'binary_precision', 'multi_precision', 'calculate_lesion_detection_rate', 'binary_lesion_detection_rate',
|
|
7
|
+
'multi_lesion_detection_rate', 'calculate_signed_rve', 'binary_signed_rve', 'multi_signed_rve',
|
|
8
|
+
'AccumulatedDice', 'AccumulatedMultiDice']
|
|
6
9
|
|
|
7
10
|
# %% ../nbs/05_vision_metrics.ipynb 1
|
|
8
11
|
import torch
|
|
9
12
|
import numpy as np
|
|
10
|
-
from monai.metrics import compute_hausdorff_distance, compute_dice
|
|
13
|
+
from monai.metrics import compute_hausdorff_distance, compute_dice, get_confusion_matrix, compute_confusion_matrix_metric
|
|
14
|
+
from scipy.ndimage import label as scipy_label
|
|
11
15
|
from .vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask
|
|
16
|
+
from fastai.learner import Metric
|
|
12
17
|
|
|
13
18
|
# %% ../nbs/05_vision_metrics.ipynb 3
|
|
14
19
|
def calculate_dsc(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
15
|
-
"""MONAI
|
|
20
|
+
"""Calculate Dice score using MONAI's compute_dice.
|
|
16
21
|
|
|
17
|
-
|
|
22
|
+
Accepts tensors of various shapes and automatically reshapes to 5D format.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
pred: Binary prediction tensor. Accepts:
|
|
26
|
+
- [D, H, W] single 3D volume
|
|
27
|
+
- [C, D, H, W] single volume with channel
|
|
28
|
+
- [B, C, D, H, W] batched volumes
|
|
29
|
+
targ: Binary target tensor (same shape options as pred).
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
Dice score(s). Single value for 3D/4D input, tensor of values for 5D batch.
|
|
33
|
+
"""
|
|
34
|
+
# Normalize to 5D: [B, C, D, H, W]
|
|
35
|
+
if pred.ndim == 3: # [D, H, W] -> [1, 1, D, H, W]
|
|
36
|
+
pred = pred.unsqueeze(0).unsqueeze(0)
|
|
37
|
+
targ = targ.unsqueeze(0).unsqueeze(0)
|
|
38
|
+
elif pred.ndim == 4: # [C, D, H, W] -> [1, C, D, H, W]
|
|
39
|
+
pred = pred.unsqueeze(0)
|
|
40
|
+
targ = targ.unsqueeze(0)
|
|
41
|
+
|
|
42
|
+
return torch.Tensor([compute_dice(p[None], t[None]) for p, t in zip(pred, targ)])
|
|
18
43
|
|
|
19
44
|
# %% ../nbs/05_vision_metrics.ipynb 4
|
|
20
45
|
def calculate_haus(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
21
|
-
"""MONAI
|
|
46
|
+
"""Compute 95th percentile Hausdorff distance (HD95) using MONAI.
|
|
47
|
+
|
|
48
|
+
HD95 is more robust than standard Hausdorff distance as it ignores
|
|
49
|
+
the top 5% of outlier distances.
|
|
50
|
+
|
|
51
|
+
Accepts tensors of various shapes and automatically reshapes to 5D format.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
pred: Binary prediction tensor. Accepts:
|
|
55
|
+
- [D, H, W] single 3D volume
|
|
56
|
+
- [C, D, H, W] single volume with channel
|
|
57
|
+
- [B, C, D, H, W] batched volumes
|
|
58
|
+
targ: Binary target tensor (same shape options as pred).
|
|
22
59
|
|
|
23
|
-
|
|
60
|
+
Returns:
|
|
61
|
+
HD95 value(s). Single value for 3D/4D input, tensor of values for 5D batch.
|
|
62
|
+
"""
|
|
63
|
+
# Normalize to 5D: [B, C, D, H, W]
|
|
64
|
+
if pred.ndim == 3: # [D, H, W] -> [1, 1, D, H, W]
|
|
65
|
+
pred = pred.unsqueeze(0).unsqueeze(0)
|
|
66
|
+
targ = targ.unsqueeze(0).unsqueeze(0)
|
|
67
|
+
elif pred.ndim == 4: # [C, D, H, W] -> [1, C, D, H, W]
|
|
68
|
+
pred = pred.unsqueeze(0)
|
|
69
|
+
targ = targ.unsqueeze(0)
|
|
70
|
+
|
|
71
|
+
return torch.Tensor([compute_hausdorff_distance(p[None], t[None], percentile=95) for p, t in zip(pred, targ)])
|
|
24
72
|
|
|
25
73
|
# %% ../nbs/05_vision_metrics.ipynb 5
|
|
26
|
-
def binary_dice_score(act: torch.
|
|
74
|
+
def binary_dice_score(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
27
75
|
"""Calculates the mean Dice score for binary semantic segmentation tasks.
|
|
28
76
|
|
|
29
77
|
Args:
|
|
@@ -56,45 +104,398 @@ def multi_dice_score(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
|
56
104
|
for c in range(1, n_classes):
|
|
57
105
|
c_pred, c_targ = torch.where(pred == c, 1, 0), torch.where(targ == c, 1, 0)
|
|
58
106
|
dsc = calculate_dsc(c_pred, c_targ)
|
|
59
|
-
binary_dice_scores.append(
|
|
107
|
+
binary_dice_scores.append(float(torch.nanmean(dsc)))
|
|
60
108
|
|
|
61
109
|
return torch.Tensor(binary_dice_scores)
|
|
62
110
|
|
|
63
111
|
# %% ../nbs/05_vision_metrics.ipynb 7
|
|
64
112
|
def binary_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
65
|
-
"""Calculate the mean
|
|
113
|
+
"""Calculate the mean HD95 for binary semantic segmentation tasks.
|
|
66
114
|
|
|
67
115
|
Args:
|
|
68
116
|
act: Activation tensor with dimensions [B, C, W, H, D].
|
|
69
117
|
targ: Target masks with dimensions [B, C, W, H, D].
|
|
70
118
|
|
|
71
119
|
Returns:
|
|
72
|
-
Mean
|
|
120
|
+
Mean HD95.
|
|
73
121
|
"""
|
|
74
|
-
|
|
75
|
-
|
|
76
122
|
pred = pred_to_binary_mask(act)
|
|
77
|
-
|
|
78
123
|
haus = calculate_haus(pred.cpu(), targ.cpu())
|
|
79
124
|
return torch.mean(haus)
|
|
80
125
|
|
|
81
126
|
# %% ../nbs/05_vision_metrics.ipynb 8
|
|
82
|
-
def multi_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor
|
|
83
|
-
"""Calculate the mean
|
|
127
|
+
def multi_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
128
|
+
"""Calculate the mean HD95 for each class in multi-class semantic segmentation tasks.
|
|
84
129
|
|
|
85
130
|
Args:
|
|
86
131
|
act: Activation tensor with dimensions [B, C, W, H, D].
|
|
87
132
|
targ: Target masks with dimensions [B, C, W, H, D].
|
|
88
133
|
|
|
89
134
|
Returns:
|
|
90
|
-
Mean
|
|
135
|
+
Mean HD95 for each class.
|
|
91
136
|
"""
|
|
92
|
-
|
|
93
137
|
pred, n_classes = batch_pred_to_multiclass_mask(act)
|
|
94
138
|
binary_haus = []
|
|
95
139
|
|
|
96
140
|
for c in range(1, n_classes):
|
|
97
141
|
c_pred, c_targ = torch.where(pred==c, 1, 0), torch.where(targ==c, 1, 0)
|
|
98
|
-
haus = calculate_haus(
|
|
99
|
-
binary_haus.append(
|
|
142
|
+
haus = calculate_haus(c_pred, c_targ)
|
|
143
|
+
binary_haus.append(float(torch.nanmean(haus)))
|
|
100
144
|
return torch.Tensor(binary_haus)
|
|
145
|
+
|
|
146
|
+
# %% ../nbs/05_vision_metrics.ipynb 10
|
|
147
|
+
def calculate_confusion_metrics(pred: torch.Tensor, targ: torch.Tensor, metric_name: str) -> torch.Tensor:
|
|
148
|
+
"""Calculate confusion matrix-based metric using MONAI.
|
|
149
|
+
|
|
150
|
+
Args:
|
|
151
|
+
pred: Binary prediction tensor [B, C, W, H, D].
|
|
152
|
+
targ: Binary target tensor [B, C, W, H, D].
|
|
153
|
+
metric_name: One of "sensitivity", "precision", "specificity", "f1 score".
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
Metric values for each sample in batch.
|
|
157
|
+
"""
|
|
158
|
+
# get_confusion_matrix expects one-hot format and returns [B, n_class, 4] where 4 = [TP, FP, TN, FN]
|
|
159
|
+
confusion_matrix = get_confusion_matrix(pred, targ, include_background=False)
|
|
160
|
+
metric = compute_confusion_matrix_metric(metric_name, confusion_matrix)
|
|
161
|
+
return metric
|
|
162
|
+
|
|
163
|
+
# %% ../nbs/05_vision_metrics.ipynb 11
|
|
164
|
+
def binary_sensitivity(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
165
|
+
"""Calculate mean sensitivity (recall) for binary segmentation.
|
|
166
|
+
|
|
167
|
+
Sensitivity = TP / (TP + FN) - measures the proportion of actual positives
|
|
168
|
+
that are correctly identified.
|
|
169
|
+
|
|
170
|
+
Args:
|
|
171
|
+
act: Activation tensor [B, C, W, H, D].
|
|
172
|
+
targ: Target masks [B, C, W, H, D].
|
|
173
|
+
|
|
174
|
+
Returns:
|
|
175
|
+
Mean sensitivity score.
|
|
176
|
+
"""
|
|
177
|
+
pred = pred_to_binary_mask(act)
|
|
178
|
+
sens = calculate_confusion_metrics(pred.cpu(), targ.cpu(), "sensitivity")
|
|
179
|
+
return torch.nanmean(sens)
|
|
180
|
+
|
|
181
|
+
# %% ../nbs/05_vision_metrics.ipynb 12
|
|
182
|
+
def multi_sensitivity(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
183
|
+
"""Calculate mean sensitivity for each class in multi-class segmentation.
|
|
184
|
+
|
|
185
|
+
Args:
|
|
186
|
+
act: Activation tensor [B, C, W, H, D].
|
|
187
|
+
targ: Target masks [B, C, W, H, D].
|
|
188
|
+
|
|
189
|
+
Returns:
|
|
190
|
+
Mean sensitivity for each class.
|
|
191
|
+
"""
|
|
192
|
+
pred, n_classes = batch_pred_to_multiclass_mask(act)
|
|
193
|
+
class_sens = []
|
|
194
|
+
|
|
195
|
+
for c in range(1, n_classes):
|
|
196
|
+
c_pred = torch.where(pred == c, 1, 0)
|
|
197
|
+
c_targ = torch.where(targ == c, 1, 0)
|
|
198
|
+
sens = calculate_confusion_metrics(c_pred, c_targ, "sensitivity")
|
|
199
|
+
class_sens.append(float(torch.nanmean(sens)))
|
|
200
|
+
|
|
201
|
+
return torch.Tensor(class_sens)
|
|
202
|
+
|
|
203
|
+
# %% ../nbs/05_vision_metrics.ipynb 13
|
|
204
|
+
def binary_precision(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
205
|
+
"""Calculate mean precision for binary segmentation.
|
|
206
|
+
|
|
207
|
+
Precision = TP / (TP + FP) - measures the proportion of positive predictions
|
|
208
|
+
that are actually correct.
|
|
209
|
+
|
|
210
|
+
Args:
|
|
211
|
+
act: Activation tensor [B, C, W, H, D].
|
|
212
|
+
targ: Target masks [B, C, W, H, D].
|
|
213
|
+
|
|
214
|
+
Returns:
|
|
215
|
+
Mean precision score.
|
|
216
|
+
"""
|
|
217
|
+
pred = pred_to_binary_mask(act)
|
|
218
|
+
prec = calculate_confusion_metrics(pred.cpu(), targ.cpu(), "precision")
|
|
219
|
+
return torch.nanmean(prec)
|
|
220
|
+
|
|
221
|
+
# %% ../nbs/05_vision_metrics.ipynb 14
|
|
222
|
+
def multi_precision(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
223
|
+
"""Calculate mean precision for each class in multi-class segmentation.
|
|
224
|
+
|
|
225
|
+
Args:
|
|
226
|
+
act: Activation tensor [B, C, W, H, D].
|
|
227
|
+
targ: Target masks [B, C, W, H, D].
|
|
228
|
+
|
|
229
|
+
Returns:
|
|
230
|
+
Mean precision for each class.
|
|
231
|
+
"""
|
|
232
|
+
pred, n_classes = batch_pred_to_multiclass_mask(act)
|
|
233
|
+
class_prec = []
|
|
234
|
+
|
|
235
|
+
for c in range(1, n_classes):
|
|
236
|
+
c_pred = torch.where(pred == c, 1, 0)
|
|
237
|
+
c_targ = torch.where(targ == c, 1, 0)
|
|
238
|
+
prec = calculate_confusion_metrics(c_pred, c_targ, "precision")
|
|
239
|
+
class_prec.append(float(torch.nanmean(prec)))
|
|
240
|
+
|
|
241
|
+
return torch.Tensor(class_prec)
|
|
242
|
+
|
|
243
|
+
# %% ../nbs/05_vision_metrics.ipynb 16
|
|
244
|
+
def calculate_lesion_detection_rate(pred: torch.Tensor, targ: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
|
|
245
|
+
"""Calculate lesion-wise detection rate.
|
|
246
|
+
|
|
247
|
+
For each connected component (lesion) in the target, check if it is
|
|
248
|
+
detected by the prediction. Detection criteria depends on threshold:
|
|
249
|
+
- threshold=0: any overlap counts as detected
|
|
250
|
+
- threshold>0: per-lesion Dice score must exceed threshold
|
|
251
|
+
|
|
252
|
+
Args:
|
|
253
|
+
pred: Binary prediction tensor [B, C, W, H, D].
|
|
254
|
+
targ: Binary target tensor [B, C, W, H, D].
|
|
255
|
+
threshold: Minimum Dice score for a lesion to be considered detected.
|
|
256
|
+
Default 0.0 means any overlap counts as detected.
|
|
257
|
+
|
|
258
|
+
Returns:
|
|
259
|
+
Detection rate (detected lesions / total lesions) for each sample.
|
|
260
|
+
"""
|
|
261
|
+
detection_rates = []
|
|
262
|
+
|
|
263
|
+
for p, t in zip(pred, targ):
|
|
264
|
+
p_np = p.squeeze().cpu().numpy()
|
|
265
|
+
t_np = t.squeeze().cpu().numpy()
|
|
266
|
+
|
|
267
|
+
# Label connected components in target
|
|
268
|
+
labeled_targ, n_lesions = scipy_label(t_np)
|
|
269
|
+
|
|
270
|
+
if n_lesions == 0:
|
|
271
|
+
detection_rates.append(float('nan'))
|
|
272
|
+
continue
|
|
273
|
+
|
|
274
|
+
detected = 0
|
|
275
|
+
for lesion_id in range(1, n_lesions + 1):
|
|
276
|
+
lesion_mask = (labeled_targ == lesion_id)
|
|
277
|
+
|
|
278
|
+
if threshold == 0.0:
|
|
279
|
+
# Original behavior: any overlap counts as detected
|
|
280
|
+
overlap = (p_np * lesion_mask).sum()
|
|
281
|
+
if overlap > 0:
|
|
282
|
+
detected += 1
|
|
283
|
+
else:
|
|
284
|
+
# Compute per-lesion Dice score
|
|
285
|
+
pred_in_lesion = p_np * lesion_mask
|
|
286
|
+
intersection = (pred_in_lesion * lesion_mask).sum()
|
|
287
|
+
lesion_vol = lesion_mask.sum()
|
|
288
|
+
pred_vol = pred_in_lesion.sum()
|
|
289
|
+
|
|
290
|
+
if (lesion_vol + pred_vol) > 0:
|
|
291
|
+
dice = 2 * intersection / (lesion_vol + pred_vol)
|
|
292
|
+
if dice > threshold:
|
|
293
|
+
detected += 1
|
|
294
|
+
|
|
295
|
+
detection_rates.append(detected / n_lesions)
|
|
296
|
+
|
|
297
|
+
return torch.Tensor(detection_rates)
|
|
298
|
+
|
|
299
|
+
# %% ../nbs/05_vision_metrics.ipynb 17
|
|
300
|
+
def binary_lesion_detection_rate(act: torch.Tensor, targ: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
|
|
301
|
+
"""Calculate mean lesion detection rate for binary segmentation.
|
|
302
|
+
|
|
303
|
+
Args:
|
|
304
|
+
act: Activation tensor [B, C, W, H, D].
|
|
305
|
+
targ: Target masks [B, C, W, H, D].
|
|
306
|
+
threshold: Minimum Dice score for a lesion to be considered detected.
|
|
307
|
+
Default 0.0 means any overlap counts as detected.
|
|
308
|
+
|
|
309
|
+
Returns:
|
|
310
|
+
Mean lesion detection rate.
|
|
311
|
+
"""
|
|
312
|
+
pred = pred_to_binary_mask(act)
|
|
313
|
+
ldr = calculate_lesion_detection_rate(pred.cpu(), targ.cpu(), threshold)
|
|
314
|
+
return torch.nanmean(ldr)
|
|
315
|
+
|
|
316
|
+
# %% ../nbs/05_vision_metrics.ipynb 18
|
|
317
|
+
def multi_lesion_detection_rate(act: torch.Tensor, targ: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
|
|
318
|
+
"""Calculate mean lesion detection rate for each class in multi-class segmentation.
|
|
319
|
+
|
|
320
|
+
Args:
|
|
321
|
+
act: Activation tensor [B, C, W, H, D].
|
|
322
|
+
targ: Target masks [B, C, W, H, D].
|
|
323
|
+
threshold: Minimum Dice score for a lesion to be considered detected.
|
|
324
|
+
Default 0.0 means any overlap counts as detected.
|
|
325
|
+
|
|
326
|
+
Returns:
|
|
327
|
+
Mean lesion detection rate for each class.
|
|
328
|
+
"""
|
|
329
|
+
pred, n_classes = batch_pred_to_multiclass_mask(act)
|
|
330
|
+
class_ldr = []
|
|
331
|
+
|
|
332
|
+
for c in range(1, n_classes):
|
|
333
|
+
c_pred = torch.where(pred == c, 1, 0)
|
|
334
|
+
c_targ = torch.where(targ == c, 1, 0)
|
|
335
|
+
ldr = calculate_lesion_detection_rate(c_pred, c_targ, threshold)
|
|
336
|
+
class_ldr.append(float(torch.nanmean(ldr)))
|
|
337
|
+
|
|
338
|
+
return torch.Tensor(class_ldr)
|
|
339
|
+
|
|
340
|
+
# %% ../nbs/05_vision_metrics.ipynb 20
|
|
341
|
+
def calculate_signed_rve(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
342
|
+
"""Calculate signed Relative Volume Error.
|
|
343
|
+
|
|
344
|
+
RVE = (pred_volume - targ_volume) / targ_volume
|
|
345
|
+
|
|
346
|
+
Positive values indicate over-segmentation (model predicts too large),
|
|
347
|
+
negative values indicate under-segmentation (model predicts too small).
|
|
348
|
+
|
|
349
|
+
Args:
|
|
350
|
+
pred: Binary prediction tensor [B, C, W, H, D].
|
|
351
|
+
targ: Binary target tensor [B, C, W, H, D].
|
|
352
|
+
|
|
353
|
+
Returns:
|
|
354
|
+
Signed RVE for each sample in batch.
|
|
355
|
+
"""
|
|
356
|
+
rve_values = []
|
|
357
|
+
|
|
358
|
+
for p, t in zip(pred, targ):
|
|
359
|
+
pred_vol = p.sum().float()
|
|
360
|
+
targ_vol = t.sum().float()
|
|
361
|
+
|
|
362
|
+
if targ_vol == 0:
|
|
363
|
+
rve_values.append(float('nan'))
|
|
364
|
+
else:
|
|
365
|
+
rve = (pred_vol - targ_vol) / targ_vol
|
|
366
|
+
rve_values.append(rve.item())
|
|
367
|
+
|
|
368
|
+
return torch.Tensor(rve_values)
|
|
369
|
+
|
|
370
|
+
# %% ../nbs/05_vision_metrics.ipynb 21
|
|
371
|
+
def binary_signed_rve(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
372
|
+
"""Calculate mean signed RVE for binary segmentation.
|
|
373
|
+
|
|
374
|
+
Args:
|
|
375
|
+
act: Activation tensor [B, C, W, H, D].
|
|
376
|
+
targ: Target masks [B, C, W, H, D].
|
|
377
|
+
|
|
378
|
+
Returns:
|
|
379
|
+
Mean signed RVE.
|
|
380
|
+
"""
|
|
381
|
+
pred = pred_to_binary_mask(act)
|
|
382
|
+
rve = calculate_signed_rve(pred.cpu(), targ.cpu())
|
|
383
|
+
return torch.nanmean(rve)
|
|
384
|
+
|
|
385
|
+
# %% ../nbs/05_vision_metrics.ipynb 22
|
|
386
|
+
def multi_signed_rve(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
|
|
387
|
+
"""Calculate mean signed RVE for each class in multi-class segmentation.
|
|
388
|
+
|
|
389
|
+
Args:
|
|
390
|
+
act: Activation tensor [B, C, W, H, D].
|
|
391
|
+
targ: Target masks [B, C, W, H, D].
|
|
392
|
+
|
|
393
|
+
Returns:
|
|
394
|
+
Mean signed RVE for each class.
|
|
395
|
+
"""
|
|
396
|
+
pred, n_classes = batch_pred_to_multiclass_mask(act)
|
|
397
|
+
class_rve = []
|
|
398
|
+
|
|
399
|
+
for c in range(1, n_classes):
|
|
400
|
+
c_pred = torch.where(pred == c, 1, 0)
|
|
401
|
+
c_targ = torch.where(targ == c, 1, 0)
|
|
402
|
+
rve = calculate_signed_rve(c_pred, c_targ)
|
|
403
|
+
class_rve.append(float(torch.nanmean(rve)))
|
|
404
|
+
|
|
405
|
+
return torch.Tensor(class_rve)
|
|
406
|
+
|
|
407
|
+
# %% ../nbs/05_vision_metrics.ipynb 24
|
|
408
|
+
class AccumulatedDice(Metric):
|
|
409
|
+
"""nnU-Net-style accumulated Dice metric for reliable pseudo dice during training.
|
|
410
|
+
|
|
411
|
+
Instead of averaging per-batch Dice scores, this metric accumulates
|
|
412
|
+
true positives, false positives, and false negatives across ALL validation
|
|
413
|
+
batches, then computes Dice from the totals. This gives more weight to
|
|
414
|
+
batches with more foreground voxels and is more statistically robust.
|
|
415
|
+
|
|
416
|
+
Args:
|
|
417
|
+
n_classes: Number of classes including background (default: 2 for binary).
|
|
418
|
+
include_background: Whether to include background in metric (default: False).
|
|
419
|
+
|
|
420
|
+
Example:
|
|
421
|
+
```python
|
|
422
|
+
learn = Learner(dls, model, loss_func=loss_func,
|
|
423
|
+
metrics=[AccumulatedDice(n_classes=2)])
|
|
424
|
+
|
|
425
|
+
# For checkpoint selection based on accumulated dice:
|
|
426
|
+
save_best = SaveModelCallback(
|
|
427
|
+
monitor='accumulated_dice',
|
|
428
|
+
comp=np.greater, # Higher dice is better
|
|
429
|
+
fname='best_model'
|
|
430
|
+
)
|
|
431
|
+
```
|
|
432
|
+
"""
|
|
433
|
+
def __init__(self, n_classes: int = 2, include_background: bool = False):
|
|
434
|
+
self.n_classes = n_classes
|
|
435
|
+
self.include_background = include_background
|
|
436
|
+
self.start_class = 0 if include_background else 1
|
|
437
|
+
|
|
438
|
+
def reset(self):
|
|
439
|
+
"""Called at start of validation epoch."""
|
|
440
|
+
n_fg_classes = self.n_classes if self.include_background else self.n_classes - 1
|
|
441
|
+
self.tp = torch.zeros(n_fg_classes, dtype=torch.float64)
|
|
442
|
+
self.fp = torch.zeros(n_fg_classes, dtype=torch.float64)
|
|
443
|
+
self.fn = torch.zeros(n_fg_classes, dtype=torch.float64)
|
|
444
|
+
|
|
445
|
+
def accumulate(self, learn):
|
|
446
|
+
"""Called after each validation batch to accumulate TP/FP/FN."""
|
|
447
|
+
pred = learn.pred # Model output [B, C, D, H, W]
|
|
448
|
+
targ = learn.y # Target [B, 1, D, H, W] or [B, D, H, W]
|
|
449
|
+
|
|
450
|
+
# Get predicted segmentation
|
|
451
|
+
pred_seg = pred.argmax(dim=1) # [B, D, H, W]
|
|
452
|
+
|
|
453
|
+
# Ensure target has same shape
|
|
454
|
+
if targ.ndim == pred_seg.ndim + 1:
|
|
455
|
+
targ = targ.squeeze(1)
|
|
456
|
+
|
|
457
|
+
# Accumulate TP/FP/FN for each foreground class
|
|
458
|
+
idx = 0
|
|
459
|
+
for c in range(self.start_class, self.n_classes):
|
|
460
|
+
c_pred = (pred_seg == c)
|
|
461
|
+
c_targ = (targ == c)
|
|
462
|
+
|
|
463
|
+
self.tp[idx] += (c_pred & c_targ).sum().float().cpu()
|
|
464
|
+
self.fp[idx] += (c_pred & ~c_targ).sum().float().cpu()
|
|
465
|
+
self.fn[idx] += (~c_pred & c_targ).sum().float().cpu()
|
|
466
|
+
idx += 1
|
|
467
|
+
|
|
468
|
+
@property
|
|
469
|
+
def value(self):
|
|
470
|
+
"""Compute Dice from accumulated counts at end of epoch."""
|
|
471
|
+
dice = 2 * self.tp / (2 * self.tp + self.fp + self.fn + 1e-8)
|
|
472
|
+
return dice.mean().item()
|
|
473
|
+
|
|
474
|
+
@property
|
|
475
|
+
def name(self):
|
|
476
|
+
return 'accumulated_dice'
|
|
477
|
+
|
|
478
|
+
# %% ../nbs/05_vision_metrics.ipynb 25
|
|
479
|
+
class AccumulatedMultiDice(AccumulatedDice):
|
|
480
|
+
"""Multi-class version of AccumulatedDice that returns per-class Dice scores.
|
|
481
|
+
|
|
482
|
+
Instead of returning a single mean Dice, this returns a tensor with the
|
|
483
|
+
Dice score for each foreground class. Useful for monitoring per-class
|
|
484
|
+
performance during training.
|
|
485
|
+
|
|
486
|
+
Example:
|
|
487
|
+
```python
|
|
488
|
+
# For 3-class segmentation (background + 2 foreground classes)
|
|
489
|
+
learn = Learner(dls, model, loss_func=loss_func,
|
|
490
|
+
metrics=[AccumulatedMultiDice(n_classes=3)])
|
|
491
|
+
```
|
|
492
|
+
"""
|
|
493
|
+
@property
|
|
494
|
+
def value(self):
|
|
495
|
+
"""Return per-class Dice scores."""
|
|
496
|
+
dice = 2 * self.tp / (2 * self.tp + self.fp + self.fn + 1e-8)
|
|
497
|
+
return dice # Returns tensor, fastai will display all values
|
|
498
|
+
|
|
499
|
+
@property
|
|
500
|
+
def name(self):
|
|
501
|
+
return 'accumulated_multi_dice'
|