konfai-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of konfai might be problematic.

@@ -0,0 +1,597 @@
+ import importlib
+ import torch
+ from abc import ABC, abstractmethod
+ import numpy as np
+ import SimpleITK as sitk
+ import torch.nn.functional as F
+ from typing import Union
+ import os
+ from KonfAI.konfai import DEEP_LEARNING_API_ROOT
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.utils.utils import _getModule
+ from KonfAI.konfai.utils.dataset import Attribute, data_to_image
+
+
+ def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
+     return torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.tensor([[0.0, 0.0, 1.0]])), dim=0)
+
+ def _translate3DMatrix(t: torch.Tensor) -> torch.Tensor:
+     return torch.cat((torch.cat((torch.eye(3), torch.tensor([[t[0]], [t[1]], [t[2]]])), dim=1), torch.tensor([[0.0, 0.0, 0.0, 1.0]])), dim=0)
+
+ def _scale2DMatrix(s: torch.Tensor) -> torch.Tensor:
+     return torch.cat((torch.cat((torch.eye(2)*s, torch.tensor([[0.0], [0.0]])), dim=1), torch.tensor([[0.0, 0.0, 1.0]])), dim=0)
+
+ def _scale3DMatrix(s: torch.Tensor) -> torch.Tensor:
+     return torch.cat((torch.cat((torch.eye(3)*s, torch.tensor([[0.0], [0.0], [0.0]])), dim=1), torch.tensor([[0.0, 0.0, 0.0, 1.0]])), dim=0)
+
+ def _rotation3DMatrix(rotation: torch.Tensor, center: Union[torch.Tensor, None] = None) -> torch.Tensor:
+     A = torch.tensor([[torch.cos(rotation[2]), -torch.sin(rotation[2]), 0], [torch.sin(rotation[2]), torch.cos(rotation[2]), 0], [0, 0, 1]])
+     B = torch.tensor([[torch.cos(rotation[1]), 0, torch.sin(rotation[1])], [0, 1, 0], [-torch.sin(rotation[1]), 0, torch.cos(rotation[1])]])
+     C = torch.tensor([[1, 0, 0], [0, torch.cos(rotation[0]), -torch.sin(rotation[0])], [0, torch.sin(rotation[0]), torch.cos(rotation[0])]])
+     rotation_matrix = torch.cat((torch.cat((A.mm(B).mm(C), torch.zeros((3, 1))), dim=1), torch.tensor([[0.0, 0.0, 0.0, 1.0]])), dim=0)
+     if center is not None:
+         translation_before = torch.eye(4)
+         translation_before[:-1, -1] = -center
+         translation_after = torch.eye(4)
+         translation_after[:-1, -1] = center
+         rotation_matrix = translation_before.mm(rotation_matrix).mm(translation_after)
+     return rotation_matrix
+
+ def _rotation2DMatrix(rotation: torch.Tensor, center: Union[torch.Tensor, None] = None) -> torch.Tensor:
+     return torch.cat((torch.cat((torch.tensor([[torch.cos(rotation[0]), -torch.sin(rotation[0])], [torch.sin(rotation[0]), torch.cos(rotation[0])]]), torch.zeros((2, 1))), dim=1), torch.tensor([[0.0, 0.0, 1.0]])), dim=0)
+
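+ # Note: the helpers above return homogeneous ((d+1)x(d+1)) matrices, so
+ # translation, scaling and rotation compose by plain matrix multiplication.
+ # A minimal sketch with made-up values (not part of the package API):
+ #
+ #     angle = torch.tensor([0.5])            # radians
+ #     shift = torch.tensor([2.0, -3.0])
+ #     m = _rotation2DMatrix(angle).mm(_translate2DMatrix(shift))
+ #     # applied to a column vector, m translates first, then rotates
+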
+ class Prob():
+
+     @config()
+     def __init__(self, prob: float = 1.0) -> None:
+         self.prob = prob
+
+ class DataAugmentationsList():
+
+     @config()
+     def __init__(self, nb: int = 10, dataAugmentations: dict[str, Prob] = {"default:RandomElastixTransform": Prob(1)}) -> None:
+         self.nb = nb
+         self.dataAugmentations: list[DataAugmentation] = []
+         self.dataAugmentationsLoader = dataAugmentations
+
+     def load(self, key: str):
+         for augmentation, prob in self.dataAugmentationsLoader.items():
+             module, name = _getModule(augmentation, "augmentation")
+             dataAugmentation: DataAugmentation = getattr(importlib.import_module(module), name)(config=None, DL_args="{}.Dataset.augmentations.{}.dataAugmentations".format(DEEP_LEARNING_API_ROOT(), key))
+             dataAugmentation.load(prob.prob)
+             self.dataAugmentations.append(dataAugmentation)
+
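+ # Note: each key of `dataAugmentations` (e.g. "default:RandomElastixTransform")
+ # is resolved by _getModule into a (module, class name) pair and instantiated
+ # through the config system; the attached Prob sets how often that
+ # augmentation fires per sample.
+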
+ class DataAugmentation(ABC):
+
+     def __init__(self) -> None:
+         self.who_index: dict[int, list[int]] = {}
+         self.shape_index: dict[int, list[list[int]]] = {}
+         self._prob: float = 0
+
+     def load(self, prob: float):
+         self._prob = prob
+
+     def state_init(self, index: Union[None, int], shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         if index is not None:
+             if index not in self.who_index:
+                 self.who_index[index] = torch.where(torch.rand(len(shapes)) < self._prob)[0].tolist()
+             else:
+                 return self.shape_index[index]
+         else:
+             index = 0
+             self.who_index[index] = torch.where(torch.rand(len(shapes)) < self._prob)[0].tolist()
+
+         if len(self.who_index[index]) > 0:
+             for i, shape in enumerate(self._state_init(index, [shapes[j] for j in self.who_index[index]], [caches_attribute[j] for j in self.who_index[index]])):
+                 shapes[self.who_index[index][i]] = shape
+         self.shape_index[index] = shapes
+         return self.shape_index[index]
+
+     @abstractmethod
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         pass
+
+     def __call__(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         if len(self.who_index[index]) > 0:
+             for i, result in enumerate(self._compute(index, [inputs[j] for j in self.who_index[index]], device)):
+                 inputs[self.who_index[index][i]] = result if device is None else result.cpu()
+         return inputs
+
+     @abstractmethod
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         pass
+
+     def inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         if a in self.who_index[index]:
+             input = self._inverse(index, a, input)
+         return input
+
+     @abstractmethod
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         pass
+
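+ # Note: concrete augmentations only implement the three hooks above. A
+ # minimal sketch of the contract (hypothetical subclass, not shipped with
+ # the package):
+ #
+ #     class Identity(DataAugmentation):
+ #         def _state_init(self, index, shapes, caches_attribute):
+ #             return shapes            # output shapes are unchanged
+ #         def _compute(self, index, inputs, device):
+ #             return inputs            # no-op on the selected samples
+ #         def _inverse(self, index, a, input):
+ #             return input
+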
+ class EulerTransform(DataAugmentation):
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.matrix: dict[int, list[torch.Tensor]] = {}
+
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         results = []
+         for input, matrix in zip(inputs, self.matrix[index]):
+             results.append(F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1, ...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0))
+         return results
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         return F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(self.matrix[index][a].inverse()[:, :-1, ...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0)
+
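+ # Note: EulerTransform subclasses only have to fill self.matrix in
+ # _state_init; resampling is then handled once here with F.affine_grid /
+ # F.grid_sample, and _inverse simply warps with the inverted matrix.
+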
+ class Translate(EulerTransform):
+
+     @config("Translate")
+     def __init__(self, t_min: float = -10, t_max: float = 10, is_int: bool = False):
+         super().__init__()
+         self.t_min = t_min
+         self.t_max = t_max
+         self.is_int = is_int
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         dim = len(shapes[0])
+         func = _translate3DMatrix if dim == 3 else _translate2DMatrix
+         translate = torch.rand((len(shapes), dim)) * (self.t_max - self.t_min) + self.t_min
+         if self.is_int:
+             # Snap offsets to whole units when integer translations are requested.
+             translate = torch.round(translate)
+         self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in translate]
+         return shapes
+
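+ # Note: F.affine_grid interprets the translation column of the matrix in
+ # normalized [-1, 1] coordinates, so t_min/t_max act as fractions of the
+ # image half-extent rather than voxel offsets.
+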
+ class Rotate(EulerTransform):
+
+     @config("Rotate")
+     def __init__(self, a_min: float = 0, a_max: float = 360, is_quarter: bool = False):
+         super().__init__()
+         self.a_min = a_min
+         self.a_max = a_max
+         self.is_quarter = is_quarter
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         dim = len(shapes[0])
+         func = _rotation3DMatrix if dim == 3 else _rotation2DMatrix
+         if self.is_quarter:
+             # One random quarter turn (90, 180 or 270 degrees) per sample and axis.
+             angles = torch.tensor([90.0, 180.0, 270.0])[torch.randint(0, 3, (len(shapes), dim))]
+         else:
+             angles = torch.rand((len(shapes), dim)) * (self.a_max - self.a_min) + self.a_min
+         # Angles are configured in degrees; the matrix helpers expect radians.
+         self.matrix[index] = [torch.unsqueeze(func(torch.deg2rad(value)), dim=0) for value in angles]
+         return shapes
+
+ class Scale(EulerTransform):
+
+     @config("Scale")
+     def __init__(self, s_std: float = 0.2):
+         super().__init__()
+         self.s_std = s_std
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         func = _scale3DMatrix if len(shapes[0]) == 3 else _scale2DMatrix
+         # Log-normal scale factor, shared across the spatial dimensions.
+         scale = torch.exp2(torch.randn(len(shapes)) * self.s_std).unsqueeze(1).repeat(1, len(shapes[0]))
+         self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in scale]
+         return shapes
+
+ class Flip(DataAugmentation):
+
+     @config("Flip")
+     def __init__(self, f_prob: Union[list[float], None] = [0.33, 0.33, 0.33]) -> None:
+         super().__init__()
+         self.f_prob = f_prob
+         self.flip: dict[int, list[list[int]]] = {}
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         prob = torch.rand((len(shapes), len(self.f_prob))) < torch.tensor(self.f_prob)
+         dims = torch.tensor([1, 2, 3][:len(self.f_prob)])
+         self.flip[index] = [dims[mask].tolist() for mask in prob]
+         return shapes
+
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         results = []
+         for input, flip in zip(inputs, self.flip[index]):
+             results.append(torch.flip(input, dims=flip))
+         return results
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         # Flipping is an involution, so the inverse re-applies the same flip.
+         return torch.flip(input, dims=self.flip[index][a])
+
+
+ class ColorTransform(DataAugmentation):
+
+     @config("ColorTransform")
+     def __init__(self) -> None:
+         super().__init__()
+         self.matrix: dict[int, list[torch.Tensor]] = {}
+
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         results = []
+         for input, matrix in zip(inputs, self.matrix[index]):
+             result = input.reshape([*input.shape[:1], int(np.prod(input.shape[1:]))])
+             if input.shape[0] == 3:
+                 matrix = matrix.to(input.device)
+                 result = matrix[:, :3, :3] @ result.float() + matrix[:, :3, 3:]
+             elif input.shape[0] == 1:
+                 matrix = matrix[:, :3, :].mean(dim=1, keepdim=True).to(input.device)
+                 result = result.float() * matrix[:, :, :3].sum(dim=2, keepdim=True) + matrix[:, :, 3:]
+             else:
+                 raise ValueError("Image must be RGB (3 channels) or L (1 channel)")
+             results.append(result.reshape(input.shape))
+         return results
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         # Color transforms are not inverted; the input is returned unchanged.
+         return input
+
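+ # Note: the color augmentations below parameterize a 4x4 homogeneous matrix
+ # acting on RGB values (the scheme used by StyleGAN2-ADA), with
+ # v = (1, 1, 1)/sqrt(3) as the luma axis: Brightness translates along it,
+ # Contrast scales uniformly, LumaFlip reflects across it, HUE rotates around
+ # it and Saturation scales the orthogonal (chroma) components.
+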
+ class Brightness(ColorTransform):
+
+     @config("Brightness")
+     def __init__(self, b_std: float) -> None:
+         super().__init__()
+         self.b_std = b_std
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         brightness = (torch.randn(len(shapes)) * self.b_std).unsqueeze(1).repeat(1, 3)
+         self.matrix[index] = [torch.unsqueeze(_translate3DMatrix(value), dim=0) for value in brightness]
+         return shapes
+
+ class Contrast(ColorTransform):
+
+     @config("Contrast")
+     def __init__(self, c_std: float) -> None:
+         super().__init__()
+         self.c_std = c_std
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         contrast = torch.exp2(torch.randn(len(shapes)) * self.c_std)
+         self.matrix[index] = [torch.unsqueeze(_scale3DMatrix(value), dim=0) for value in contrast]
+         return shapes
+
+ class LumaFlip(ColorTransform):
+
+     @config("LumaFlip")
+     def __init__(self) -> None:
+         super().__init__()
+         self.v = torch.tensor([1.0, 1.0, 1.0, 0.0])/torch.sqrt(torch.tensor(3.0))
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         luma = torch.floor(torch.rand([len(shapes), 1, 1]) * 2)
+         # Householder reflection across the luma axis, applied with probability 0.5.
+         self.matrix[index] = [torch.unsqueeze(torch.eye(4) - 2 * self.v.ger(self.v) * value, dim=0) for value in luma]
+         return shapes
+
+ class HUE(ColorTransform):
+
+     @config("HUE")
+     def __init__(self, hue_max: float) -> None:
+         super().__init__()
+         self.hue_max = hue_max
+         self.v = torch.tensor([1.0, 1.0, 1.0])/torch.sqrt(torch.tensor(3.0))
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         theta = (torch.rand([len(shapes)]) * 2 - 1) * np.pi * self.hue_max
+         # Rotation about the grey diagonal (luma axis) by a random angle.
+         self.matrix[index] = [torch.unsqueeze(_rotation3DMatrix(value.repeat(3), self.v), dim=0) for value in theta]
+         return shapes
+
+ class Saturation(ColorTransform):
+
+     @config("Saturation")
+     def __init__(self, s_std: float) -> None:
+         super().__init__()
+         self.s_std = s_std
+         self.v = torch.tensor([1.0, 1.0, 1.0, 0.0])/torch.sqrt(torch.tensor(3.0))
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         saturation = torch.exp2(torch.randn(len(shapes)) * self.s_std)
+         # Keep the projection onto the luma axis (v v^T) and scale only the
+         # orthogonal (chroma) complement by the random factor.
+         self.matrix[index] = [(self.v.ger(self.v) + (torch.eye(4) - self.v.ger(self.v)) * value).unsqueeze(0) for value in saturation]
+         return shapes
+
+ """class Filter(DataAugmentation):
+
+     def __init__(self) -> None:
+         super().__init__()
+         wavelets = {
+             'haar': [0.7071067811865476, 0.7071067811865476],
+             'db1': [0.7071067811865476, 0.7071067811865476],
+             'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+             'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+             'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
+             'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
+             'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
+             'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
+             'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
+             'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
+             'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
+             'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
+             'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
+             'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
+             'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
+             'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
+         }
+         Hz_lo = np.asarray(wavelets['sym2'])            # H(z)
+         Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
+         Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2    # H(z) * H(z^-1) / 2
+         Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2    # H(-z) * H(-z^-1) / 2
+         Hz_fbank = np.eye(4, 1)                         # Bandpass(H(z), b_i)
+         for i in range(1, Hz_fbank.shape[0]):
+             Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
+             Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
+             Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
+         self.Hz_fbank = torch.as_tensor(Hz_fbank, dtype=torch.float32)
+
+         self.imgfilter_bands = [1, 1, 1, 1]
+         self.imgfilter_std = 1
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         return shapes
+
+     def _compute(self, index: int, inputs: list[torch.Tensor]) -> list[torch.Tensor]:
+         num_bands = self.Hz_fbank.shape[0]
+         assert len(self.imgfilter_bands) == num_bands
+         expected_power = torch.tensor([10, 1, 1, 1]) / 13 # Expected power spectrum (1/f).
+         for input in inputs:
+             batch_size = input.shape[0]
+             num_channels = input.shape[1]
+             # Apply amplification for each band with probability (imgfilter * strength * band_strength).
+             g = torch.ones([batch_size, num_bands]) # Global gain vector (identity).
+             for i, band_strength in enumerate(self.imgfilter_bands):
+                 t_i = torch.exp2(torch.randn([batch_size]) * self.imgfilter_std)
+                 t = torch.ones([batch_size, num_bands])    # Temporary gain vector.
+                 t[:, i] = t_i                              # Replace i'th element.
+                 t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
+                 g = g * t                                  # Accumulate into global gain.
+
+             # Construct combined amplification filter.
+             Hz_prime = g @ self.Hz_fbank                                     # [batch, tap]
+             Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1])   # [batch, channels, tap]
+             Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
+
+             # Apply filter.
+             p = self.Hz_fbank.shape[1] // 2
+             images = images.reshape([1, batch_size * num_channels, height, width])
+             images = torch.nn.functional.pad(input=images, pad=[p, p, p, p], mode='reflect')
+             images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
+             images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
+             images = images.reshape([batch_size, num_channels, height, width])
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         pass """
+
+ class Noise(DataAugmentation):
+
+     @config("Noise")
+     def __init__(self, n_std: float, noise_step: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
+         super().__init__()
+         self.n_std = n_std
+         self.noise_step = noise_step
+
+         self.ts: dict[int, list[torch.Tensor]] = {}
+         self.betas = torch.linspace(beta_start, beta_end, noise_step)
+         self.betas = Noise.enforce_zero_terminal_snr(self.betas)
+         self.alphas = 1 - self.betas
+         self.alpha_hat = torch.cat((torch.ones(1), torch.cumprod(self.alphas, dim=0)))
+         self.max_T = 0
+
+         self.C = 1
+         self.n = 4
+         self.d = 0.25
+         self._prob = 1
+
+     @staticmethod
+     def enforce_zero_terminal_snr(betas: torch.Tensor):
+         # Rescale the schedule so the terminal SNR is exactly zero
+         # (Lin et al., "Common Diffusion Noise Schedules and Sample Steps are Flawed").
+         alphas = 1 - betas
+         alphas_bar = alphas.cumprod(0)
+         alphas_bar_sqrt = alphas_bar.sqrt()
+         alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+         alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+         alphas_bar_sqrt -= alphas_bar_sqrt_T
+         alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+         alphas_bar = alphas_bar_sqrt ** 2
+         alphas = alphas_bar[1:] / alphas_bar[:-1]
+         alphas = torch.cat([alphas_bar[0:1], alphas])
+         betas = 1 - alphas
+         return betas
+
+     def load(self, prob: float):
+         # Here `prob` rescales the maximum diffusion time step rather than acting as a probability.
+         self.max_T = prob * self.noise_step
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         if int(self.max_T) == 0:
+             self.ts[index] = [torch.zeros(1, dtype=torch.int64) for _ in shapes]
+         else:
+             self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
+         return shapes
+
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         results = []
+         for input, t in zip(inputs, self.ts[index]):
+             alpha_hat_t = self.alpha_hat[t].to(input.device).reshape(*[1 for _ in range(len(input.shape))])
+             results.append(alpha_hat_t.sqrt() * input + (1 - alpha_hat_t).sqrt() * torch.randn_like(input.float()).to(input.device) * self.n_std)
+         return results
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         # Added noise cannot be removed; the input is returned unchanged.
+         return input
+
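+ # Note: _compute above applies the standard DDPM forward process,
+ # x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps with
+ # eps ~ N(0, I), the noise term additionally scaled by n_std.
+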
+ class CutOUT(DataAugmentation):
+
+     @config("CutOUT")
+     def __init__(self, c_prob: float, cutout_size: float, value: float) -> None:
+         super().__init__()
+         self.c_prob = c_prob
+         self.cutout_size = cutout_size
+         self.centers: dict[int, list[torch.Tensor]] = {}
+         self.value = value
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         self.centers[index] = [torch.rand(3 if len(shape) == 3 else 2) for shape in shapes]
+         return shapes
+
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         results = []
+         for input, center in zip(inputs, self.centers[index]):
+             masks = []
+             for i, w in enumerate(input.shape[1:]):
+                 re = [1]*i + [-1] + [1]*(len(input.shape[1:])-i-1)
+                 # True where the voxel lies outside the cutout box along axis i
+                 # (coordinates and cutout_size are normalized to [0, 1]).
+                 masks.append(((torch.arange(w).reshape(re) + 0.5) / w - center[i].reshape([1, 1])).abs() >= torch.tensor(self.cutout_size).reshape([1, 1]) / 2)
+             result = masks[0]
+             for mask in masks[1:]:
+                 result = torch.logical_or(result, mask)
+             result = result.unsqueeze(0).repeat([input.shape[0], *[1 for _ in range(len(input.shape)-1)]])
+             results.append(torch.where(result.to(input.device), input, torch.tensor(self.value).to(input.device)))
+         return results
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         # Cut-out voxels cannot be restored; the input is returned unchanged.
+         return input
+
+ class Elastix(DataAugmentation):
+
+     @config("Elastix")
+     def __init__(self, grid_spacing: int = 16, max_displacement: int = 16) -> None:
+         super().__init__()
+         self.grid_spacing = grid_spacing
+         self.max_displacement = max_displacement
+         self.displacement_fields: dict[int, list[torch.Tensor]] = {}
+         self.displacement_fields_true: dict[int, list[np.ndarray]] = {}
+
+     @staticmethod
+     def _formatLoc(new_locs, shape):
+         # Normalize voxel coordinates to [-1, 1] and reverse the axis order,
+         # as expected by F.grid_sample.
+         for i in range(len(shape)):
+             new_locs[..., i] = 2 * (new_locs[..., i] / (shape[i] - 1) - 0.5)
+         new_locs = new_locs[..., [i for i in reversed(range(len(shape)))]]
+         return new_locs
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         print("Compute Displacement Field for index {}".format(index))
+         self.displacement_fields[index] = []
+         self.displacement_fields_true[index] = []
+         for i, (shape, cache_attribute) in enumerate(zip(shapes, caches_attribute)):
+             dim = len(shape)
+             if "Spacing" not in cache_attribute:
+                 spacing = np.array([1.0 for _ in range(dim)])
+             else:
+                 spacing = cache_attribute.get_np_array("Spacing")
+
+             grid_physical_spacing = [self.grid_spacing]*dim
+             image_physical_size = [size*s for size, s in zip(shape, spacing)]
+             mesh_size = [int(image_size/grid_spacing + 0.5) for image_size, grid_spacing in zip(image_physical_size, grid_physical_spacing)]
+             if "Spacing" not in cache_attribute:
+                 cache_attribute["Spacing"] = np.array([1.0 for _ in range(dim)])
+             if "Origin" not in cache_attribute:
+                 cache_attribute["Origin"] = np.array([0.0 for _ in range(dim)])
+             if "Direction" not in cache_attribute:
+                 cache_attribute["Direction"] = np.eye(dim).flatten()
+
+             ref_image = data_to_image(np.expand_dims(np.zeros(shape), 0), cache_attribute)
+
+             bspline_transform = sitk.BSplineTransformInitializer(image1=ref_image, transformDomainMeshSize=mesh_size, order=3)
+             displacement_filter = sitk.TransformToDisplacementFieldFilter()
+             displacement_filter.SetReferenceImage(ref_image)
+
+             vectors = [torch.arange(0, s) for s in shape]
+             grids = torch.meshgrid(vectors, indexing='ij')
+             grid = torch.stack(grids)
+             grid = torch.unsqueeze(grid, 0)
+             grid = grid.type(torch.float).permute([0]+[i+2 for i in range(len(shape))] + [1])
+
+             # Random B-spline control-point offsets in [-max_displacement, max_displacement].
+             control_points = torch.rand(*[size+3 for size in mesh_size], dim)
+             control_points -= 0.5
+             control_points *= 2*self.max_displacement
+             bspline_transform.SetParameters(control_points.flatten().tolist())
+             displacement = sitk.GetArrayFromImage(displacement_filter.Execute(bspline_transform))
+             self.displacement_fields_true[index].append(displacement)
+             new_locs = grid + torch.unsqueeze(torch.from_numpy(displacement), 0).type(torch.float32)
+             self.displacement_fields[index].append(Elastix._formatLoc(new_locs, shape))
+             print("Compute in progress : {:.2f} %".format((i+1)/len(shapes)*100))
+         return shapes
+
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         results = []
+         for input, displacement_field in zip(inputs, self.displacement_fields[index]):
+             results.append(F.grid_sample(input.type(torch.float32).unsqueeze(0), displacement_field.to(input.device), align_corners=True, mode="bilinear", padding_mode="border").type(input.dtype).squeeze(0))
+         return results
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         # The random elastic deformation is not inverted; the input is returned unchanged.
+         return input
+
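+ # Note: the displacement field produced by SimpleITK is expressed in
+ # physical units; adding it to the voxel-index grid above is consistent
+ # under the unit spacing that is assumed when none is provided.
+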
+ class Permute(DataAugmentation):
+
+     @config("Permute")
+     def __init__(self, prob_permute: Union[list[float], None] = [0.33, 0.33]) -> None:
+         super().__init__()
+         self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
+         self.prob_permute = prob_permute
+         self.permute: dict[int, torch.Tensor] = {}
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         if len(shapes):
+             dim = len(shapes[0])
+             assert dim == 3, "The permute augmentation only supports 3D images"
+             if self.prob_permute:
+                 assert len(self.prob_permute) == 2, "len of prob_permute must be equal to 2"
+                 self.permute[index] = torch.rand((len(shapes), len(self.prob_permute))) < torch.tensor(self.prob_permute)
+             else:
+                 assert len(shapes) == 2, "The number of augmentation images must be equal to 2"
+                 self.permute[index] = torch.eye(2, dtype=torch.bool)
+             for i, prob in enumerate(self.permute[index]):
+                 for permute in self._permute_dims[prob]:
+                     shapes[i] = [shapes[i][d-1] for d in permute[1:]]
+         return shapes
+
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         results = []
+         for input, prob in zip(inputs, self.permute[index]):
+             res = input
+             for permute in self._permute_dims[prob]:
+                 res = res.permute(tuple(permute))
+             results.append(res)
+         return results
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         # Undo the axis permutations in reverse order using the inverse permutation.
+         for permute in reversed(self._permute_dims[self.permute[index][a]]):
+             input = input.permute(tuple(np.argsort(permute)))
+         return input
+
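+ # Note: for a permutation p, np.argsort(p) is its inverse permutation, which
+ # is what _inverse relies on; e.g. p = (0, 2, 1, 3) is its own inverse.
+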
+ class Mask(DataAugmentation):
+
+     @config("Mask")
+     def __init__(self, mask: str, value: float) -> None:
+         super().__init__()
+         if mask is not None:
+             if os.path.exists(mask):
+                 self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
+             else:
+                 raise FileNotFoundError("Mask file not found")
+         self.positions: dict[int, list[torch.Tensor]] = {}
+         self.value = value
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         # Random top-left corner of the mask window inside each image.
+         self.positions[index] = [torch.rand(3 if len(shape) == 3 else 2) * torch.tensor([max(s1-s2, 0) for s1, s2 in zip(shape, self.mask.shape)]) for shape in shapes]
+         return [list(self.mask.shape) for _ in shapes]
+
+     def _compute(self, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+         results = []
+         for input, position in zip(inputs, self.positions[index]):
+             slices = [slice(None, None)]+[slice(int(s1), int(s1)+s2) for s1, s2 in zip(position, self.mask.shape)]
+             padding = []
+             # F.pad takes pairs per dimension, starting from the last one.
+             for s1, s2 in zip(reversed(input.shape), reversed(self.mask.shape)):
+                 padding.append(0)
+                 padding.append(max(s2-s1, 0))
+             fill = 0 if input.dtype == torch.uint8 else self.value
+             value = torch.tensor(fill, dtype=input.dtype, device=input.device)
+             results.append(torch.where(self.mask.to(input.device) == 1, torch.nn.functional.pad(input, tuple(padding), mode="constant", value=fill)[tuple(slices)], value))
+         return results
+
+     def _inverse(self, index: int, a: int, input: torch.Tensor) -> torch.Tensor:
+         # Masked-out voxels cannot be restored; the input is returned unchanged.
+         return input