konfai-1.1.7-py3-none-any.whl → konfai-1.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
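
Beyond the line counts, the pattern across these 36 files is a rename sweep: camelCase helpers become snake_case (the diff below shows, for example, _getModule → get_module, KONFAI_ROOT → konfai_root, setDatasets → set_datasets) and typing.Union annotations become PEP 604 X | None unions, so 1.1.7 call sites will break. A hedged sketch of how a downstream project might bridge both versions (this shim is hypothetical, not shipped with konfai):

    # Hypothetical compatibility shim, not part of konfai: prefer the
    # 1.1.9 snake_case name and fall back to the 1.1.7 camelCase one.
    try:
        from konfai.utils.utils import get_module  # konfai 1.1.9
    except ImportError:
        from konfai.utils.utils import _getModule as get_module  # konfai 1.1.7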
konfai/data/augmentation.py

@@ -1,33 +1,87 @@
  import importlib
- import torch
- from abc import ABC, abstractmethod
- import numpy as np
- import SimpleITK as sitk
- import torch.nn.functional as F
- from typing import Union
  import os
- from konfai import KONFAI_ROOT
- from konfai.utils.config import config
- from konfai.utils.utils import _getModule, AugmentationError
- from konfai.utils.dataset import Attribute, data_to_image, Dataset
-
- def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
-     return torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.Tensor([[0,0,1]])), dim=0)
-
- def _translate3DMatrix(t: torch.Tensor) -> torch.Tensor:
-     return torch.cat((torch.cat((torch.eye(3), torch.tensor([[t[0]], [t[1]], [t[2]]])), dim=1), torch.Tensor([[0,0,0,1]])), dim=0)
-
- def _scale2DMatrix(s: torch.Tensor) -> torch.Tensor:
-     return torch.cat((torch.cat((torch.eye(2)*s, torch.tensor([[0], [0]])), dim=1), torch.tensor([[0, 0, 1]])), dim=0)
+ from abc import ABC, abstractmethod

- def _scale3DMatrix(s: torch.Tensor) -> torch.Tensor:
-     return torch.cat((torch.cat((torch.eye(3)*s, torch.tensor([[0], [0], [0]])), dim=1), torch.tensor([[0, 0, 0, 1]])), dim=0)
+ import numpy as np
+ import SimpleITK as sitk  # noqa: N813
+ import torch
+ import torch.nn.functional as F  # noqa: N812

- def _rotation3DMatrix(rotation : torch.Tensor, center: Union[torch.Tensor, None] = None) -> torch.Tensor:
-     A = torch.tensor([[torch.cos(rotation[2]), -torch.sin(rotation[2]), 0], [torch.sin(rotation[2]), torch.cos(rotation[2]), 0], [0, 0, 1]])
-     B = torch.tensor([[torch.cos(rotation[1]), 0, torch.sin(rotation[1])], [0, 1, 0], [-torch.sin(rotation[1]), 0, torch.cos(rotation[1])]])
-     C = torch.tensor([[1, 0, 0], [0, torch.cos(rotation[0]), -torch.sin(rotation[0])], [0, torch.sin(rotation[0]), torch.cos(rotation[0])]])
-     rotation_matrix = torch.cat((torch.cat((A.mm(B).mm(C), torch.zeros((3, 1))), dim=1), torch.tensor([[0, 0, 0, 1]])), dim=0)
+ from konfai import konfai_root
+ from konfai.utils.config import config
+ from konfai.utils.dataset import Attribute, Dataset, data_to_image
+ from konfai.utils.utils import AugmentationError, NeedDevice, get_module
+
+
+ def _translate_2d_matrix(t: torch.Tensor) -> torch.Tensor:
+     return torch.cat(
+         (
+             torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1),
+             torch.Tensor([[0, 0, 1]]),
+         ),
+         dim=0,
+     )
+
+
+ def _translate_3d_matrix(t: torch.Tensor) -> torch.Tensor:
+     return torch.cat(
+         (
+             torch.cat((torch.eye(3), torch.tensor([[t[0]], [t[1]], [t[2]]])), dim=1),
+             torch.Tensor([[0, 0, 0, 1]]),
+         ),
+         dim=0,
+     )
+
+
+ def _scale_2d_matrix(s: torch.Tensor) -> torch.Tensor:
+     return torch.cat(
+         (
+             torch.cat((torch.eye(2) * s, torch.tensor([[0], [0]])), dim=1),
+             torch.tensor([[0, 0, 1]]),
+         ),
+         dim=0,
+     )
+
+
+ def _scale_3d_matrix(s: torch.Tensor) -> torch.Tensor:
+     return torch.cat(
+         (
+             torch.cat((torch.eye(3) * s, torch.tensor([[0], [0], [0]])), dim=1),
+             torch.tensor([[0, 0, 0, 1]]),
+         ),
+         dim=0,
+     )
+
+
+ def _rotation_3d_matrix(rotation: torch.Tensor, center: torch.Tensor | None = None) -> torch.Tensor:
+     a = torch.tensor(
+         [
+             [torch.cos(rotation[2]), -torch.sin(rotation[2]), 0],
+             [torch.sin(rotation[2]), torch.cos(rotation[2]), 0],
+             [0, 0, 1],
+         ]
+     )
+     b = torch.tensor(
+         [
+             [torch.cos(rotation[1]), 0, torch.sin(rotation[1])],
+             [0, 1, 0],
+             [-torch.sin(rotation[1]), 0, torch.cos(rotation[1])],
+         ]
+     )
+     c = torch.tensor(
+         [
+             [1, 0, 0],
+             [0, torch.cos(rotation[0]), -torch.sin(rotation[0])],
+             [0, torch.sin(rotation[0]), torch.cos(rotation[0])],
+         ]
+     )
+     rotation_matrix = torch.cat(
+         (
+             torch.cat((a.mm(b).mm(c), torch.zeros((3, 1))), dim=1),
+             torch.tensor([[0, 0, 0, 1]]),
+         ),
+         dim=0,
+     )
      if center is not None:
          translation_before = torch.eye(4)
          translation_before[:-1, -1] = -center
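
All of the helpers in this hunk share the same homogeneous-coordinate layout: an N×N linear block with the translation in the last column and a [0 … 0 1] bottom row. A minimal standalone check of that layout (translate_3d below is a reimplementation for illustration, not konfai's API):

    import torch

    # Reimplements the 4x4 homogeneous layout used by _translate_3d_matrix
    # above, so the check runs without konfai installed.
    def translate_3d(t: torch.Tensor) -> torch.Tensor:
        top = torch.cat((torch.eye(3), t.reshape(3, 1)), dim=1)  # [I | t]
        return torch.cat((top, torch.tensor([[0.0, 0.0, 0.0, 1.0]])), dim=0)

    m = translate_3d(torch.tensor([1.0, 2.0, 3.0]))
    origin = torch.tensor([0.0, 0.0, 0.0, 1.0])  # homogeneous point
    assert torch.allclose(m @ origin, torch.tensor([1.0, 2.0, 3.0, 1.0]))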
@@ -38,117 +92,197 @@ def _rotation3DMatrix(rotation : torch.Tensor, center: Union[torch.Tensor, None]
      rotation_matrix = rotation_matrix.mm(translation_after)
      return rotation_matrix

- def _rotation2DMatrix(rotation : torch.Tensor, center: Union[torch.Tensor, None] = None) -> torch.Tensor:
-     return torch.cat((torch.cat((torch.tensor([[torch.cos(rotation[0]), -torch.sin(rotation[0])], [torch.sin(rotation[0]), torch.cos(rotation[0])]]), torch.zeros((2, 1))), dim=1), torch.tensor([[0, 0, 1]])), dim=0)

- class Prob():
+ def _rotation_2d_matrix(rotation: torch.Tensor, center: torch.Tensor | None = None) -> torch.Tensor:
+     return torch.cat(
+         (
+             torch.cat(
+                 (
+                     torch.tensor(
+                         [
+                             [torch.cos(rotation[0]), -torch.sin(rotation[0])],
+                             [torch.sin(rotation[0]), torch.cos(rotation[0])],
+                         ]
+                     ),
+                     torch.zeros((2, 1)),
+                 ),
+                 dim=1,
+             ),
+             torch.tensor([[0, 0, 1]]),
+         ),
+         dim=0,
+     )
+
+
+ class Prob:

      @config()
      def __init__(self, prob: float = 1.0) -> None:
          self.prob = prob

- class DataAugmentationsList():
+
+ class DataAugmentationsList:

      @config()
-     def __init__(self, nb : int = 10, dataAugmentations: dict[str, Prob] = {"default:Flip" : Prob(1)}) -> None:
+     def __init__(
+         self,
+         nb: int = 10,
+         data_augmentations: dict[str, Prob] = {"default:Flip": Prob(1)},
+     ) -> None:
          self.nb = nb
-         self.dataAugmentations : list[DataAugmentation] = []
-         self.dataAugmentationsLoader = dataAugmentations
+         self.data_augmentations: list[DataAugmentation] = []
+         self.data_augmentationsLoader = data_augmentations

      def load(self, key: str, datasets: list[Dataset]):
-         for augmentation, prob in self.dataAugmentationsLoader.items():
-             module, name = _getModule(augmentation, "konfai.data.augmentation")
-             dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
-             dataAugmentation.load(prob.prob)
-             dataAugmentation.setDatasets(datasets)
-             self.dataAugmentations.append(dataAugmentation)
+         for augmentation, prob in self.data_augmentationsLoader.items():
+             module, name = get_module(augmentation, "konfai.data.augmentation")
+             data_augmentation: DataAugmentation = config(
+                 f"{konfai_root()}.Dataset.augmentations.{key}.data_augmentations.{augmentation}"
+             )(getattr(importlib.import_module(module), name))(config=None)
+             data_augmentation.load(prob.prob)
+             data_augmentation.set_datasets(datasets)
+             self.data_augmentations.append(data_augmentation)
+

- class DataAugmentation(ABC):
+ class DataAugmentation(NeedDevice, ABC):

-     def __init__(self, groups: Union[list[str], None] = None) -> None:
+     def __init__(self, groups: list[str] | None = None) -> None:
          self.who_index: dict[int, list[int]] = {}
          self.shape_index: dict[int, list[list[int]]] = {}
          self._prob: float = 0
          self.groups = groups
-         self.datasets : list[Dataset] = []
+         self.datasets: list[Dataset] = []

      def load(self, prob: float):
          self._prob = prob

-     def setDatasets(self, datasets: list[Dataset]):
+     def set_datasets(self, datasets: list[Dataset]):
          self.datasets = datasets
-
-     def state_init(self, index: Union[None, int], shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+
+     def state_init(
+         self,
+         index: None | int,
+         shapes: list[list[int]],
+         caches_attribute: list[Attribute],
+     ) -> list[list[int]]:
          if index is not None:
              if index not in self.who_index:
                  self.who_index[index] = torch.where(torch.rand(len(shapes)) < self._prob)[0].tolist()
              else:
                  return self.shape_index[index]
-         else:
+         else:
              index = 0
              self.who_index[index] = torch.where(torch.rand(len(shapes)) < self._prob)[0].tolist()
-
+
          if len(self.who_index[index]) > 0:
-             for i, shape in enumerate(self._state_init(index, [shapes[i] for i in self.who_index[index]], [caches_attribute[i] for i in self.who_index[index]])):
+             for i, shape in enumerate(
+                 self._state_init(
+                     index,
+                     [shapes[i] for i in self.who_index[index]],
+                     [caches_attribute[i] for i in self.who_index[index]],
+                 )
+             ):
                  shapes[self.who_index[index][i]] = shape
          self.shape_index[index] = shapes
          return self.shape_index[index]
-
+
      @abstractmethod
-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          pass

-     def __call__(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+     def __call__(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          if len(self.who_index[index]) > 0:
-             for i, result in enumerate(self._compute(name, index, [inputs[i] for i in self.who_index[index]], device)):
-                 inputs[self.who_index[index][i]] = result if device is None else result.cpu()
-         return inputs
-
+             for i, result in enumerate(self._compute(name, index, [tensors[i] for i in self.who_index[index]])):
+                 tensors[self.who_index[index][i]] = result
+         return tensors
+
      @abstractmethod
-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          pass

-     def inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
+     def inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
          if a in self.who_index[index]:
-             input = self._inverse(index, a, input)
-         return input
-
+             tensor = self._inverse(index, a, tensor)
+         return tensor
+
      @abstractmethod
-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
+     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
          pass

+
  class EulerTransform(DataAugmentation):

      def __init__(self) -> None:
          super().__init__()
          self.matrix: dict[int, list[torch.Tensor]] = {}

-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          results = []
-         for input, matrix in zip(inputs, self.matrix[index]):
-             results.append(F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0))
+         for tensor, matrix in zip(tensors, self.matrix[index]):
+             results.append(
+                 F.grid_sample(
+                     tensor.unsqueeze(0).type(torch.float32),
+                     F.affine_grid(matrix[:, :-1, ...], [1] + list(tensor.shape), align_corners=True).to(tensor.device),
+                     align_corners=True,
+                     mode="bilinear",
+                     padding_mode="reflection",
+                 )
+                 .type(tensor.dtype)
+                 .squeeze(0)
+             )
          return results
-
-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
-         return F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(self.matrix[index][a].inverse()[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0)
-
+
+     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
+         return (
+             F.grid_sample(
+                 tensor.unsqueeze(0).type(torch.float32),
+                 F.affine_grid(
+                     self.matrix[index][a].inverse()[:, :-1, ...],
+                     [1] + list(tensor.shape),
+                     align_corners=True,
+                 ).to(tensor.device),
+                 align_corners=True,
+                 mode="bilinear",
+                 padding_mode="reflection",
+             )
+             .type(tensor.dtype)
+             .squeeze(0)
+         )
+
+
  class Translate(EulerTransform):
-
-     def __init__(self, t_min: float = -10, t_max = 10, is_int: bool = False):
+
+     def __init__(self, t_min: float = -10, t_max=10, is_int: bool = False):
          super().__init__()
          self.t_min = t_min
          self.t_max = t_max
          self.is_int = is_int

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          dim = len(shapes[0])
-         func = _translate3DMatrix if dim == 3 else _translate2DMatrix
-         translate = torch.rand((len(shapes), dim)) * torch.tensor(self.t_max-self.t_min) + torch.tensor(self.t_min)
+         func = _translate_3d_matrix if dim == 3 else _translate_2d_matrix
+         translate = torch.rand((len(shapes), dim)) * torch.tensor(self.t_max - self.t_min) + torch.tensor(self.t_min)
          if self.is_int:
-             translate = torch.round(translate*100)/100
+             translate = torch.round(translate * 100) / 100
          self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in translate]
          return shapes

+
  class Rotate(EulerTransform):

      def __init__(self, a_min: float = 0, a_max: float = 360, is_quarter: bool = False):
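
EulerTransform in the hunk above never touches pixels directly: it hands the first dim rows of each stored (dim+1)×(dim+1) matrix to F.affine_grid and resamples with F.grid_sample, so the inverse augmentation is simply the same call with matrix.inverse(). A minimal 2D sketch of that pattern (identity matrix, so the resampling is a numerical no-op):

    import torch
    import torch.nn.functional as F

    # Identity affine warp: with align_corners=True the grid lands exactly
    # on pixel centers, so bilinear sampling reproduces the input.
    img = torch.arange(16.0).reshape(1, 1, 4, 4)  # N, C, H, W
    theta = torch.eye(2, 3).unsqueeze(0)          # batch of one 2x3 affine
    grid = F.affine_grid(theta, list(img.shape), align_corners=True)
    out = F.grid_sample(img, grid, align_corners=True, mode="bilinear")
    assert torch.allclose(out, img)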
@@ -157,18 +291,19 @@ class Rotate(EulerTransform):
          self.a_max = a_max
          self.is_quarter = is_quarter

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          dim = len(shapes[0])
-         func = _rotation3DMatrix if dim == 3 else _rotation2DMatrix
+         func = _rotation_3d_matrix if dim == 3 else _rotation_2d_matrix
          angles = []
-
+
          if self.is_quarter:
-             angles = torch.Tensor.repeat(torch.tensor([90,180,270]), 3)
+             angles = torch.Tensor.repeat(torch.tensor([90, 180, 270]), 3)
          else:
-             angles = torch.rand((len(shapes), dim))*torch.tensor(self.a_max-self.a_min) + torch.tensor(self.a_min)
+             angles = torch.rand((len(shapes), dim)) * torch.tensor(self.a_max - self.a_min) + torch.tensor(self.a_min)

          self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in angles]
-         return shapes
+         return shapes
+

  class Scale(EulerTransform):

@@ -176,193 +311,148 @@ class Scale(EulerTransform):
          super().__init__()
          self.s_std = s_std

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
-         func = _scale3DMatrix if len(shapes[0]) == 3 else _scale2DMatrix
-         scale = torch.Tensor.repeat(torch.exp2(torch.randn((len(shapes))) * self.s_std).unsqueeze(1), [1, len(shapes[0])])
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         func = _scale_3d_matrix if len(shapes[0]) == 3 else _scale_2d_matrix
+         scale = torch.Tensor.repeat(
+             torch.exp2(torch.randn(len(shapes)) * self.s_std).unsqueeze(1),
+             [1, len(shapes[0])],
+         )
          self.matrix[index] = [torch.unsqueeze(func(value), dim=0) for value in scale]
          return shapes

+
  class Flip(DataAugmentation):

-     def __init__(self, f_prob: Union[list[float], None] = [0.33, 0.33 ,0.33]) -> None:
+     def __init__(self, f_prob: list[float] = [0.33, 0.33, 0.33]) -> None:
          super().__init__()
          self.f_prob = f_prob
          self.flip: dict[int, list[int]] = {}

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          prob = torch.rand((len(shapes), len(self.f_prob))) < torch.tensor(self.f_prob)
-         dims = torch.tensor([1, 2, 3][:len(self.f_prob)])
+         dims = torch.tensor([1, 2, 3][: len(self.f_prob)])
          self.flip[index] = [dims[mask].tolist() for mask in prob]
          return shapes

-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          results = []
-         for input, flip in zip(inputs, self.flip[index]):
-             results.append(torch.flip(input, dims=flip))
+         for tensor, flip in zip(tensors, self.flip[index]):
+             results.append(torch.flip(tensor, dims=flip))
          return results
-
-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
-         return torch.flip(input, dims=self.flip[index][a])
-
+
+     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
+         return torch.flip(tensor, dims=self.flip[index][a])
+

  class ColorTransform(DataAugmentation):

-     def __init__(self, groups: Union[list[str], None] = None) -> None:
+     def __init__(self, groups: list[str] | None = None) -> None:
          super().__init__(groups)
          self.matrix: dict[int, list[torch.Tensor]] = {}
-
-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          results = []
-         for input, matrix in zip(inputs, self.matrix[index]):
-             result = input.reshape([*input.shape[:1], int(np.prod(input.shape[1:]))])
-             if input.shape[0] == 3:
-                 matrix = matrix.to(input.device)
+         for tensor, matrix in zip(tensors, self.matrix[index]):
+             result = tensor.reshape([*tensor.shape[:1], int(np.prod(tensor.shape[1:]))])
+             if tensor.shape[0] == 3:
+                 matrix = matrix.to(tensor.device)
                  result = matrix[:, :3, :3] @ result.float() + matrix[:, :3, 3:]
-             elif input.shape[0] == 1:
-                 matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(input.device)
+             elif tensor.shape[0] == 1:
+                 matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(tensor.device)
                  result = result.float() * matrix[:, :, :3].sum(dim=2, keepdims=True) + matrix[:, :, 3:]
              else:
-                 raise AugmentationError('Image must be RGB (3 channels) or L (1 channel)')
-             results.append(result.reshape(input.shape))
+                 raise AugmentationError("Image must be RGB (3 channels) or L (1 channel)")
+             results.append(result.reshape(tensor.shape))
          return results
-
-     def _inverse(self, index: int, a: int, inputs : torch.Tensor) -> torch.Tensor:
+
+     def _inverse(self, index: int, a: int, tensors: torch.Tensor) -> torch.Tensor:
          pass

+
  class Brightness(ColorTransform):

-     def __init__(self, b_std: float, groups: Union[list[str], None] = None) -> None:
+     def __init__(self, b_std: float, groups: list[str] | None = None) -> None:
          super().__init__(groups)
          self.b_std = b_std

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
-         brightness = torch.Tensor.repeat((torch.randn((len(shapes)))*self.b_std).unsqueeze(1), [1, 3])
-         self.matrix[index] = [torch.unsqueeze(_translate3DMatrix(value), dim=0) for value in brightness]
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         brightness = torch.Tensor.repeat((torch.randn(len(shapes)) * self.b_std).unsqueeze(1), [1, 3])
+         self.matrix[index] = [torch.unsqueeze(_translate_3d_matrix(value), dim=0) for value in brightness]
          return shapes

+
  class Contrast(ColorTransform):

-     def __init__(self, c_std: float, groups: Union[list[str], None] = None) -> None:
+     def __init__(self, c_std: float, groups: list[str] | None = None) -> None:
          super().__init__(groups)
          self.c_std = c_std

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
-         contrast = torch.exp2(torch.randn((len(shapes))) * self.c_std)
-         self.matrix[index] = [torch.unsqueeze(_scale3DMatrix(value), dim=0) for value in contrast]
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         contrast = torch.exp2(torch.randn(len(shapes)) * self.c_std)
+         self.matrix[index] = [torch.unsqueeze(_scale_3d_matrix(value), dim=0) for value in contrast]
          return shapes

+
  class LumaFlip(ColorTransform):

-     def __init__(self, groups: Union[list[str], None] = None) -> None:
+     def __init__(self, groups: list[str] | None = None) -> None:
          super().__init__(groups)
-         self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
-
-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         self.v = torch.tensor([1, 1, 1, 0]) / torch.sqrt(torch.tensor(3))
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          luma = torch.floor(torch.rand([len(shapes), 1, 1]) * 2)
          self.matrix[index] = [torch.unsqueeze((torch.eye(4) - 2 * self.v.ger(self.v) * value), dim=0) for value in luma]
          return shapes

+
  class HUE(ColorTransform):

-     def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
+     def __init__(self, hue_max: float, groups: list[str] | None = None) -> None:
          super().__init__(groups)
          self.hue_max = hue_max
-         self.v = torch.tensor([1, 1, 1])/torch.sqrt(torch.tensor(3))
-
-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         self.v = torch.tensor([1, 1, 1]) / torch.sqrt(torch.tensor(3))
+
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          theta = (torch.rand([len(shapes)]) * 2 - 1) * np.pi * self.hue_max
-         self.matrix[index] = [torch.unsqueeze(_rotation3DMatrix(value.repeat(3), self.v), dim=0) for value in theta]
+         self.matrix[index] = [torch.unsqueeze(_rotation_3d_matrix(value.repeat(3), self.v), dim=0) for value in theta]
          return shapes
-
+
+
  class Saturation(ColorTransform):

-     def __init__(self, s_std: float, groups: Union[list[str], None] = None) -> None:
+     def __init__(self, s_std: float, groups: list[str] | None = None) -> None:
          super().__init__(groups)
          self.s_std = s_std
-         self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
+         self.v = torch.tensor([1, 1, 1, 0]) / torch.sqrt(torch.tensor(3))

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
-         saturation = torch.exp2(torch.randn((len(shapes))) * self.s_std)
-         self.matrix[index] = [(self.v.ger(self.v) + (torch.eye(4) - self.v.ger(self.v))).unsqueeze(0) * value for value in saturation]
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         saturation = torch.exp2(torch.randn(len(shapes)) * self.s_std)
+         self.matrix[index] = [
+             (self.v.ger(self.v) + (torch.eye(4) - self.v.ger(self.v))).unsqueeze(0) * value for value in saturation
+         ]
          return shapes
-
- """class Filter(DataAugmentation):

-     def __init__(self) -> None:
-         super().__init__()
-         wavelets = {
-             'haar': [0.7071067811865476, 0.7071067811865476],
-             'db1': [0.7071067811865476, 0.7071067811865476],
-             'db2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
-             'db3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
-             'db4': [-0.010597401784997278, 0.032883011666982945, 0.030841381835986965, -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, 0.7148465705525415, 0.23037781330885523],
-             'db5': [0.003335725285001549, -0.012580751999015526, -0.006241490213011705, 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, 0.160102397974125],
-             'db6': [-0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, 0.7511339080215775, 0.4946238903983854, 0.11154074335008017],
-             'db7': [0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, 0.39653931948230575, 0.07785205408506236],
-             'db8': [-0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, -0.04408825393106472, -0.01736930100202211, 0.128747426620186, 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, 0.05441584224308161],
-             'sym2': [-0.12940952255092145, 0.22414386804185735, 0.836516303737469, 0.48296291314469025],
-             'sym3': [0.035226291882100656, -0.08544127388224149, -0.13501102001039084, 0.4598775021193313, 0.8068915093133388, 0.3326705529509569],
-             'sym4': [-0.07576571478927333, -0.02963552764599851, 0.49761866763201545, 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, -0.012603967262037833, 0.0322231006040427],
-             'sym5': [0.027333068345077982, 0.029519490925774643, -0.039134249302383094, 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, 0.019538882735286728],
-             'sym6': [0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, -0.048311742585633, 0.4910559419267466, 0.787641141030194, 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148],
-             'sym7': [0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, 0.017441255086855827, 0.5361019170917628, 0.767764317003164, 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, 0.004010244871533663, 0.010268176708511255],
-             'sym8': [-0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, 0.0018899503327594609],
-         }
-         Hz_lo = np.asarray(wavelets['sym2']) # H(z)
-         Hz_hi = Hz_lo * ((-1) ** np.arange(Hz_lo.size)) # H(-z)
-         Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2
-         Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2
-         Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i)
-         for i in range(1, Hz_fbank.shape[0]):
-             Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank)]).reshape(Hz_fbank.shape[0], -1)[:, :-1]
-             Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2])
-             Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // 2 : (Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2
-         self.Hz_fbank = torch.as_tensor(Hz_fbank, dtype=torch.float32)
-
-         self.imgfilter_bands = [1,1,1,1]
-         self.imgfilter_std = 1
-
-
-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
-
-         return shapes
-
-     def _compute(self, index: int, inputs : list[torch.Tensor]) -> list[torch.Tensor]:
-         num_bands = self.Hz_fbank.shape[0]
-         assert len(self.imgfilter_bands) == num_bands
-         expected_power =torch.tensor([10, 1, 1, 1]) / 13 # Expected power spectrum (1/f).
-         for input in inputs:
-             batch_size = input.shape[0]
-             num_channels = input.shape[1]
-             # Apply amplification for each band with probability (imgfilter * strength * band_strength).
-             g = torch.ones([batch_size, num_bands]) # Global gain vector (identity).
-             for i, band_strength in enumerate(self.imgfilter_bands):
-                 t_i = torch.exp2(torch.randn([batch_size]) * self.imgfilter_std)
-                 t = torch.ones([batch_size, num_bands]) # Temporary gain vector.
-                 t[:, i] = t_i # Replace i'th element.
-                 t = t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt() # Normalize power.
-                 g = g * t # Accumulate into global gain.
-
-             # Construct combined amplification filter.
-             Hz_prime = g @ self.Hz_fbank # [batch, tap]
-             Hz_prime = Hz_prime.unsqueeze(1).repeat([1, num_channels, 1]) # [batch, channels, tap]
-             Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, -1]) # [batch * channels, 1, tap]
-
-             # Apply filter.
-             p = self.Hz_fbank.shape[1] // 2
-             images = images.reshape([1, batch_size * num_channels, height, width])
-             images = torch.nn.functional.pad(input=images, pad=[p,p,p,p], mode='reflect')
-             images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(2), groups=batch_size*num_channels)
-             images = conv2d_gradfix.conv2d(input=images, weight=Hz_prime.unsqueeze(3), groups=batch_size*num_channels)
-             images = images.reshape([batch_size, num_channels, height, width])
-
-
-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
-         pass """

  class Noise(DataAugmentation):

-     def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02, groups: Union[list[str], None] = None) -> None:
+     def __init__(
+         self,
+         n_std: float,
+         noise_step: int = 1000,
+         beta_start: float = 1e-4,
+         beta_end: float = 0.02,
+         groups: list[str] | None = None,
+     ) -> None:
          super().__init__(groups)
          self.n_std = n_std
          self.noise_step = noise_step
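
The ColorTransform family above treats an image as a channels × pixels matrix and encodes each augmentation as a 4×4 homogeneous color matrix: Brightness is a translation, Contrast a scaling, LumaFlip a reflection, HUE a rotation about the gray axis. A minimal sketch of the idea (standalone; brightness_matrix is an illustrative name, not konfai's API):

    import torch

    # A 4x4 homogeneous "translation" in color space shifts all RGB
    # channels by the same offset, matching what Brightness builds via
    # _translate_3d_matrix above.
    def brightness_matrix(offset: float) -> torch.Tensor:
        m = torch.eye(4)
        m[:3, 3] = offset
        return m

    img = torch.rand(3, 8, 8)              # C, H, W with C == 3
    flat = img.reshape(3, -1)              # 3 x num_pixels
    m = brightness_matrix(0.1)
    out = (m[:3, :3] @ flat + m[:3, 3:]).reshape(img.shape)
    assert torch.allclose(out, img + 0.1)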
@@ -372,78 +462,108 @@ class Noise(DataAugmentation):
          self.betas = Noise.enforce_zero_terminal_snr(self.betas)
          self.alphas = 1 - self.betas
          self.alpha_hat = torch.concat((torch.ones(1), torch.cumprod(self.alphas, dim=0)))
-         self.max_T = 0
+         self.max_T = 0.0

          self.C = 1
          self.n = 4
          self.d = 0.25
          self._prob = 1

+     @staticmethod
      def enforce_zero_terminal_snr(betas: torch.Tensor):
          alphas = 1 - betas
          alphas_bar = alphas.cumprod(0)
          alphas_bar_sqrt = alphas_bar.sqrt()
          alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-         alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-         alphas_bar_sqrt -= alphas_bar_sqrt_T
-         alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
-             alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
-         alphas_bar = alphas_bar_sqrt ** 2
+         alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
+         alphas_bar_sqrt -= alphas_bar_sqrt_t
+         alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)
+         alphas_bar = alphas_bar_sqrt**2
          alphas = alphas_bar[1:] / alphas_bar[:-1]
          alphas = torch.cat([alphas_bar[0:1], alphas])
          betas = 1 - alphas
          return betas
-
+
      def load(self, prob: float):
-         self.max_T = prob*self.noise_step
+         self.max_T = prob * self.noise_step

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          if int(self.max_T) == 0:
              self.ts[index] = [0 for _ in shapes]
-         else:
+         else:
              self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
          return shapes
-
-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          results = []
-         for input, t in zip(inputs, self.ts[index]):
-             alpha_hat_t = self.alpha_hat[t].to(input.device).reshape(*[1 for _ in range(len(input.shape))])
-             results.append(alpha_hat_t.sqrt() * input + (1 - alpha_hat_t).sqrt() * torch.randn_like(input.float()).to(input.device)*self.n_std)
+         for tensor, t in zip(tensors, self.ts[index]):
+             alpha_hat_t = self.alpha_hat[t].to(tensor.device).reshape(*[1 for _ in range(len(tensor.shape))])
+             results.append(
+                 alpha_hat_t.sqrt() * tensor
+                 + (1 - alpha_hat_t).sqrt() * torch.randn_like(tensor.float()).to(tensor.device) * self.n_std
+             )
          return results

-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
+     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
          pass

+
  class CutOUT(DataAugmentation):

-     def __init__(self, c_prob: float, cutout_size: int, value: float, groups: Union[list[str], None] = None) -> None:
+     def __init__(
+         self,
+         c_prob: float,
+         cutout_size: int,
+         value: float,
+         groups: list[str] | None = None,
+     ) -> None:
          super().__init__(groups)
          self.c_prob = c_prob
          self.cutout_size = cutout_size
          self.centers: dict[int, list[torch.Tensor]] = {}
          self.value = value

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          self.centers[index] = [torch.rand((3) if len(shape) == 3 else (2)) for shape in shapes]
          return shapes
-
-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          results = []
-         for input, center in zip(inputs, self.centers[index]):
+         for tensor, center in zip(tensors, self.centers[index]):
              masks = []
-             for i, w in enumerate(input.shape[1:]):
-                 re = [1]*i+[-1]+[1]*(len(input.shape[1:])-i-1)
-                 masks.append((((torch.arange(w).reshape(re) + 0.5) / w - center[i].reshape([1, 1])).abs() >= torch.tensor(self.cutout_size).reshape([1, 1])/ 2))
+             for i, w in enumerate(tensor.shape[1:]):
+                 re = [1] * i + [-1] + [1] * (len(tensor.shape[1:]) - i - 1)
+                 masks.append(
+                     ((torch.arange(w).reshape(re) + 0.5) / w - center[i].reshape([1, 1])).abs()
+                     >= torch.tensor(self.cutout_size).reshape([1, 1]) / 2
+                 )
              result = masks[0]
              for mask in masks[1:]:
                  result = torch.logical_or(result, mask)
-             result = result.unsqueeze(0).repeat([input.shape[0], *[1 for _ in range(len(input.shape)-1)]])
-             results.append(torch.where(result.to(input.device) == 1, input, torch.tensor(self.value).to(input.device)))
+             result = result.unsqueeze(0).repeat([tensor.shape[0], *[1 for _ in range(len(tensor.shape) - 1)]])
+             results.append(
+                 torch.where(
+                     result.to(tensor.device) == 1,
+                     tensor,
+                     torch.tensor(self.value).to(tensor.device),
+                 )
+             )
          return results
-
-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
+
+     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
          pass

+
  class Elastix(DataAugmentation):

      def __init__(self, grid_spacing: int = 16, max_displacement: int = 16) -> None:
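
Noise.enforce_zero_terminal_snr above rescales the beta schedule so that the cumulative signal at the final step is exactly zero, the "zero terminal SNR" correction known from the diffusion-model literature. A standalone re-check of the algebra (the same computation as the static method, fed a linear schedule):

    import torch

    # Same math as Noise.enforce_zero_terminal_snr: shift and rescale the
    # sqrt of the cumulative alphas so the last entry becomes exactly zero.
    def enforce_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
        alphas_bar_sqrt = (1 - betas).cumprod(0).sqrt()
        a0, at = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
        alphas_bar_sqrt = (alphas_bar_sqrt - at) * a0 / (a0 - at)
        alphas_bar = alphas_bar_sqrt**2
        alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
        return 1 - alphas

    betas = torch.linspace(1e-4, 0.02, 1000)
    fixed = enforce_zero_terminal_snr(betas)
    # Terminal cumulative alpha (hence SNR) is now zero.
    assert torch.isclose((1 - fixed).cumprod(0)[-1], torch.tensor(0.0), atol=1e-6)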
@@ -452,16 +572,16 @@ class Elastix(DataAugmentation):
          self.max_displacement = max_displacement
          self.displacement_fields: dict[int, list[torch.Tensor]] = {}
          self.displacement_fields_true: dict[int, list[torch.Tensor]] = {}
-
+
      @staticmethod
-     def _formatLoc(new_locs, shape):
+     def _format_loc(new_locs, shape):
          for i in range(len(shape)):
              new_locs[..., i] = 2 * (new_locs[..., i] / (shape[i] - 1) - 0.5)
-         new_locs = new_locs[..., [i for i in reversed(range(len(shape)))]]
+         new_locs = new_locs[..., list(reversed(range(len(shape))))]
          return new_locs

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
-         print("Compute Displacement Field for index {}".format(index))
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         print(f"Compute Displacement Field for index {index}")
          self.displacement_fields[index] = []
          self.displacement_fields_true[index] = []
          for i, (shape, cache_attribute) in enumerate(zip(shapes, caches_attribute)):
@@ -471,117 +591,168 @@
                  spacing = np.array([1.0 for _ in range(dim)])
              else:
                  spacing = cache_attribute.get_np_array("Spacing")
-
-             grid_physical_spacing = [self.grid_spacing]*dim
-             image_physical_size = [size*spacing for size, spacing in zip(shape, spacing)]
-             mesh_size = [int(image_size/grid_spacing + 0.5) for image_size,grid_spacing in zip(image_physical_size, grid_physical_spacing)]
+
+             grid_physical_spacing = [self.grid_spacing] * dim
+             image_physical_size = [size * spacing for size, spacing in zip(shape, spacing)]
+             mesh_size = [
+                 int(image_size / grid_spacing + 0.5)
+                 for image_size, grid_spacing in zip(image_physical_size, grid_physical_spacing)
+             ]
              if "Spacing" not in cache_attribute:
                  cache_attribute["Spacing"] = np.array([1.0 for _ in range(dim)])
              if "Origin" not in cache_attribute:
                  cache_attribute["Origin"] = np.array([1.0 for _ in range(dim)])
              if "Direction" not in cache_attribute:
                  cache_attribute["Direction"] = np.eye(dim).flatten()
-
+
              ref_image = data_to_image(np.expand_dims(np.zeros(shape), 0), cache_attribute)

-             bspline_transform = sitk.BSplineTransformInitializer(image1 = ref_image, transformDomainMeshSize = mesh_size, order=3)
+             bspline_transform = sitk.BSplineTransformInitializer(
+                 image1=ref_image, transformDomainMeshSize=mesh_size, order=3
+             )
              displacement_filter = sitk.TransformToDisplacementFieldFilter()
              displacement_filter.SetReferenceImage(ref_image)
-
+
              vectors = [torch.arange(0, s) for s in shape]
-             grids = torch.meshgrid(vectors, indexing='ij')
+             grids = torch.meshgrid(vectors, indexing="ij")
              grid = torch.stack(grids)
              grid = torch.unsqueeze(grid, 0)
-             grid = grid.type(torch.float).permute([0]+[i+2 for i in range(len(shape))] + [1])
-
-             control_points = torch.rand(*[size+3 for size in mesh_size], dim)
+             grid = grid.type(torch.float).permute([0] + [i + 2 for i in range(len(shape))] + [1])
+
+             control_points = torch.rand(*[size + 3 for size in mesh_size], dim)
              control_points -= 0.5
-             control_points *= 2*self.max_displacement
+             control_points *= 2 * self.max_displacement
              bspline_transform.SetParameters(control_points.flatten().tolist())
              displacement = sitk.GetArrayFromImage(displacement_filter.Execute(bspline_transform))
              self.displacement_fields_true[index].append(displacement)
-             new_locs = grid+torch.unsqueeze(torch.from_numpy(displacement), 0).type(torch.float32)
-             self.displacement_fields[index].append(Elastix._formatLoc(new_locs, shape))
-             print("Compute in progress : {:.2f} %".format((i+1)/len(shapes)*100))
+             new_locs = grid + torch.unsqueeze(torch.from_numpy(displacement), 0).type(torch.float32)
+             self.displacement_fields[index].append(Elastix._format_loc(new_locs, shape))
+             print(f"Compute in progress : {(i + 1) / len(shapes) * 100:.2f} %")
          return shapes
-
-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          results = []
-         for input, displacement_field in zip(inputs, self.displacement_fields[index]):
-             results.append(F.grid_sample(input.type(torch.float32).unsqueeze(0), displacement_field.to(input.device), align_corners=True, mode="bilinear", padding_mode="border").type(input.dtype).squeeze(0))
+         for tensor, displacement_field in zip(tensors, self.displacement_fields[index]):
+             results.append(
+                 F.grid_sample(
+                     tensor.type(torch.float32).unsqueeze(0),
+                     displacement_field.to(tensor.device),
+                     align_corners=True,
+                     mode="bilinear",
+                     padding_mode="border",
+                 )
+                 .type(tensor.dtype)
+                 .squeeze(0)
+             )
          return results
-
-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
+
+     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
          pass

+
  class Permute(DataAugmentation):

-     def __init__(self, prob_permute: Union[list[float], None] = [0.5 ,0.5]) -> None:
+     def __init__(self, prob_permute: list[float] | None = [0.5, 0.5]) -> None:
          super().__init__()
          self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
          self.prob_permute = prob_permute
          self.permute: dict[int, torch.Tensor] = {}

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
          if len(shapes):
              dim = len(shapes[0])
-             assert dim == 3, "The permute augmentation only support 3D images"
+             if dim != 3:
+                 raise ValueError("The permute augmentation only support 3D images")
              if self.prob_permute:
-                 assert len(self.prob_permute) == 2, "len of prob_permute must be equal 2"
-                 self.permute[index] = torch.rand((len(shapes), len(self.prob_permute))) < torch.tensor(self.prob_permute)
+                 if len(self.prob_permute) != 2:
+                     raise ValueError("Size of prob_permute must be equal 2")
+                 self.permute[index] = torch.rand((len(shapes), len(self.prob_permute))) < torch.tensor(
+                     self.prob_permute
+                 )
              else:
-                 assert len(shapes) == 2, "The number of augmentation images must be equal to 2"
+                 if len(shapes) != 2:
+                     raise ValueError("The number of augmentation images must be equal to 2")
                  self.permute[index] = torch.eye(2, dtype=torch.bool)
              for i, prob in enumerate(self.permute[index]):
                  for permute in self._permute_dims[prob]:
-                     shapes[i] = [shapes[i][dim-1] for dim in permute[1:]]
+                     shapes[i] = [shapes[i][dim - 1] for dim in permute[1:]]
          return shapes
-
-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          results = []
-         for input, prob in zip(inputs, self.permute[index]):
-             res = input
+         for tensor, prob in zip(tensors, self.permute[index]):
+             res = tensor
              for permute in self._permute_dims[prob]:
                  res = res.permute(tuple(permute))
              results.append(res)
          return results
-
-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
+
+     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
          for permute in reversed(self._permute_dims[self.permute[index][a]]):
-             input = input.permute(tuple(np.argsort(permute)))
-         return input
+             tensor = tensor.permute(tuple(np.argsort(permute)))
+         return tensor
+

  class Mask(DataAugmentation):

-     def __init__(self, mask: str, value: float, groups: Union[list[str], None] = None) -> None:
+     def __init__(self, mask: str, value: float, groups: list[str] | None = None) -> None:
          super().__init__(groups)
          if mask is not None:
              if os.path.exists(mask):
                  self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
              else:
-                 raise NameError('Mask file not found')
+                 raise NameError("Mask file not found")
          self.positions: dict[int, list[torch.Tensor]] = {}
          self.value = value

-     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
-         self.positions[index] = [torch.rand((3) if len(shape) == 3 else (2))*(torch.tensor([max(s1-s2, 0) for s1, s2 in zip(torch.tensor(shape), torch.tensor(self.mask.shape))])) for shape in shapes]
+     def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
+         self.positions[index] = [
+             torch.rand((3) if len(shape) == 3 else (2))
+             * (torch.tensor([max(s1 - s2, 0) for s1, s2 in zip(torch.tensor(shape), torch.tensor(self.mask.shape))]))
+             for shape in shapes
+         ]
          return [self.mask.shape for _ in shapes]
-
-     def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+
+     def _compute(
+         self,
+         name: str,
+         index: int,
+         tensors: list[torch.Tensor],
+     ) -> list[torch.Tensor]:
          results = []
-         for input, position in zip(inputs, self.positions[index]):
-             slices = [slice(None, None)]+[slice(int(s1), int(s1)+s2) for s1, s2 in zip(position, self.mask.shape)]
+         for tensor, position in zip(tensors, self.positions[index]):
+             slices = [slice(None, None)] + [slice(int(s1), int(s1) + s2) for s1, s2 in zip(position, self.mask.shape)]
              padding = []
-             for s1, s2 in zip(reversed(input.shape), reversed(self.mask.shape)):
+             for s1, s2 in zip(reversed(tensor.shape), reversed(self.mask.shape)):
                  if s1 < s2:
-                     pad = s2-s1
+                     pad = s2 - s1
                  else:
                      pad = 0
                  padding.append(0)
                  padding.append(pad)
-             value = torch.tensor(0, dtype=torch.uint8) if input.dtype == torch.uint8 else torch.tensor(self.value).to(input.device)
-             results.append(torch.where(self.mask.to(input.device) == 1, torch.nn.functional.pad(input, tuple(padding), mode="constant", value=value)[tuple(slices)], value))
+             value = (
+                 torch.tensor(0, dtype=torch.uint8)
+                 if tensor.dtype == torch.uint8
+                 else torch.tensor(self.value).to(tensor.device)
+             )
+             results.append(
+                 torch.where(
+                     self.mask.to(tensor.device) == 1,
+                     torch.nn.functional.pad(tensor, tuple(padding), mode="constant", value=value)[tuple(slices)],
+                     value,
+                 )
+             )
          return results
-
-     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
+
+     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
          pass
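
Elastix above builds an absolute voxel-coordinate grid, adds a random B-spline displacement field generated with SimpleITK, and _format_loc then maps those coordinates into F.grid_sample's convention: each axis is rescaled to [-1, 1] and the axis order is reversed to (x, y, z). A tiny standalone check of that normalization (format_loc mirrors the static method above):

    import torch

    # Voxel index i in [0, s - 1] maps to [-1, 1]; the trailing axis order
    # is reversed for F.grid_sample's (x, y, z) coordinate layout.
    def format_loc(new_locs: torch.Tensor, shape: list[int]) -> torch.Tensor:
        for i in range(len(shape)):
            new_locs[..., i] = 2 * (new_locs[..., i] / (shape[i] - 1) - 0.5)
        return new_locs[..., list(reversed(range(len(shape))))]

    locs = torch.tensor([[0.0, 0.0], [4.0, 9.0]])  # (row, col)-style indices
    out = format_loc(locs.clone(), [5, 10])
    assert torch.allclose(out, torch.tensor([[-1.0, -1.0], [1.0, 1.0]]))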