fastMONAI 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fastMONAI/vision_data.py CHANGED
@@ -10,168 +10,258 @@ from fastai.vision.data import *
10
10
  from .vision_core import *
11
11
 
12
12
  # %% ../nbs/02_vision_data.ipynb 5
13
- def pred_to_multiclass_mask(pred:torch.Tensor # [C,W,H,D] activation tensor
14
- ) -> torch.Tensor:
15
- '''Apply Softmax function on the predicted tensor to rescale the values in the range [0, 1] and sum to 1.
16
- Then apply argmax to get the indices of the maximum value of all elements in the predicted Tensor.
17
- Returns: Predicted mask.
18
- '''
13
+ def pred_to_multiclass_mask(pred: torch.Tensor) -> torch.Tensor:
14
+ """Apply Softmax on the predicted tensor to rescale the values in the range [0, 1]
15
+ and sum to 1. Then apply argmax to get the indices of the maximum value of all
16
+ elements in the predicted Tensor.
17
+
18
+ Args:
19
+ pred: [C,W,H,D] activation tensor.
20
+
21
+ Returns:
22
+ Predicted mask.
23
+ """
24
+
19
25
  pred = pred.softmax(dim=0)
26
+
20
27
  return pred.argmax(dim=0, keepdims=True)
21
28
 
22
29
  # %% ../nbs/02_vision_data.ipynb 6
23
- def batch_pred_to_multiclass_mask(pred:torch.Tensor # [B, C, W, H, D] batch of activations
24
- ) -> (torch.Tensor, int):
25
- '''Convert a batch of predicted activation tensors to masks.
26
- Returns batch of predicted masks and number of classes.
27
- '''
28
-
30
+ def batch_pred_to_multiclass_mask(pred: torch.Tensor) -> (torch.Tensor, int):
31
+ """Convert a batch of predicted activation tensors to masks.
32
+
33
+ Args:
34
+ pred: [B, C, W, H, D] batch of activations.
35
+
36
+ Returns:
37
+ Tuple of batch of predicted masks and number of classes.
38
+ """
39
+
29
40
  n_classes = pred.shape[1]
30
41
  pred = [pred_to_multiclass_mask(p) for p in pred]
31
42
 
32
43
  return torch.stack(pred), n_classes
33
44
 
34
45
  # %% ../nbs/02_vision_data.ipynb 7
35
- def pred_to_binary_mask(pred # [B, C, W, H, D] or [C, W, H, D] activation tensor
36
- ) -> torch.Tensor:
37
- '''Apply Sigmoid function that squishes activations into a range between 0 and 1.
38
- Then we classify all values greater than or equal to 0.5 to 1, and the values below 0.5 to 0.
39
-
40
- Returns predicted binary mask(s).
41
- '''
42
-
46
+ def pred_to_binary_mask(pred: torch.Tensor) -> torch.Tensor:
47
+ """Apply Sigmoid function that squishes activations into a range between 0 and 1.
48
+ Then we classify all values greater than or equal to 0.5 to 1,
49
+ and the values below 0.5 to 0.
50
+
51
+ Args:
52
+ pred: [B, C, W, H, D] or [C, W, H, D] activation tensor.
53
+
54
+ Returns:
55
+ Predicted binary mask(s).
56
+ """
57
+
43
58
  pred = torch.sigmoid(pred)
44
- return torch.where(pred>=0.5, 1, 0)
59
+
60
+ return torch.where(pred >= 0.5, 1, 0)
45
61
 
46
62
  # %% ../nbs/02_vision_data.ipynb 9
47
63
  class MedDataBlock(DataBlock):
48
- '''Container to quickly build dataloaders.'''
64
+ """Container to quickly build dataloaders."""
65
+ #TODO add get_x
66
+ def __init__(self, blocks: list = None, dl_type: TfmdDL = None, getters: list = None,
67
+ n_inp: int = None, item_tfms: list = None, batch_tfms: list = None,
68
+ reorder: bool = False, resample: (int, list) = None, **kwargs):
49
69
 
50
- def __init__(self, blocks:list=None,dl_type:TfmdDL=None, getters:list=None, n_inp:int=None, item_tfms:list=None,
51
- batch_tfms:list=None, reorder:bool=False, resample:(int, list)=None, **kwargs):
70
+ super().__init__(blocks, dl_type, getters, n_inp, item_tfms,
71
+ batch_tfms, **kwargs)
52
72
 
