konfai 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,597 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import torch
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
import numpy as np
|
|
5
|
+
import SimpleITK as sitk
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from typing import Union
|
|
8
|
+
import os
|
|
9
|
+
from KonfAI.konfai import DEEP_LEARNING_API_ROOT
|
|
10
|
+
from KonfAI.konfai.utils.config import config
|
|
11
|
+
from KonfAI.konfai.utils.utils import _getModule
|
|
12
|
+
from KonfAI.konfai.utils.dataset import Attribute, data_to_image
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 homogeneous matrix translating 2D points by ``(t[0], t[1])``."""
    offset_column = torch.tensor([[t[axis]] for axis in range(2)])
    upper_rows = torch.cat((torch.eye(2), offset_column), dim=1)
    homogeneous_row = torch.Tensor([[0,0,1]])
    return torch.cat((upper_rows, homogeneous_row), dim=0)
|
|
17
|
+
|
|
18
|
+
def _translate3DMatrix(t: torch.Tensor) -> torch.Tensor:
    """Return the 4x4 homogeneous matrix translating 3D points by ``(t[0], t[1], t[2])``."""
    offset_column = torch.tensor([[t[axis]] for axis in range(3)])
    upper_rows = torch.cat((torch.eye(3), offset_column), dim=1)
    homogeneous_row = torch.Tensor([[0,0,0,1]])
    return torch.cat((upper_rows, homogeneous_row), dim=0)
|
|
20
|
+
|
|
21
|
+
def _scale2DMatrix(s: torch.Tensor) -> torch.Tensor:
    """Return the 3x3 homogeneous 2D scaling matrix.

    ``s`` is either a scalar (isotropic scaling) or a length-2 tensor holding one
    factor per axis; it is broadcast against a 2x2 identity.
    """
    linear_part = torch.eye(2) * s
    homogeneous = torch.zeros((3, 3), dtype=linear_part.dtype)
    homogeneous[:2, :2] = linear_part
    homogeneous[2, 2] = 1
    return homogeneous
|
|
23
|
+
|
|
24
|
+
def _scale3DMatrix(s: torch.Tensor) -> torch.Tensor:
    """Return the 4x4 homogeneous 3D scaling matrix.

    ``s`` is either a scalar (isotropic scaling) or a length-3 tensor holding one
    factor per axis; it is broadcast against a 3x3 identity.
    """
    linear_part = torch.eye(3) * s
    homogeneous = torch.zeros((4, 4), dtype=linear_part.dtype)
    homogeneous[:3, :3] = linear_part
    homogeneous[3, 3] = 1
    return homogeneous
|
|
26
|
+
|
|
27
|
+
def _rotation3DMatrix(rotation : torch.Tensor, center: Union[torch.Tensor, None] = None) -> torch.Tensor:
    """Return the 4x4 homogeneous 3D rotation matrix ``Rz @ Ry @ Rx``.

    Args:
        rotation: angles in radians; ``rotation[0]`` about x, ``rotation[1]``
            about y, ``rotation[2]`` about z.
        center: optional length-3 point to rotate about. When given, the result is
            ``T(center) @ R @ T(-center)``: move ``center`` to the origin,
            rotate, then move back, so ``center`` is a fixed point.

    Returns:
        A ``(4, 4)`` tensor.
    """
    A = torch.tensor([[torch.cos(rotation[2]), -torch.sin(rotation[2]), 0], [torch.sin(rotation[2]), torch.cos(rotation[2]), 0], [0, 0, 1]])
    B = torch.tensor([[torch.cos(rotation[1]), 0, torch.sin(rotation[1])], [0, 1, 0], [-torch.sin(rotation[1]), 0, torch.cos(rotation[1])]])
    C = torch.tensor([[1, 0, 0], [0, torch.cos(rotation[0]), -torch.sin(rotation[0])], [0, torch.sin(rotation[0]), torch.cos(rotation[0])]])
    rotation_matrix = torch.cat((torch.cat((A.mm(B).mm(C), torch.zeros((3, 1))), dim=1), torch.tensor([[0, 0, 0, 1]])), dim=0)
    if center is not None:
        # Conjugate by translations. The previous T(-c) @ R @ T(c) ordering
        # rotated about -center instead of center.
        to_origin = torch.eye(4)
        to_origin[:-1, -1] = -center
        from_origin = torch.eye(4)
        from_origin[:-1, -1] = center
        rotation_matrix = from_origin.mm(rotation_matrix).mm(to_origin)
    return rotation_matrix
|
|
41
|
+
|
|
42
|
+
def _rotation2DMatrix(rotation : torch.Tensor, center: Union[torch.Tensor, None] = None) -> torch.Tensor:
    """Return the 3x3 homogeneous 2D rotation matrix for angle ``rotation[0]`` (radians).

    Args:
        rotation: tensor whose first element is the angle; further elements are ignored.
        center: optional length-2 point to rotate about. When given, the result is
            ``T(center) @ R @ T(-center)`` so that ``center`` is a fixed point.
            Previously this argument was accepted but silently ignored.

    Returns:
        A ``(3, 3)`` tensor.
    """
    cos_a, sin_a = torch.cos(rotation[0]), torch.sin(rotation[0])
    linear_part = torch.tensor([[cos_a, -sin_a], [sin_a, cos_a]])
    rotation_matrix = torch.cat((torch.cat((linear_part, torch.zeros((2, 1))), dim=1), torch.tensor([[0, 0, 1]])), dim=0)
    if center is not None:
        to_origin = torch.eye(3)
        to_origin[:-1, -1] = -center
        from_origin = torch.eye(3)
        from_origin[:-1, -1] = center
        rotation_matrix = from_origin.mm(rotation_matrix).mm(to_origin)
    return rotation_matrix
|
|
44
|
+
|
|
45
|
+
class Prob:
    """Configuration holder for the probability attached to one augmentation.

    ``DataAugmentationsList.load`` forwards ``prob`` to the augmentation's ``load``.
    """

    @config()
    def __init__(self, prob: float = 1.0) -> None:
        # Stored verbatim; its meaning is defined by the receiving augmentation's load().
        self.prob = prob
|
|
50
|
+
|
|
51
|
+
class DataAugmentationsList():
    """Configured list of augmentations, instantiated lazily by :meth:`load`.

    ``dataAugmentations`` maps ``"<module>:<ClassName>"``-style keys (resolved by
    ``_getModule``) to a :class:`Prob`.
    """

    @config()
    def __init__(self, nb : int = 10, dataAugmentations: dict[str, Prob] = {"default:RandomElastixTransform" : Prob(1)}) -> None:
        # NOTE(review): no class named RandomElastixTransform is visible in this
        # module (the elastic augmentation here is `Elastix`); confirm how
        # _getModule resolves this default key.
        # nb is stored only; presumably the number of augmented copies — TODO confirm.
        self.nb = nb
        self.dataAugmentations : list[DataAugmentation] = []
        self.dataAugmentationsLoader = dataAugmentations

    def load(self, key: str):
        """Instantiate every configured augmentation and append it to ``self.dataAugmentations``.

        Args:
            key: name of the augmentation group, used to build the config path
                ``<root>.Dataset.augmentations.<key>.dataAugmentations``.
        """
        for qualified_name, probability in self.dataAugmentationsLoader.items():
            module_path, class_name = _getModule(qualified_name, "augmentation")
            augmentation_class = getattr(importlib.import_module(module_path), class_name)
            instance: DataAugmentation = augmentation_class(config = None, DL_args="{}.Dataset.augmentations.{}.dataAugmentations".format(DEEP_LEARNING_API_ROOT(), key))
            instance.load(probability.prob)
            self.dataAugmentations.append(instance)
|
|
65
|
+
|
|
66
|
+
class DataAugmentation(ABC):
    """Abstract base class for the random augmentations of this module.

    An augmentation acts on a group of tensors. For each sample ``index``,
    :meth:`state_init` draws which tensors of the group are augmented (each one
    independently with probability ``self._prob``) and hands only those tensors'
    shapes to :meth:`_state_init`, so subclass state lists are aligned with that
    selected subset. :meth:`__call__` and :meth:`inverse` reuse the drawn state.
    """

    def __init__(self) -> None:
        # index -> positions (within the group) of the tensors selected for augmentation.
        self.who_index: dict[int, list[int]] = {}
        # index -> shapes of the whole group after augmentation.
        self.shape_index: dict[int, list[list[int]]] = {}
        # Per-tensor application probability; set by load().
        self._prob: float = 0

    def load(self, prob: float):
        """Set the per-tensor probability of applying this augmentation."""
        self._prob = prob

    def state_init(self, index: Union[None, int], shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        """Draw (or recall) the random state for sample ``index``.

        Args:
            index: sample key. If it was already initialised, the cached output
                shapes are returned and nothing is redrawn. ``None`` means "do not
                cache": state is always redrawn and stored under key 0.
            shapes: per-tensor shapes of the group; updated in place with the
                shapes produced by the augmentation.
            caches_attribute: per-tensor attributes, aligned with ``shapes``.

        Returns:
            The (possibly updated) list of shapes.
        """
        if index is not None:
            if index not in self.who_index:
                self.who_index[index] = torch.where(torch.rand(len(shapes)) < self._prob)[0].tolist()
            else:
                return self.shape_index[index]
        else:
            index = 0
            self.who_index[index] = torch.where(torch.rand(len(shapes)) < self._prob)[0].tolist()

        selected = self.who_index[index]
        if len(selected) > 0:
            new_shapes = self._state_init(index, [shapes[i] for i in selected], [caches_attribute[i] for i in selected])
            for i, shape in enumerate(new_shapes):
                shapes[selected[i]] = shape
        self.shape_index[index] = shapes
        return self.shape_index[index]

    @abstractmethod
    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        """Draw random parameters for the selected tensors and return their new shapes."""

    def __call__(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        """Augment the tensors selected for sample ``index``; ``inputs`` is updated in place and returned.

        When ``device`` is not ``None`` the augmented results are moved back to the CPU.
        """
        selected = self.who_index[index]
        if len(selected) > 0:
            for i, result in enumerate(self._compute(index, [inputs[i] for i in selected], device)):
                inputs[selected[i]] = result if device is None else result.cpu()
        return inputs

    @abstractmethod
    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        """Apply the augmentation to the selected tensors, in subset order."""

    def inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        """Undo the augmentation on ``input`` if tensor ``a`` of the group was augmented.

        Args:
            index: sample key used in :meth:`state_init`.
            a: position of the tensor within the whole group.
            input: tensor to restore.

        Returns:
            The restored tensor, or ``input`` unchanged if tensor ``a`` was not selected.
        """
        selected = self.who_index[index]
        if a in selected:
            # Subclass state is stored per *selected* tensor, so translate the
            # group position into the position within the selected subset.
            input = self._inverse(index, selected.index(a), input)
        return input

    @abstractmethod
    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        """Undo the augmentation for the ``a``-th selected tensor of sample ``index``."""
|
|
114
|
+
|
|
115
|
+
class EulerTransform(DataAugmentation):
    """Shared resampling logic for affine augmentations (translate, rotate, scale).

    Subclasses fill ``self.matrix[index]`` with one homogeneous matrix of shape
    ``(1, dim + 1, dim + 1)`` per selected tensor; this class warps the tensors
    with ``F.affine_grid`` / ``F.grid_sample`` (bilinear, reflection padding).
    """

    def __init__(self) -> None:
        super().__init__()
        # index -> list of homogeneous matrices, one per selected tensor.
        self.matrix: dict[int, list[torch.Tensor]] = {}

    @staticmethod
    def _warp(image: torch.Tensor, homogeneous: torch.Tensor) -> torch.Tensor:
        """Resample ``image`` (channels first, no batch axis) through ``homogeneous``.

        The computation runs in float32 and the result is cast back to ``image``'s dtype.
        """
        batched = image.unsqueeze(0).type(torch.float32)
        # affine_grid expects the top (dim x dim+1) block, i.e. without the homogeneous row.
        sampling_grid = F.affine_grid(homogeneous[:, :-1,...], [1]+list(image.shape), align_corners=True).to(image.device)
        warped = F.grid_sample(batched, sampling_grid, align_corners=True, mode="bilinear", padding_mode="reflection")
        return warped.type(image.dtype).squeeze(0)

    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        return [EulerTransform._warp(image, homogeneous) for image, homogeneous in zip(inputs, self.matrix[index])]

    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        # Warp with the inverse of the matrix stored at position `a` of self.matrix[index].
        return EulerTransform._warp(input, self.matrix[index][a].inverse())
|
|
129
|
+
|
|
130
|
+
class Translate(EulerTransform):
    """Random translation, drawn uniformly in ``[t_min, t_max)`` independently per axis.

    Offsets are expressed in the normalised coordinates used by ``F.affine_grid``.
    """

    @config("Translate")
    def __init__(self, t_min: float = -10, t_max = 10, is_int: bool = False):
        super().__init__()
        self.t_min = t_min
        self.t_max = t_max
        self.is_int = is_int

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        n_axes = len(shapes[0])
        build_matrix = _translate2DMatrix if n_axes != 3 else _translate3DMatrix
        span = torch.tensor(self.t_max-self.t_min)
        offsets = torch.rand((len(shapes), n_axes)) * span + torch.tensor(self.t_min)
        if self.is_int:
            # NOTE(review): despite its name, `is_int` rounds to two decimal
            # places rather than to integers — confirm the intended quantisation.
            offsets = torch.round(offsets*100)/100
        self.matrix[index] = [build_matrix(row).unsqueeze(0) for row in offsets]
        return shapes
|
|
147
|
+
|
|
148
|
+
class Rotate(EulerTransform):
    """Random rotation about the origin of ``F.affine_grid``'s normalised coordinates.

    Args (config):
        a_min, a_max: range of the uniformly drawn angles, in **degrees**.
        is_quarter: if True, each axis is instead rotated by a random quarter turn
            chosen from 90, 180 and 270 degrees.

    One angle per spatial axis is drawn per selected tensor; 2D inputs only use the first.
    """

    @config("Rotate")
    def __init__(self, a_min: float = 0, a_max: float = 360, is_quarter: bool = False):
        super().__init__()
        self.a_min = a_min
        self.a_max = a_max
        self.is_quarter = is_quarter

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        dim = len(shapes[0])
        func = _rotation3DMatrix if dim == 3 else _rotation2DMatrix
        if self.is_quarter:
            # One quarter turn per tensor and per axis. The former
            # `repeat([90,180,270], 3)` produced a flat 9-element tensor that did
            # not match len(shapes) and crashed in the matrix helpers.
            quarters = torch.tensor([90.0, 180.0, 270.0])
            angles = quarters[torch.randint(0, len(quarters), (len(shapes), dim))]
        else:
            angles = torch.rand((len(shapes), dim))*(self.a_max-self.a_min) + self.a_min
        # The rotation helpers call torch.cos / torch.sin, which take radians.
        angles = torch.deg2rad(angles)
        self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in angles]
        return shapes
|
|
169
|
+
|
|
170
|
+
class Scale(EulerTransform):
    """Random isotropic scaling by ``2 ** N(0, s_std**2)``, one factor per selected tensor."""

    @config("Scale")
    def __init__(self, s_std: float = 0.2):
        super().__init__()
        self.s_std = s_std

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        n_axes = len(shapes[0])
        build_matrix = _scale2DMatrix if n_axes != 3 else _scale3DMatrix
        factors = torch.exp2(torch.randn((len(shapes))) * self.s_std)
        # Copy each tensor's factor onto every spatial axis.
        per_axis = factors.unsqueeze(1).repeat([1, n_axes])
        self.matrix[index] = [build_matrix(row).unsqueeze(0) for row in per_axis]
        return shapes
|
|
182
|
+
|
|
183
|
+
class Flip(DataAugmentation):
    """Random mirroring along spatial axes.

    ``f_prob[k]`` is the probability of flipping tensor dimension ``k + 1``
    (dimension 0 is never flipped). Axes beyond a tensor's spatial rank are ignored.
    """

    @config("Flip")
    def __init__(self, f_prob: Union[list[float], None] = [0.33, 0.33 ,0.33]) -> None:
        super().__init__()
        self.f_prob = f_prob
        # index -> for each selected tensor, the list of tensor dimensions to flip.
        self.flip: dict[int, list[list[int]]] = {}

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        draws = torch.rand((len(shapes), len(self.f_prob))) < torch.tensor(self.f_prob)
        # Clip to each shape's spatial rank: previously dims [1, 2, 3] were used
        # regardless of rank, so flipping a 2D tensor (3 dims) on dim 3 failed.
        self.flip[index] = [
            [axis + 1 for axis in range(min(len(shape), len(self.f_prob))) if mask[axis]]
            for shape, mask in zip(shapes, draws)
        ]
        return shapes

    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        return [torch.flip(input, dims=flip) for input, flip in zip(inputs, self.flip[index])]

    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        # A flip is its own inverse; reuse the dims stored at position `a`.
        return torch.flip(input, dims=self.flip[index][a])
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
class ColorTransform(DataAugmentation):
    """Base class for colour augmentations expressed as a 4x4 homogeneous colour matrix.

    Subclasses fill ``self.matrix[index]`` with one ``(1, 4, 4)`` matrix per
    selected tensor. Three-channel tensors get the full 3x3 mix plus offset.
    One-channel tensors get a scalar gain (row sums averaged over the three
    rows) plus the averaged offset.
    """

    @config("ColorTransform")
    def __init__(self) -> None:
        super().__init__()
        # index -> list of (1, 4, 4) colour matrices, one per selected tensor.
        self.matrix: dict[int, list[torch.Tensor]] = {}

    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        results = []
        for image, colour_matrix in zip(inputs, self.matrix[index]):
            n_channels = image.shape[0]
            # Flatten all spatial axes: (channels, pixels).
            flat = image.reshape([*image.shape[:1], int(np.prod(image.shape[1:]))])
            if n_channels == 3:
                m = colour_matrix.to(image.device)
                transformed = m[:, :3, :3] @ flat.float() + m[:, :3, 3:]
            elif n_channels == 1:
                m = colour_matrix[:, :3, :].mean(dim=1, keepdims=True).to(image.device)
                transformed = flat.float() * m[:, :, :3].sum(dim=2, keepdims=True) + m[:, :, 3:]
            else:
                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
            results.append(transformed.reshape(image.shape))
        return results

    def _inverse(self, index: int, a: int, inputs : torch.Tensor) -> torch.Tensor:
        # Not invertible here: returns None.
        return None
|
|
231
|
+
|
|
232
|
+
class Brightness(ColorTransform):
    """Random brightness: add the same ``N(0, b_std**2)`` offset to every colour channel."""

    @config("Brightness")
    def __init__(self, b_std: float) -> None:
        super().__init__()
        self.b_std = b_std

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        shift = (torch.randn((len(shapes))) * self.b_std).unsqueeze(1).repeat([1, 3])
        self.matrix[index] = [_translate3DMatrix(row).unsqueeze(0) for row in shift]
        return shapes
|
|
243
|
+
|
|
244
|
+
class Contrast(ColorTransform):
    """Random contrast: multiply colour values by ``2 ** N(0, c_std**2)`` (no offset)."""

    @config("Contrast")
    def __init__(self, c_std: float) -> None:
        super().__init__()
        self.c_std = c_std

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        gains = torch.exp2(torch.randn((len(shapes))) * self.c_std)
        self.matrix[index] = [_scale3DMatrix(gain).unsqueeze(0) for gain in gains]
        return shapes
|
|
255
|
+
|
|
256
|
+
class LumaFlip(ColorTransform):
    """With probability 1/2 per tensor, reflect colours across the plane orthogonal to the grey axis.

    The reflection is the Householder matrix ``I - 2 v v^T``; otherwise the identity is used.
    """

    @config("LumaFlip")
    def __init__(self) -> None:
        super().__init__()
        # Homogeneous unit vector along the grey diagonal (last component 0).
        self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        coin = torch.floor(torch.rand([len(shapes), 1, 1]) * 2)  # 0 or 1 per tensor
        householder = 2 * torch.outer(self.v, self.v)
        self.matrix[index] = [(torch.eye(4) - householder * flag).unsqueeze(0) for flag in coin]
        return shapes
|
|
267
|
+
|
|
268
|
+
class HUE(ColorTransform):
    """Random hue shift: rotate RGB colours around the grey axis ``r = g = b``.

    The angle is drawn uniformly in ``[-pi * hue_max, pi * hue_max]`` per
    selected tensor. The rotation axis is the grey diagonal, so grey pixels are
    left unchanged. The previous version applied Euler angles ``(t, t, t)``
    about a translated centre, which is not a hue rotation.
    """

    @config("HUE")
    def __init__(self, hue_max: float) -> None:
        super().__init__()
        self.hue_max = hue_max
        # Unit vector along the grey diagonal; used as the rotation axis.
        self.v = torch.tensor([1, 1, 1])/torch.sqrt(torch.tensor(3))

    @staticmethod
    def _axis_rotation(axis: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        """Return the 4x4 homogeneous matrix rotating by ``theta`` radians around unit vector ``axis``.

        Uses Rodrigues' formula ``R = cos(t) I + sin(t) K + (1 - cos(t)) a a^T``,
        where ``K`` is the cross-product matrix of ``axis``; the offset part is zero.
        """
        x, y, z = axis[0], axis[1], axis[2]
        cross = torch.tensor([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        matrix = torch.eye(4)
        matrix[:3, :3] = cos_t * torch.eye(3) + sin_t * cross + (1 - cos_t) * torch.outer(axis, axis)
        return matrix

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        theta = (torch.rand([len(shapes)]) * 2 - 1) * np.pi * self.hue_max
        self.matrix[index] = [torch.unsqueeze(self._axis_rotation(self.v, value), dim=0) for value in theta]
        return shapes
|
|
280
|
+
|
|
281
|
+
class Saturation(ColorTransform):
    """Random saturation: scale the chroma of each colour by ``2 ** N(0, s_std**2)``.

    Colours are split into their projection on the grey axis (kept) and the
    orthogonal chroma part (scaled): ``M = P + (I - P) * s`` with ``P = v v^T``.
    """

    @config("Saturation")
    def __init__(self, s_std: float) -> None:
        super().__init__()
        self.s_std = s_std
        # Homogeneous unit vector along the grey diagonal (last component 0).
        self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        saturation = torch.exp2(torch.randn((len(shapes))) * self.s_std)
        grey_projector = torch.outer(self.v, self.v)
        # Only the chroma component (I - P) is scaled. The former
        # `(P + (I - P)) * s` reduced to `s * I`, a plain contrast change.
        self.matrix[index] = [(grey_projector + (torch.eye(4) - grey_projector) * value).unsqueeze(0) for value in saturation]
        return shapes
|
|
293
|
+
|
|
294
|
+
# NOTE: an experimental wavelet band-pass ``Filter`` augmentation used to sit here
# inside a disabled module-level string literal. It could not run: it referenced
# names that are never defined or imported in this module (``scipy``,
# ``conv2d_gradfix``, ``images``, ``height``, ``width``), and its ``_compute``
# lacked the ``device`` parameter required by ``DataAugmentation``. It was removed
# as dead code; recover it from version control if it is ever revived.
|
|
367
|
+
|
|
368
|
+
class Noise(DataAugmentation):
    """Diffusion-style Gaussian noising.

    Each selected tensor ``x`` becomes
    ``sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * n_std * eps`` with
    ``eps ~ N(0, 1)`` and a step ``t`` drawn in ``[0, max_T)``. ``alpha_hat`` is
    built from a linear beta schedule rescaled to zero terminal SNR.
    """

    @config("Noise")
    def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
        super().__init__()
        self.n_std = n_std
        self.noise_step = noise_step

        # index -> one diffusion step per selected tensor (int 0 or a (1,) tensor).
        self.ts: dict[int, list[torch.Tensor]] = {}
        self.betas = torch.linspace(beta_start, beta_end, noise_step)
        self.betas = Noise.enforce_zero_terminal_snr(self.betas)
        self.alphas = 1 - self.betas
        # alpha_hat[0] = 1 (no noise); alpha_hat[t] = prod(alphas[:t]).
        self.alpha_hat = torch.concat((torch.ones(1), torch.cumprod(self.alphas, dim=0)))
        self.max_T = 0

        self.C = 1
        self.n = 4
        self.d = 0.25
        # Every tensor is noised: load() reuses its argument as noise strength instead.
        self._prob = 1

    @staticmethod
    def enforce_zero_terminal_snr(betas: torch.Tensor):
        """Rescale a beta schedule so that the final cumulative alpha is exactly 0.

        ``sqrt(alpha_bar)`` is shifted so its last value is 0 and rescaled so its
        first value is unchanged; betas are then recomputed from the result.
        Declared static so it also works when called on an instance.
        """
        alphas = 1 - betas
        alphas_bar = alphas.cumprod(0)
        alphas_bar_sqrt = alphas_bar.sqrt()
        alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
        alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
        alphas_bar_sqrt -= alphas_bar_sqrt_T
        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
        alphas_bar = alphas_bar_sqrt ** 2
        alphas = alphas_bar[1:] / alphas_bar[:-1]
        alphas = torch.cat([alphas_bar[0:1], alphas])
        betas = 1 - alphas
        return betas

    def load(self, prob: float):
        """Use ``prob`` as the fraction of the schedule available: ``max_T = prob * noise_step``."""
        self.max_T = prob*self.noise_step

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        if int(self.max_T) == 0:
            self.ts[index] = [0 for _ in shapes]
        else:
            self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
        return shapes

    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        results = []
        for input, t in zip(inputs, self.ts[index]):
            # Reshape to all-ones dims so the coefficient broadcasts over the tensor.
            alpha_hat_t = self.alpha_hat[t].to(input.device).reshape(*[1 for _ in range(len(input.shape))])
            # randn_like already allocates on input's device.
            noise = torch.randn_like(input.float()) * self.n_std
            results.append(alpha_hat_t.sqrt() * input + (1 - alpha_hat_t).sqrt() * noise)
        return results

    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        # Not invertible: returns None.
        pass
|
|
422
|
+
|
|
423
|
+
class CutOUT(DataAugmentation):
    """Random cutout: fill an axis-aligned box with a constant ``value``.

    Box centres are drawn uniformly in normalised coordinates ``[0, 1)`` per axis.
    ``cutout_size`` is the box side, also in normalised units.
    """

    @config("CutOUT")
    def __init__(self, c_prob: float, cutout_size: int, value: float) -> None:
        super().__init__()
        self.c_prob = c_prob
        self.cutout_size = cutout_size
        # index -> one normalised centre per selected tensor.
        self.centers: dict[int, list[torch.Tensor]] = {}
        self.value = value

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        self.centers[index] = [torch.rand(3 if len(shape) == 3 else 2) for shape in shapes]
        return shapes

    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        half_size = torch.tensor(self.cutout_size).reshape([1, 1])/ 2
        results = []
        for image, center in zip(inputs, self.centers[index]):
            spatial = image.shape[1:]
            n_axes = len(spatial)
            outside_per_axis = []
            for axis, length in enumerate(spatial):
                view = [1] * n_axes
                view[axis] = -1
                # Normalised pixel-centre coordinates along this axis.
                coords = (torch.arange(length).reshape(view) + 0.5) / length
                outside_per_axis.append((coords - center[axis].reshape([1, 1])).abs() >= half_size)
            # A pixel is kept when it lies outside the box along at least one axis.
            keep = outside_per_axis[0]
            for extra in outside_per_axis[1:]:
                keep = torch.logical_or(keep, extra)
            keep = keep.unsqueeze(0).repeat([image.shape[0], *([1] * n_axes)])
            fill = torch.tensor(self.value).to(image.device)
            results.append(torch.where(keep.to(image.device) == 1, image, fill))
        return results

    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        # Not invertible: returns None.
        return None
|
|
453
|
+
|
|
454
|
+
class Elastix(DataAugmentation):
    """Random elastic deformation driven by a cubic B-spline displacement field.

    For each selected tensor, a SimpleITK B-spline transform is built. Its
    control-point parameters are drawn uniformly in
    ``[-max_displacement, max_displacement]``. The transform is converted to a
    dense displacement field, added to the voxel index grid, and stored as a
    normalised sampling grid for ``F.grid_sample``.
    """

    @config("Elastix")
    def __init__(self, grid_spacing: int = 16, max_displacement: int = 16) -> None:
        super().__init__()
        # Control-point spacing, in the same physical units as the image spacing.
        self.grid_spacing = grid_spacing
        # Half-width of the uniform range of control-point parameters.
        self.max_displacement = max_displacement
        # index -> normalised sampling grids, ready for F.grid_sample.
        self.displacement_fields: dict[int, list[torch.Tensor]] = {}
        # index -> raw displacement arrays returned by sitk.GetArrayFromImage (numpy).
        self.displacement_fields_true: dict[int, list[np.ndarray]] = {}

    @staticmethod
    def _formatLoc(new_locs, shape):
        """Convert voxel coordinates to grid_sample's normalised ``[-1, 1]`` convention.

        Index 0 maps to -1 and ``shape[i] - 1`` maps to 1 (align_corners=True).
        The last axis is then reversed so components run from the last array axis to the first.
        Note: the first loop modifies ``new_locs`` in place.
        """
        for i in range(len(shape)):
            new_locs[..., i] = 2 * (new_locs[..., i] / (shape[i] - 1) - 0.5)
        new_locs = new_locs[..., [i for i in reversed(range(len(shape)))]]
        return new_locs

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        """Draw one random B-spline displacement field per selected tensor.

        Side effect: missing "Spacing", "Origin" and "Direction" entries are
        written into the corresponding ``caches_attribute`` objects.
        """
        print("Compute Displacement Field for index {}".format(index))
        self.displacement_fields[index] = []
        self.displacement_fields_true[index] = []
        for i, (shape, cache_attribute) in enumerate(zip(shapes, caches_attribute)):
            shape = shape
            dim = len(shape)
            if "Spacing" not in cache_attribute:
                spacing = np.array([1.0 for _ in range(dim)])
            else:
                spacing = cache_attribute.get_np_array("Spacing")

            # Number of B-spline mesh cells per axis = physical extent / control-point spacing.
            # NOTE(review): assumes the stored spacing is in the same axis order as `shape` — confirm.
            grid_physical_spacing = [self.grid_spacing]*dim
            image_physical_size = [size*spacing for size, spacing in zip(shape, spacing)]
            mesh_size = [int(image_size/grid_spacing + 0.5) for image_size,grid_spacing in zip(image_physical_size, grid_physical_spacing)]
            # Fill in default geometry so data_to_image can build a reference image.
            if "Spacing" not in cache_attribute:
                cache_attribute["Spacing"] = np.array([1.0 for _ in range(dim)])
            if "Origin" not in cache_attribute:
                cache_attribute["Origin"] = np.array([1.0 for _ in range(dim)])
            if "Direction" not in cache_attribute:
                cache_attribute["Direction"] = np.eye(dim).flatten()

            ref_image = data_to_image(np.expand_dims(np.zeros(shape), 0), cache_attribute)

            bspline_transform = sitk.BSplineTransformInitializer(image1 = ref_image, transformDomainMeshSize = mesh_size, order=3)
            displacement_filter = sitk.TransformToDisplacementFieldFilter()
            displacement_filter.SetReferenceImage(ref_image)

            # Identity grid of voxel indices, shape (1, *shape, dim), components in array-axis order.
            vectors = [torch.arange(0, s) for s in shape]
            grids = torch.meshgrid(vectors, indexing='ij')
            grid = torch.stack(grids)
            grid = torch.unsqueeze(grid, 0)
            grid = grid.type(torch.float).permute([0]+[i+2 for i in range(len(shape))] + [1])

            # A cubic B-spline over `mesh_size` cells has mesh_size + 3 control points per axis.
            control_points = torch.rand(*[size+3 for size in mesh_size], dim)
            control_points -= 0.5
            control_points *= 2*self.max_displacement
            bspline_transform.SetParameters(control_points.flatten().tolist())
            displacement = sitk.GetArrayFromImage(displacement_filter.Execute(bspline_transform))
            self.displacement_fields_true[index].append(displacement)
            # NOTE(review): SimpleITK displacement vectors are presumably in physical
            # units with components ordered (x, y[, z]), whereas `grid` holds voxel
            # indices in array-axis order; confirm whether a component reversal and a
            # division by spacing are missing before this addition.
            new_locs = grid+torch.unsqueeze(torch.from_numpy(displacement), 0).type(torch.float32)
            self.displacement_fields[index].append(Elastix._formatLoc(new_locs, shape))
            print("Compute in progress : {:.2f} %".format((i+1)/len(shapes)*100))
        return shapes

    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        """Resample each selected tensor through its stored grid (bilinear, border padding)."""
        results = []
        for input, displacement_field in zip(inputs, self.displacement_fields[index]):
            results.append(F.grid_sample(input.type(torch.float32).unsqueeze(0), displacement_field.to(input.device), align_corners=True, mode="bilinear", padding_mode="border").type(input.dtype).squeeze(0))
        return results

    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        # Not invertible: returns None.
        pass
|
|
524
|
+
|
|
525
|
+
class Permute(DataAugmentation):
    """Randomly reorder the spatial axes of 3D images.

    Two axis permutations are available. Both keep the first tensor axis in
    place (presumably channels; confirm):

    * ``[0, 2, 1, 3]`` swaps the first two spatial axes;
    * ``[0, 3, 1, 2]`` moves the last spatial axis to the front.

    Each selected permutation is applied in that order, so an image may get
    none, one, or both.
    """

    @config("Permute")
    def __init__(self, prob_permute: Union[list[float], None] = [0.33, 0.33]) -> None:
        """
        Args:
            prob_permute: independent per-image probability of applying each
                of the two permutations. When ``None`` or empty, the
                augmentation is deterministic: exactly two images are
                expected, the first gets permutation 0 and the second gets
                permutation 1.
        """
        # NOTE: the list default is never mutated. It is not replaced by None
        # because None selects deterministic mode. The @config decorator may
        # also read this default (TODO confirm).
        super().__init__()
        self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
        self.prob_permute = prob_permute
        # index -> bool mask of shape (num_images, 2): which permutations apply.
        self.permute: dict[int, torch.Tensor] = {}

    def _selected_permutations(self, selected: torch.Tensor) -> list[list[int]]:
        """Return the permutations enabled by the bool mask ``selected``, as int lists."""
        return self._permute_dims[selected].tolist()

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        """Draw which permutations apply to each image and return the permuted shapes.

        Args:
            index: key under which the per-image permutation mask is stored.
            shapes: spatial shape of each image. It is updated in place and
                returned.
            caches_attribute: unused here.

        Returns:
            ``shapes``, with each entry reordered to match its selected
            permutations.
        """
        if not shapes:
            # Store an empty mask so _compute/_inverse find this index
            # (instead of raising KeyError or reusing a stale mask).
            self.permute[index] = torch.zeros((0, len(self._permute_dims)), dtype=torch.bool)
            return shapes
        dim = len(shapes[0])
        assert dim == 3, "The permute augmentation only support 3D images"
        if self.prob_permute:
            assert len(self.prob_permute) == 2, "len of prob_permute must be equal 2"
            self.permute[index] = torch.rand((len(shapes), len(self.prob_permute))) < torch.tensor(self.prob_permute)
        else:
            assert len(shapes) == 2, "The number of augmentation images must be equal to 2"
            self.permute[index] = torch.eye(2, dtype=torch.bool)
        for i, selected in enumerate(self.permute[index]):
            for permutation in self._selected_permutations(selected):
                # Permutations index the full tensor (leading axis first), while
                # shapes[i] holds only the spatial sizes, hence the -1 offset.
                shapes[i] = [shapes[i][axis - 1] for axis in permutation[1:]]
        return shapes

    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        """Apply each image's selected permutations, in order. ``device`` is unused."""
        results = []
        for tensor, selected in zip(inputs, self.permute[index]):
            for permutation in self._selected_permutations(selected):
                tensor = tensor.permute(permutation)
            results.append(tensor)
        return results

    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        """Undo the permutations applied to image ``a``, last one first."""
        for permutation in reversed(self._selected_permutations(self.permute[index][a])):
            # Inverse permutation: the axis that went to position p goes back.
            inverse = [0] * len(permutation)
            for position, axis in enumerate(permutation):
                inverse[axis] = position
            input = input.permute(inverse)
        return input
