konfai 1.1.2__tar.gz → 1.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- {konfai-1.1.2 → konfai-1.1.3}/PKG-INFO +1 -1
- {konfai-1.1.2 → konfai-1.1.3}/konfai/data/augmentation.py +41 -36
- {konfai-1.1.2 → konfai-1.1.3}/konfai/data/data_manager.py +2 -2
- {konfai-1.1.2 → konfai-1.1.3}/konfai/data/patching.py +2 -1
- {konfai-1.1.2 → konfai-1.1.3}/konfai/data/transform.py +1 -1
- {konfai-1.1.2 → konfai-1.1.3}/konfai/main.py +2 -1
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/segmentation/UNet.py +9 -10
- {konfai-1.1.2 → konfai-1.1.3}/konfai/predictor.py +17 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/trainer.py +14 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/PKG-INFO +1 -1
- {konfai-1.1.2 → konfai-1.1.3}/pyproject.toml +1 -1
- {konfai-1.1.2 → konfai-1.1.3}/LICENSE +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/README.md +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/data/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/evaluator.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/metric/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/metric/measure.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/metric/schedulers.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/classification/convNeXt.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/classification/resnet.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/cStyleGan.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/ddpm.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/diffusionGan.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/gan.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/vae.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/registration/registration.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/representation/representation.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/models/segmentation/NestedUNet.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/network/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/network/blocks.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/network/network.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/ITK.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/config.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/dataset.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/registration.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/utils.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/SOURCES.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/requires.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/setup.cfg +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/tests/test_config.py +0 -0
- {konfai-1.1.2 → konfai-1.1.3}/tests/test_dataset.py +0 -0
|
@@ -9,8 +9,7 @@ import os
|
|
|
9
9
|
from konfai import KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
11
|
from konfai.utils.utils import _getModule, AugmentationError
|
|
12
|
-
from konfai.utils.dataset import Attribute, data_to_image
|
|
13
|
-
|
|
12
|
+
from konfai.utils.dataset import Attribute, data_to_image, Dataset
|
|
14
13
|
|
|
15
14
|
def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
|
|
16
15
|
return torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.Tensor([[0,0,1]])), dim=0)
|
|
@@ -56,23 +55,29 @@ class DataAugmentationsList():
|
|
|
56
55
|
self.dataAugmentations : list[DataAugmentation] = []
|
|
57
56
|
self.dataAugmentationsLoader = dataAugmentations
|
|
58
57
|
|
|
59
|
-
def load(self, key: str):
|
|
58
|
+
def load(self, key: str, datasets: list[Dataset]):
|
|
60
59
|
for augmentation, prob in self.dataAugmentationsLoader.items():
|
|
61
60
|
module, name = _getModule(augmentation, "data.augmentation")
|
|
62
|
-
dataAugmentation: DataAugmentation =
|
|
61
|
+
dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
|
|
63
62
|
dataAugmentation.load(prob.prob)
|
|
63
|
+
dataAugmentation.setDatasets(datasets)
|
|
64
64
|
self.dataAugmentations.append(dataAugmentation)
|
|
65
|
-
|
|
65
|
+
|
|
66
66
|
class DataAugmentation(ABC):
|
|
67
67
|
|
|
68
|
-
def __init__(self) -> None:
|
|
68
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
69
69
|
self.who_index: dict[int, list[int]] = {}
|
|
70
70
|
self.shape_index: dict[int, list[list[int]]] = {}
|
|
71
71
|
self._prob: float = 0
|
|
72
|
+
self.groups = groups
|
|
73
|
+
self.datasets : list[Dataset] = []
|
|
72
74
|
|
|
73
75
|
def load(self, prob: float):
|
|
74
76
|
self._prob = prob
|
|
75
77
|
|
|
78
|
+
def setDatasets(self, datasets: list[Dataset]):
|
|
79
|
+
self.datasets = datasets
|
|
80
|
+
|
|
76
81
|
def state_init(self, index: Union[None, int], shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
77
82
|
if index is not None:
|
|
78
83
|
if index not in self.who_index:
|
|
@@ -93,14 +98,14 @@ class DataAugmentation(ABC):
|
|
|
93
98
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
94
99
|
pass
|
|
95
100
|
|
|
96
|
-
def __call__(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
101
|
+
def __call__(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
97
102
|
if len(self.who_index[index]) > 0:
|
|
98
|
-
for i, result in enumerate(self._compute(index, [inputs[i] for i in self.who_index[index]], device)):
|
|
103
|
+
for i, result in enumerate(self._compute(name, index, [inputs[i] for i in self.who_index[index]], device)):
|
|
99
104
|
inputs[self.who_index[index][i]] = result if device is None else result.cpu()
|
|
100
105
|
return inputs
|
|
101
106
|
|
|
102
107
|
@abstractmethod
|
|
103
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
108
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
104
109
|
pass
|
|
105
110
|
|
|
106
111
|
def inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
|
|
@@ -118,7 +123,7 @@ class EulerTransform(DataAugmentation):
|
|
|
118
123
|
super().__init__()
|
|
119
124
|
self.matrix: dict[int, list[torch.Tensor]] = {}
|
|
120
125
|
|
|
121
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
126
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
122
127
|
results = []
|
|
123
128
|
for input, matrix in zip(inputs, self.matrix[index]):
|
|
124
129
|
results.append(F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0))
|
|
@@ -126,7 +131,7 @@ class EulerTransform(DataAugmentation):
|
|
|
126
131
|
|
|
127
132
|
def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
|
|
128
133
|
return F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(self.matrix[index][a].inverse()[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0)
|
|
129
|
-
|
|
134
|
+
|
|
130
135
|
class Translate(EulerTransform):
|
|
131
136
|
|
|
132
137
|
@config("Translate")
|
|
@@ -194,7 +199,7 @@ class Flip(DataAugmentation):
|
|
|
194
199
|
self.flip[index] = [dims[mask].tolist() for mask in prob]
|
|
195
200
|
return shapes
|
|
196
201
|
|
|
197
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
202
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
198
203
|
results = []
|
|
199
204
|
for input, flip in zip(inputs, self.flip[index]):
|
|
200
205
|
results.append(torch.flip(input, dims=flip))
|
|
@@ -207,11 +212,11 @@ class Flip(DataAugmentation):
|
|
|
207
212
|
class ColorTransform(DataAugmentation):
|
|
208
213
|
|
|
209
214
|
@config("ColorTransform")
|
|
210
|
-
def __init__(self) -> None:
|
|
211
|
-
super().__init__()
|
|
215
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
216
|
+
super().__init__(groups)
|
|
212
217
|
self.matrix: dict[int, list[torch.Tensor]] = {}
|
|
213
218
|
|
|
214
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
219
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
215
220
|
results = []
|
|
216
221
|
for input, matrix in zip(inputs, self.matrix[index]):
|
|
217
222
|
result = input.reshape([*input.shape[:1], int(np.prod(input.shape[1:]))])
|
|
@@ -232,8 +237,8 @@ class ColorTransform(DataAugmentation):
|
|
|
232
237
|
class Brightness(ColorTransform):
|
|
233
238
|
|
|
234
239
|
@config("Brightness")
|
|
235
|
-
def __init__(self, b_std: float) -> None:
|
|
236
|
-
super().__init__()
|
|
240
|
+
def __init__(self, b_std: float, groups: Union[list[str], None] = None) -> None:
|
|
241
|
+
super().__init__(groups)
|
|
237
242
|
self.b_std = b_std
|
|
238
243
|
|
|
239
244
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -244,8 +249,8 @@ class Brightness(ColorTransform):
|
|
|
244
249
|
class Contrast(ColorTransform):
|
|
245
250
|
|
|
246
251
|
@config("Contrast")
|
|
247
|
-
def __init__(self, c_std: float) -> None:
|
|
248
|
-
super().__init__()
|
|
252
|
+
def __init__(self, c_std: float, groups: Union[list[str], None] = None) -> None:
|
|
253
|
+
super().__init__(groups)
|
|
249
254
|
self.c_std = c_std
|
|
250
255
|
|
|
251
256
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -256,8 +261,8 @@ class Contrast(ColorTransform):
|
|
|
256
261
|
class LumaFlip(ColorTransform):
|
|
257
262
|
|
|
258
263
|
@config("LumaFlip")
|
|
259
|
-
def __init__(self) -> None:
|
|
260
|
-
super().__init__()
|
|
264
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
265
|
+
super().__init__(groups)
|
|
261
266
|
self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
|
|
262
267
|
|
|
263
268
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -268,8 +273,8 @@ class LumaFlip(ColorTransform):
|
|
|
268
273
|
class HUE(ColorTransform):
|
|
269
274
|
|
|
270
275
|
@config("HUE")
|
|
271
|
-
def __init__(self, hue_max: float) -> None:
|
|
272
|
-
super().__init__()
|
|
276
|
+
def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
|
|
277
|
+
super().__init__(groups)
|
|
273
278
|
self.hue_max = hue_max
|
|
274
279
|
self.v = torch.tensor([1, 1, 1])/torch.sqrt(torch.tensor(3))
|
|
275
280
|
|
|
@@ -281,8 +286,8 @@ class HUE(ColorTransform):
|
|
|
281
286
|
class Saturation(ColorTransform):
|
|
282
287
|
|
|
283
288
|
@config("Saturation")
|
|
284
|
-
def __init__(self, s_std: float) -> None:
|
|
285
|
-
super().__init__()
|
|
289
|
+
def __init__(self, s_std: float, groups: Union[list[str], None] = None) -> None:
|
|
290
|
+
super().__init__(groups)
|
|
286
291
|
self.s_std = s_std
|
|
287
292
|
self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
|
|
288
293
|
|
|
@@ -368,8 +373,8 @@ class Saturation(ColorTransform):
|
|
|
368
373
|
class Noise(DataAugmentation):
|
|
369
374
|
|
|
370
375
|
@config("Noise")
|
|
371
|
-
def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
|
|
372
|
-
super().__init__()
|
|
376
|
+
def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02, groups: Union[list[str], None] = None) -> None:
|
|
377
|
+
super().__init__(groups)
|
|
373
378
|
self.n_std = n_std
|
|
374
379
|
self.noise_step = noise_step
|
|
375
380
|
|
|
@@ -410,7 +415,7 @@ class Noise(DataAugmentation):
|
|
|
410
415
|
self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
|
|
411
416
|
return shapes
|
|
412
417
|
|
|
413
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
418
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
414
419
|
results = []
|
|
415
420
|
for input, t in zip(inputs, self.ts[index]):
|
|
416
421
|
alpha_hat_t = self.alpha_hat[t].to(input.device).reshape(*[1 for _ in range(len(input.shape))])
|
|
@@ -423,8 +428,8 @@ class Noise(DataAugmentation):
|
|
|
423
428
|
class CutOUT(DataAugmentation):
|
|
424
429
|
|
|
425
430
|
@config("CutOUT")
|
|
426
|
-
def __init__(self, c_prob: float, cutout_size: int, value: float) -> None:
|
|
427
|
-
super().__init__()
|
|
431
|
+
def __init__(self, c_prob: float, cutout_size: int, value: float, groups: Union[list[str], None] = None) -> None:
|
|
432
|
+
super().__init__(groups)
|
|
428
433
|
self.c_prob = c_prob
|
|
429
434
|
self.cutout_size = cutout_size
|
|
430
435
|
self.centers: dict[int, list[torch.Tensor]] = {}
|
|
@@ -434,7 +439,7 @@ class CutOUT(DataAugmentation):
|
|
|
434
439
|
self.centers[index] = [torch.rand((3) if len(shape) == 3 else (2)) for shape in shapes]
|
|
435
440
|
return shapes
|
|
436
441
|
|
|
437
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
442
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
438
443
|
results = []
|
|
439
444
|
for input, center in zip(inputs, self.centers[index]):
|
|
440
445
|
masks = []
|
|
@@ -513,7 +518,7 @@ class Elastix(DataAugmentation):
|
|
|
513
518
|
print("Compute in progress : {:.2f} %".format((i+1)/len(shapes)*100))
|
|
514
519
|
return shapes
|
|
515
520
|
|
|
516
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
521
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
517
522
|
results = []
|
|
518
523
|
for input, displacement_field in zip(inputs, self.displacement_fields[index]):
|
|
519
524
|
results.append(F.grid_sample(input.type(torch.float32).unsqueeze(0), displacement_field.to(input.device), align_corners=True, mode="bilinear", padding_mode="border").type(input.dtype).squeeze(0))
|
|
@@ -546,7 +551,7 @@ class Permute(DataAugmentation):
|
|
|
546
551
|
shapes[i] = [shapes[i][dim-1] for dim in permute[1:]]
|
|
547
552
|
return shapes
|
|
548
553
|
|
|
549
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
554
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
550
555
|
results = []
|
|
551
556
|
for input, prob in zip(inputs, self.permute[index]):
|
|
552
557
|
res = input
|
|
@@ -563,8 +568,8 @@ class Permute(DataAugmentation):
|
|
|
563
568
|
class Mask(DataAugmentation):
|
|
564
569
|
|
|
565
570
|
@config("Mask")
|
|
566
|
-
def __init__(self, mask: str, value: float) -> None:
|
|
567
|
-
super().__init__()
|
|
571
|
+
def __init__(self, mask: str, value: float, groups: Union[list[str], None] = None) -> None:
|
|
572
|
+
super().__init__(groups)
|
|
568
573
|
if mask is not None:
|
|
569
574
|
if os.path.exists(mask):
|
|
570
575
|
self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
|
|
@@ -577,7 +582,7 @@ class Mask(DataAugmentation):
|
|
|
577
582
|
self.positions[index] = [torch.rand((3) if len(shape) == 3 else (2))*(torch.tensor([max(s1-s2, 0) for s1, s2 in zip(torch.tensor(shape), torch.tensor(self.mask.shape))])) for shape in shapes]
|
|
578
583
|
return [self.mask.shape for _ in shapes]
|
|
579
584
|
|
|
580
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
585
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
581
586
|
results = []
|
|
582
587
|
for input, position in zip(inputs, self.positions[index]):
|
|
583
588
|
slices = [slice(None, None)]+[slice(int(s1), int(s1)+s2) for s1, s2 in zip(position, self.mask.shape)]
|
|
@@ -184,7 +184,7 @@ class DatasetIter(data.Dataset):
|
|
|
184
184
|
data = {}
|
|
185
185
|
x, a, p = self.map[index]
|
|
186
186
|
if x not in self._index_cache:
|
|
187
|
-
if len(self._index_cache) >= self.buffer_size and not self.use_cache:
|
|
187
|
+
if x not in self._index_cache and len(self._index_cache) >= self.buffer_size and not self.use_cache:
|
|
188
188
|
self._unloadData(self._index_cache[0])
|
|
189
189
|
self._loadData(x)
|
|
190
190
|
|
|
@@ -390,7 +390,7 @@ class Data(ABC):
|
|
|
390
390
|
)
|
|
391
391
|
|
|
392
392
|
for key, dataAugmentations in self.dataAugmentationsList.items():
|
|
393
|
-
dataAugmentations.load(key)
|
|
393
|
+
dataAugmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
|
|
394
394
|
|
|
395
395
|
names = set()
|
|
396
396
|
dataset_name : dict[str, dict[str, list[str]]] = {}
|
|
@@ -326,7 +326,8 @@ class DatasetManager():
|
|
|
326
326
|
for dataAugmentations in dataAugmentationsList:
|
|
327
327
|
a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
|
|
328
328
|
for dataAugmentation in dataAugmentations.dataAugmentations:
|
|
329
|
-
|
|
329
|
+
if dataAugmentation.groups is None or self.group_dest in dataAugmentation.groups:
|
|
330
|
+
a_data = dataAugmentation(self.name, self.index, a_data, device)
|
|
330
331
|
|
|
331
332
|
for d in a_data:
|
|
332
333
|
self.data.append(d)
|
|
@@ -52,7 +52,7 @@ class Clip(Transform):
|
|
|
52
52
|
input[torch.where(input < self.min_value)] = self.min_value
|
|
53
53
|
input[torch.where(input > self.max_value)] = self.max_value
|
|
54
54
|
if self.saveClip_min:
|
|
55
|
-
cache_attribute["Min"] = self
|
|
55
|
+
cache_attribute["Min"] = self.min_value
|
|
56
56
|
if self.saveClip_max:
|
|
57
57
|
cache_attribute["Max"] = self.max_value
|
|
58
58
|
return input
|
|
@@ -9,6 +9,8 @@ import sys
|
|
|
9
9
|
sys.path.insert(0, os.getcwd())
|
|
10
10
|
|
|
11
11
|
def main():
|
|
12
|
+
import tracemalloc
|
|
13
|
+
|
|
12
14
|
parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
13
15
|
try:
|
|
14
16
|
with setup(parser) as distributedObject:
|
|
@@ -22,7 +24,6 @@ def main():
|
|
|
22
24
|
except KeyboardInterrupt:
|
|
23
25
|
print("\n[KonfAI] Manual interruption (Ctrl+C)")
|
|
24
26
|
|
|
25
|
-
|
|
26
27
|
def cluster():
|
|
27
28
|
parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
28
29
|
|
|
@@ -7,32 +7,33 @@ from konfai.data.patching import ModelPatch
|
|
|
7
7
|
|
|
8
8
|
class UNetHead(network.ModuleArgsDict):
|
|
9
9
|
|
|
10
|
-
def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
|
|
10
|
+
def __init__(self, in_channels: int, nb_class: int, dim: int, level: int) -> None:
|
|
11
11
|
super().__init__()
|
|
12
12
|
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
|
13
13
|
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
14
14
|
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
15
|
+
|
|
15
16
|
|
|
16
17
|
class UNetBlock(network.ModuleArgsDict):
|
|
17
18
|
|
|
18
|
-
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0
|
|
19
|
+
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0) -> None:
|
|
19
20
|
super().__init__()
|
|
20
21
|
blockConfig_stride = blockConfig
|
|
21
22
|
if i > 0:
|
|
22
23
|
if downSampleMode != blocks.DownSampleMode.CONV_STRIDE:
|
|
23
24
|
self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
|
|
24
25
|
else:
|
|
25
|
-
blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size,
|
|
26
|
+
blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
|
|
26
27
|
self.add_module("DownConvBlock", block(in_channels=channels[0], out_channels=channels[1], blockConfigs=[blockConfig_stride]+[blockConfig]*(nb_conv_per_stage-1), dim=dim))
|
|
27
28
|
if len(channels) > 2:
|
|
28
|
-
self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1
|
|
29
|
+
self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1))
|
|
29
30
|
self.add_module("UpConvBlock", block(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
|
|
30
31
|
if nb_class > 0:
|
|
31
|
-
self.add_module("Head", UNetHead(channels[1], nb_class, dim), out_branch=[-1])
|
|
32
|
+
self.add_module("Head", UNetHead(channels[1], nb_class, dim, i), out_branch=[-1])
|
|
32
33
|
if i > 0:
|
|
33
34
|
if attention:
|
|
34
35
|
self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[1, 0], out_branch=[1])
|
|
35
|
-
self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=
|
|
36
|
+
self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=2, stride=2))
|
|
36
37
|
self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
|
|
37
38
|
|
|
38
39
|
class UNet(network.Network):
|
|
@@ -51,8 +52,6 @@ class UNet(network.Network):
|
|
|
51
52
|
downSampleMode: str = "MAXPOOL",
|
|
52
53
|
upSampleMode: str = "CONV_TRANSPOSE",
|
|
53
54
|
attention : bool = False,
|
|
54
|
-
blockType: str = "Conv"
|
|
55
|
-
mri: bool = False) -> None:
|
|
55
|
+
blockType: str = "Conv") -> None:
|
|
56
56
|
super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
57
|
-
self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim
|
|
58
|
-
|
|
57
|
+
self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim))
|
|
@@ -91,6 +91,23 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
91
91
|
super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
|
|
92
92
|
self.attributes.pop(index)
|
|
93
93
|
|
|
94
|
+
class Reduction():
|
|
95
|
+
|
|
96
|
+
def __init__(self):
|
|
97
|
+
pass
|
|
98
|
+
|
|
99
|
+
class ReductionMean():
|
|
100
|
+
|
|
101
|
+
def __init__(self):
|
|
102
|
+
pass
|
|
103
|
+
|
|
104
|
+
class ReductionMedian():
|
|
105
|
+
|
|
106
|
+
def __init__(self):
|
|
107
|
+
pass
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
|
|
94
111
|
class OutSameAsGroupDataset(OutDataset):
|
|
95
112
|
|
|
96
113
|
@config("OutDataset")
|
|
@@ -109,6 +109,8 @@ class _Trainer():
|
|
|
109
109
|
for data in data_log:
|
|
110
110
|
self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
|
|
111
111
|
|
|
112
|
+
|
|
113
|
+
|
|
112
114
|
def __enter__(self):
|
|
113
115
|
return self
|
|
114
116
|
|
|
@@ -314,6 +316,18 @@ class Trainer(DistributedObject):
|
|
|
314
316
|
self.ema_decay = ema_decay
|
|
315
317
|
self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
|
|
316
318
|
self.data_log = data_log
|
|
319
|
+
|
|
320
|
+
modules = []
|
|
321
|
+
for i,_ in self.model.named_modules():
|
|
322
|
+
modules.append(i)
|
|
323
|
+
for k in self.data_log:
|
|
324
|
+
tmp = k.split("/")[0].replace(":", ".")
|
|
325
|
+
if tmp not in self.dataset.getGroupsDest() and tmp not in modules:
|
|
326
|
+
raise TrainerError( f"Invalid key '{tmp}' in `data_log`.",
|
|
327
|
+
f"This key is neither a destination group from the dataset ({self.dataset.getGroupsDest()})",
|
|
328
|
+
f"nor a valid module name in the model ({modules}).",
|
|
329
|
+
"Please check your `data_log` configuration — it should reference either a model output or a dataset group.")
|
|
330
|
+
|
|
317
331
|
self.gradient_checkpoints = gradient_checkpoints
|
|
318
332
|
self.gpu_checkpoints = gpu_checkpoints
|
|
319
333
|
self.save_checkpoint_mode = save_checkpoint_mode
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|