53
- super().__init__(blocks, dl_type, getters, n_inp, item_tfms, batch_tfms, **kwargs)
54
- MedBase.item_preprocessing(resample,reorder)
73
+ MedBase.item_preprocessing(resample, reorder)
55
74
 
56
- # %% ../nbs/02_vision_data.ipynb 12
75
+ # %% ../nbs/02_vision_data.ipynb 11
57
76
  def MedMaskBlock():
77
+ """Create a TransformBlock for medical masks."""
58
78
  return TransformBlock(type_tfms=MedMask.create)
59
79
 
60
- # %% ../nbs/02_vision_data.ipynb 14
80
+ # %% ../nbs/02_vision_data.ipynb 13
61
81
  class MedImageDataLoaders(DataLoaders):
62
- '''Higher-level `MedDataBlock` API.'''
63
-
82
+ """Higher-level `MedDataBlock` API."""
83
+
64
84
  @classmethod
65
85
  @delegates(DataLoaders.from_dblock)
66
- def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='', label_col=1, label_delim=None,
67
- y_block=None, valid_col=None, item_tfms=None, batch_tfms=None, reorder=False, resample=None, **kwargs):
68
- '''Create from DataFrame.'''
69
-
86
+ def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='',
87
+ label_col=1, label_delim=None, y_block=None, valid_col=None,
88
+ item_tfms=None, batch_tfms=None, reorder=False, resample=None, **kwargs):
89
+ """Create from DataFrame."""
90
+
70
91
  if y_block is None:
71
92
  is_multi = (is_listy(label_col) and len(label_col) > 1) or label_delim is not None
72
93
  y_block = MultiCategoryBlock if is_multi else CategoryBlock
73
- splitter = RandomSplitter(valid_pct, seed=seed) if valid_col is None else ColSplitter(valid_col)
74
94
 
95
+ splitter = (RandomSplitter(valid_pct, seed=seed)
96
+ if valid_col is None else ColSplitter(valid_col))
75
97
 
76
- dblock = MedDataBlock(blocks=(ImageBlock(cls=MedImage), y_block), get_x=ColReader(fn_col, suff=suff),
77
- get_y=ColReader(label_col, label_delim=label_delim),
78
- splitter=splitter,
79
- item_tfms=item_tfms,
80
- reorder=reorder,
81
- resample=resample)
98
+ dblock = MedDataBlock(
99
+ blocks=(ImageBlock(cls=MedImage), y_block),
100
+ get_x=ColReader(fn_col, suff=suff),
101
+ get_y=ColReader(label_col, label_delim=label_delim),
102
+ splitter=splitter,
103
+ item_tfms=item_tfms,
104
+ reorder=reorder,
105
+ resample=resample
106
+ )
82
107
 
83
108
  return cls.from_dblock(dblock, df, **kwargs)
84
109
 
85
- # %% ../nbs/02_vision_data.ipynb 19
110
+ # %% ../nbs/02_vision_data.ipynb 16
86
111
  @typedispatch
87
- def show_batch(x:MedImage, y, samples, ctxs=None, max_n=6, nrows=None, ncols=None, figsize=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
112
+ def show_batch(x: MedImage, y, samples, ctxs=None, max_n=6, nrows=None,
113
+ ncols=None, figsize=None, channel=0, indices=None,
114
+ anatomical_plane=0, **kwargs):
88
115
  '''Showing a batch of samples for classification and regression tasks.'''
89
116
 
90
- if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
117
+ if ctxs is None:
118
+ ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
119
+
91
120
  n = 1 if y is None else 2
121
+
92
122
  for i in range(n):
93
- ctxs = [b.show(ctx=c, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]
123
+ ctxs = [
124
+ b.show(ctx=c, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs)
125
+ for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))
126
+ ]
94
127
 
95
128
  plt.tight_layout()
129
+
96
130
  return ctxs
97
131
 
98
- # %% ../nbs/02_vision_data.ipynb 20
132
+ # %% ../nbs/02_vision_data.ipynb 17
99
133
  @typedispatch
