fastMONAI 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastMONAI/__init__.py +1 -1
- fastMONAI/_modidx.py +6 -20
- fastMONAI/dataset_info.py +58 -50
- fastMONAI/external_data.py +181 -91
- fastMONAI/utils.py +10 -12
- fastMONAI/vision_augmentation.py +160 -139
- fastMONAI/vision_core.py +43 -27
- fastMONAI/vision_data.py +175 -85
- fastMONAI/vision_inference.py +37 -22
- fastMONAI/vision_loss.py +51 -42
- fastMONAI/vision_metrics.py +46 -23
- fastMONAI/vision_plot.py +15 -13
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/METADATA +1 -1
- fastMONAI-0.3.3.dist-info/RECORD +20 -0
- fastMONAI-0.3.1.dist-info/RECORD +0 -20
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/LICENSE +0 -0
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/WHEEL +0 -0
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/entry_points.txt +0 -0
- {fastMONAI-0.3.1.dist-info → fastMONAI-0.3.3.dist-info}/top_level.txt +0 -0
fastMONAI/vision_augmentation.py
CHANGED
|
@@ -12,75 +12,88 @@ import torchio as tio
|
|
|
12
12
|
|
|
13
13
|
# %% ../nbs/03_vision_augment.ipynb 5
|
|
14
14
|
class CustomDictTransform(ItemTransform):
|
|
15
|
-
|
|
15
|
+
"""A class that serves as a wrapper to perform an identical transformation on both
|
|
16
|
+
the image and the target (if it's a mask).
|
|
17
|
+
"""
|
|
16
18
|
|
|
17
|
-
split_idx = 0
|
|
18
|
-
|
|
19
|
+
split_idx = 0 # Only perform transformations on training data. Use TTA() for transformations on validation data.
|
|
20
|
+
|
|
21
|
+
def __init__(self, aug):
|
|
22
|
+
"""Constructs CustomDictTransform object.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
aug (Callable): Function to apply augmentation on the image.
|
|
26
|
+
"""
|
|
27
|
+
self.aug = aug
|
|
19
28
|
|
|
20
29
|
def encodes(self, x):
|
|
21
|
-
|
|
30
|
+
"""
|
|
31
|
+
Applies the stored transformation to an image, and the same random transformation
|
|
32
|
+
to the target if it is a mask. If the target is not a mask, it returns the target as is.
|
|
22
33
|
|
|
23
34
|
Args:
|
|
24
|
-
x:
|
|
35
|
+
x (Tuple[MedImage, Union[MedMask, TensorCategory]]): A tuple containing the
|
|
36
|
+
image and the target.
|
|
25
37
|
|
|
26
38
|
Returns:
|
|
27
|
-
MedImage:
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
39
|
+
Tuple[MedImage, Union[MedMask, TensorCategory]]: The transformed image and target.
|
|
40
|
+
If the target is a mask, it's transformed identically to the image. If the target
|
|
41
|
+
is not a mask, the original target is returned.
|
|
42
|
+
"""
|
|
31
43
|
img, y_true = x
|
|
32
44
|
|
|
33
45
|
if isinstance(y_true, (MedMask)):
|
|
34
|
-
aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=MedImage.affine_matrix),
|
|
46
|
+
aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img, affine=MedImage.affine_matrix),
|
|
47
|
+
mask=tio.LabelMap(tensor=y_true, affine=MedImage.affine_matrix)))
|
|
35
48
|
return MedImage.create(aug['img'].data), MedMask.create(aug['mask'].data)
|
|
36
|
-
else:
|
|
37
|
-
aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img)))
|
|
38
|
-
return MedImage.create(aug['img'].data), y_true
|
|
39
49
|
|
|
40
|
-
|
|
41
|
-
|
|
50
|
+
aug = self.aug(tio.Subject(img=tio.ScalarImage(tensor=img)))
|
|
51
|
+
return MedImage.create(aug['img'].data), y_true
|
|
42
52
|
|
|
53
|
+
|
|
54
|
+
# %% ../nbs/03_vision_augment.ipynb 7
|
|
55
|
+
def do_pad_or_crop(o, target_shape, padding_mode, mask_name, dtype=torch.Tensor):
|
|
56
|
+
#TODO:refactorize
|
|
43
57
|
pad_or_crop = tio.CropOrPad(target_shape=target_shape, padding_mode=padding_mode, mask_name=mask_name)
|
|
44
58
|
return dtype(pad_or_crop(o))
|
|
45
59
|
|
|
46
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
60
|
+
# %% ../nbs/03_vision_augment.ipynb 8
|
|
47
61
|
class PadOrCrop(DisplayedTransform):
|
|
48
|
-
|
|
62
|
+
"""Resize image using TorchIO `CropOrPad`."""
|
|
63
|
+
|
|
64
|
+
order = 0
|
|
49
65
|
|
|
50
|
-
order=0
|
|
51
66
|
def __init__(self, size, padding_mode=0, mask_name=None):
|
|
52
|
-
if not is_listy(size):
|
|
53
|
-
|
|
67
|
+
if not is_listy(size):
|
|
68
|
+
size = [size, size, size]
|
|
69
|
+
self.pad_or_crop = tio.CropOrPad(target_shape=size,
|
|
70
|
+
padding_mode=padding_mode,
|
|
71
|
+
mask_name=mask_name)
|
|
54
72
|
|
|
55
|
-
def encodes(self, o:(MedImage, MedMask)):
|
|
56
|
-
return
|
|
73
|
+
def encodes(self, o: (MedImage, MedMask)):
|
|
74
|
+
return type(o)(self.pad_or_crop(o))
|
|
57
75
|
|
|
58
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
z_normalization = tio.ZNormalization(masking_method=masking_method)
|
|
62
|
-
normalized_tensor = torch.zeros(o.shape)
|
|
76
|
+
# %% ../nbs/03_vision_augment.ipynb 9
|
|
77
|
+
class ZNormalization(DisplayedTransform):
|
|
78
|
+
"""Apply TorchIO `ZNormalization`."""
|
|
63
79
|
|
|
64
|
-
|
|
65
|
-
for idx, c in enumerate(o):
|
|
66
|
-
normalized_tensor[idx] = z_normalization(c[None])[0]
|
|
67
|
-
|
|
68
|
-
else: normalized_tensor = z_normalization(o)
|
|
80
|
+
order = 0
|
|
69
81
|
|
|
70
|
-
|
|
82
|
+
def __init__(self, masking_method=None, channel_wise=True):
|
|
83
|
+
self.z_normalization = tio.ZNormalization(masking_method=masking_method)
|
|
84
|
+
self.channel_wise = channel_wise
|
|
71
85
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
86
|
+
def encodes(self, o: MedImage):
|
|
87
|
+
if self.channel_wise:
|
|
88
|
+
o = torch.stack([self.z_normalization(c[None])[0] for c in o])
|
|
89
|
+
else: o = self.z_normalization(o)
|
|
75
90
|
|
|
76
|
-
|
|
77
|
-
def __init__(self, masking_method=None, channel_wise=True):
|
|
78
|
-
self.masking_method, self.channel_wise = masking_method, channel_wise
|
|
91
|
+
return MedImage.create(o)
|
|
79
92
|
|
|
80
|
-
def encodes(self, o:
|
|
81
|
-
|
|
93
|
+
def encodes(self, o: MedMask):
|
|
94
|
+
return o
|
|
82
95
|
|
|
83
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
96
|
+
# %% ../nbs/03_vision_augment.ipynb 10
|
|
84
97
|
class BraTSMaskConverter(DisplayedTransform):
|
|
85
98
|
'''Convert BraTS masks.'''
|
|
86
99
|
|
|
@@ -92,115 +105,95 @@ class BraTSMaskConverter(DisplayedTransform):
|
|
|
92
105
|
o = torch.where(o==4, 3., o)
|
|
93
106
|
return MedMask.create(o)
|
|
94
107
|
|
|
95
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
108
|
+
# %% ../nbs/03_vision_augment.ipynb 11
|
|
96
109
|
class BinaryConverter(DisplayedTransform):
|
|
97
110
|
'''Convert to binary mask.'''
|
|
98
111
|
|
|
99
112
|
order=1
|
|
100
113
|
|
|
101
|
-
def encodes(self, o:
|
|
114
|
+
def encodes(self, o: MedImage):
|
|
115
|
+
return o
|
|
102
116
|
|
|
103
|
-
def encodes(self, o:
|
|
117
|
+
def encodes(self, o: MedMask):
|
|
104
118
|
o = torch.where(o>0, 1., 0)
|
|
105
119
|
return MedMask.create(o)
|
|
106
120
|
|
|
107
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
108
|
-
def _do_rand_ghosting(o, intensity, p):
|
|
109
|
-
|
|
110
|
-
add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)
|
|
111
|
-
return add_ghosts(o)
|
|
112
|
-
|
|
113
|
-
# %% ../nbs/03_vision_augment.ipynb 19
|
|
121
|
+
# %% ../nbs/03_vision_augment.ipynb 12
|
|
114
122
|
class RandomGhosting(DisplayedTransform):
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
split_idx,order=0,1
|
|
118
|
-
|
|
119
|
-
def __init__(self, intensity =(0.5, 1), p=0.5):
|
|
120
|
-
self.intensity, self.p = intensity, p
|
|
123
|
+
"""Apply TorchIO `RandomGhosting`."""
|
|
124
|
+
|
|
125
|
+
split_idx, order = 0, 1
|
|
121
126
|
|
|
122
|
-
def
|
|
123
|
-
|
|
127
|
+
def __init__(self, intensity=(0.5, 1), p=0.5):
|
|
128
|
+
self.add_ghosts = tio.RandomGhosting(intensity=intensity, p=p)
|
|
124
129
|
|
|
125
|
-
|
|
126
|
-
|
|
130
|
+
def encodes(self, o: MedImage):
|
|
131
|
+
return MedImage.create(self.add_ghosts(o))
|
|
127
132
|
|
|
128
|
-
|
|
129
|
-
|
|
133
|
+
def encodes(self, o: MedMask):
|
|
134
|
+
return o
|
|
130
135
|
|
|
131
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
136
|
+
# %% ../nbs/03_vision_augment.ipynb 13
|
|
132
137
|
class RandomSpike(DisplayedTransform):
|
|
133
138
|
'''Apply TorchIO `RandomSpike`.'''
|
|
134
139
|
|
|
135
140
|
split_idx,order=0,1
|
|
136
141
|
|
|
137
142
|
def __init__(self, num_spikes=1, intensity=(1, 3), p=0.5):
|
|
138
|
-
self.
|
|
143
|
+
self.add_spikes = tio.RandomSpike(num_spikes=num_spikes, intensity=intensity, p=p)
|
|
139
144
|
|
|
140
|
-
def encodes(self, o:
|
|
141
|
-
|
|
145
|
+
def encodes(self, o:MedImage):
|
|
146
|
+
return MedImage.create(self.add_spikes(o))
|
|
147
|
+
|
|
148
|
+
def encodes(self, o:MedMask):
|
|
149
|
+
return o
|
|
142
150
|
|
|
143
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
144
|
-
def _do_rand_noise(o, mean, std, p):
|
|
145
|
-
|
|
146
|
-
add_noise = tio.RandomNoise(mean=mean, std=std, p=p)
|
|
147
|
-
return add_noise(o) #return torch tensor
|
|
148
|
-
|
|
149
|
-
# %% ../nbs/03_vision_augment.ipynb 25
|
|
151
|
+
# %% ../nbs/03_vision_augment.ipynb 14
|
|
150
152
|
class RandomNoise(DisplayedTransform):
|
|
151
153
|
'''Apply TorchIO `RandomNoise`.'''
|
|
152
154
|
|
|
153
155
|
split_idx,order=0,1
|
|
154
156
|
|
|
155
157
|
def __init__(self, mean=0, std=(0, 0.25), p=0.5):
|
|
156
|
-
self.
|
|
157
|
-
|
|
158
|
-
def encodes(self, o:(MedImage)): return MedImage.create(_do_rand_noise(o, mean=self.mean, std=self.std, p=self.p))
|
|
159
|
-
def encodes(self, o:(MedMask)):return o
|
|
158
|
+
self.add_noise = tio.RandomNoise(mean=mean, std=std, p=p)
|
|
160
159
|
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
160
|
+
def encodes(self, o: MedImage):
|
|
161
|
+
return MedImage.create(self.add_noise(o))
|
|
162
|
+
|
|
163
|
+
def encodes(self, o: MedMask):
|
|
164
|
+
return o
|
|
166
165
|
|
|
167
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
166
|
+
# %% ../nbs/03_vision_augment.ipynb 15
|
|
168
167
|
class RandomBiasField(DisplayedTransform):
|
|
169
168
|
'''Apply TorchIO `RandomBiasField`.'''
|
|
170
169
|
|
|
171
170
|
split_idx,order=0,1
|
|
172
171
|
|
|
173
172
|
def __init__(self, coefficients=0.5, order=3, p=0.5):
|
|
174
|
-
self.
|
|
175
|
-
|
|
176
|
-
def encodes(self, o:(MedImage)): return MedImage.create(_do_rand_biasfield(o, coefficients=self.coefficients, order=self.order, p=self.p))
|
|
177
|
-
def encodes(self, o:(MedMask)):return o
|
|
178
|
-
|
|
179
|
-
# %% ../nbs/03_vision_augment.ipynb 30
|
|
180
|
-
def _do_rand_blur(o, std, p):
|
|
173
|
+
self.add_biasfield = tio.RandomBiasField(coefficients=coefficients, order=order, p=p)
|
|
181
174
|
|
|
182
|
-
|
|
183
|
-
|
|
175
|
+
def encodes(self, o: MedImage):
|
|
176
|
+
return MedImage.create(self.add_biasfield(o))
|
|
177
|
+
|
|
178
|
+
def encodes(self, o: MedMask):
|
|
179
|
+
return o
|
|
184
180
|
|
|
185
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
181
|
+
# %% ../nbs/03_vision_augment.ipynb 16
|
|
186
182
|
class RandomBlur(DisplayedTransform):
|
|
187
183
|
'''Apply TorchIO `RandomBiasField`.'''
|
|
188
184
|
|
|
189
185
|
split_idx,order=0,1
|
|
190
186
|
|
|
191
187
|
def __init__(self, std=(0, 2), p=0.5):
|
|
192
|
-
self.
|
|
193
|
-
|
|
194
|
-
def encodes(self, o:
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
add_gamma = tio.RandomGamma(log_gamma=log_gamma, p=p)
|
|
201
|
-
return add_gamma(o)
|
|
188
|
+
self.add_blur = tio.RandomBlur(std=std, p=p)
|
|
189
|
+
|
|
190
|
+
def encodes(self, o: MedImage):
|
|
191
|
+
return MedImage.create(self.add_blur(o))
|
|
192
|
+
|
|
193
|
+
def encodes(self, o: MedMask):
|
|
194
|
+
return o
|
|
202
195
|
|
|
203
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
196
|
+
# %% ../nbs/03_vision_augment.ipynb 17
|
|
204
197
|
class RandomGamma(DisplayedTransform):
|
|
205
198
|
'''Apply TorchIO `RandomGamma`.'''
|
|
206
199
|
|
|
@@ -208,53 +201,81 @@ class RandomGamma(DisplayedTransform):
|
|
|
208
201
|
split_idx,order=0,1
|
|
209
202
|
|
|
210
203
|
def __init__(self, log_gamma=(-0.3, 0.3), p=0.5):
|
|
211
|
-
self.
|
|
212
|
-
|
|
213
|
-
def encodes(self, o:(MedImage)): return MedImage.create(_do_rand_gamma(o, log_gamma=self.log_gamma, p=self.p))
|
|
214
|
-
def encodes(self, o:(MedMask)):return o
|
|
215
|
-
|
|
216
|
-
# %% ../nbs/03_vision_augment.ipynb 36
|
|
217
|
-
def _do_rand_motion(o, degrees, translation, num_transforms, image_interpolation, p):
|
|
204
|
+
self.add_gamma = tio.RandomGamma(log_gamma=log_gamma, p=p)
|
|
218
205
|
|
|
219
|
-
|
|
220
|
-
|
|
206
|
+
def encodes(self, o: MedImage):
|
|
207
|
+
return MedImage.create(self.add_gamma(o))
|
|
208
|
+
|
|
209
|
+
def encodes(self, o: MedMask):
|
|
210
|
+
return o
|
|
221
211
|
|
|
222
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
212
|
+
# %% ../nbs/03_vision_augment.ipynb 18
|
|
223
213
|
class RandomMotion(DisplayedTransform):
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
split_idx,order=0,1
|
|
227
|
-
|
|
228
|
-
def __init__(
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
214
|
+
"""Apply TorchIO `RandomMotion`."""
|
|
215
|
+
|
|
216
|
+
split_idx, order = 0, 1
|
|
217
|
+
|
|
218
|
+
def __init__(
|
|
219
|
+
self,
|
|
220
|
+
degrees=10,
|
|
221
|
+
translation=10,
|
|
222
|
+
num_transforms=2,
|
|
223
|
+
image_interpolation='linear',
|
|
224
|
+
p=0.5
|
|
225
|
+
):
|
|
226
|
+
self.add_motion = tio.RandomMotion(
|
|
227
|
+
degrees=degrees,
|
|
228
|
+
translation=translation,
|
|
229
|
+
num_transforms=num_transforms,
|
|
230
|
+
image_interpolation=image_interpolation,
|
|
231
|
+
p=p
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
def encodes(self, o: MedImage):
|
|
235
|
+
return MedImage.create(self.add_motion(o))
|
|
236
|
+
|
|
237
|
+
def encodes(self, o: MedMask):
|
|
238
|
+
return o
|
|
239
|
+
|
|
240
|
+
# %% ../nbs/03_vision_augment.ipynb 20
|
|
235
241
|
class RandomElasticDeformation(CustomDictTransform):
|
|
236
|
-
|
|
242
|
+
"""Apply TorchIO `RandomElasticDeformation`."""
|
|
237
243
|
|
|
238
|
-
def __init__(self,num_control_points=7, max_displacement=7.5,
|
|
239
|
-
|
|
244
|
+
def __init__(self, num_control_points=7, max_displacement=7.5,
|
|
245
|
+
image_interpolation='linear', p=0.5):
|
|
246
|
+
|
|
247
|
+
super().__init__(tio.RandomElasticDeformation(
|
|
248
|
+
num_control_points=num_control_points,
|
|
249
|
+
max_displacement=max_displacement,
|
|
250
|
+
image_interpolation=image_interpolation,
|
|
251
|
+
p=p))
|
|
240
252
|
|
|
241
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
253
|
+
# %% ../nbs/03_vision_augment.ipynb 21
|
|
242
254
|
class RandomAffine(CustomDictTransform):
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
def __init__(self, scales=0, degrees=10, translation=0, isotropic=False,
|
|
246
|
-
|
|
255
|
+
"""Apply TorchIO `RandomAffine`."""
|
|
256
|
+
|
|
257
|
+
def __init__(self, scales=0, degrees=10, translation=0, isotropic=False,
|
|
258
|
+
image_interpolation='linear', default_pad_value=0., p=0.5):
|
|
259
|
+
|
|
260
|
+
super().__init__(tio.RandomAffine(
|
|
261
|
+
scales=scales,
|
|
262
|
+
degrees=degrees,
|
|
263
|
+
translation=translation,
|
|
264
|
+
isotropic=isotropic,
|
|
265
|
+
image_interpolation=image_interpolation,
|
|
266
|
+
default_pad_value=default_pad_value,
|
|
267
|
+
p=p))
|
|
247
268
|
|
|
248
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
269
|
+
# %% ../nbs/03_vision_augment.ipynb 22
|
|
249
270
|
class RandomFlip(CustomDictTransform):
|
|
250
|
-
|
|
271
|
+
"""Apply TorchIO `RandomFlip`."""
|
|
251
272
|
|
|
252
273
|
def __init__(self, axes='LR', p=0.5):
|
|
253
274
|
super().__init__(tio.RandomFlip(axes=axes, flip_probability=p))
|
|
254
275
|
|
|
255
|
-
# %% ../nbs/03_vision_augment.ipynb
|
|
276
|
+
# %% ../nbs/03_vision_augment.ipynb 23
|
|
256
277
|
class OneOf(CustomDictTransform):
|
|
257
|
-
|
|
278
|
+
"""Apply only one of the given transforms using TorchIO `OneOf`."""
|
|
258
279
|
|
|
259
280
|
def __init__(self, transform_dict, p=1):
|
|
260
281
|
super().__init__(tio.OneOf(transform_dict, p=p))
|
fastMONAI/vision_core.py
CHANGED
|
@@ -10,7 +10,8 @@ from torchio import ScalarImage, LabelMap, ToCanonical, Resample
|
|
|
10
10
|
|
|
11
11
|
# %% ../nbs/01_vision_core.ipynb 5
|
|
12
12
|
def _preprocess(obj, reorder, resample):
|
|
13
|
-
"""
|
|
13
|
+
"""
|
|
14
|
+
Preprocesses the given object.
|
|
14
15
|
|
|
15
16
|
Args:
|
|
16
17
|
obj: The object to preprocess.
|
|
@@ -83,12 +84,8 @@ def _multi_channel(image_paths: list, reorder: bool, resample: list, dtype, only
|
|
|
83
84
|
|
|
84
85
|
|
|
85
86
|
# %% ../nbs/01_vision_core.ipynb 8
|
|
86
|
-
def med_img_reader(
|
|
87
|
-
|
|
88
|
-
dtype=torch.Tensor,
|
|
89
|
-
reorder: bool = False,
|
|
90
|
-
resample: list = None,
|
|
91
|
-
only_tensor: bool = True
|
|
87
|
+
def med_img_reader(file_path: (str, Path), dtype=torch.Tensor, reorder: bool = False,
|
|
88
|
+
resample: list = None, only_tensor: bool = True
|
|
92
89
|
):
|
|
93
90
|
"""Loads and preprocesses a medical image.
|
|
94
91
|
|
|
@@ -120,32 +117,36 @@ def med_img_reader(
|
|
|
120
117
|
|
|
121
118
|
# %% ../nbs/01_vision_core.ipynb 10
|
|
122
119
|
class MetaResolver(type(torch.Tensor), metaclass=BypassNewMeta):
|
|
123
|
-
|
|
120
|
+
"""
|
|
121
|
+
A class to bypass metaclass conflict:
|
|
124
122
|
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/data/batch.html
|
|
125
|
-
|
|
123
|
+
"""
|
|
126
124
|
pass
|
|
127
125
|
|
|
128
126
|
# %% ../nbs/01_vision_core.ipynb 11
|
|
129
|
-
class MedBase(torch.Tensor, metaclass=MetaResolver):
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
127
|
+
class MedBase(torch.Tensor, metaclass=MetaResolver):
|
|
128
|
+
"""A class that represents an image object.
|
|
129
|
+
Metaclass casts `x` to this class if it is of type `cls._bypass_type`."""
|
|
130
|
+
|
|
131
|
+
_bypass_type = torch.Tensor
|
|
133
132
|
_show_args = {'cmap':'gray'}
|
|
134
133
|
resample, reorder = None, False
|
|
135
134
|
affine_matrix = None
|
|
136
135
|
|
|
137
|
-
|
|
138
136
|
@classmethod
|
|
139
|
-
def create(cls, fn: (Path, str, torch.Tensor), **kwargs):
|
|
137
|
+
def create(cls, fn: (Path, str, torch.Tensor), **kwargs) -> torch.Tensor:
|
|
140
138
|
"""
|
|
141
|
-
|
|
139
|
+
Opens a medical image and casts it to MedBase object.
|
|
140
|
+
If `fn` is a torch.Tensor, it's cast to MedBase object.
|
|
142
141
|
|
|
143
142
|
Args:
|
|
144
|
-
fn:
|
|
145
|
-
|
|
143
|
+
fn : (Path, str, torch.Tensor)
|
|
144
|
+
Image path or a 4D torch.Tensor.
|
|
145
|
+
kwargs : dict
|
|
146
|
+
Additional parameters for the medical image reader.
|
|
146
147
|
|
|
147
148
|
Returns:
|
|
148
|
-
A 4D tensor as MedBase object.
|
|
149
|
+
torch.Tensor : A 4D tensor as a MedBase object.
|
|
149
150
|
"""
|
|
150
151
|
if isinstance(fn, torch.Tensor):
|
|
151
152
|
return cls(fn)
|
|
@@ -155,18 +156,32 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
|
|
|
155
156
|
@classmethod
|
|
156
157
|
def item_preprocessing(cls, resample: (list, int, tuple), reorder: bool):
|
|
157
158
|
"""
|
|
158
|
-
|
|
159
|
+
Changes the values for the class variables `resample` and `reorder`.
|
|
159
160
|
|
|
160
161
|
Args:
|
|
161
|
-
resample:
|
|
162
|
-
|
|
162
|
+
resample : (list, int, tuple)
|
|
163
|
+
A list with voxel spacing.
|
|
164
|
+
reorder : bool
|
|
165
|
+
Whether to reorder the data to be closest to canonical (RAS+) orientation.
|
|
163
166
|
"""
|
|
164
167
|
cls.resample = resample
|
|
165
168
|
cls.reorder = reorder
|
|
166
169
|
|
|
167
|
-
def show(self, ctx=None, channel=0, indices=None, anatomical_plane=0, **kwargs):
|
|
170
|
+
def show(self, ctx=None, channel: int = 0, indices: int = None, anatomical_plane: int = 0, **kwargs):
|
|
168
171
|
"""
|
|
169
|
-
|
|
172
|
+
Displays the Medimage using `merge(self._show_args, kwargs)`.
|
|
173
|
+
|
|
174
|
+
Args:
|
|
175
|
+
ctx : Any, optional
|
|
176
|
+
Context to use for the display. Defaults to None.
|
|
177
|
+
channel : int, optional
|
|
178
|
+
The channel of the image to be displayed. Defaults to 0.
|
|
179
|
+
indices : list or None, optional
|
|
180
|
+
Indices of the images to be displayed. Defaults to None.
|
|
181
|
+
anatomical_plane : int, optional
|
|
182
|
+
Anatomical plane of the image to be displayed. Defaults to 0.
|
|
183
|
+
kwargs : dict, optional
|
|
184
|
+
Additional parameters for the show function.
|
|
170
185
|
|
|
171
186
|
Returns:
|
|
172
187
|
Shown image.
|
|
@@ -177,15 +192,16 @@ class MedBase(torch.Tensor, metaclass=MetaResolver):
|
|
|
177
192
|
**merge(self._show_args, kwargs)
|
|
178
193
|
)
|
|
179
194
|
|
|
180
|
-
def __repr__(self):
|
|
195
|
+
def __repr__(self) -> str:
|
|
196
|
+
"""Returns the string representation of the MedBase instance."""
|
|
181
197
|
return f'{self.__class__.__name__} mode={self.mode} size={"x".join([str(d) for d in self.size])}'
|
|
182
198
|
|
|
183
199
|
# %% ../nbs/01_vision_core.ipynb 12
|
|
184
200
|
class MedImage(MedBase):
|
|
185
|
-
|
|
201
|
+
"""Subclass of MedBase that represents an image object."""
|
|
186
202
|
pass
|
|
187
203
|
|
|
188
204
|
# %% ../nbs/01_vision_core.ipynb 13
|
|
189
205
|
class MedMask(MedBase):
|
|
190
|
-
|
|
206
|
+
"""Subclass of MedBase that represents an mask object."""
|
|
191
207
|
_show_args = {'alpha':0.5, 'cmap':'tab20'}
|