fastMONAI 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,7 @@
2
2
 
3
3
  # %% auto 0
4
4
  __all__ = ['save_series_pred', 'load_system_resources', 'inference', 'compute_binary_tumor_volume', 'refine_binary_pred_mask',
5
- 'gradio_image_classifier']
5
+ 'keep_largest', 'gradio_image_classifier']
6
6
 
7
7
  # %% ../nbs/06_vision_inference.ipynb 1
8
8
  from copy import copy
@@ -67,18 +67,18 @@ def load_system_resources(models_path, learner_fn, variables_fn):
67
67
 
68
68
  learn = load_learner(models_path / learner_fn, cpu=True)
69
69
  vars_fn = models_path / variables_fn
70
- _, reorder, resample = load_variables(pkl_fn=vars_fn)
70
+ _, apply_reorder, target_spacing = load_variables(pkl_fn=vars_fn)
71
71
 
72
- return learn, reorder, resample
72
+ return learn, apply_reorder, target_spacing
73
73
 
74
74
  # %% ../nbs/06_vision_inference.ipynb 8
75
- def inference(learn_inf, reorder, resample, fn: (str, Path) = '',
75
+ def inference(learn_inf, apply_reorder, target_spacing, fn: (str, Path) = '',
76
76
  save_path: (str, Path) = None, org_img=None, input_img=None,
77
77
  org_size=None):
78
78
  """Predict on new data using exported model."""
79
79
 
80
80
  if None in [org_img, input_img, org_size]:
81
- org_img, input_img, org_size = med_img_reader(fn, reorder, resample,
81
+ org_img, input_img, org_size = med_img_reader(fn, apply_reorder, target_spacing,
82
82
  only_tensor=False)
83
83
  else:
84
84
  org_img, input_img = copy(org_img), copy(input_img)
@@ -148,6 +148,10 @@ def refine_binary_pred_mask(pred_mask,
148
148
  if verbose:
149
149
  print(n_components)
150
150
 
151
+ # Handle empty mask case (no foreground components)
152
+ if n_components == 0:
153
+ return torch.zeros_like(torch.Tensor(pred_mask)).float()
154
+
151
155
  if remove_size is None:
152
156
  sizes = np.bincount(labeled_mask.ravel())
153
157
  max_label = sizes[1:].argmax() + 1
@@ -157,14 +161,36 @@ def refine_binary_pred_mask(pred_mask,
157
161
  processed_mask = remove_small_objects(
158
162
  labeled_mask, min_size=small_objects_threshold)
159
163
 
160
- return torch.Tensor(processed_mask > 0).float()
164
+ return torch.Tensor(processed_mask > 0).float()
165
+
166
+ # %% ../nbs/06_vision_inference.ipynb 12
def keep_largest(pred_mask: torch.Tensor) -> torch.Tensor:
    """Keep only the largest connected component in a binary mask.

    Args:
        pred_mask: Binary prediction mask, either a torch tensor (on any
            device, with or without gradient tracking) or a NumPy array.

    Returns:
        A float32 mask of the same shape that is 1 inside the largest
        connected component and 0 elsewhere. A tensor input yields a tensor
        on the input's device; an array input yields a new array. A mask
        without foreground yields all zeros.
    """
    is_tensor = isinstance(pred_mask, torch.Tensor)
    # detach() and cpu() make the NumPy conversion work for tensors that track
    # gradients or live on an accelerator; a bare .numpy() raises for both.
    mask_np = pred_mask.detach().cpu().numpy() if is_tensor else np.asarray(pred_mask)
    labeled_mask, n_components = label(mask_np)

    if n_components == 0:
        # Same dtype as the non-empty path, so callers always get float32.
        result = np.zeros(mask_np.shape, dtype=np.float32)
    else:
        sizes = np.bincount(labeled_mask.ravel())
        largest_label = sizes[1:].argmax() + 1  # Skip background (label 0)
        result = (labeled_mask == largest_label).astype(np.float32)

    if is_tensor:
        return torch.from_numpy(result).to(pred_mask.device)
    return result
161
187
 
162
- # %% ../nbs/06_vision_inference.ipynb 13
163
- def gradio_image_classifier(file_obj, learn, reorder, resample):
188
+ # %% ../nbs/06_vision_inference.ipynb 14
189
+ def gradio_image_classifier(file_obj, learn, apply_reorder, target_spacing):
164
190
  """Predict on images using exported learner and return the result as a dictionary."""
165
191
 
166
192
  img_path = Path(file_obj.name)
167
- img = med_img_reader(img_path, reorder=reorder, resample=resample)
193
+ img = med_img_reader(img_path, apply_reorder=apply_reorder, target_spacing=target_spacing)
168
194
 
169
195
  _, _, predictions = learn.predict(img)
170
196
  prediction_dict = {index: value.item() for index, value in enumerate(predictions)}
@@ -2,28 +2,76 @@
2
2
 
3
3
  # %% auto 0
4
4
  __all__ = ['calculate_dsc', 'calculate_haus', 'binary_dice_score', 'multi_dice_score', 'binary_hausdorff_distance',
5
- 'multi_hausdorff_distance']
5
+ 'multi_hausdorff_distance', 'calculate_confusion_metrics', 'binary_sensitivity', 'multi_sensitivity',
6
+ 'binary_precision', 'multi_precision', 'calculate_lesion_detection_rate', 'binary_lesion_detection_rate',
7
+ 'multi_lesion_detection_rate', 'calculate_signed_rve', 'binary_signed_rve', 'multi_signed_rve',
8
+ 'AccumulatedDice', 'AccumulatedMultiDice']
6
9
 
7
10
  # %% ../nbs/05_vision_metrics.ipynb 1
8
11
  import torch
9
12
  import numpy as np
10
- from monai.metrics import compute_hausdorff_distance, compute_dice
13
+ from monai.metrics import compute_hausdorff_distance, compute_dice, get_confusion_matrix, compute_confusion_matrix_metric
14
+ from scipy.ndimage import label as scipy_label
11
15
  from .vision_data import pred_to_binary_mask, batch_pred_to_multiclass_mask
16
+ from fastai.learner import Metric
12
17
 
13
18
  # %% ../nbs/05_vision_metrics.ipynb 3
def calculate_dsc(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Compute per-sample Dice scores with MONAI's ``compute_dice``.

    Inputs with three or four dimensions are expanded to the 5D layout
    [B, C, D, H, W]. The expansion applied to ``targ`` is chosen from the
    rank of ``pred``, so both are expected to have the same rank.

    Args:
        pred: Binary prediction tensor. Accepts:
            - [D, H, W] single 3D volume
            - [C, D, H, W] single volume with channel
            - [B, C, D, H, W] batched volumes
        targ: Binary target tensor (same shape options as pred).

    Returns:
        Tensor with one Dice value per sample (a single entry for 3D/4D input).
    """
    if pred.ndim == 4:
        # [C, D, H, W] -> [1, C, D, H, W]
        pred, targ = pred[None], targ[None]
    elif pred.ndim == 3:
        # [D, H, W] -> [1, 1, D, H, W]
        pred, targ = pred[None, None], targ[None, None]

    scores = []
    for sample_pred, sample_targ in zip(pred, targ):
        # compute_dice is called one sample at a time, each as a batch of one.
        scores.append(compute_dice(sample_pred.unsqueeze(0), sample_targ.unsqueeze(0)))
    return torch.Tensor(scores)
18
43
 
19
44
  # %% ../nbs/05_vision_metrics.ipynb 4
def calculate_haus(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Compute the per-sample 95th percentile Hausdorff distance (HD95) with MONAI.

    HD95 is more robust than the standard Hausdorff distance because it
    ignores the top 5% of outlier distances.

    Inputs with three or four dimensions are expanded to the 5D layout
    [B, C, D, H, W]. The expansion applied to ``targ`` is chosen from the
    rank of ``pred``, so both are expected to have the same rank.

    Args:
        pred: Binary prediction tensor. Accepts:
            - [D, H, W] single 3D volume
            - [C, D, H, W] single volume with channel
            - [B, C, D, H, W] batched volumes
        targ: Binary target tensor (same shape options as pred).

    Returns:
        Tensor with one HD95 value per sample (a single entry for 3D/4D input).
    """
    if pred.ndim == 4:
        # [C, D, H, W] -> [1, C, D, H, W]
        pred, targ = pred[None], targ[None]
    elif pred.ndim == 3:
        # [D, H, W] -> [1, 1, D, H, W]
        pred, targ = pred[None, None], targ[None, None]

    distances = []
    for sample_pred, sample_targ in zip(pred, targ):
        # Each sample is passed to MONAI as a batch of one.
        distances.append(
            compute_hausdorff_distance(
                sample_pred.unsqueeze(0), sample_targ.unsqueeze(0), percentile=95
            )
        )
    return torch.Tensor(distances)
24
72
 
25
73
  # %% ../nbs/05_vision_metrics.ipynb 5
26
- def binary_dice_score(act: torch.tensor, targ: torch.Tensor) -> torch.Tensor:
74
+ def binary_dice_score(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
27
75
  """Calculates the mean Dice score for binary semantic segmentation tasks.
28
76
 
29
77
  Args:
@@ -56,45 +104,398 @@ def multi_dice_score(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
56
104
  for c in range(1, n_classes):
57
105
  c_pred, c_targ = torch.where(pred == c, 1, 0), torch.where(targ == c, 1, 0)
58
106
  dsc = calculate_dsc(c_pred, c_targ)
59
- binary_dice_scores.append(np.nanmean(dsc)) # #TODO update torch to get torch.nanmean() to work
107
+ binary_dice_scores.append(float(torch.nanmean(dsc)))
60
108
 
61
109
  return torch.Tensor(binary_dice_scores)
62
110
 
63
111
  # %% ../nbs/05_vision_metrics.ipynb 7
def binary_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Calculate the mean HD95 for binary semantic segmentation tasks.

    Args:
        act: Activation tensor with dimensions [B, C, W, H, D].
        targ: Target masks with dimensions [B, C, W, H, D].

    Returns:
        Mean HD95 over the samples whose distance is defined. Samples with a
        NaN distance are skipped, matching the other metrics in this module.
    """
    pred = pred_to_binary_mask(act)
    haus = calculate_haus(pred.cpu(), targ.cpu())
    # nanmean instead of mean: one undefined (NaN) sample must not turn the
    # whole batch metric into NaN. multi_hausdorff_distance already does this.
    return torch.nanmean(haus)
80
125
 
81
126
  # %% ../nbs/05_vision_metrics.ipynb 8
def multi_hausdorff_distance(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Calculate the mean HD95 for each class in multi-class semantic segmentation tasks.

    Args:
        act: Activation tensor with dimensions [B, C, W, H, D].
        targ: Target masks with dimensions [B, C, W, H, D].

    Returns:
        Mean HD95 for each class, starting at class 1 (class 0 is skipped).
    """
    pred, n_classes = batch_pred_to_multiclass_mask(act)

    def _class_hd95(class_idx):
        # Reduce the multi-class problem to a one-vs-rest binary mask pair.
        class_pred = torch.where(pred == class_idx, 1, 0)
        class_targ = torch.where(targ == class_idx, 1, 0)
        return float(torch.nanmean(calculate_haus(class_pred, class_targ)))

    return torch.Tensor([_class_hd95(class_idx) for class_idx in range(1, n_classes)])
145
+
146
+ # %% ../nbs/05_vision_metrics.ipynb 10
def calculate_confusion_metrics(pred: torch.Tensor, targ: torch.Tensor, metric_name: str) -> torch.Tensor:
    """Compute a confusion-matrix-based metric through MONAI.

    Args:
        pred: Binary prediction tensor [B, C, W, H, D].
        targ: Binary target tensor [B, C, W, H, D].
        metric_name: One of "sensitivity", "precision", "specificity", "f1 score".

    Returns:
        Metric values for each sample in batch.
    """
    # get_confusion_matrix expects one-hot input and yields [B, n_class, 4],
    # the last axis holding [TP, FP, TN, FN]; the metric is derived from it.
    return compute_confusion_matrix_metric(
        metric_name,
        get_confusion_matrix(pred, targ, include_background=False),
    )
162
+
163
+ # %% ../nbs/05_vision_metrics.ipynb 11
def binary_sensitivity(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Mean sensitivity (recall) for binary segmentation.

    Sensitivity = TP / (TP + FN), i.e. the share of actual positives that
    the prediction identifies correctly.

    Args:
        act: Activation tensor [B, C, W, H, D].
        targ: Target masks [B, C, W, H, D].

    Returns:
        Mean sensitivity score, ignoring NaN entries.
    """
    binary_pred = pred_to_binary_mask(act).cpu()
    per_sample = calculate_confusion_metrics(binary_pred, targ.cpu(), "sensitivity")
    return torch.nanmean(per_sample)
180
+
181
+ # %% ../nbs/05_vision_metrics.ipynb 12
def multi_sensitivity(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Mean sensitivity for each class in multi-class segmentation.

    Args:
        act: Activation tensor [B, C, W, H, D].
        targ: Target masks [B, C, W, H, D].

    Returns:
        Mean sensitivity for each class, starting at class 1 (class 0 is skipped).
    """
    pred, n_classes = batch_pred_to_multiclass_mask(act)

    def _class_sensitivity(class_idx):
        # One-vs-rest binary masks for the current class.
        class_pred = torch.where(pred == class_idx, 1, 0)
        class_targ = torch.where(targ == class_idx, 1, 0)
        values = calculate_confusion_metrics(class_pred, class_targ, "sensitivity")
        return float(torch.nanmean(values))

    return torch.Tensor([_class_sensitivity(class_idx) for class_idx in range(1, n_classes)])
202
+
203
+ # %% ../nbs/05_vision_metrics.ipynb 13
def binary_precision(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Mean precision for binary segmentation.

    Precision = TP / (TP + FP), i.e. the share of positive predictions that
    are actually correct.

    Args:
        act: Activation tensor [B, C, W, H, D].
        targ: Target masks [B, C, W, H, D].

    Returns:
        Mean precision score, ignoring NaN entries.
    """
    binary_pred = pred_to_binary_mask(act).cpu()
    per_sample = calculate_confusion_metrics(binary_pred, targ.cpu(), "precision")
    return torch.nanmean(per_sample)
220
+
221
+ # %% ../nbs/05_vision_metrics.ipynb 14
def multi_precision(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Mean precision for each class in multi-class segmentation.

    Args:
        act: Activation tensor [B, C, W, H, D].
        targ: Target masks [B, C, W, H, D].

    Returns:
        Mean precision for each class, starting at class 1 (class 0 is skipped).
    """
    pred, n_classes = batch_pred_to_multiclass_mask(act)

    def _class_precision(class_idx):
        # One-vs-rest binary masks for the current class.
        class_pred = torch.where(pred == class_idx, 1, 0)
        class_targ = torch.where(targ == class_idx, 1, 0)
        values = calculate_confusion_metrics(class_pred, class_targ, "precision")
        return float(torch.nanmean(values))

    return torch.Tensor([_class_precision(class_idx) for class_idx in range(1, n_classes)])
242
+
243
+ # %% ../nbs/05_vision_metrics.ipynb 16
def calculate_lesion_detection_rate(pred: torch.Tensor, targ: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    """Calculate lesion-wise detection rate.

    Every connected component (lesion) of the target is checked against the
    prediction. The detection criterion depends on ``threshold``:
    - threshold <= 0: any overlap counts as detected
    - threshold > 0: the per-lesion Dice score must exceed the threshold

    The per-lesion Dice only looks at the prediction inside the lesion, so it
    equals ``2 * overlap / (lesion_volume + overlap)``.

    Args:
        pred: Binary prediction tensor [B, C, W, H, D].
        targ: Binary target tensor [B, C, W, H, D].
        threshold: Minimum Dice score for a lesion to be considered detected.
            Default 0.0 means any overlap counts as detected. Negative values
            are treated like 0.0, so a lesion without any overlap is never
            reported as detected.

    Returns:
        Detection rate (detected lesions / total lesions) for each sample.
        Samples whose target contains no lesion yield NaN.
    """
    detection_rates = []

    for p, t in zip(pred, targ):
        p_np = p.squeeze().cpu().numpy()
        t_np = t.squeeze().cpu().numpy()

        # Label connected components in target
        labeled_targ, n_lesions = scipy_label(t_np)

        if n_lesions == 0:
            detection_rates.append(float('nan'))
            continue

        if p_np.shape != labeled_targ.shape:
            # Keep accepting predictions that merely broadcast to the target.
            p_np = np.broadcast_to(p_np, labeled_targ.shape)

        # One pass over the volume yields every lesion's size and its overlap
        # with the prediction, instead of masking the full volume per lesion.
        flat_labels = labeled_targ.ravel()
        n_bins = n_lesions + 1
        lesion_vol = np.bincount(flat_labels, minlength=n_bins)[1:]
        overlap = np.bincount(
            flat_labels, weights=p_np.ravel().astype(np.float64), minlength=n_bins
        )[1:]

        detected = overlap > 0
        if threshold > 0:
            # lesion_vol is at least 1 for every label, so no division by zero.
            dice = 2 * overlap / (lesion_vol + overlap)
            detected &= dice > threshold

        detection_rates.append(int(detected.sum()) / n_lesions)

    return torch.Tensor(detection_rates)
298
+
299
+ # %% ../nbs/05_vision_metrics.ipynb 17
def binary_lesion_detection_rate(act: torch.Tensor, targ: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    """Mean lesion detection rate for binary segmentation.

    Args:
        act: Activation tensor [B, C, W, H, D].
        targ: Target masks [B, C, W, H, D].
        threshold: Minimum Dice score for a lesion to be considered detected.
            Default 0.0 means any overlap counts as detected.

    Returns:
        Mean lesion detection rate, ignoring NaN entries.
    """
    binary_pred = pred_to_binary_mask(act).cpu()
    per_sample = calculate_lesion_detection_rate(binary_pred, targ.cpu(), threshold)
    return torch.nanmean(per_sample)
315
+
316
+ # %% ../nbs/05_vision_metrics.ipynb 18
def multi_lesion_detection_rate(act: torch.Tensor, targ: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    """Mean lesion detection rate for each class in multi-class segmentation.

    Args:
        act: Activation tensor [B, C, W, H, D].
        targ: Target masks [B, C, W, H, D].
        threshold: Minimum Dice score for a lesion to be considered detected.
            Default 0.0 means any overlap counts as detected.

    Returns:
        Mean lesion detection rate for each class, starting at class 1
        (class 0 is skipped).
    """
    pred, n_classes = batch_pred_to_multiclass_mask(act)

    def _class_rate(class_idx):
        # One-vs-rest binary masks for the current class.
        class_pred = torch.where(pred == class_idx, 1, 0)
        class_targ = torch.where(targ == class_idx, 1, 0)
        rates = calculate_lesion_detection_rate(class_pred, class_targ, threshold)
        return float(torch.nanmean(rates))

    return torch.Tensor([_class_rate(class_idx) for class_idx in range(1, n_classes)])
339
+
340
+ # %% ../nbs/05_vision_metrics.ipynb 20
def calculate_signed_rve(pred: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Calculate signed Relative Volume Error.

    RVE = (pred_volume - targ_volume) / targ_volume

    Positive values indicate over-segmentation (model predicts too large),
    negative values indicate under-segmentation (model predicts too small).

    Args:
        pred: Binary prediction tensor [B, C, W, H, D].
        targ: Binary target tensor [B, C, W, H, D].

    Returns:
        Signed RVE for each sample in batch. Samples with an empty target
        yield NaN because the relative error is undefined there.
    """
    half_dtypes = (torch.float16, torch.bfloat16)

    def _volume(mask: torch.Tensor) -> torch.Tensor:
        # A half-precision sum overflows to inf beyond 65504 voxels, so such
        # masks are upcast before summing. Other dtypes are summed as they are.
        if mask.dtype in half_dtypes:
            return mask.float().sum()
        return mask.sum().float()

    rve_values = []

    for p, t in zip(pred, targ):
        pred_vol = _volume(p)
        targ_vol = _volume(t)

        if targ_vol == 0:
            rve_values.append(float('nan'))
        else:
            rve_values.append(((pred_vol - targ_vol) / targ_vol).item())

    return torch.Tensor(rve_values)
369
+
370
+ # %% ../nbs/05_vision_metrics.ipynb 21
def binary_signed_rve(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Mean signed Relative Volume Error for binary segmentation.

    Args:
        act: Activation tensor [B, C, W, H, D].
        targ: Target masks [B, C, W, H, D].

    Returns:
        Mean signed RVE, ignoring NaN entries.
    """
    binary_pred = pred_to_binary_mask(act).cpu()
    per_sample = calculate_signed_rve(binary_pred, targ.cpu())
    return torch.nanmean(per_sample)
384
+
385
+ # %% ../nbs/05_vision_metrics.ipynb 22
def multi_signed_rve(act: torch.Tensor, targ: torch.Tensor) -> torch.Tensor:
    """Mean signed Relative Volume Error for each class in multi-class segmentation.

    Args:
        act: Activation tensor [B, C, W, H, D].
        targ: Target masks [B, C, W, H, D].

    Returns:
        Mean signed RVE for each class, starting at class 1 (class 0 is skipped).
    """
    pred, n_classes = batch_pred_to_multiclass_mask(act)

    def _class_rve(class_idx):
        # One-vs-rest binary masks for the current class.
        class_pred = torch.where(pred == class_idx, 1, 0)
        class_targ = torch.where(targ == class_idx, 1, 0)
        return float(torch.nanmean(calculate_signed_rve(class_pred, class_targ)))

    return torch.Tensor([_class_rve(class_idx) for class_idx in range(1, n_classes)])
406
+
407
+ # %% ../nbs/05_vision_metrics.ipynb 24
class AccumulatedDice(Metric):
    """nnU-Net-style accumulated Dice metric for reliable pseudo dice during training.

    Instead of averaging per-batch Dice scores, this metric accumulates
    true positives, false positives, and false negatives across ALL validation
    batches, then computes Dice from the totals. This gives more weight to
    batches with more foreground voxels and is more statistically robust.

    Args:
        n_classes: Number of classes including background (default: 2 for binary).
        include_background: Whether to include background in metric (default: False).

    Example:
        ```python
        learn = Learner(dls, model, loss_func=loss_func,
                        metrics=[AccumulatedDice(n_classes=2)])

        # For checkpoint selection based on accumulated dice:
        save_best = SaveModelCallback(
            monitor='accumulated_dice',
            comp=np.greater,  # Higher dice is better
            fname='best_model'
        )
        ```
    """
    def __init__(self, n_classes: int = 2, include_background: bool = False):
        self.n_classes = n_classes
        self.include_background = include_background
        self.start_class = 0 if include_background else 1
        # Create the counters right away so `value` is usable before the
        # first `reset()` call.
        self.reset()

    def reset(self):
        """Called at start of validation epoch; zeroes the per-class counters."""
        n_tracked = self.n_classes - self.start_class
        # float64 keeps voxel counts exact over a whole validation epoch.
        self.tp = torch.zeros(n_tracked, dtype=torch.float64)
        self.fp = torch.zeros(n_tracked, dtype=torch.float64)
        self.fn = torch.zeros(n_tracked, dtype=torch.float64)

    def accumulate(self, learn):
        """Called after each validation batch to accumulate TP/FP/FN."""
        pred = learn.pred  # Model output [B, C, D, H, W]
        targ = learn.y     # Target [B, 1, D, H, W] or [B, D, H, W]

        if pred.shape[1] == 1:
            # argmax over a single channel is always 0, which would report a
            # Dice of 0 for every binary single-output model.
            # NOTE(review): assumes the single channel holds a logit, so
            # "> 0" equals sigmoid > 0.5 - confirm against the model head.
            pred_seg = (pred[:, 0] > 0).long()
        else:
            pred_seg = pred.argmax(dim=1)  # [B, D, H, W]

        # Drop the target's channel axis so it lines up with pred_seg.
        if targ.ndim == pred_seg.ndim + 1:
            targ = targ.squeeze(1)

        for idx, c in enumerate(range(self.start_class, self.n_classes)):
            c_pred = (pred_seg == c)
            c_targ = (targ == c)

            self.tp[idx] += (c_pred & c_targ).sum().float().cpu()
            self.fp[idx] += (c_pred & ~c_targ).sum().float().cpu()
            self.fn[idx] += (~c_pred & c_targ).sum().float().cpu()

    @property
    def value(self):
        """Compute Dice from accumulated counts at end of epoch."""
        # The epsilon keeps the ratio finite for classes that never occurred.
        dice = 2 * self.tp / (2 * self.tp + self.fp + self.fn + 1e-8)
        return dice.mean().item()

    @property
    def name(self):
        return 'accumulated_dice'
477
+
478
+ # %% ../nbs/05_vision_metrics.ipynb 25
class AccumulatedMultiDice(AccumulatedDice):
    """Per-class variant of AccumulatedDice.

    Rather than collapsing the accumulated counts into one mean Dice, this
    metric reports a tensor holding the Dice score of every tracked class,
    which helps when monitoring per-class performance during training.

    Example:
        ```python
        # For 3-class segmentation (background + 2 foreground classes)
        learn = Learner(dls, model, loss_func=loss_func,
                        metrics=[AccumulatedMultiDice(n_classes=3)])
        ```
    """
    @property
    def value(self):
        """Return per-class Dice scores."""
        numerator = 2 * self.tp
        # The epsilon keeps the ratio finite for classes that never occurred.
        denominator = 2 * self.tp + self.fp + self.fn + 1e-8
        return numerator / denominator  # Tensor with one entry per tracked class

    @property
    def name(self):
        return 'accumulated_multi_dice'