100
- def show_batch(x:MedImage, y:MedMask, samples, ctxs=None, max_n=6, nrows=None, ncols=None, figsize=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
101
- '''Showing a batch of decoded segmentation samples.'''
134
+ def show_batch(x: MedImage, y: MedMask, samples, ctxs=None, max_n=6, nrows=None,
135
+ ncols=None, figsize=None, channel=0, indices=None,
136
+ anatomical_plane=0, **kwargs):
137
+ """Showing a batch of decoded segmentation samples."""
102
138
 
103
139
  nrows, ncols = min(len(samples), max_n), x.shape[1] + 1
104
140
  imgs = []
105
141
 
106
- fig,axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
142
+ fig, axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
107
143
  axs = axs.flatten()
108
144
 
109
- for img, mask in list(zip(x,y)):
145
+ for img, mask in zip(x, y):
110
146
  im_channels = [MedImage(c_img[None]) for c_img in img]
111
147
  im_channels.append(MedMask(mask))
112
148
  imgs.extend(im_channels)
113
149
 
114
- ctxs = [im.show(ax=ax, indices=indices, anatomical_plane=anatomical_plane) for im, ax in zip(imgs, axs)]
150
+ ctxs = [im.show(ax=ax, indices=indices, anatomical_plane=anatomical_plane)
151
+ for im, ax in zip(imgs, axs)]
152
+
115
153
  plt.tight_layout()
116
154
 
117
155
  return ctxs
118
156
 
119
- # %% ../nbs/02_vision_data.ipynb 22
157
+ # %% ../nbs/02_vision_data.ipynb 19
120
158
  @typedispatch
121
- def show_results(x:MedImage, y:torch.Tensor, samples, outs, ctxs=None, max_n=6, nrows=None, ncols=None, figsize=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
122
- '''Showing samples and their corresponding predictions for regression tasks.'''
159
+ def show_results(x: MedImage, y: torch.Tensor, samples, outs, ctxs=None, max_n: int = 6,
160
+ nrows: int = None, ncols: int = None, figsize=None, channel: int = 0,
161
+ indices: int = None, anatomical_plane: int = 0, **kwargs):
162
+ """Showing samples and their corresponding predictions for regression tasks."""
123
163
 
124
- if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
164
+ if ctxs is None:
165
+ ctxs = get_grid(min(len(samples), max_n), nrows=nrows,
166
+ ncols=ncols, figsize=figsize)
125
167
 
126
168
  for i in range(len(samples[0])):
127
- ctxs = [b.show(ctx=c, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]
169
+ ctxs = [
170
+ b.show(ctx=c, channel=channel, indices=indices,
171
+ anatomical_plane=anatomical_plane, **kwargs)
172
+ for b, c, _ in zip(samples.itemgot(i), ctxs, range(max_n))
173
+ ]
174
+
128
175
  for i in range(len(outs[0])):
129
- ctxs = [b.show(ctx=c, **kwargs) for b,c,_ in zip(outs.itemgot(i),ctxs,range(max_n))]
176
+ ctxs = [
177
+ b.show(ctx=c, **kwargs)
178
+ for b, c, _ in zip(outs.itemgot(i), ctxs, range(max_n))
179
+ ]
180
+
130
181
  return ctxs
131
182
 
132
- # %% ../nbs/02_vision_data.ipynb 23
183
+ # %% ../nbs/02_vision_data.ipynb 20
133
184
  @typedispatch
134
- def show_results(x:MedImage, y:TensorCategory, samples, outs, ctxs=None, max_n=6, nrows=None, ncols=None, figsize=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
135
- '''Showing samples and their corresponding predictions for classification tasks.'''
136
-
137
- if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
185
+ def show_results(x: MedImage, y: TensorCategory, samples, outs, ctxs=None,
186
+ max_n: int = 6, nrows: int = None, ncols: int = None, figsize=None, channel: int = 0,
187
+ indices: int = None, anatomical_plane: int = 0, **kwargs):
188
+ """Showing samples and their corresponding predictions for classification tasks."""
189
+
190
+ if ctxs is None:
191
+ ctxs = get_grid(min(len(samples), max_n), nrows=nrows,
192
+ ncols=ncols, figsize=figsize)
193
+
138
194
  for i in range(2):
139
- ctxs = [b.show(ctx=c, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))]
140
- ctxs = [r.show(ctx=c, color='green' if b==r else 'red', **kwargs) for b,r,c,_ in zip(samples.itemgot(1),outs.itemgot(0),ctxs,range(max_n))]
195
+ ctxs = [b.show(ctx=c, channel=channel, indices=indices,
196
+ anatomical_plane=anatomical_plane, **kwargs)
197
+ for b, c, _ in zip(samples.itemgot(i), ctxs, range(max_n))]
198
+
199
+ ctxs = [r.show(ctx=c, color='green' if b == r else 'red', **kwargs)
200
+ for b, r, c, _ in zip(samples.itemgot(1), outs.itemgot(0), ctxs, range(max_n))]
201
+
141
202
  return ctxs
142
203
 
143
- # %% ../nbs/02_vision_data.ipynb 24
204
+ # %% ../nbs/02_vision_data.ipynb 21
144
205
  @typedispatch
145
- def show_results(x:MedImage, y:MedMask, samples, outs, ctxs=None, max_n=6, nrows=None, ncols=1, figsize=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
146
- ''' Showing decoded samples and their corresponding predictions for segmentation tasks.'''
206
+ def show_results(x: MedImage, y: MedMask, samples, outs, ctxs=None, max_n: int = 6,
207
+ nrows: int = None, ncols: int = 1, figsize=None, channel: int = 0,
208
+ indices: int = None, anatomical_plane: int = 0, **kwargs):
209
+ """Showing decoded samples and their corresponding predictions for segmentation tasks."""
210
+
211
+ if ctxs is None:
212
+ ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols,
213
+ figsize=figsize, double=True, title='Target/Prediction')
147
214
 
148
- if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize, double=True, title='Target/Prediction')
149
215
  for i in range(2):
150
- ctxs[::2] = [b.show(ctx=c, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs) for b,c,_ in zip(samples.itemgot(i),ctxs[::2],range(2*max_n))]
151
- for o in [samples,outs]:
152
- ctxs[1::2] = [b.show(ctx=c, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs) for b,c,_ in zip(o.itemgot(0),ctxs[1::2],range(2*max_n))]
216
+ ctxs[::2] = [b.show(ctx=c, channel=channel, indices=indices,
217
+ anatomical_plane=anatomical_plane, **kwargs)
218
+ for b, c, _ in zip(samples.itemgot(i), ctxs[::2], range(2 * max_n))]
219
+
220
+ for o in [samples, outs]:
221
+ ctxs[1::2] = [b.show(ctx=c, channel=channel, indices=indices,
222
+ anatomical_plane=anatomical_plane, **kwargs)
223
+ for b, c, _ in zip(o.itemgot(0), ctxs[1::2], range(2 * max_n))]
224
+
153
225
  return ctxs
154
226
 
155
- # %% ../nbs/02_vision_data.ipynb 26
227
+ # %% ../nbs/02_vision_data.ipynb 23
156
228
  @typedispatch
157
- def plot_top_losses(x: MedImage, y, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
158
- '''Show images in top_losses along with their prediction, actual, loss, and probability of actual class.'''
229
+ def plot_top_losses(x: MedImage, y: TensorCategory, samples, outs, raws, losses, nrows: int = None,
230
+ ncols: int = None, figsize=None, channel: int = 0, indices: int = None,
231
+ anatomical_plane: int = 0, **kwargs):
232
+ """Show images in top_losses along with their prediction, actual, loss, and probability of actual class."""
159
233
 
160
- title = 'Prediction/Actual/Loss' if type(y) == torch.Tensor else 'Prediction/Actual/Loss/Probability'
234
+ title = 'Prediction/Actual/Loss' if isinstance(y, torch.Tensor) else 'Prediction/Actual/Loss/Probability'
161
235
  axs = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize, title=title)
162
- for ax,s,o,r,l in zip(axs, samples, outs, raws, losses):
236
+
237
+ for ax, s, o, r, l in zip(axs, samples, outs, raws, losses):
163
238
  s[0].show(ctx=ax, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs)
164
- if type(y) == torch.Tensor: ax.set_title(f'{r.max().item():.2f}/{s[1]} / {l.item():.2f}')
165
- else: ax.set_title(f'{o[0]}/{s[1]} / {l.item():.2f} / {r.max().item():.2f}')
166
239
 
167
- # %% ../nbs/02_vision_data.ipynb 27
240
+ if isinstance(y, torch.Tensor):
241
+ ax.set_title(f'{r.max().item():.2f}/{s[1]} / {l.item():.2f}')
242
+ else:
243
+ ax.set_title(f'{o[0]}/{s[1]} / {l.item():.2f} / {r.max().item():.2f}')
244
+
245
+ # %% ../nbs/02_vision_data.ipynb 24
168
246
  @typedispatch
169
- def plot_top_losses(x: MedImage, y:TensorMultiCategory, samples, outs, raws, losses, nrows=None, ncols=None, figsize=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
170
- #TODO: not tested yet
247
+ def plot_top_losses(x: MedImage, y: TensorMultiCategory, samples, outs, raws,
248
+ losses, nrows: int = None, ncols: int = None, figsize=None,
249
+ channel: int = 0, indices: int = None,
250
+ anatomical_plane: int = 0, **kwargs):
251
+ # TODO: not tested yet
171
252
  axs = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize)
172
- for i,(ax,s) in enumerate(zip(axs, samples)): s[0].show(ctx=ax, title=f'Image {i}', channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs)
253
+
254
+ for i, (ax, s) in enumerate(zip(axs, samples)):
255
+ s[0].show(ctx=ax, title=f'Image {i}', channel=channel,
256
+ indices=indices, anatomical_plane=anatomical_plane, **kwargs)
257
+
173
258
  rows = get_empty_df(len(samples))
174
- outs = L(s[1:] + o + (TitledStr(r), TitledFloat(l.item())) for s,o,r,l in zip(samples, outs, raws, losses))
175
- for i,l in enumerate(["target", "predicted", "probabilities", "loss"]):
176
- rows = [b.show(ctx=r, label=l, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs) for b,r in zip(outs.itemgot(i),rows)]
259
+ outs = L(s[1:] + o + (TitledStr(r), TitledFloat(l.item()))
260
+ for s, o, r, l in zip(samples, outs, raws, losses))
261
+
262
+ for i, l in enumerate(["target", "predicted", "probabilities", "loss"]):
263
+ rows = [b.show(ctx=r, label=l, channel=channel, indices=indices,
264
+ anatomical_plane=anatomical_plane, **kwargs)
265
+ for b, r in zip(outs.itemgot(i), rows)]
266
+
177
267
  display_df(pd.DataFrame(rows))
@@ -4,15 +4,16 @@
4
4
  __all__ = ['inference', 'refine_binary_pred_mask']
5
5
 
6
6
  # %% ../nbs/06_vision_inference.ipynb 1
7
- import numpy as np
7
+ from copy import copy
8
8
  from pathlib import Path
9
- from torchio import Resize
9
+ import torch
10
+ import numpy as np
10
11
  from scipy.ndimage import label
11
- from .vision_core import *
12
- from .vision_augmentation import do_pad_or_crop
13
12
  from skimage.morphology import remove_small_objects
14
13
  from SimpleITK import DICOMOrient, GetArrayFromImage
15
- from copy import copy
14
+ from torchio import Resize
15
+ from .vision_core import *
16
+ from .vision_augmentation import do_pad_or_crop
16
17
 
17
18
  # %% ../nbs/06_vision_inference.ipynb 3
18
19
  def _to_original_orientation(input_img, org_orientation):
@@ -24,27 +25,42 @@ def _to_original_orientation(input_img, org_orientation):
24
25
  return reoriented_array[None]
25
26
 
26
27
  # %% ../nbs/06_vision_inference.ipynb 4
27
- def _do_resize(o, target_shape, image_interpolation='linear', label_interpolation='nearest'):
28
- '''Resample images so the output shape matches the given target shape.'''
28
+ def _do_resize(o, target_shape, image_interpolation='linear',
29
+ label_interpolation='nearest'):
30
+ """
31
+ Resample images so the output shape matches the given target shape.
32
+ """
29
33
 
30
- resize = Resize(target_shape, image_interpolation=image_interpolation, label_interpolation=label_interpolation)
34
+ resize = Resize(
35
+ target_shape,
36
+ image_interpolation=image_interpolation,
37
+ label_interpolation=label_interpolation
38
+ )
39
+
31
40
  return resize(o)
32
41
 
33
42
  # %% ../nbs/06_vision_inference.ipynb 5
34
- def inference(learn_inf, reorder, resample, fn:(Path,str)='', save_path:(str,Path)=None, org_img=None, input_img=None, org_size=None):
35
- '''Predict on new data using exported model'''
43
+ def inference(learn_inf, reorder, resample, fn: (str, Path) = '',
44
+ save_path: (str, Path) = None, org_img=None, input_img=None,
45
+ org_size=None):
46
+ """Predict on new data using exported model."""
47
+
36
48
  if None in [org_img, input_img, org_size]:
37
- org_img, input_img, org_size = med_img_reader(fn, reorder, resample, only_tensor=False)
38
- else: org_img, input_img = copy(org_img), copy(input_img)
49
+ org_img, input_img, org_size = med_img_reader(fn, reorder, resample,
50
+ only_tensor=False)
51
+ else:
52
+ org_img, input_img = copy(org_img), copy(input_img)
39
53
 
40
- pred, *_ = learn_inf.predict(input_img.data);
54
+ pred, *_ = learn_inf.predict(input_img.data)
41
55
 
42
- pred_mask = do_pad_or_crop(pred.float(), input_img.shape[1:], padding_mode=0, mask_name=None)
56
+ pred_mask = do_pad_or_crop(pred.float(), input_img.shape[1:], padding_mode=0,
57
+ mask_name=None)
43
58
  input_img.set_data(pred_mask)
44
59
 
45
60
  input_img = _do_resize(input_img, org_size, image_interpolation='nearest')
46
61
 
47
- reoriented_array = _to_original_orientation(input_img.as_sitk(), ('').join(org_img.orientation))
62
+ reoriented_array = _to_original_orientation(input_img.as_sitk(),
63
+ ('').join(org_img.orientation))
48
64
 
49
65
  org_img.set_data(reoriented_array)
50
66
 
@@ -56,12 +72,10 @@ def inference(learn_inf, reorder, resample, fn:(Path,str)='', save_path:(str,Pat
56
72
  return org_img
57
73
 
58
74
  # %% ../nbs/06_vision_inference.ipynb 7
59
- def refine_binary_pred_mask(
60
- pred_mask,
61
- remove_size: (int, float) = None,
62
- percentage: float = 0.2,
63
- verbose: bool = False
64
- ):
75
+ def refine_binary_pred_mask(pred_mask,
76
+ remove_size: (int, float) = None,
77
+ percentage: float = 0.2,
78
+ verbose: bool = False) -> torch.Tensor:
65
79
  """Removes small objects from the predicted binary mask.
66
80
 
67
81
  Args:
@@ -74,6 +88,7 @@ def refine_binary_pred_mask(
74
88
  Returns:
75
89
  The processed mask with small objects removed.
76
90
  """
91
+
77
92
  labeled_mask, n_components = label(pred_mask)
78
93
 
79
94
  if verbose:
@@ -88,4 +103,4 @@ def refine_binary_pred_mask(
88
103
  processed_mask = remove_small_objects(
89
104
  labeled_mask, min_size=small_objects_threshold)
90
105
 
91
- return np.where(processed_mask > 0, 1., 0.)
106
+ return torch.Tensor(processed_mask > 0).float()
fastMONAI/vision_loss.py CHANGED
@@ -12,40 +12,49 @@ from torch.nn.modules.loss import _Loss
12
12
 
13
13
  # %% ../nbs/04_vision_loss_functions.ipynb 3
14
14
  class CustomLoss:
15
- '''Wrapper to get show_results to work.'''
15
+ """A custom loss wrapper class for loss functions to allow them to work with
16
+ the 'show_results' method in fastai.
17
+ """
16
18
 
17
19
  def __init__(self, loss_func):
20
+ """Constructs CustomLoss object."""
21
+
18
22
  self.loss_func = loss_func
19
23
 
20
24
  def __call__(self, pred, targ):
21
- if isinstance(pred, MedBase): pred, targ = torch.Tensor(pred.cpu()), torch.Tensor(targ.cpu().float())
25
+ """Computes the loss for given predictions and targets."""
26
+
27
+ if isinstance(pred, MedBase):
28
+ pred, targ = torch.Tensor(pred.cpu()), torch.Tensor(targ.cpu().float())
29
+
22
30
  return self.loss_func(pred, targ)
23
31
 
24
32
  def activation(self, x):
25
33
  return x
26
34
 
27
- def decodes(self, x):
28
- '''Converts model output to target format.
29
-
35
+ def decodes(self, x) -> torch.Tensor:
36
+ """Converts model output to target format.
37
+
30
38
  Args:
31
- x: Activations for each class [B, C, W, H, D]
39
+ x: Activations for each class with dimensions [B, C, W, H, D].
32
40
 
33
41
  Returns:
34
- torch.Tensor: Predicted mask.
35
- '''
36
-
42
+ The predicted mask.
43
+ """
44
+
37
45
  n_classes = x.shape[1]
38
- if n_classes == 1: x = pred_to_binary_mask(x)
39
- else: x,_ = batch_pred_to_multiclass_mask(x)
46
+ if n_classes == 1:
47
+ x = pred_to_binary_mask(x)
48
+ else:
49
+ x,_ = batch_pred_to_multiclass_mask(x)
40
50
 
41
51
  return x
42
52
 
43
53
  # %% ../nbs/04_vision_loss_functions.ipynb 4
44
54
  class TverskyFocalLoss(_Loss):
45
55
  """
46
- Compute both Dice loss and Focal Loss, and return the weighted sum of these two losses.
47
- The details of Dice loss is shown in ``monai.losses.DiceLoss``.
48
- The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
56
+ Compute Tversky loss with a focus parameter, gamma, applied.
57
+ The details of Tversky loss are shown in ``monai.losses.TverskyLoss``.
49
58
  """
50
59
 
51
60
  def __init__(
@@ -54,45 +63,45 @@ class TverskyFocalLoss(_Loss):
54
63
  to_onehot_y: bool = False,
55
64
  sigmoid: bool = False,
56
65
  softmax: bool = False,
57
- reduction: str = "mean",
58
66
  gamma: float = 2,
59
- #focal_weight: (float, int, torch.Tensor) = None,
60
- #lambda_dice: float = 1.0,
61
- #lambda_focal: float = 1.0,
62
- alpha = 0.5,
63
- beta = 0.99
64
- ) -> None:
65
-
67
+ alpha: float = 0.5,
68
+ beta: float = 0.99):
69
+ """
70
+ Args:
71
+ include_background: whether to calculate loss for the background class.
72
+ to_onehot_y: whether to convert `y` into one-hot format.
73
+ sigmoid: if True, apply a sigmoid function to the prediction.
74
+ softmax: if True, apply a softmax function to the prediction.
75
+ gamma: the focal parameter, it modulates the loss with regard to
76
+ how far the prediction is from target.
77
+ alpha: the weight of false positive in Tversky loss calculation.
78
+ beta: the weight of false negative in Tversky loss calculation.
79
+ """
80
+
66
81
  super().__init__()
67
- self.tversky = TverskyLoss(to_onehot_y=to_onehot_y, include_background=include_background, sigmoid=sigmoid, softmax=softmax, alpha=alpha, beta=beta)
68
- #self.focal = FocalLoss(to_onehot_y=to_onehot_y, include_background=include_background, gamma=gamma, weight=focal_weight, reduction=reduction)
69
-
70
- #if lambda_dice < 0.0: raise ValueError("lambda_dice should be no less than 0.0.")
71
- #if lambda_focal < 0.0: raise ValueError("lambda_focal should be no less than 0.0.")
72
- #self.lambda_dice = lambda_dice
73
- #self.lambda_focal = lambda_focal
74
- self.to_onehot_y = to_onehot_y
82
+ self.tversky = TverskyLoss(
83
+ to_onehot_y=to_onehot_y,
84
+ include_background=include_background,
85
+ sigmoid=sigmoid,
86
+ softmax=softmax,
87
+ alpha=alpha,
88
+ beta=beta
89
+ )
75
90
  self.gamma = gamma
76
- self.include_background = include_background
77
91
 
78
92
  def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
79
93
  """
80
94
  Args:
81
- input: the shape should be BNH[WD]. The input should be the original logits
82
- due to the restriction of ``monai.losses.FocalLoss``.
83
- target: the shape should be BNH[WD] or B1H[WD].
95
+ input: the shape should be [B, C, W, H, D]. The input should be the original logits.
96
+ target: the shape should be [B, C, W, H, D].
97
+
84
98
  Raises:
85
99
  ValueError: When number of dimensions for input and target are different.
86
- ValueError: When number of channels for target is neither 1 nor the same as input.
87
100
  """
88
101
  if len(input.shape) != len(target.shape):
89
- raise ValueError("the number of dimensions for input and target should be the same.")
90
-
91
- n_pred_ch = input.shape[1]
102
+ raise ValueError("The number of dimensions for input and target should be the same.")
92
103
 
93
104
  tversky_loss = self.tversky(input, target)
94
- #focal_loss = self.focal(input, target)
95
- total_loss: torch.Tensor = 1 - ((1 - tversky_loss)**self.gamma) #tversky_loss
96
- #print(total_loss,total_loss.shape)
97
- #tversky_loss + focal_loss
105
+ total_loss: torch.Tensor = 1 - ((1 - tversky_loss)**self.gamma)
106
+
98
107
  return total_loss