|
|
562
|
+
|
|
563
|
+
class Mask(DataAugmentation):
    """Crop each image to the shape of a fixed mask and fill voxels outside it.

    The mask is read once from disk. For every image, a random integer offset
    is drawn on each spatial axis in ``0 .. max(image_size - mask_size, 0)``.
    The image is cropped to the mask's shape at that offset. Axes where the
    image is smaller than the mask are padded at the far end first. Every
    voxel where the mask is not exactly ``1`` is then replaced by the fill
    value (``0`` for ``uint8`` inputs, ``value`` otherwise).
    """

    @config("Mask")
    def __init__(self, mask: str, value: float) -> None:
        """
        Args:
            mask: path to the mask image, read with SimpleITK.
            value: fill value for non-``uint8`` inputs.

        Raises:
            NameError: if ``mask`` is not ``None`` and the path does not exist.
        """
        super().__init__()
        if mask is not None:
            if not os.path.exists(mask):
                raise NameError('Mask file not found')
            self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
        # NOTE(review): when mask is None, self.mask is never set, and
        # _state_init/_compute then fail with AttributeError. Confirm whether
        # None is a supported configuration.
        # index -> one integer start-offset tensor per image.
        self.positions: dict[int, list[torch.Tensor]] = {}
        self.value = value

    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
        """Draw a random crop offset per image and return the output shapes.

        Offsets are uniform over every valid integer start position on each
        axis compared (``zip`` of image and mask shapes), i.e.
        ``0 .. max(image_size - mask_size, 0)`` inclusive.

        Returns:
            One copy of the mask's shape for each input shape.
        """
        mask_shape = [int(s) for s in self.mask.shape]
        positions = []
        for shape in shapes:
            # Extra voxels the image has over the mask on each axis.
            slack = torch.tensor(
                [max(int(s) - m, 0) for s, m in zip(shape, mask_shape)],
                dtype=torch.float32,
            )
            # torch.rand lies in [0, 1), so floor(rand * (slack + 1)) covers
            # 0..slack inclusive. Multiplying by slack alone would never pick
            # the last valid offset (crop flush with the far edge).
            positions.append(torch.floor(torch.rand(len(slack)) * (slack + 1)).long())
        self.positions[index] = positions
        return [list(mask_shape) for _ in shapes]

    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
        """Crop/pad each input to the mask shape and fill voxels outside the mask.

        ``device`` is unused. The mask and fill value follow each input's device.
        """
        results = []
        for tensor, position in zip(inputs, self.positions[index]):
            # The leading axis is kept whole; the trailing axes are cropped.
            slices = [slice(None, None)] + [
                slice(int(start), int(start) + size)
                for start, size in zip(position, self.mask.shape)
            ]
            # F.pad takes (before, after) pairs starting from the last axis.
            # Only the far side is padded, and only up to the mask size.
            padding = []
            for image_size, mask_size in zip(reversed(tensor.shape), reversed(self.mask.shape)):
                padding.extend((0, max(mask_size - image_size, 0)))
            # NOTE(review): for integer inputs other than uint8, the float
            # `value` tensor promotes torch.where's result to a float dtype.
            # Confirm that this is intended.
            value = torch.tensor(0, dtype=torch.uint8) if tensor.dtype == torch.uint8 else torch.tensor(self.value).to(tensor.device)
            cropped = torch.nn.functional.pad(tensor, tuple(padding), mode="constant", value=value)[tuple(slices)]
            results.append(torch.where(self.mask.to(tensor.device) == 1, cropped, value))
        return results

    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
        """Inverse of the mask crop (not implemented). Always returns ``None``."""
        return None
|