fastMONAI 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +6 -20
- fastMONAI/dataset_info.py +58 -50
- fastMONAI/external_data.py +181 -91
- fastMONAI/utils.py +10 -12
- fastMONAI/vision_augmentation.py +160 -139
- fastMONAI/vision_core.py +43 -27
- fastMONAI/vision_data.py +175 -85
- fastMONAI/vision_inference.py +37 -22
- fastMONAI/vision_loss.py +51 -42
- fastMONAI/vision_metrics.py +46 -23
- fastMONAI/vision_plot.py +15 -13
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/METADATA +1 -1
- fastMONAI-0.3.3.dist-info/RECORD +20 -0
- fastMONAI-0.3.1.dist-info/RECORD +0 -20
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/LICENSE +0 -0
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/WHEEL +0 -0
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/entry_points.txt +0 -0
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/top_level.txt +0 -0
fastMONAI/vision_data.py
CHANGED
@@ -10,168 +10,258 @@ from fastai.vision.data import *
 from .vision_core import *
 
 # %% ../nbs/02_vision_data.ipynb 5
-def pred_to_multiclass_mask(pred:torch.Tensor
-
-
-
-
-
+def pred_to_multiclass_mask(pred: torch.Tensor) -> torch.Tensor:
+    """Apply Softmax on the predicted tensor to rescale the values in the range [0, 1]
+    and sum to 1. Then apply argmax to get the indices of the maximum value of all
+    elements in the predicted Tensor.
+
+    Args:
+        pred: [C,W,H,D] activation tensor.
+
+    Returns:
+        Predicted mask.
+    """
+
     pred = pred.softmax(dim=0)
+
     return pred.argmax(dim=0, keepdims=True)
 
 # %% ../nbs/02_vision_data.ipynb 6
-def batch_pred_to_multiclass_mask(pred:torch.Tensor
-
-
-
-
-
+def batch_pred_to_multiclass_mask(pred: torch.Tensor) -> (torch.Tensor, int):
+    """Convert a batch of predicted activation tensors to masks.
+
+    Args:
+        pred: [B, C, W, H, D] batch of activations.
+
+    Returns:
+        Tuple of batch of predicted masks and number of classes.
+    """
+
     n_classes = pred.shape[1]
     pred = [pred_to_multiclass_mask(p) for p in pred]
 
     return torch.stack(pred), n_classes
 
 # %% ../nbs/02_vision_data.ipynb 7
-def pred_to_binary_mask(pred
-
-
-
-
-
-
-
+def pred_to_binary_mask(pred: torch.Tensor) -> torch.Tensor:
+    """Apply Sigmoid function that squishes activations into a range between 0 and 1.
+    Then we classify all values greater than or equal to 0.5 to 1,
+    and the values below 0.5 to 0.
+
+    Args:
+        pred: [B, C, W, H, D] or [C, W, H, D] activation tensor
+
+    Returns:
+        Predicted binary mask(s).
+    """
+
     pred = torch.sigmoid(pred)
-
+
+    return torch.where(pred >= 0.5, 1, 0)
 
 # %% ../nbs/02_vision_data.ipynb 9
 class MedDataBlock(DataBlock):
-
+    """Container to quickly build dataloaders."""
+    #TODO add get_x
+    def __init__(self, blocks: list = None, dl_type: TfmdDL = None, getters: list = None,
+                 n_inp: int = None, item_tfms: list = None, batch_tfms: list = None,
+                 reorder: bool = False, resample: (int, list) = None, **kwargs):
 
-
-
+        super().__init__(blocks, dl_type, getters, n_inp, item_tfms,
+                         batch_tfms, **kwargs)
 
-
-        MedBase.item_preprocessing(resample,reorder)
+        MedBase.item_preprocessing(resample, reorder)
 
-# %% ../nbs/02_vision_data.ipynb
+# %% ../nbs/02_vision_data.ipynb 11
 def MedMaskBlock():
+    """Create a TransformBlock for medical masks."""
     return TransformBlock(type_tfms=MedMask.create)
 
-# %% ../nbs/02_vision_data.ipynb
+# %% ../nbs/02_vision_data.ipynb 13
 class MedImageDataLoaders(DataLoaders):
-
-
+    """Higher-level `MedDataBlock` API."""
+
     @classmethod
     @delegates(DataLoaders.from_dblock)
-    def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='',
-
-
-
+    def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='',
+                label_col=1, label_delim=None, y_block=None, valid_col=None,
+                item_tfms=None, batch_tfms=None, reorder=False, resample=None, **kwargs):
+        """Create from DataFrame."""
+
         if y_block is None:
             is_multi = (is_listy(label_col) and len(label_col) > 1) or label_delim is not None
             y_block = MultiCategoryBlock if is_multi else CategoryBlock
-        splitter = RandomSplitter(valid_pct, seed=seed) if valid_col is None else ColSplitter(valid_col)
 
+        splitter = (RandomSplitter(valid_pct, seed=seed)
+                    if valid_col is None else ColSplitter(valid_col))
 
-        dblock = MedDataBlock(
-
-
-
-
-
+        dblock = MedDataBlock(
+            blocks=(ImageBlock(cls=MedImage), y_block),
+            get_x=ColReader(fn_col, suff=suff),
+            get_y=ColReader(label_col, label_delim=label_delim),
+            splitter=splitter,
+            item_tfms=item_tfms,
+            reorder=reorder,
+            resample=resample
+        )
 
         return cls.from_dblock(dblock, df, **kwargs)
 
-# %% ../nbs/02_vision_data.ipynb
+# %% ../nbs/02_vision_data.ipynb 16
 @typedispatch
-def show_batch(x:MedImage, y, samples, ctxs=None, max_n=6, nrows=None,
+def show_batch(x: MedImage, y, samples, ctxs=None, max_n=6, nrows=None,
+               ncols=None, figsize=None, channel=0, indices=None,
+               anatomical_plane=0, **kwargs):
     '''Showing a batch of samples for classification and regression tasks.'''
 
-    if ctxs is None:
+    if ctxs is None:
+        ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize)
+
     n = 1 if y is None else 2
+
     for i in range(n):
-        ctxs = [
+        ctxs = [
+            b.show(ctx=c, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs)
+            for b,c,_ in zip(samples.itemgot(i),ctxs,range(max_n))
+        ]
 
     plt.tight_layout()
+
     return ctxs
 
-# %% ../nbs/02_vision_data.ipynb
+# %% ../nbs/02_vision_data.ipynb 17
 @typedispatch
-def show_batch(x:MedImage, y:MedMask, samples, ctxs=None, max_n=6, nrows=None,
-
+def show_batch(x: MedImage, y: MedMask, samples, ctxs=None, max_n=6, nrows=None,
+               ncols=None, figsize=None, channel=0, indices=None,
+               anatomical_plane=0, **kwargs):
+    """Showing a batch of decoded segmentation samples."""
 
     nrows, ncols = min(len(samples), max_n), x.shape[1] + 1
     imgs = []
 
-    fig,axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
+    fig, axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
     axs = axs.flatten()
 
-    for img, mask in
+    for img, mask in zip(x, y):
         im_channels = [MedImage(c_img[None]) for c_img in img]
         im_channels.append(MedMask(mask))
         imgs.extend(im_channels)
 
-    ctxs = [im.show(ax=ax, indices=indices, anatomical_plane=anatomical_plane)
+    ctxs = [im.show(ax=ax, indices=indices, anatomical_plane=anatomical_plane)
+            for im, ax in zip(imgs, axs)]
+
     plt.tight_layout()
 
     return ctxs
 
-# %% ../nbs/02_vision_data.ipynb
+# %% ../nbs/02_vision_data.ipynb 19
 @typedispatch
-def show_results(x:MedImage, y:torch.Tensor, samples, outs, ctxs=None, max_n
-
+def show_results(x: MedImage, y: torch.Tensor, samples, outs, ctxs=None, max_n: int = 6,
+                 nrows: int = None, ncols: int = None, figsize=None, channel: int = 0,
+                 indices: int = None, anatomical_plane: int = 0, **kwargs):
+    """Showing samples and their corresponding predictions for regression tasks."""
 
-    if ctxs is None:
+    if ctxs is None:
+        ctxs = get_grid(min(len(samples), max_n), nrows=nrows,
+                        ncols=ncols, figsize=figsize)
 
     for i in range(len(samples[0])):
-        ctxs = [
+        ctxs = [
+            b.show(ctx=c, channel=channel, indices=indices,
+                   anatomical_plane=anatomical_plane, **kwargs)
+            for b, c, _ in zip(samples.itemgot(i), ctxs, range(max_n))
+        ]
+
     for i in range(len(outs[0])):
-        ctxs = [
+        ctxs = [
+            b.show(ctx=c, **kwargs)
+            for b, c, _ in zip(outs.itemgot(i), ctxs, range(max_n))
+        ]
+
     return ctxs
 
-# %% ../nbs/02_vision_data.ipynb
+# %% ../nbs/02_vision_data.ipynb 20
 @typedispatch
-def show_results(x:MedImage, y:TensorCategory, samples, outs, ctxs=None,
-
-
-
+def show_results(x: MedImage, y: TensorCategory, samples, outs, ctxs=None,
+                 max_n: int = 6, nrows: int = None, ncols: int = None, figsize=None, channel: int = 0,
+                 indices: int = None, anatomical_plane: int = 0, **kwargs):
+    """Showing samples and their corresponding predictions for classification tasks."""
+
+    if ctxs is None:
+        ctxs = get_grid(min(len(samples), max_n), nrows=nrows,
+                        ncols=ncols, figsize=figsize)
+
     for i in range(2):
-        ctxs = [b.show(ctx=c, channel=channel, indices=indices,
-
+        ctxs = [b.show(ctx=c, channel=channel, indices=indices,
+                       anatomical_plane=anatomical_plane, **kwargs)
+                for b, c, _ in zip(samples.itemgot(i), ctxs, range(max_n))]
+
+    ctxs = [r.show(ctx=c, color='green' if b == r else 'red', **kwargs)
+            for b, r, c, _ in zip(samples.itemgot(1), outs.itemgot(0), ctxs, range(max_n))]
+
     return ctxs
 
-# %% ../nbs/02_vision_data.ipynb
+# %% ../nbs/02_vision_data.ipynb 21
 @typedispatch
-def show_results(x:MedImage, y:MedMask, samples, outs, ctxs=None, max_n
-
+def show_results(x: MedImage, y: MedMask, samples, outs, ctxs=None, max_n: int = 6,
+                 nrows: int = None, ncols: int = 1, figsize=None, channel: int = 0,
+                 indices: int = None, anatomical_plane: int = 0, **kwargs):
+    """Showing decoded samples and their corresponding predictions for segmentation tasks."""
+
+    if ctxs is None:
+        ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols,
+                        figsize=figsize, double=True, title='Target/Prediction')
 
-    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), nrows=nrows, ncols=ncols, figsize=figsize, double=True, title='Target/Prediction')
     for i in range(2):
-        ctxs[::2] = [b.show(ctx=c, channel=channel, indices=indices,
-
-
+        ctxs[::2] = [b.show(ctx=c, channel=channel, indices=indices,
+                            anatomical_plane=anatomical_plane, **kwargs)
+                     for b, c, _ in zip(samples.itemgot(i), ctxs[::2], range(2 * max_n))]
+
+    for o in [samples, outs]:
+        ctxs[1::2] = [b.show(ctx=c, channel=channel, indices=indices,
+                             anatomical_plane=anatomical_plane, **kwargs)
+                      for b, c, _ in zip(o.itemgot(0), ctxs[1::2], range(2 * max_n))]
+
     return ctxs
 
-# %% ../nbs/02_vision_data.ipynb
+# %% ../nbs/02_vision_data.ipynb 23
 @typedispatch
-def plot_top_losses(x: MedImage, y, samples, outs, raws, losses, nrows
-
+def plot_top_losses(x: MedImage, y: TensorCategory, samples, outs, raws, losses, nrows: int = None,
+                    ncols: int = None, figsize=None, channel: int = 0, indices: int = None,
+                    anatomical_plane: int = 0, **kwargs):
+    """Show images in top_losses along with their prediction, actual, loss, and probability of actual class."""
 
-    title = 'Prediction/Actual/Loss' if
+    title = 'Prediction/Actual/Loss' if isinstance(y, torch.Tensor) else 'Prediction/Actual/Loss/Probability'
     axs = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize, title=title)
-
+
+    for ax, s, o, r, l in zip(axs, samples, outs, raws, losses):
         s[0].show(ctx=ax, channel=channel, indices=indices, anatomical_plane=anatomical_plane, **kwargs)
-        if type(y) == torch.Tensor: ax.set_title(f'{r.max().item():.2f}/{s[1]} / {l.item():.2f}')
-        else: ax.set_title(f'{o[0]}/{s[1]} / {l.item():.2f} / {r.max().item():.2f}')
 
-
+        if isinstance(y, torch.Tensor):
+            ax.set_title(f'{r.max().item():.2f}/{s[1]} / {l.item():.2f}')
+        else:
+            ax.set_title(f'{o[0]}/{s[1]} / {l.item():.2f} / {r.max().item():.2f}')
+
+# %% ../nbs/02_vision_data.ipynb 24
 @typedispatch
-def plot_top_losses(x: MedImage, y:TensorMultiCategory, samples, outs, raws,
-
+def plot_top_losses(x: MedImage, y: TensorMultiCategory, samples, outs, raws,
+                    losses, nrows: int = None, ncols: int = None, figsize=None,
+                    channel: int = 0, indices: int = None,
+                    anatomical_plane: int = 0, **kwargs):
+    # TODO: not tested yet
     axs = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize)
-
+
+    for i, (ax, s) in enumerate(zip(axs, samples)):
+        s[0].show(ctx=ax, title=f'Image {i}', channel=channel,
+                  indices=indices, anatomical_plane=anatomical_plane, **kwargs)
+
     rows = get_empty_df(len(samples))
-    outs = L(s[1:] + o + (TitledStr(r), TitledFloat(l.item()))
-
-
+    outs = L(s[1:] + o + (TitledStr(r), TitledFloat(l.item()))
+             for s, o, r, l in zip(samples, outs, raws, losses))
+
+    for i, l in enumerate(["target", "predicted", "probabilities", "loss"]):
+        rows = [b.show(ctx=r, label=l, channel=channel, indices=indices,
+                       anatomical_plane=anatomical_plane, **kwargs)
+                for b, r in zip(outs.itemgot(i), rows)]
+
     display_df(pd.DataFrame(rows))
fastMONAI/vision_inference.py
CHANGED
@@ -4,15 +4,16 @@
 __all__ = ['inference', 'refine_binary_pred_mask']
 
 # %% ../nbs/06_vision_inference.ipynb 1
-
+from copy import copy
 from pathlib import Path
-
+import torch
+import numpy as np
 from scipy.ndimage import label
-from .vision_core import *
-from .vision_augmentation import do_pad_or_crop
 from skimage.morphology import remove_small_objects
 from SimpleITK import DICOMOrient, GetArrayFromImage
-from
+from torchio import Resize
+from .vision_core import *
+from .vision_augmentation import do_pad_or_crop
 
 # %% ../nbs/06_vision_inference.ipynb 3
 def _to_original_orientation(input_img, org_orientation):
@@ -24,27 +25,42 @@ def _to_original_orientation(input_img, org_orientation):
     return reoriented_array[None]
 
 # %% ../nbs/06_vision_inference.ipynb 4
-def _do_resize(o, target_shape, image_interpolation='linear',
-
+def _do_resize(o, target_shape, image_interpolation='linear',
+               label_interpolation='nearest'):
+    """
+    Resample images so the output shape matches the given target shape.
+    """
 
-    resize = Resize(
+    resize = Resize(
+        target_shape,
+        image_interpolation=image_interpolation,
+        label_interpolation=label_interpolation
+    )
+
     return resize(o)
 
 # %% ../nbs/06_vision_inference.ipynb 5
-def inference(learn_inf, reorder, resample, fn:(
-
+def inference(learn_inf, reorder, resample, fn: (str, Path) = '',
+              save_path: (str, Path) = None, org_img=None, input_img=None,
+              org_size=None):
+    """Predict on new data using exported model."""
+
     if None in [org_img, input_img, org_size]:
-        org_img, input_img, org_size = med_img_reader(fn, reorder, resample,
-
+        org_img, input_img, org_size = med_img_reader(fn, reorder, resample,
+                                                      only_tensor=False)
+    else:
+        org_img, input_img = copy(org_img), copy(input_img)
 
-    pred, *_ = learn_inf.predict(input_img.data)
+    pred, *_ = learn_inf.predict(input_img.data)
 
-    pred_mask = do_pad_or_crop(pred.float(), input_img.shape[1:], padding_mode=0,
+    pred_mask = do_pad_or_crop(pred.float(), input_img.shape[1:], padding_mode=0,
+                               mask_name=None)
     input_img.set_data(pred_mask)
 
     input_img = _do_resize(input_img, org_size, image_interpolation='nearest')
 
-    reoriented_array = _to_original_orientation(input_img.as_sitk(),
+    reoriented_array = _to_original_orientation(input_img.as_sitk(),
+                                                ('').join(org_img.orientation))
 
     org_img.set_data(reoriented_array)
 
@@ -56,12 +72,10 @@ def inference(learn_inf, reorder, resample, fn:(Path,str)='', save_path:(str,Pat
     return org_img
 
 # %% ../nbs/06_vision_inference.ipynb 7
-def refine_binary_pred_mask(
-
-
-
-        verbose: bool = False
-):
+def refine_binary_pred_mask(pred_mask,
+                            remove_size: (int, float) = None,
+                            percentage: float = 0.2,
+                            verbose: bool = False) -> torch.Tensor:
     """Removes small objects from the predicted binary mask.
 
     Args:
@@ -74,6 +88,7 @@ def refine_binary_pred_mask(
     Returns:
         The processed mask with small objects removed.
     """
+
    labeled_mask, n_components = label(pred_mask)
 
    if verbose:
@@ -88,4 +103,4 @@ def refine_binary_pred_mask(
    processed_mask = remove_small_objects(
        labeled_mask, min_size=small_objects_threshold)
 
-    return
+    return torch.Tensor(processed_mask > 0).float()
fastMONAI/vision_loss.py
CHANGED
@@ -12,40 +12,49 @@ from torch.nn.modules.loss import _Loss
 
 # %% ../nbs/04_vision_loss_functions.ipynb 3
 class CustomLoss:
-
+    """A custom loss wrapper class for loss functions to allow them to work with
+    the 'show_results' method in fastai.
+    """
 
     def __init__(self, loss_func):
+        """Constructs CustomLoss object."""
+
         self.loss_func = loss_func
 
     def __call__(self, pred, targ):
-
+        """Computes the loss for given predictions and targets."""
+
+        if isinstance(pred, MedBase):
+            pred, targ = torch.Tensor(pred.cpu()), torch.Tensor(targ.cpu().float())
+
         return self.loss_func(pred, targ)
 
     def activation(self, x):
         return x
 
-    def decodes(self, x):
-
-
+    def decodes(self, x) -> torch.Tensor:
+        """Converts model output to target format.
+
         Args:
-            x: Activations for each class [B, C, W, H, D]
+            x: Activations for each class with dimensions [B, C, W, H, D].
 
         Returns:
-
-
-
+            The predicted mask.
+        """
+
         n_classes = x.shape[1]
-        if n_classes == 1:
-
+        if n_classes == 1:
+            x = pred_to_binary_mask(x)
+        else:
+            x,_ = batch_pred_to_multiclass_mask(x)
 
         return x
 
 # %% ../nbs/04_vision_loss_functions.ipynb 4
 class TverskyFocalLoss(_Loss):
     """
-    Compute
-    The details of
-    The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
+    Compute Tversky loss with a focus parameter, gamma, applied.
+    The details of Tversky loss is shown in ``monai.losses.TverskyLoss``.
     """
 
     def __init__(
@@ -54,45 +63,45 @@ class TverskyFocalLoss(_Loss):
         to_onehot_y: bool = False,
         sigmoid: bool = False,
         softmax: bool = False,
-        reduction: str = "mean",
         gamma: float = 2,
-
-
-
-
-
-
-
+        alpha: float = 0.5,
+        beta: float = 0.99):
+        """
+        Args:
+            include_background: if to calculate loss for the background class.
+            to_onehot_y: whether to convert `y` into one-hot format.
+            sigmoid: if True, apply a sigmoid function to the prediction.
+            softmax: if True, apply a softmax function to the prediction.
+            gamma: the focal parameter, it modulates the loss with regards to
+                how far the prediction is from target.
+            alpha: the weight of false positive in Tversky loss calculation.
+            beta: the weight of false negative in Tversky loss calculation.
+        """
+
         super().__init__()
-        self.tversky = TverskyLoss(
-
-
-
-
-
-
-
+        self.tversky = TverskyLoss(
+            to_onehot_y=to_onehot_y,
+            include_background=include_background,
+            sigmoid=sigmoid,
+            softmax=softmax,
+            alpha=alpha,
+            beta=beta
+        )
        self.gamma = gamma
-        self.include_background = include_background
 
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
-            input: the shape should be
-
-
+            input: the shape should be [B, C, W, H, D]. The input should be the original logits.
+            target: the shape should be[B, C, W, H, D].
+
        Raises:
            ValueError: When number of dimensions for input and target are different.
-            ValueError: When number of channels for target is neither 1 nor the same as input.
        """
        if len(input.shape) != len(target.shape):
-            raise ValueError("
-
-        n_pred_ch = input.shape[1]
+            raise ValueError("The number of dimensions for input and target should be the same.")
 
        tversky_loss = self.tversky(input, target)
-
-
-        #print(total_loss,total_loss.shape)
-        #tversky_loss + focal_loss
+        total_loss: torch.Tensor = 1 - ((1 - tversky_loss)**self.gamma)
+
        return total_loss