konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/data/augmentation.py
CHANGED
@@ -1,33 +1,87 @@
 import importlib
-import torch
-from abc import ABC, abstractmethod
-import numpy as np
-import SimpleITK as sitk
-import torch.nn.functional as F
-from typing import Union
 import os
-from
-from konfai.utils.config import config
-from konfai.utils.utils import _getModule, AugmentationError
-from konfai.utils.dataset import Attribute, data_to_image, Dataset
-
-def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
-    return torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.Tensor([[0,0,1]])), dim=0)
-
-def _translate3DMatrix(t: torch.Tensor) -> torch.Tensor:
-    return torch.cat((torch.cat((torch.eye(3), torch.tensor([[t[0]], [t[1]], [t[2]]])), dim=1), torch.Tensor([[0,0,0,1]])), dim=0)
-
-def _scale2DMatrix(s: torch.Tensor) -> torch.Tensor:
-    return torch.cat((torch.cat((torch.eye(2)*s, torch.tensor([[0], [0]])), dim=1), torch.tensor([[0, 0, 1]])), dim=0)
+from abc import ABC, abstractmethod
 
-
-
+import numpy as np
+import SimpleITK as sitk  # noqa: N813
+import torch
+import torch.nn.functional as F  # noqa: N812
 
-
-
-
-
-
+from konfai import konfai_root
+from konfai.utils.config import config
+from konfai.utils.dataset import Attribute, Dataset, data_to_image
+from konfai.utils.utils import AugmentationError, NeedDevice, get_module
+
+
+def _translate_2d_matrix(t: torch.Tensor) -> torch.Tensor:
+    return torch.cat(
+        (
+            torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1),
+            torch.Tensor([[0, 0, 1]]),
+        ),
+        dim=0,
+    )
+
+
+def _translate_3d_matrix(t: torch.Tensor) -> torch.Tensor:
+    return torch.cat(
+        (
+            torch.cat((torch.eye(3), torch.tensor([[t[0]], [t[1]], [t[2]]])), dim=1),
+            torch.Tensor([[0, 0, 0, 1]]),
+        ),
+        dim=0,
+    )
+
+
+def _scale_2d_matrix(s: torch.Tensor) -> torch.Tensor:
+    return torch.cat(
+        (
+            torch.cat((torch.eye(2) * s, torch.tensor([[0], [0]])), dim=1),
+            torch.tensor([[0, 0, 1]]),
+        ),
+        dim=0,
+    )
+
+
+def _scale_3d_matrix(s: torch.Tensor) -> torch.Tensor:
+    return torch.cat(
+        (
+            torch.cat((torch.eye(3) * s, torch.tensor([[0], [0], [0]])), dim=1),
+            torch.tensor([[0, 0, 0, 1]]),
+        ),
+        dim=0,
+    )
+
+
+def _rotation_3d_matrix(rotation: torch.Tensor, center: torch.Tensor | None = None) -> torch.Tensor:
+    a = torch.tensor(
+        [
+            [torch.cos(rotation[2]), -torch.sin(rotation[2]), 0],
+            [torch.sin(rotation[2]), torch.cos(rotation[2]), 0],
+            [0, 0, 1],
+        ]
+    )
+    b = torch.tensor(
+        [
+            [torch.cos(rotation[1]), 0, torch.sin(rotation[1])],
+            [0, 1, 0],
+            [-torch.sin(rotation[1]), 0, torch.cos(rotation[1])],
+        ]
+    )
+    c = torch.tensor(
+        [
+            [1, 0, 0],
+            [0, torch.cos(rotation[0]), -torch.sin(rotation[0])],
+            [0, torch.sin(rotation[0]), torch.cos(rotation[0])],
+        ]
+    )
+    rotation_matrix = torch.cat(
+        (
+            torch.cat((a.mm(b).mm(c), torch.zeros((3, 1))), dim=1),
+            torch.tensor([[0, 0, 0, 1]]),
+        ),
+        dim=0,
+    )
     if center is not None:
         translation_before = torch.eye(4)
         translation_before[:-1, -1] = -center
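The helpers above were renamed from camelCase (_translate2DMatrix, _rotation3DMatrix, ...) to snake_case, but each still builds a homogeneous transform matrix that EulerTransform later hands to F.affine_grid / F.grid_sample. A minimal sketch of that pipeline, with the two 2D helpers copied from the new module and illustrative inputs (the angle and the shift, expressed in the normalized coordinates affine_grid expects, are hypothetical values, not konfai defaults):

import torch
import torch.nn.functional as F

def _translate_2d_matrix(t: torch.Tensor) -> torch.Tensor:
    # 3x3 homogeneous matrix [[1, 0, tx], [0, 1, ty], [0, 0, 1]]
    return torch.cat(
        (torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.Tensor([[0, 0, 1]])),
        dim=0,
    )

def _rotation_2d_matrix(rotation: torch.Tensor) -> torch.Tensor:
    # 3x3 homogeneous matrix [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]]
    r = torch.tensor(
        [[torch.cos(rotation[0]), -torch.sin(rotation[0])], [torch.sin(rotation[0]), torch.cos(rotation[0])]]
    )
    return torch.cat((torch.cat((r, torch.zeros((2, 1))), dim=1), torch.tensor([[0, 0, 1]])), dim=0)

matrix = _translate_2d_matrix(torch.tensor([0.1, -0.2])).mm(_rotation_2d_matrix(torch.tensor([0.3])))
image = torch.rand(1, 64, 64)  # (C, H, W), the per-sample layout EulerTransform._compute receives

# Drop the homogeneous row, add a batch axis, and resample exactly as the diff does.
grid = F.affine_grid(matrix.unsqueeze(0)[:, :-1, ...], [1] + list(image.shape), align_corners=True)
warped = F.grid_sample(
    image.unsqueeze(0).type(torch.float32),
    grid,
    align_corners=True,
    mode="bilinear",
    padding_mode="reflection",
).squeeze(0)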
@@ -38,117 +92,197 @@ def _rotation3DMatrix(rotation : torch.Tensor, center: Union[torch.Tensor, None]
         rotation_matrix = rotation_matrix.mm(translation_after)
     return rotation_matrix
 
-def _rotation2DMatrix(rotation : torch.Tensor, center: Union[torch.Tensor, None] = None) -> torch.Tensor:
-    return torch.cat((torch.cat((torch.tensor([[torch.cos(rotation[0]), -torch.sin(rotation[0])], [torch.sin(rotation[0]), torch.cos(rotation[0])]]), torch.zeros((2, 1))), dim=1), torch.tensor([[0, 0, 1]])), dim=0)
 
-
+def _rotation_2d_matrix(rotation: torch.Tensor, center: torch.Tensor | None = None) -> torch.Tensor:
+    return torch.cat(
+        (
+            torch.cat(
+                (
+                    torch.tensor(
+                        [
+                            [torch.cos(rotation[0]), -torch.sin(rotation[0])],
+                            [torch.sin(rotation[0]), torch.cos(rotation[0])],
+                        ]
+                    ),
+                    torch.zeros((2, 1)),
+                ),
+                dim=1,
+            ),
+            torch.tensor([[0, 0, 1]]),
+        ),
+        dim=0,
+    )
+
+
+class Prob:
 
     @config()
     def __init__(self, prob: float = 1.0) -> None:
         self.prob = prob
 
-
+
+class DataAugmentationsList:
 
     @config()
-    def __init__(
+    def __init__(
+        self,
+        nb: int = 10,
+        data_augmentations: dict[str, Prob] = {"default:Flip": Prob(1)},
+    ) -> None:
         self.nb = nb
-        self.
-        self.
+        self.data_augmentations: list[DataAugmentation] = []
+        self.data_augmentationsLoader = data_augmentations
 
     def load(self, key: str, datasets: list[Dataset]):
-        for augmentation, prob in self.
-            module, name =
-
-
-
-
+        for augmentation, prob in self.data_augmentationsLoader.items():
+            module, name = get_module(augmentation, "konfai.data.augmentation")
+            data_augmentation: DataAugmentation = config(
+                f"{konfai_root()}.Dataset.augmentations.{key}.data_augmentations.{augmentation}"
+            )(getattr(importlib.import_module(module), name))(config=None)
+            data_augmentation.load(prob.prob)
+            data_augmentation.set_datasets(datasets)
+            self.data_augmentations.append(data_augmentation)
+
 
-class DataAugmentation(ABC):
+class DataAugmentation(NeedDevice, ABC):
 
-    def __init__(self, groups:
+    def __init__(self, groups: list[str] | None = None) -> None:
         self.who_index: dict[int, list[int]] = {}
         self.shape_index: dict[int, list[list[int]]] = {}
         self._prob: float = 0
         self.groups = groups
-        self.datasets
+        self.datasets: list[Dataset] = []
 
     def load(self, prob: float):
         self._prob = prob
 
-    def
+    def set_datasets(self, datasets: list[Dataset]):
         self.datasets = datasets
-
-    def state_init(
+
+    def state_init(
+        self,
+        index: None | int,
+        shapes: list[list[int]],
+        caches_attribute: list[Attribute],
+    ) -> list[list[int]]:
         if index is not None:
             if index not in self.who_index:
                 self.who_index[index] = torch.where(torch.rand(len(shapes)) < self._prob)[0].tolist()
             else:
                 return self.shape_index[index]
-        else:
+        else:
             index = 0
             self.who_index[index] = torch.where(torch.rand(len(shapes)) < self._prob)[0].tolist()
-
+
         if len(self.who_index[index]) > 0:
-            for i, shape in enumerate(
+            for i, shape in enumerate(
+                self._state_init(
+                    index,
+                    [shapes[i] for i in self.who_index[index]],
+                    [caches_attribute[i] for i in self.who_index[index]],
+                )
+            ):
                 shapes[self.who_index[index][i]] = shape
         self.shape_index[index] = shapes
         return self.shape_index[index]
-
+
     @abstractmethod
-    def _state_init(self, index
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         pass
 
-    def __call__(
+    def __call__(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         if len(self.who_index[index]) > 0:
-            for i, result in enumerate(self._compute(name, index, [
-
-            return
-
+            for i, result in enumerate(self._compute(name, index, [tensors[i] for i in self.who_index[index]])):
+                tensors[self.who_index[index][i]] = result
+        return tensors
+
     @abstractmethod
-    def _compute(
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         pass
 
-    def inverse(self, index: int, a: int,
+    def inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
         if a in self.who_index[index]:
-
-        return
-
+            tensor = self._inverse(index, a, tensor)
+        return tensor
+
     @abstractmethod
-    def _inverse(self, index: int, a: int,
+    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
         pass
 
+
 class EulerTransform(DataAugmentation):
 
     def __init__(self) -> None:
         super().__init__()
         self.matrix: dict[int, list[torch.Tensor]] = {}
 
-    def _compute(
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         results = []
-        for
-            results.append(
+        for tensor, matrix in zip(tensors, self.matrix[index]):
+            results.append(
+                F.grid_sample(
+                    tensor.unsqueeze(0).type(torch.float32),
+                    F.affine_grid(matrix[:, :-1, ...], [1] + list(tensor.shape), align_corners=True).to(tensor.device),
+                    align_corners=True,
+                    mode="bilinear",
+                    padding_mode="reflection",
+                )
+                .type(tensor.dtype)
+                .squeeze(0)
+            )
         return results
-
-    def _inverse(self, index: int, a: int,
-        return
-
+
+    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
+        return (
+            F.grid_sample(
+                tensor.unsqueeze(0).type(torch.float32),
+                F.affine_grid(
+                    self.matrix[index][a].inverse()[:, :-1, ...],
+                    [1] + list(tensor.shape),
+                    align_corners=True,
+                ).to(tensor.device),
+                align_corners=True,
+                mode="bilinear",
+                padding_mode="reflection",
+            )
+            .type(tensor.dtype)
+            .squeeze(0)
+        )
+
+
 class Translate(EulerTransform):
-
-    def __init__(self, t_min: float = -10, t_max
+
+    def __init__(self, t_min: float = -10, t_max=10, is_int: bool = False):
         super().__init__()
         self.t_min = t_min
         self.t_max = t_max
         self.is_int = is_int
 
-    def _state_init(self, index
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         dim = len(shapes[0])
-        func =
-        translate = torch.rand((len(shapes), dim)) * torch.tensor(self.t_max-self.t_min) + torch.tensor(self.t_min)
+        func = _translate_3d_matrix if dim == 3 else _translate_2d_matrix
+        translate = torch.rand((len(shapes), dim)) * torch.tensor(self.t_max - self.t_min) + torch.tensor(self.t_min)
         if self.is_int:
-            translate = torch.round(translate*100)/100
+            translate = torch.round(translate * 100) / 100
         self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in translate]
         return shapes
 
+
 class Rotate(EulerTransform):
 
     def __init__(self, a_min: float = 0, a_max: float = 360, is_quarter: bool = False):
@@ -157,18 +291,19 @@ class Rotate(EulerTransform):
         self.a_max = a_max
         self.is_quarter = is_quarter
 
-    def _state_init(self, index
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         dim = len(shapes[0])
-        func =
+        func = _rotation_3d_matrix if dim == 3 else _rotation_2d_matrix
         angles = []
-
+
         if self.is_quarter:
-            angles = torch.Tensor.repeat(torch.tensor([90,180,270]), 3)
+            angles = torch.Tensor.repeat(torch.tensor([90, 180, 270]), 3)
         else:
-            angles = torch.rand((len(shapes), dim))*torch.tensor(self.a_max-self.a_min) + torch.tensor(self.a_min)
+            angles = torch.rand((len(shapes), dim)) * torch.tensor(self.a_max - self.a_min) + torch.tensor(self.a_min)
 
         self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in angles]
-        return shapes
+        return shapes
+
 
 class Scale(EulerTransform):
 
@@ -176,193 +311,148 @@ class Scale(EulerTransform):
         super().__init__()
         self.s_std = s_std
 
-    def _state_init(self, index
-        func =
-        scale = torch.Tensor.repeat(
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+        func = _scale_3d_matrix if len(shapes[0]) == 3 else _scale_2d_matrix
+        scale = torch.Tensor.repeat(
+            torch.exp2(torch.randn(len(shapes)) * self.s_std).unsqueeze(1),
+            [1, len(shapes[0])],
+        )
         self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in scale]
         return shapes
 
+
 class Flip(DataAugmentation):
 
-    def __init__(self, f_prob:
+    def __init__(self, f_prob: list[float] = [0.33, 0.33, 0.33]) -> None:
         super().__init__()
         self.f_prob = f_prob
         self.flip: dict[int, list[int]] = {}
 
-    def _state_init(self, index
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         prob = torch.rand((len(shapes), len(self.f_prob))) < torch.tensor(self.f_prob)
-        dims = torch.tensor([1, 2, 3][:len(self.f_prob)])
+        dims = torch.tensor([1, 2, 3][: len(self.f_prob)])
         self.flip[index] = [dims[mask].tolist() for mask in prob]
         return shapes
 
-    def _compute(
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         results = []
-        for
-            results.append(torch.flip(
+        for tensor, flip in zip(tensors, self.flip[index]):
+            results.append(torch.flip(tensor, dims=flip))
         return results
-
-    def _inverse(self, index: int, a: int,
-        return torch.flip(
-
+
+    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
+        return torch.flip(tensor, dims=self.flip[index][a])
+
 
 class ColorTransform(DataAugmentation):
 
-    def __init__(self, groups:
+    def __init__(self, groups: list[str] | None = None) -> None:
         super().__init__(groups)
         self.matrix: dict[int, list[torch.Tensor]] = {}
-
-    def _compute(
+
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         results = []
-        for
-            result =
-            if
-                matrix = matrix.to(
+        for tensor, matrix in zip(tensors, self.matrix[index]):
+            result = tensor.reshape([*tensor.shape[:1], int(np.prod(tensor.shape[1:]))])
+            if tensor.shape[0] == 3:
+                matrix = matrix.to(tensor.device)
                 result = matrix[:, :3, :3] @ result.float() + matrix[:, :3, 3:]
-            elif
-                matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(
+            elif tensor.shape[0] == 1:
+                matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(tensor.device)
                 result = result.float() * matrix[:, :, :3].sum(dim=2, keepdims=True) + matrix[:, :, 3:]
             else:
-                raise AugmentationError(
-            results.append(result.reshape(
+                raise AugmentationError("Image must be RGB (3 channels) or L (1 channel)")
+            results.append(result.reshape(tensor.shape))
         return results
-
-    def _inverse(self, index: int, a: int,
+
+    def _inverse(self, index: int, a: int, tensors: torch.Tensor) -> torch.Tensor:
         pass
 
+
 class Brightness(ColorTransform):
 
-    def __init__(self, b_std: float, groups:
+    def __init__(self, b_std: float, groups: list[str] | None = None) -> None:
         super().__init__(groups)
         self.b_std = b_std
 
-    def _state_init(self, index
-        brightness = torch.Tensor.repeat((torch.randn(
-        self.matrix[index] = [torch.unsqueeze(
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+        brightness = torch.Tensor.repeat((torch.randn(len(shapes)) * self.b_std).unsqueeze(1), [1, 3])
+        self.matrix[index] = [torch.unsqueeze(_translate_3d_matrix(value), dim=0) for value in brightness]
         return shapes
 
+
 class Contrast(ColorTransform):
 
-    def __init__(self, c_std: float, groups:
+    def __init__(self, c_std: float, groups: list[str] | None = None) -> None:
         super().__init__(groups)
         self.c_std = c_std
 
-    def _state_init(self, index
-        contrast = torch.exp2(torch.randn(
-        self.matrix[index] = [torch.unsqueeze(
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+        contrast = torch.exp2(torch.randn(len(shapes)) * self.c_std)
+        self.matrix[index] = [torch.unsqueeze(_scale_3d_matrix(value), dim=0) for value in contrast]
         return shapes
 
+
 class LumaFlip(ColorTransform):
 
-    def __init__(self, groups:
+    def __init__(self, groups: list[str] | None = None) -> None:
         super().__init__(groups)
-        self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
-
-    def _state_init(self, index
+        self.v = torch.tensor([1, 1, 1, 0]) / torch.sqrt(torch.tensor(3))
+
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         luma = torch.floor(torch.rand([len(shapes), 1, 1]) * 2)
         self.matrix[index] = [torch.unsqueeze((torch.eye(4) - 2 * self.v.ger(self.v) * value), dim=0) for value in luma]
         return shapes
 
+
 class HUE(ColorTransform):
 
-    def __init__(self, hue_max: float, groups:
+    def __init__(self, hue_max: float, groups: list[str] | None = None) -> None:
         super().__init__(groups)
         self.hue_max = hue_max
-        self.v = torch.tensor([1, 1, 1])/torch.sqrt(torch.tensor(3))
-
-    def _state_init(self, index
+        self.v = torch.tensor([1, 1, 1]) / torch.sqrt(torch.tensor(3))
+
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         theta = (torch.rand([len(shapes)]) * 2 - 1) * np.pi * self.hue_max
-        self.matrix[index] = [torch.unsqueeze(
+        self.matrix[index] = [torch.unsqueeze(_rotation_3d_matrix(value.repeat(3), self.v), dim=0) for value in theta]
         return shapes
-
+
+
 class Saturation(ColorTransform):
 
-    def __init__(self, s_std: float, groups:
+    def __init__(self, s_std: float, groups: list[str] | None = None) -> None:
         super().__init__(groups)
         self.s_std = s_std
-        self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
+        self.v = torch.tensor([1, 1, 1, 0]) / torch.sqrt(torch.tensor(3))
 
-    def _state_init(self, index
-        saturation = torch.exp2(torch.randn(
-        self.matrix[index] = [
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+        saturation = torch.exp2(torch.randn(len(shapes)) * self.s_std)
+        self.matrix[index] = [
+            (self.v.ger(self.v) + (torch.eye(4) - self.v.ger(self.v))).unsqueeze(0) * value for value in saturation
+        ]
         return shapes
-
-"""class Filter(DataAugmentation):
 
-    def __init__(self) -> None:
-        super().__init__()
-        wavelets = {
-            'haar': [0.7071067811865476, 0.7071067811865476],
-            'db1': [0.7071067811865476, 0.7071067811865476],
-            'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
-            'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
-            'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
-            'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
-            'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
-            'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
-            'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
-            'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
-            'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
-            'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
-            'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
-            'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
-            'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
-            'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
-        }
-        Hz_lo = np.asarray(wavelets['sym2'])            # H(z)
-        Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
-        Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2    # H(z) * H(z^-1) / 2
-        Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2    # H(-z) * H(-z^-1) / 2
-        Hz_fbank = np.eye(4, 1)                         # Bandpass(H(z), b_i)
-        for i in range(1, Hz_fbank.shape[0]):
-            Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
-            Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
-            Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
-        self.Hz_fbank = torch.as_tensor(Hz_fbank, dtype=torch.float32)
-
-        self.imgfilter_bands = [1,1,1,1]
-        self.imgfilter_std = 1
-
-
-    def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
-
-        return shapes
-
-    def _compute(self, index: int, inputs : list[torch.Tensor]) -> list[torch.Tensor]:
-        num_bands = self.Hz_fbank.shape[0]
-        assert len(self.imgfilter_bands) == num_bands
-        expected_power =torch.tensor([10, 1, 1, 1]) / 13 # Expected power spectrum (1/f).
-        for input in inputs:
-            batch_size = input.shape[0]
-            num_channels = input.shape[1]
-            # Apply amplification for each band with probability (imgfilter * strength * band_strength).
-            g = torch.ones([batch_size, num_bands]) # Global gain vector (identity).
-            for i, band_strength in enumerate(self.imgfilter_bands):
-                t_i = torch.exp2(torch.randn([batch_size]) * self.imgfilter_std)
-                t = torch.ones([batch_size, num_bands]) # Temporary gain vector.
-                t[:, i] = t_i # Replace i'th element.
-                t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
-                g = g * t # Accumulate into global gain.
-
-            # Construct combined amplification filter.
-            Hz_prime = g @ self.Hz_fbank # [batch, tap]
-            Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
-            Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
-
-            # Apply filter.
-            p = self.Hz_fbank.shape[1] // 2
-            images = images.reshape([1, batch_size * num_channels, height, width])
-            images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
-            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
-            images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
-            images = images.reshape([batch_size, num_channels, height, width])
-
-
-    def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
-        pass """
 
 class Noise(DataAugmentation):
 
-    def __init__(
+    def __init__(
+        self,
+        n_std: float,
+        noise_step: int = 1000,
+        beta_start: float = 1e-4,
+        beta_end: float = 0.02,
+        groups: list[str] | None = None,
+    ) -> None:
         super().__init__(groups)
         self.n_std = n_std
         self.noise_step = noise_step
@@ -372,78 +462,108 @@ class Noise(DataAugmentation):
         self.betas = Noise.enforce_zero_terminal_snr(self.betas)
         self.alphas = 1 - self.betas
         self.alpha_hat = torch.concat((torch.ones(1), torch.cumprod(self.alphas, dim=0)))
-        self.max_T = 0
+        self.max_T = 0.0
 
         self.C = 1
         self.n = 4
         self.d = 0.25
         self._prob = 1
 
+    @staticmethod
     def enforce_zero_terminal_snr(betas: torch.Tensor):
         alphas = 1 - betas
         alphas_bar = alphas.cumprod(0)
         alphas_bar_sqrt = alphas_bar.sqrt()
         alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-
-        alphas_bar_sqrt -=
-        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
-
-        alphas_bar = alphas_bar_sqrt ** 2
+        alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
+        alphas_bar_sqrt -= alphas_bar_sqrt_t
+        alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)
+        alphas_bar = alphas_bar_sqrt**2
         alphas = alphas_bar[1:] / alphas_bar[:-1]
         alphas = torch.cat([alphas_bar[0:1], alphas])
         betas = 1 - alphas
         return betas
-
+
     def load(self, prob: float):
-        self.max_T = prob*self.noise_step
+        self.max_T = prob * self.noise_step
 
-    def _state_init(self, index
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         if int(self.max_T) == 0:
             self.ts[index] = [0 for _ in shapes]
-        else:
+        else:
             self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
         return shapes
-
-    def _compute(
+
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         results = []
-        for
-            alpha_hat_t = self.alpha_hat[t].to(
-            results.append(
+        for tensor, t in zip(tensors, self.ts[index]):
+            alpha_hat_t = self.alpha_hat[t].to(tensor.device).reshape(*[1 for _ in range(len(tensor.shape))])
+            results.append(
+                alpha_hat_t.sqrt() * tensor
+                + (1 - alpha_hat_t).sqrt() * torch.randn_like(tensor.float()).to(tensor.device) * self.n_std
+            )
         return results
 
-    def _inverse(self, index: int, a: int,
+    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
         pass
 
+
 class CutOUT(DataAugmentation):
 
-    def __init__(
+    def __init__(
+        self,
+        c_prob: float,
+        cutout_size: int,
+        value: float,
+        groups: list[str] | None = None,
+    ) -> None:
         super().__init__(groups)
         self.c_prob = c_prob
         self.cutout_size = cutout_size
         self.centers: dict[int, list[torch.Tensor]] = {}
         self.value = value
 
-    def _state_init(self, index
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         self.centers[index] = [torch.rand((3) if len(shape) == 3 else (2)) for shape in shapes]
         return shapes
-
-    def _compute(
+
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         results = []
-        for
+        for tensor, center in zip(tensors, self.centers[index]):
             masks = []
-            for i, w in enumerate(
-                re = [1]*i+[-1]+[1]*(len(
-                masks.append(
+            for i, w in enumerate(tensor.shape[1:]):
+                re = [1] * i + [-1] + [1] * (len(tensor.shape[1:]) - i - 1)
+                masks.append(
+                    ((torch.arange(w).reshape(re) + 0.5) / w - center[i].reshape([1, 1])).abs()
+                    >= torch.tensor(self.cutout_size).reshape([1, 1]) / 2
+                )
             result = masks[0]
             for mask in masks[1:]:
                 result = torch.logical_or(result, mask)
-            result = result.unsqueeze(0).repeat([
-            results.append(
+            result = result.unsqueeze(0).repeat([tensor.shape[0], *[1 for _ in range(len(tensor.shape) - 1)]])
+            results.append(
+                torch.where(
+                    result.to(tensor.device) == 1,
+                    tensor,
+                    torch.tensor(self.value).to(tensor.device),
+                )
+            )
         return results
-
-    def _inverse(self, index: int, a: int,
+
+    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
         pass
 
+
 class Elastix(DataAugmentation):
 
     def __init__(self, grid_spacing: int = 16, max_displacement: int = 16) -> None:
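Aside on the hunk above: enforce_zero_terminal_snr is now an explicit @staticmethod, and the restored lines complete the rescaling that shifts and rescales sqrt(alpha_bar) so the cumulative product reaches exactly zero at the final step (the "zero terminal SNR" fix for diffusion schedules) while the first step is preserved. A quick self-contained check; the linear beta schedule is an assumption for illustration (the class defaults beta_start=1e-4, beta_end=0.02, noise_step=1000 appear in __init__, but the schedule construction itself is outside this hunk):

import torch

def enforce_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    # Body reproduced from the 1.1.9 diff.
    alphas = 1 - betas
    alphas_bar = alphas.cumprod(0)
    alphas_bar_sqrt = alphas_bar.sqrt()
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt -= alphas_bar_sqrt_t  # shift so the terminal value becomes 0
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)  # keep the first value
    alphas_bar = alphas_bar_sqrt**2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1 - alphas

betas = torch.linspace(1e-4, 0.02, 1000)  # assumed linear schedule, for illustration only
rescaled = enforce_zero_terminal_snr(betas)
assert torch.isclose((1 - rescaled).cumprod(0)[-1], torch.tensor(0.0), atol=1e-6)  # terminal alpha_bar is 0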
@@ -452,16 +572,16 @@ class Elastix(DataAugmentation):
         self.max_displacement = max_displacement
         self.displacement_fields: dict[int, list[torch.Tensor]] = {}
         self.displacement_fields_true: dict[int, list[torch.Tensor]] = {}
-
+
     @staticmethod
-    def
+    def _format_loc(new_locs, shape):
         for i in range(len(shape)):
             new_locs[..., i] = 2 * (new_locs[..., i] / (shape[i] - 1) - 0.5)
-        new_locs = new_locs[...,
+        new_locs = new_locs[..., list(reversed(range(len(shape))))]
         return new_locs
 
-    def _state_init(self, index
-        print("Compute Displacement Field for index {}"
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+        print(f"Compute Displacement Field for index {index}")
         self.displacement_fields[index] = []
         self.displacement_fields_true[index] = []
         for i, (shape, cache_attribute) in enumerate(zip(shapes, caches_attribute)):
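The renamed _format_loc above converts absolute voxel coordinates into the convention F.grid_sample expects: each coordinate is rescaled to [-1, 1] and the coordinate channels are reversed, since the grid is built in (z, y, x) index order while grid_sample wants (x, y[, z]). A standalone illustration, with the body copied from the diff and a hypothetical 4x5 grid:

import torch

def _format_loc(new_locs: torch.Tensor, shape: list[int]) -> torch.Tensor:
    for i in range(len(shape)):
        new_locs[..., i] = 2 * (new_locs[..., i] / (shape[i] - 1) - 0.5)  # voxel index -> [-1, 1]
    return new_locs[..., list(reversed(range(len(shape))))]  # reverse channel order for grid_sample

shape = [4, 5]
ident = torch.stack(torch.meshgrid([torch.arange(0, s) for s in shape], indexing="ij"), dim=-1).float()
grid = _format_loc(ident.unsqueeze(0), shape)
print(grid[0, 0, 0], grid[0, -1, -1])  # tensor([-1., -1.]) tensor([1., 1.]): corners map to the grid extremes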
@@ -471,117 +591,168 @@ class Elastix(DataAugmentation):
                 spacing = np.array([1.0 for _ in range(dim)])
             else:
                 spacing = cache_attribute.get_np_array("Spacing")
-
-            grid_physical_spacing = [self.grid_spacing]*dim
-            image_physical_size = [size*spacing for size, spacing in zip(shape, spacing)]
-            mesh_size = [
+
+            grid_physical_spacing = [self.grid_spacing] * dim
+            image_physical_size = [size * spacing for size, spacing in zip(shape, spacing)]
+            mesh_size = [
+                int(image_size / grid_spacing + 0.5)
+                for image_size, grid_spacing in zip(image_physical_size, grid_physical_spacing)
+            ]
             if "Spacing" not in cache_attribute:
                 cache_attribute["Spacing"] = np.array([1.0 for _ in range(dim)])
             if "Origin" not in cache_attribute:
                 cache_attribute["Origin"] = np.array([1.0 for _ in range(dim)])
             if "Direction" not in cache_attribute:
                 cache_attribute["Direction"] = np.eye(dim).flatten()
-
+
             ref_image = data_to_image(np.expand_dims(np.zeros(shape), 0), cache_attribute)
 
-            bspline_transform = sitk.BSplineTransformInitializer(
+            bspline_transform = sitk.BSplineTransformInitializer(
+                image1=ref_image, transformDomainMeshSize=mesh_size, order=3
+            )
             displacement_filter = sitk.TransformToDisplacementFieldFilter()
             displacement_filter.SetReferenceImage(ref_image)
-
+
             vectors = [torch.arange(0, s) for s in shape]
-            grids = torch.meshgrid(vectors, indexing=
+            grids = torch.meshgrid(vectors, indexing="ij")
             grid = torch.stack(grids)
             grid = torch.unsqueeze(grid, 0)
-            grid = grid.type(torch.float).permute([0]+[i+2 for i in range(len(shape))] + [1])
-
-            control_points = torch.rand(*[size+3 for size in mesh_size], dim)
+            grid = grid.type(torch.float).permute([0] + [i + 2 for i in range(len(shape))] + [1])
+
+            control_points = torch.rand(*[size + 3 for size in mesh_size], dim)
             control_points -= 0.5
-            control_points *= 2*self.max_displacement
+            control_points *= 2 * self.max_displacement
             bspline_transform.SetParameters(control_points.flatten().tolist())
             displacement = sitk.GetArrayFromImage(displacement_filter.Execute(bspline_transform))
             self.displacement_fields_true[index].append(displacement)
-            new_locs = grid+torch.unsqueeze(torch.from_numpy(displacement), 0).type(torch.float32)
-            self.displacement_fields[index].append(Elastix.
-            print("Compute in progress : {
+            new_locs = grid + torch.unsqueeze(torch.from_numpy(displacement), 0).type(torch.float32)
+            self.displacement_fields[index].append(Elastix._format_loc(new_locs, shape))
+            print(f"Compute in progress : {(i + 1) / len(shapes) * 100:.2f} %")
         return shapes
-
-    def _compute(
+
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         results = []
-        for
-            results.append(
+        for tensor, displacement_field in zip(tensors, self.displacement_fields[index]):
+            results.append(
+                F.grid_sample(
+                    tensor.type(torch.float32).unsqueeze(0),
+                    displacement_field.to(tensor.device),
+                    align_corners=True,
+                    mode="bilinear",
+                    padding_mode="border",
+                )
+                .type(tensor.dtype)
+                .squeeze(0)
+            )
        return results
-
-    def _inverse(self, index: int, a: int,
+
+    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
         pass
 
+
 class Permute(DataAugmentation):
 
-    def __init__(self, prob_permute:
+    def __init__(self, prob_permute: list[float] | None = [0.5, 0.5]) -> None:
         super().__init__()
         self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
         self.prob_permute = prob_permute
         self.permute: dict[int, torch.Tensor] = {}
 
-    def _state_init(self, index
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         if len(shapes):
             dim = len(shapes[0])
-
+            if dim != 3:
+                raise ValueError("The permute augmentation only support 3D images")
             if self.prob_permute:
-
-
+                if len(self.prob_permute) != 2:
+                    raise ValueError("Size of prob_permute must be equal 2")
+                self.permute[index] = torch.rand((len(shapes), len(self.prob_permute))) < torch.tensor(
+                    self.prob_permute
+                )
             else:
-
+                if len(shapes) != 2:
+                    raise ValueError("The number of augmentation images must be equal to 2")
                 self.permute[index] = torch.eye(2, dtype=torch.bool)
             for i, prob in enumerate(self.permute[index]):
                 for permute in self._permute_dims[prob]:
-                    shapes[i] = [shapes[i][dim-1] for dim in permute[1:]]
+                    shapes[i] = [shapes[i][dim - 1] for dim in permute[1:]]
         return shapes
-
-    def _compute(
+
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         results = []
-        for
-            res =
+        for tensor, prob in zip(tensors, self.permute[index]):
+            res = tensor
             for permute in self._permute_dims[prob]:
                 res = res.permute(tuple(permute))
             results.append(res)
         return results
-
-    def _inverse(self, index: int, a: int,
+
+    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
         for permute in reversed(self._permute_dims[self.permute[index][a]]):
-
-        return
+            tensor = tensor.permute(tuple(np.argsort(permute)))
+        return tensor
+
 
 class Mask(DataAugmentation):
 
-    def __init__(self, mask: str, value: float, groups:
+    def __init__(self, mask: str, value: float, groups: list[str] | None = None) -> None:
         super().__init__(groups)
         if mask is not None:
             if os.path.exists(mask):
                 self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
             else:
-                raise NameError(
+                raise NameError("Mask file not found")
         self.positions: dict[int, list[torch.Tensor]] = {}
         self.value = value
 
-    def _state_init(self, index
-        self.positions[index] = [
+    def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+        self.positions[index] = [
+            torch.rand((3) if len(shape) == 3 else (2))
+            * (torch.tensor([max(s1 - s2, 0) for s1, s2 in zip(torch.tensor(shape), torch.tensor(self.mask.shape))]))
+            for shape in shapes
+        ]
         return [self.mask.shape for _ in shapes]
-
-    def _compute(
+
+    def _compute(
+        self,
+        name: str,
+        index: int,
+        tensors: list[torch.Tensor],
+    ) -> list[torch.Tensor]:
         results = []
-        for
-            slices = [slice(None, None)]+[slice(int(s1), int(s1)+s2) for s1, s2 in zip(position, self.mask.shape)]
+        for tensor, position in zip(tensors, self.positions[index]):
+            slices = [slice(None, None)] + [slice(int(s1), int(s1) + s2) for s1, s2 in zip(position, self.mask.shape)]
             padding = []
-            for s1, s2 in zip(reversed(
+            for s1, s2 in zip(reversed(tensor.shape), reversed(self.mask.shape)):
                 if s1 < s2:
-                    pad = s2-s1
+                    pad = s2 - s1
                 else:
                     pad = 0
                 padding.append(0)
                 padding.append(pad)
-            value =
-
+            value = (
+                torch.tensor(0, dtype=torch.uint8)
+                if tensor.dtype == torch.uint8
+                else torch.tensor(self.value).to(tensor.device)
+            )
+            results.append(
+                torch.where(
+                    self.mask.to(tensor.device) == 1,
+                    torch.nn.functional.pad(tensor, tuple(padding), mode="constant", value=value)[tuple(slices)],
+                    value,
+                )
+            )
         return results
-
-    def _inverse(self, index: int, a: int,
+
+    def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
         pass