konfai 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +1 -1
- konfai/data/augmentation.py +41 -36
- konfai/data/data_manager.py +57 -34
- konfai/data/patching.py +37 -13
- konfai/data/transform.py +49 -21
- konfai/evaluator.py +24 -7
- konfai/main.py +5 -3
- konfai/models/segmentation/UNet.py +9 -10
- konfai/network/network.py +0 -1
- konfai/predictor.py +41 -21
- konfai/trainer.py +24 -10
- konfai/utils/dataset.py +27 -2
- konfai/utils/utils.py +49 -12
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/METADATA +1 -1
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/RECORD +19 -19
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/WHEEL +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.3.dist-info}/top_level.txt +0 -0
konfai/__init__.py
CHANGED
|
@@ -12,5 +12,5 @@ CONFIG_FILE = lambda : os.environ["KONFAI_CONFIG_FILE"]
|
|
|
12
12
|
KONFAI_STATE = lambda : os.environ["KONFAI_STATE"]
|
|
13
13
|
KONFAI_ROOT = lambda : os.environ["KONFAI_ROOT"]
|
|
14
14
|
CUDA_VISIBLE_DEVICES = lambda : os.environ["CUDA_VISIBLE_DEVICES"]
|
|
15
|
-
|
|
15
|
+
KONFAI_NB_CORES = lambda : os.environ["KONFAI_NB_CORES"]
|
|
16
16
|
DATE = lambda : datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
konfai/data/augmentation.py
CHANGED
|
@@ -9,8 +9,7 @@ import os
|
|
|
9
9
|
from konfai import KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
11
|
from konfai.utils.utils import _getModule, AugmentationError
|
|
12
|
-
from konfai.utils.dataset import Attribute, data_to_image
|
|
13
|
-
|
|
12
|
+
from konfai.utils.dataset import Attribute, data_to_image, Dataset
|
|
14
13
|
|
|
15
14
|
def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
|
|
16
15
|
return torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.Tensor([[0,0,1]])), dim=0)
|
|
@@ -56,23 +55,29 @@ class DataAugmentationsList():
|
|
|
56
55
|
self.dataAugmentations : list[DataAugmentation] = []
|
|
57
56
|
self.dataAugmentationsLoader = dataAugmentations
|
|
58
57
|
|
|
59
|
-
def load(self, key: str):
|
|
58
|
+
def load(self, key: str, datasets: list[Dataset]):
|
|
60
59
|
for augmentation, prob in self.dataAugmentationsLoader.items():
|
|
61
60
|
module, name = _getModule(augmentation, "data.augmentation")
|
|
62
|
-
dataAugmentation: DataAugmentation =
|
|
61
|
+
dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
|
|
63
62
|
dataAugmentation.load(prob.prob)
|
|
63
|
+
dataAugmentation.setDatasets(datasets)
|
|
64
64
|
self.dataAugmentations.append(dataAugmentation)
|
|
65
|
-
|
|
65
|
+
|
|
66
66
|
class DataAugmentation(ABC):
|
|
67
67
|
|
|
68
|
-
def __init__(self) -> None:
|
|
68
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
69
69
|
self.who_index: dict[int, list[int]] = {}
|
|
70
70
|
self.shape_index: dict[int, list[list[int]]] = {}
|
|
71
71
|
self._prob: float = 0
|
|
72
|
+
self.groups = groups
|
|
73
|
+
self.datasets : list[Dataset] = []
|
|
72
74
|
|
|
73
75
|
def load(self, prob: float):
|
|
74
76
|
self._prob = prob
|
|
75
77
|
|
|
78
|
+
def setDatasets(self, datasets: list[Dataset]):
|
|
79
|
+
self.datasets = datasets
|
|
80
|
+
|
|
76
81
|
def state_init(self, index: Union[None, int], shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
77
82
|
if index is not None:
|
|
78
83
|
if index not in self.who_index:
|
|
@@ -93,14 +98,14 @@ class DataAugmentation(ABC):
|
|
|
93
98
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
94
99
|
pass
|
|
95
100
|
|
|
96
|
-
def __call__(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
101
|
+
def __call__(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
97
102
|
if len(self.who_index[index]) > 0:
|
|
98
|
-
for i, result in enumerate(self._compute(index, [inputs[i] for i in self.who_index[index]], device)):
|
|
103
|
+
for i, result in enumerate(self._compute(name, index, [inputs[i] for i in self.who_index[index]], device)):
|
|
99
104
|
inputs[self.who_index[index][i]] = result if device is None else result.cpu()
|
|
100
105
|
return inputs
|
|
101
106
|
|
|
102
107
|
@abstractmethod
|
|
103
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
108
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
104
109
|
pass
|
|
105
110
|
|
|
106
111
|
def inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
|
|
@@ -118,7 +123,7 @@ class EulerTransform(DataAugmentation):
|
|
|
118
123
|
super().__init__()
|
|
119
124
|
self.matrix: dict[int, list[torch.Tensor]] = {}
|
|
120
125
|
|
|
121
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
126
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
122
127
|
results = []
|
|
123
128
|
for input, matrix in zip(inputs, self.matrix[index]):
|
|
124
129
|
results.append(F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0))
|
|
@@ -126,7 +131,7 @@ class EulerTransform(DataAugmentation):
|
|
|
126
131
|
|
|
127
132
|
def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
|
|
128
133
|
return F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(self.matrix[index][a].inverse()[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0)
|
|
129
|
-
|
|
134
|
+
|
|
130
135
|
class Translate(EulerTransform):
|
|
131
136
|
|
|
132
137
|
@config("Translate")
|
|
@@ -194,7 +199,7 @@ class Flip(DataAugmentation):
|
|
|
194
199
|
self.flip[index] = [dims[mask].tolist() for mask in prob]
|
|
195
200
|
return shapes
|
|
196
201
|
|
|
197
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
202
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
198
203
|
results = []
|
|
199
204
|
for input, flip in zip(inputs, self.flip[index]):
|
|
200
205
|
results.append(torch.flip(input, dims=flip))
|
|
@@ -207,11 +212,11 @@ class Flip(DataAugmentation):
|
|
|
207
212
|
class ColorTransform(DataAugmentation):
|
|
208
213
|
|
|
209
214
|
@config("ColorTransform")
|
|
210
|
-
def __init__(self) -> None:
|
|
211
|
-
super().__init__()
|
|
215
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
216
|
+
super().__init__(groups)
|
|
212
217
|
self.matrix: dict[int, list[torch.Tensor]] = {}
|
|
213
218
|
|
|
214
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
219
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
215
220
|
results = []
|
|
216
221
|
for input, matrix in zip(inputs, self.matrix[index]):
|
|
217
222
|
result = input.reshape([*input.shape[:1], int(np.prod(input.shape[1:]))])
|
|
@@ -232,8 +237,8 @@ class ColorTransform(DataAugmentation):
|
|
|
232
237
|
class Brightness(ColorTransform):
|
|
233
238
|
|
|
234
239
|
@config("Brightness")
|
|
235
|
-
def __init__(self, b_std: float) -> None:
|
|
236
|
-
super().__init__()
|
|
240
|
+
def __init__(self, b_std: float, groups: Union[list[str], None] = None) -> None:
|
|
241
|
+
super().__init__(groups)
|
|
237
242
|
self.b_std = b_std
|
|
238
243
|
|
|
239
244
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -244,8 +249,8 @@ class Brightness(ColorTransform):
|
|
|
244
249
|
class Contrast(ColorTransform):
|
|
245
250
|
|
|
246
251
|
@config("Contrast")
|
|
247
|
-
def __init__(self, c_std: float) -> None:
|
|
248
|
-
super().__init__()
|
|
252
|
+
def __init__(self, c_std: float, groups: Union[list[str], None] = None) -> None:
|
|
253
|
+
super().__init__(groups)
|
|
249
254
|
self.c_std = c_std
|
|
250
255
|
|
|
251
256
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -256,8 +261,8 @@ class Contrast(ColorTransform):
|
|
|
256
261
|
class LumaFlip(ColorTransform):
|
|
257
262
|
|
|
258
263
|
@config("LumaFlip")
|
|
259
|
-
def __init__(self) -> None:
|
|
260
|
-
super().__init__()
|
|
264
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
265
|
+
super().__init__(groups)
|
|
261
266
|
self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
|
|
262
267
|
|
|
263
268
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -268,8 +273,8 @@ class LumaFlip(ColorTransform):
|
|
|
268
273
|
class HUE(ColorTransform):
|
|
269
274
|
|
|
270
275
|
@config("HUE")
|
|
271
|
-
def __init__(self, hue_max: float) -> None:
|
|
272
|
-
super().__init__()
|
|
276
|
+
def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
|
|
277
|
+
super().__init__(groups)
|
|
273
278
|
self.hue_max = hue_max
|
|
274
279
|
self.v = torch.tensor([1, 1, 1])/torch.sqrt(torch.tensor(3))
|
|
275
280
|
|
|
@@ -281,8 +286,8 @@ class HUE(ColorTransform):
|
|
|
281
286
|
class Saturation(ColorTransform):
|
|
282
287
|
|
|
283
288
|
@config("Saturation")
|
|
284
|
-
def __init__(self, s_std: float) -> None:
|
|
285
|
-
super().__init__()
|
|
289
|
+
def __init__(self, s_std: float, groups: Union[list[str], None] = None) -> None:
|
|
290
|
+
super().__init__(groups)
|
|
286
291
|
self.s_std = s_std
|
|
287
292
|
self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
|
|
288
293
|
|
|
@@ -368,8 +373,8 @@ class Saturation(ColorTransform):
|
|
|
368
373
|
class Noise(DataAugmentation):
|
|
369
374
|
|
|
370
375
|
@config("Noise")
|
|
371
|
-
def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
|
|
372
|
-
super().__init__()
|
|
376
|
+
def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02, groups: Union[list[str], None] = None) -> None:
|
|
377
|
+
super().__init__(groups)
|
|
373
378
|
self.n_std = n_std
|
|
374
379
|
self.noise_step = noise_step
|
|
375
380
|
|
|
@@ -410,7 +415,7 @@ class Noise(DataAugmentation):
|
|
|
410
415
|
self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
|
|
411
416
|
return shapes
|
|
412
417
|
|
|
413
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
418
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
414
419
|
results = []
|
|
415
420
|
for input, t in zip(inputs, self.ts[index]):
|
|
416
421
|
alpha_hat_t = self.alpha_hat[t].to(input.device).reshape(*[1 for _ in range(len(input.shape))])
|
|
@@ -423,8 +428,8 @@ class Noise(DataAugmentation):
|
|
|
423
428
|
class CutOUT(DataAugmentation):
|
|
424
429
|
|
|
425
430
|
@config("CutOUT")
|
|
426
|
-
def __init__(self, c_prob: float, cutout_size: int, value: float) -> None:
|
|
427
|
-
super().__init__()
|
|
431
|
+
def __init__(self, c_prob: float, cutout_size: int, value: float, groups: Union[list[str], None] = None) -> None:
|
|
432
|
+
super().__init__(groups)
|
|
428
433
|
self.c_prob = c_prob
|
|
429
434
|
self.cutout_size = cutout_size
|
|
430
435
|
self.centers: dict[int, list[torch.Tensor]] = {}
|
|
@@ -434,7 +439,7 @@ class CutOUT(DataAugmentation):
|
|
|
434
439
|
self.centers[index] = [torch.rand((3) if len(shape) == 3 else (2)) for shape in shapes]
|
|
435
440
|
return shapes
|
|
436
441
|
|
|
437
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
442
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
438
443
|
results = []
|
|
439
444
|
for input, center in zip(inputs, self.centers[index]):
|
|
440
445
|
masks = []
|
|
@@ -513,7 +518,7 @@ class Elastix(DataAugmentation):
|
|
|
513
518
|
print("Compute in progress : {:.2f} %".format((i+1)/len(shapes)*100))
|
|
514
519
|
return shapes
|
|
515
520
|
|
|
516
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
521
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
517
522
|
results = []
|
|
518
523
|
for input, displacement_field in zip(inputs, self.displacement_fields[index]):
|
|
519
524
|
results.append(F.grid_sample(input.type(torch.float32).unsqueeze(0), displacement_field.to(input.device), align_corners=True, mode="bilinear", padding_mode="border").type(input.dtype).squeeze(0))
|
|
@@ -546,7 +551,7 @@ class Permute(DataAugmentation):
|
|
|
546
551
|
shapes[i] = [shapes[i][dim-1] for dim in permute[1:]]
|
|
547
552
|
return shapes
|
|
548
553
|
|
|
549
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
554
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
550
555
|
results = []
|
|
551
556
|
for input, prob in zip(inputs, self.permute[index]):
|
|
552
557
|
res = input
|
|
@@ -563,8 +568,8 @@ class Permute(DataAugmentation):
|
|
|
563
568
|
class Mask(DataAugmentation):
|
|
564
569
|
|
|
565
570
|
@config("Mask")
|
|
566
|
-
def __init__(self, mask: str, value: float) -> None:
|
|
567
|
-
super().__init__()
|
|
571
|
+
def __init__(self, mask: str, value: float, groups: Union[list[str], None] = None) -> None:
|
|
572
|
+
super().__init__(groups)
|
|
568
573
|
if mask is not None:
|
|
569
574
|
if os.path.exists(mask):
|
|
570
575
|
self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
|
|
@@ -577,7 +582,7 @@ class Mask(DataAugmentation):
|
|
|
577
582
|
self.positions[index] = [torch.rand((3) if len(shape) == 3 else (2))*(torch.tensor([max(s1-s2, 0) for s1, s2 in zip(torch.tensor(shape), torch.tensor(self.mask.shape))])) for shape in shapes]
|
|
578
583
|
return [self.mask.shape for _ in shapes]
|
|
579
584
|
|
|
580
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
585
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
581
586
|
results = []
|
|
582
587
|
for input, position in zip(inputs, self.positions[index]):
|
|
583
588
|
slices = [slice(None, None)]+[slice(int(s1), int(s1)+s2) for s1, s2 in zip(position, self.mask.shape)]
|
konfai/data/data_manager.py
CHANGED
|
@@ -62,12 +62,25 @@ class GroupTransform:
|
|
|
62
62
|
for transform in self.post_transforms:
|
|
63
63
|
transform.setDevice(device)
|
|
64
64
|
|
|
65
|
+
class GroupTransformMetric(GroupTransform):
|
|
66
|
+
|
|
67
|
+
@config()
|
|
68
|
+
def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
|
|
69
|
+
post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
|
|
70
|
+
super().__init__(pre_transforms, post_transforms)
|
|
71
|
+
|
|
65
72
|
class Group(dict[str, GroupTransform]):
|
|
66
73
|
|
|
67
74
|
@config()
|
|
68
75
|
def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
|
|
69
76
|
super().__init__(groups_dest)
|
|
70
77
|
|
|
78
|
+
class GroupMetric(dict[str, GroupTransformMetric]):
|
|
79
|
+
|
|
80
|
+
@config()
|
|
81
|
+
def __init__(self, groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()}):
|
|
82
|
+
super().__init__(groups_dest)
|
|
83
|
+
|
|
71
84
|
class CustomSampler(Sampler[int]):
|
|
72
85
|
|
|
73
86
|
def __init__(self, size: int, shuffle: bool = False) -> None:
|
|
@@ -109,32 +122,33 @@ class DatasetIter(data.Dataset):
|
|
|
109
122
|
def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
|
|
110
123
|
return self.data[group_dest][index]
|
|
111
124
|
|
|
112
|
-
def resetAugmentation(self):
|
|
113
|
-
if self.inlineAugmentations:
|
|
125
|
+
def resetAugmentation(self, label):
|
|
126
|
+
if self.inlineAugmentations and len(self.dataAugmentationsList) > 0:
|
|
114
127
|
for index in range(self.nb_dataset):
|
|
115
|
-
self._unloadData(index)
|
|
116
128
|
for group_src in self.groups_src:
|
|
117
129
|
for group_dest in self.groups_src[group_src]:
|
|
130
|
+
self.data[group_dest][index].unloadAugmentation()
|
|
118
131
|
self.data[group_dest][index].resetAugmentation()
|
|
132
|
+
self.load(label + " Augmentation")
|
|
119
133
|
|
|
120
|
-
def load(self):
|
|
134
|
+
def load(self, label: str):
|
|
121
135
|
if self.use_cache:
|
|
122
136
|
memory_init = getMemory()
|
|
123
137
|
|
|
124
|
-
indexs = [index for index in range(self.nb_dataset)
|
|
138
|
+
indexs = [index for index in range(self.nb_dataset)]
|
|
125
139
|
if len(indexs) > 0:
|
|
126
140
|
memory_lock = threading.Lock()
|
|
141
|
+
desc = lambda : "Caching "+ label +": {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo())
|
|
127
142
|
pbar = tqdm.tqdm(
|
|
128
143
|
total=len(indexs),
|
|
129
|
-
desc=
|
|
130
|
-
leave=False
|
|
131
|
-
disable=self.rank != 0 and "KONFAI_CLUSTER" not in os.environ
|
|
144
|
+
desc=desc(),
|
|
145
|
+
leave=False
|
|
132
146
|
)
|
|
133
147
|
|
|
134
148
|
def process(index):
|
|
135
149
|
self._loadData(index)
|
|
136
150
|
with memory_lock:
|
|
137
|
-
pbar.set_description(
|
|
151
|
+
pbar.set_description(desc())
|
|
138
152
|
pbar.update(1)
|
|
139
153
|
with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
|
|
140
154
|
futures = [executor.submit(process, index) for index in indexs]
|
|
@@ -170,7 +184,7 @@ class DatasetIter(data.Dataset):
|
|
|
170
184
|
data = {}
|
|
171
185
|
x, a, p = self.map[index]
|
|
172
186
|
if x not in self._index_cache:
|
|
173
|
-
if len(self._index_cache) >= self.buffer_size and not self.use_cache:
|
|
187
|
+
if x not in self._index_cache and len(self._index_cache) >= self.buffer_size and not self.use_cache:
|
|
174
188
|
self._unloadData(self._index_cache[0])
|
|
175
189
|
self._loadData(x)
|
|
176
190
|
|
|
@@ -257,10 +271,10 @@ class Data(ABC):
|
|
|
257
271
|
groups_src : dict[str, Group],
|
|
258
272
|
patch : Union[DatasetPatch, None],
|
|
259
273
|
use_cache : bool,
|
|
260
|
-
subset :
|
|
274
|
+
subset : Subset,
|
|
261
275
|
num_workers : int,
|
|
262
276
|
batch_size : int,
|
|
263
|
-
validation: Union[float, str, list[int], list[str]] =
|
|
277
|
+
validation: Union[float, str, list[int], list[str], None] = None,
|
|
264
278
|
inlineAugmentations: bool = False,
|
|
265
279
|
dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
|
|
266
280
|
self.dataset_filenames = dataset_filenames
|
|
@@ -343,7 +357,7 @@ class Data(ABC):
|
|
|
343
357
|
append = flag == "a"
|
|
344
358
|
|
|
345
359
|
if format not in SUPPORTED_EXTENSIONS:
|
|
346
|
-
|
|
360
|
+
raise DatasetManagerError(f"Unsupported file format '{format}'.",
|
|
347
361
|
f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
|
|
348
362
|
|
|
349
363
|
dataset = Dataset(filename, format)
|
|
@@ -362,7 +376,7 @@ class Data(ABC):
|
|
|
362
376
|
raise DatasetManagerError(
|
|
363
377
|
f"Group source '{group_src}' not found in any dataset.",
|
|
364
378
|
f"Dataset filenames provided: {self.dataset_filenames}",
|
|
365
|
-
|
|
379
|
+
"Available groups across all datasets: {}".format(["{} {}".format(f, d.getGroup()) for f, d in self.datasets.items()]),
|
|
366
380
|
f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
|
|
367
381
|
)
|
|
368
382
|
|
|
@@ -376,34 +390,44 @@ class Data(ABC):
|
|
|
376
390
|
)
|
|
377
391
|
|
|
378
392
|
for key, dataAugmentations in self.dataAugmentationsList.items():
|
|
379
|
-
dataAugmentations.load(key)
|
|
393
|
+
dataAugmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
|
|
380
394
|
|
|
381
395
|
names = set()
|
|
382
396
|
dataset_name : dict[str, dict[str, list[str]]] = {}
|
|
383
397
|
dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
|
|
384
398
|
for group in self.groups_src:
|
|
399
|
+
namesByGroup = set()
|
|
385
400
|
if group not in dataset_name:
|
|
386
401
|
dataset_name[group] = {}
|
|
387
402
|
dataset_info[group] = {}
|
|
388
403
|
for filename, _ in datasets[group]:
|
|
389
|
-
|
|
404
|
+
namesByGroup.update(self.datasets[filename].getNames(group))
|
|
390
405
|
dataset_name[group][filename] = self.datasets[filename].getNames(group)
|
|
391
406
|
dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
|
|
407
|
+
if len(names) == 0:
|
|
408
|
+
names.update(namesByGroup)
|
|
409
|
+
else:
|
|
410
|
+
names = names.intersection(namesByGroup)
|
|
411
|
+
if len(names) == 0:
|
|
412
|
+
raise DatasetManagerError(
|
|
413
|
+
f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
|
|
414
|
+
)
|
|
415
|
+
|
|
392
416
|
subset_names = set()
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
for filename,
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
417
|
+
for group in dataset_name:
|
|
418
|
+
subset_names_bygroup = set()
|
|
419
|
+
for filename, append in datasets[group]:
|
|
420
|
+
if append:
|
|
421
|
+
subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
422
|
+
else:
|
|
423
|
+
if len(subset_names_bygroup) == 0:
|
|
424
|
+
subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
402
425
|
else:
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
426
|
+
subset_names_bygroup = subset_names_bygroup.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
427
|
+
if len(subset_names) == 0:
|
|
428
|
+
subset_names.update(subset_names_bygroup)
|
|
429
|
+
else:
|
|
430
|
+
subset_names = subset_names.intersection(subset_names_bygroup)
|
|
407
431
|
if len(subset_names) == 0:
|
|
408
432
|
raise DatasetManagerError("All data entries were excluded by the subset filter.",
|
|
409
433
|
f"Dataset entries found: {', '.join(names)}",
|
|
@@ -424,7 +448,7 @@ class Data(ABC):
|
|
|
424
448
|
validate_map = []
|
|
425
449
|
if isinstance(self.validation, float) or isinstance(self.validation, int):
|
|
426
450
|
if self.validation <= 0 or self.validation >= 1:
|
|
427
|
-
raise DatasetManagerError("
|
|
451
|
+
raise DatasetManagerError("Validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
|
|
428
452
|
|
|
429
453
|
train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
|
|
430
454
|
elif isinstance(self.validation, str):
|
|
@@ -527,20 +551,19 @@ class DataPrediction(Data):
|
|
|
527
551
|
groups_src : dict[str, Group] = {"default" : Group()},
|
|
528
552
|
augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
|
|
529
553
|
patch : Union[DatasetPatch, None] = DatasetPatch(),
|
|
530
|
-
use_cache : bool = True,
|
|
531
554
|
subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
|
|
532
555
|
num_workers : int = 4,
|
|
533
556
|
batch_size : int = 1) -> None:
|
|
534
557
|
|
|
535
|
-
super().__init__(dataset_filenames, groups_src, patch,
|
|
558
|
+
super().__init__(dataset_filenames, groups_src, patch, False, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
|
|
536
559
|
|
|
537
560
|
class DataMetric(Data):
|
|
538
561
|
|
|
539
562
|
@config("Dataset")
|
|
540
563
|
def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
|
|
541
|
-
groups_src : dict[str,
|
|
564
|
+
groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
|
|
542
565
|
subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
|
|
543
566
|
validation: Union[str, None] = None,
|
|
544
567
|
num_workers : int = 4) -> None:
|
|
545
568
|
|
|
546
|
-
super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=
|
|
569
|
+
super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=validation)
|
konfai/data/patching.py
CHANGED
|
@@ -15,7 +15,6 @@ from konfai.utils.dataset import Dataset, Attribute
|
|
|
15
15
|
from konfai.data.transform import Transform, Save
|
|
16
16
|
from konfai.data.augmentation import DataAugmentationsList
|
|
17
17
|
|
|
18
|
-
|
|
19
18
|
class PathCombine(ABC):
|
|
20
19
|
|
|
21
20
|
def __init__(self) -> None:
|
|
@@ -240,6 +239,7 @@ class DatasetManager():
|
|
|
240
239
|
self.index = index
|
|
241
240
|
self.dataset = dataset
|
|
242
241
|
self.loaded = False
|
|
242
|
+
self.augmentationLoaded = False
|
|
243
243
|
self.cache_attributes: list[Attribute] = []
|
|
244
244
|
_shape, cache_attribute = self.dataset.getInfos(self.group_src, name)
|
|
245
245
|
self.cache_attributes.append(cache_attribute)
|
|
@@ -258,6 +258,7 @@ class DatasetManager():
|
|
|
258
258
|
self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
|
|
259
259
|
|
|
260
260
|
def resetAugmentation(self):
|
|
261
|
+
self.cache_attributes[:] = self.cache_attributes[:1]
|
|
261
262
|
i = 1
|
|
262
263
|
for dataAugmentations in self.dataAugmentationsList:
|
|
263
264
|
shape = []
|
|
@@ -272,15 +273,24 @@ class DatasetManager():
|
|
|
272
273
|
self.cache_attributes.append(caches_attribute[it])
|
|
273
274
|
self.patch.load(s, i)
|
|
274
275
|
i+=1
|
|
275
|
-
|
|
276
|
+
|
|
276
277
|
def load(self, pre_transform : list[Transform], dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
|
|
277
|
-
if self.loaded:
|
|
278
|
-
|
|
278
|
+
if not self.loaded:
|
|
279
|
+
self._load(pre_transform)
|
|
280
|
+
if not self.augmentationLoaded:
|
|
281
|
+
self._loadAugmentation(dataAugmentationsList, device)
|
|
282
|
+
|
|
283
|
+
def _load(self, pre_transform : list[Transform]):
|
|
284
|
+
self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
|
|
279
285
|
i = len(pre_transform)
|
|
280
286
|
data = None
|
|
281
287
|
for transformFunction in reversed(pre_transform):
|
|
282
288
|
if isinstance(transformFunction, Save):
|
|
283
|
-
|
|
289
|
+
if len(transformFunction.dataset.split(":")) > 1:
|
|
290
|
+
filename, format = transformFunction.dataset.split(":")
|
|
291
|
+
else:
|
|
292
|
+
filename = transformFunction.dataset.split(":")
|
|
293
|
+
format = "mha"
|
|
284
294
|
dataset = Dataset(filename, format)
|
|
285
295
|
if dataset.isDatasetExist(self.group_dest, self.name):
|
|
286
296
|
data, attrib = dataset.readData(self.group_dest, self.name)
|
|
@@ -298,27 +308,41 @@ class DatasetManager():
|
|
|
298
308
|
for transformFunction in pre_transform[i:]:
|
|
299
309
|
data = transformFunction(self.name, data, self.cache_attributes[0])
|
|
300
310
|
if isinstance(transformFunction, Save):
|
|
301
|
-
|
|
311
|
+
if len(transformFunction.dataset.split(":")) > 1:
|
|
312
|
+
filename, format = transformFunction.dataset.split(":")
|
|
313
|
+
else:
|
|
314
|
+
filename = transformFunction.dataset.split(":")
|
|
315
|
+
format = "mha"
|
|
302
316
|
dataset = Dataset(filename, format)
|
|
303
317
|
dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
|
|
304
318
|
self.data : list[torch.Tensor] = list()
|
|
305
319
|
self.data.append(data)
|
|
306
320
|
|
|
321
|
+
for i in range(len(self.cache_attributes)-1):
|
|
322
|
+
self.cache_attributes[i+1].update(self.cache_attributes[0])
|
|
323
|
+
self.loaded = True
|
|
324
|
+
|
|
325
|
+
def _loadAugmentation(self, dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
|
|
307
326
|
for dataAugmentations in dataAugmentationsList:
|
|
308
|
-
a_data = [data.clone() for _ in range(dataAugmentations.nb)]
|
|
327
|
+
a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
|
|
309
328
|
for dataAugmentation in dataAugmentations.dataAugmentations:
|
|
310
|
-
|
|
329
|
+
if dataAugmentation.groups is None or self.group_dest in dataAugmentation.groups:
|
|
330
|
+
a_data = dataAugmentation(self.name, self.index, a_data, device)
|
|
311
331
|
|
|
312
332
|
for d in a_data:
|
|
313
333
|
self.data.append(d)
|
|
314
|
-
self.
|
|
334
|
+
self.augmentationLoaded = True
|
|
315
335
|
|
|
316
336
|
def unload(self) -> None:
|
|
317
|
-
|
|
318
|
-
del self.data
|
|
319
|
-
self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
|
|
337
|
+
self.data.clear()
|
|
320
338
|
self.loaded = False
|
|
321
|
-
|
|
339
|
+
self.augmentationLoaded = False
|
|
340
|
+
|
|
341
|
+
def unloadAugmentation(self) -> None:
|
|
342
|
+
self.data[:] = self.data[:1]
|
|
343
|
+
self.augmentationLoaded = False
|
|
344
|
+
|
|
345
|
+
|
|
322
346
|
def getData(self, index : int, a : int, post_transforms : list[Transform], isInput: bool) -> torch.Tensor:
|
|
323
347
|
data = self.patch.getData(self.data[a], index, a, isInput)
|
|
324
348
|
for transformFunction in post_transforms:
|