konfai 1.1.3__tar.gz → 1.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- {konfai-1.1.3 → konfai-1.1.5}/PKG-INFO +1 -1
- {konfai-1.1.3 → konfai-1.1.5}/konfai/data/augmentation.py +0 -14
- {konfai-1.1.3 → konfai-1.1.5}/konfai/data/data_manager.py +14 -6
- {konfai-1.1.3 → konfai-1.1.5}/konfai/data/patching.py +14 -17
- {konfai-1.1.3 → konfai-1.1.5}/konfai/evaluator.py +6 -3
- {konfai-1.1.3 → konfai-1.1.5}/konfai/metric/measure.py +37 -13
- konfai-1.1.5/konfai/models/segmentation/NestedUNet.py +115 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/segmentation/UNet.py +0 -1
- {konfai-1.1.3 → konfai-1.1.5}/konfai/network/blocks.py +32 -13
- {konfai-1.1.3 → konfai-1.1.5}/konfai/network/network.py +29 -15
- {konfai-1.1.3 → konfai-1.1.5}/konfai/predictor.py +29 -6
- {konfai-1.1.3 → konfai-1.1.5}/konfai/trainer.py +3 -2
- {konfai-1.1.3 → konfai-1.1.5}/konfai/utils/config.py +10 -7
- {konfai-1.1.3 → konfai-1.1.5}/konfai/utils/utils.py +9 -4
- {konfai-1.1.3 → konfai-1.1.5}/konfai.egg-info/PKG-INFO +1 -1
- {konfai-1.1.3 → konfai-1.1.5}/pyproject.toml +1 -1
- konfai-1.1.3/konfai/models/segmentation/NestedUNet.py +0 -53
- {konfai-1.1.3 → konfai-1.1.5}/LICENSE +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/README.md +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/__init__.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/data/__init__.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/data/transform.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/main.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/metric/__init__.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/metric/schedulers.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/classification/convNeXt.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/classification/resnet.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/generation/cStyleGan.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/generation/ddpm.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/generation/diffusionGan.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/generation/gan.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/generation/vae.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/registration/registration.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/models/representation/representation.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/network/__init__.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/utils/ITK.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/utils/__init__.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/utils/dataset.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai/utils/registration.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai.egg-info/SOURCES.txt +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai.egg-info/requires.txt +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/setup.cfg +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/tests/test_config.py +0 -0
- {konfai-1.1.3 → konfai-1.1.5}/tests/test_dataset.py +0 -0
|
@@ -134,7 +134,6 @@ class EulerTransform(DataAugmentation):
|
|
|
134
134
|
|
|
135
135
|
class Translate(EulerTransform):
|
|
136
136
|
|
|
137
|
-
@config("Translate")
|
|
138
137
|
def __init__(self, t_min: float = -10, t_max = 10, is_int: bool = False):
|
|
139
138
|
super().__init__()
|
|
140
139
|
self.t_min = t_min
|
|
@@ -152,7 +151,6 @@ class Translate(EulerTransform):
|
|
|
152
151
|
|
|
153
152
|
class Rotate(EulerTransform):
|
|
154
153
|
|
|
155
|
-
@config("Rotate")
|
|
156
154
|
def __init__(self, a_min: float = 0, a_max: float = 360, is_quarter: bool = False):
|
|
157
155
|
super().__init__()
|
|
158
156
|
self.a_min = a_min
|
|
@@ -174,7 +172,6 @@ class Rotate(EulerTransform):
|
|
|
174
172
|
|
|
175
173
|
class Scale(EulerTransform):
|
|
176
174
|
|
|
177
|
-
@config("Scale")
|
|
178
175
|
def __init__(self, s_std: float = 0.2):
|
|
179
176
|
super().__init__()
|
|
180
177
|
self.s_std = s_std
|
|
@@ -187,7 +184,6 @@ class Scale(EulerTransform):
|
|
|
187
184
|
|
|
188
185
|
class Flip(DataAugmentation):
|
|
189
186
|
|
|
190
|
-
@config("Flip")
|
|
191
187
|
def __init__(self, f_prob: Union[list[float], None] = [0.33, 0.33 ,0.33]) -> None:
|
|
192
188
|
super().__init__()
|
|
193
189
|
self.f_prob = f_prob
|
|
@@ -211,7 +207,6 @@ class Flip(DataAugmentation):
|
|
|
211
207
|
|
|
212
208
|
class ColorTransform(DataAugmentation):
|
|
213
209
|
|
|
214
|
-
@config("ColorTransform")
|
|
215
210
|
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
216
211
|
super().__init__(groups)
|
|
217
212
|
self.matrix: dict[int, list[torch.Tensor]] = {}
|
|
@@ -236,7 +231,6 @@ class ColorTransform(DataAugmentation):
|
|
|
236
231
|
|
|
237
232
|
class Brightness(ColorTransform):
|
|
238
233
|
|
|
239
|
-
@config("Brightness")
|
|
240
234
|
def __init__(self, b_std: float, groups: Union[list[str], None] = None) -> None:
|
|
241
235
|
super().__init__(groups)
|
|
242
236
|
self.b_std = b_std
|
|
@@ -248,7 +242,6 @@ class Brightness(ColorTransform):
|
|
|
248
242
|
|
|
249
243
|
class Contrast(ColorTransform):
|
|
250
244
|
|
|
251
|
-
@config("Contrast")
|
|
252
245
|
def __init__(self, c_std: float, groups: Union[list[str], None] = None) -> None:
|
|
253
246
|
super().__init__(groups)
|
|
254
247
|
self.c_std = c_std
|
|
@@ -260,7 +253,6 @@ class Contrast(ColorTransform):
|
|
|
260
253
|
|
|
261
254
|
class LumaFlip(ColorTransform):
|
|
262
255
|
|
|
263
|
-
@config("LumaFlip")
|
|
264
256
|
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
265
257
|
super().__init__(groups)
|
|
266
258
|
self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
|
|
@@ -285,7 +277,6 @@ class HUE(ColorTransform):
|
|
|
285
277
|
|
|
286
278
|
class Saturation(ColorTransform):
|
|
287
279
|
|
|
288
|
-
@config("Saturation")
|
|
289
280
|
def __init__(self, s_std: float, groups: Union[list[str], None] = None) -> None:
|
|
290
281
|
super().__init__(groups)
|
|
291
282
|
self.s_std = s_std
|
|
@@ -372,7 +363,6 @@ class Saturation(ColorTransform):
|
|
|
372
363
|
|
|
373
364
|
class Noise(DataAugmentation):
|
|
374
365
|
|
|
375
|
-
@config("Noise")
|
|
376
366
|
def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02, groups: Union[list[str], None] = None) -> None:
|
|
377
367
|
super().__init__(groups)
|
|
378
368
|
self.n_std = n_std
|
|
@@ -427,7 +417,6 @@ class Noise(DataAugmentation):
|
|
|
427
417
|
|
|
428
418
|
class CutOUT(DataAugmentation):
|
|
429
419
|
|
|
430
|
-
@config("CutOUT")
|
|
431
420
|
def __init__(self, c_prob: float, cutout_size: int, value: float, groups: Union[list[str], None] = None) -> None:
|
|
432
421
|
super().__init__(groups)
|
|
433
422
|
self.c_prob = c_prob
|
|
@@ -458,7 +447,6 @@ class CutOUT(DataAugmentation):
|
|
|
458
447
|
|
|
459
448
|
class Elastix(DataAugmentation):
|
|
460
449
|
|
|
461
|
-
@config("Elastix")
|
|
462
450
|
def __init__(self, grid_spacing: int = 16, max_displacement: int = 16) -> None:
|
|
463
451
|
super().__init__()
|
|
464
452
|
self.grid_spacing = grid_spacing
|
|
@@ -529,7 +517,6 @@ class Elastix(DataAugmentation):
|
|
|
529
517
|
|
|
530
518
|
class Permute(DataAugmentation):
|
|
531
519
|
|
|
532
|
-
@config("Permute")
|
|
533
520
|
def __init__(self, prob_permute: Union[list[float], None] = [0.5 ,0.5]) -> None:
|
|
534
521
|
super().__init__()
|
|
535
522
|
self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
|
|
@@ -567,7 +554,6 @@ class Permute(DataAugmentation):
|
|
|
567
554
|
|
|
568
555
|
class Mask(DataAugmentation):
|
|
569
556
|
|
|
570
|
-
@config("Mask")
|
|
571
557
|
def __init__(self, mask: str, value: float, groups: Union[list[str], None] = None) -> None:
|
|
572
558
|
super().__init__(groups)
|
|
573
559
|
if mask is not None:
|
|
@@ -129,7 +129,7 @@ class DatasetIter(data.Dataset):
|
|
|
129
129
|
for group_dest in self.groups_src[group_src]:
|
|
130
130
|
self.data[group_dest][index].unloadAugmentation()
|
|
131
131
|
self.data[group_dest][index].resetAugmentation()
|
|
132
|
-
|
|
132
|
+
self.load(label + " Augmentation")
|
|
133
133
|
|
|
134
134
|
def load(self, label: str):
|
|
135
135
|
if self.use_cache:
|
|
@@ -217,7 +217,6 @@ class Subset():
|
|
|
217
217
|
else:
|
|
218
218
|
names_filtred = names
|
|
219
219
|
size = len(names_filtred)
|
|
220
|
-
|
|
221
220
|
index = []
|
|
222
221
|
if self.subset is None:
|
|
223
222
|
index = list(range(0, size))
|
|
@@ -226,15 +225,24 @@ class Subset():
|
|
|
226
225
|
r = np.clip(np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]), 0, size)
|
|
227
226
|
index = list(range(r[0], r[1]))
|
|
228
227
|
elif os.path.exists(self.subset):
|
|
229
|
-
|
|
228
|
+
train_names = []
|
|
230
229
|
with open(self.subset, "r") as f:
|
|
231
230
|
for name in f:
|
|
232
|
-
|
|
231
|
+
train_names.append(name.strip())
|
|
233
232
|
index = []
|
|
234
233
|
for i, name in enumerate(names_filtred):
|
|
235
|
-
if name in
|
|
234
|
+
if name in train_names:
|
|
236
235
|
index.append(i)
|
|
237
|
-
|
|
236
|
+
elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
|
|
237
|
+
exclude_names = []
|
|
238
|
+
with open(self.subset[1:], "r") as f:
|
|
239
|
+
for name in f:
|
|
240
|
+
exclude_names.append(name.strip())
|
|
241
|
+
index = []
|
|
242
|
+
for i, name in enumerate(names_filtred):
|
|
243
|
+
if name not in exclude_names:
|
|
244
|
+
index.append(i)
|
|
245
|
+
|
|
238
246
|
elif isinstance(self.subset, list):
|
|
239
247
|
if len(self.subset) > 0:
|
|
240
248
|
if isinstance(self.subset[0], int):
|
|
@@ -125,23 +125,20 @@ class Accumulator():
|
|
|
125
125
|
return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])
|
|
126
126
|
|
|
127
127
|
def assemble(self) -> torch.Tensor:
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
|
|
143
|
-
else:
|
|
144
|
-
result = torch.cat(tuple(self._layer_accumulator), dim=0)
|
|
128
|
+
N = 2 if self.batch else 1
|
|
129
|
+
result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
|
|
130
|
+
for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
|
|
131
|
+
slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
|
|
132
|
+
|
|
133
|
+
for dim, s in enumerate(patch_slice):
|
|
134
|
+
if s.stop-s.start == 1:
|
|
135
|
+
data = data.unsqueeze(dim=dim+N)
|
|
136
|
+
if self.patchCombine is not None:
|
|
137
|
+
result[slices_dest] += self.patchCombine(data)
|
|
138
|
+
else:
|
|
139
|
+
result[slices_dest] = data
|
|
140
|
+
result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
|
|
141
|
+
|
|
145
142
|
self._layer_accumulator.clear()
|
|
146
143
|
return result
|
|
147
144
|
|
|
@@ -102,7 +102,7 @@ class Evaluator(DistributedObject):
|
|
|
102
102
|
result = {}
|
|
103
103
|
for output_group in self.metrics:
|
|
104
104
|
for target_group in self.metrics[output_group]:
|
|
105
|
-
targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("
|
|
105
|
+
targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split(";") if group in data_dict]
|
|
106
106
|
name = data_dict[output_group][1][0]
|
|
107
107
|
for metric in self.metrics[output_group][target_group]:
|
|
108
108
|
result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
|
|
@@ -135,8 +135,11 @@ class Evaluator(DistributedObject):
|
|
|
135
135
|
f"Available groups: {sorted(groupsDest)}"
|
|
136
136
|
)
|
|
137
137
|
|
|
138
|
-
target_groups =
|
|
139
|
-
|
|
138
|
+
target_groups = []
|
|
139
|
+
for i in {target for targets in self.metrics.values() for target in targets}:
|
|
140
|
+
for u in i.split(";"):
|
|
141
|
+
target_groups.append(u)
|
|
142
|
+
missing_targets = set(target_groups) - set(groupsDest)
|
|
140
143
|
if missing_targets:
|
|
141
144
|
raise EvaluatorError(
|
|
142
145
|
f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
|
|
@@ -3,6 +3,9 @@ import importlib
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
5
5
|
|
|
6
|
+
from torchvision.models import inception_v3, Inception_V3_Weights
|
|
7
|
+
from torchvision.transforms import functional as F
|
|
8
|
+
from scipy import linalg
|
|
6
9
|
|
|
7
10
|
import torch.nn.functional as F
|
|
8
11
|
import os
|
|
@@ -401,21 +404,42 @@ class L1LossRepresentation(Criterion):
|
|
|
401
404
|
def forward(self, output: torch.Tensor) -> torch.Tensor:
|
|
402
405
|
return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])
|
|
403
406
|
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
407
|
+
class FocalLoss(Criterion):
|
|
408
|
+
|
|
409
|
+
def __init__(self, gamma: float = 2.0, alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1], reduction: str = "mean"):
|
|
410
|
+
super().__init__()
|
|
411
|
+
raw_alpha = torch.tensor(alpha, dtype=torch.float32)
|
|
412
|
+
self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
|
|
413
|
+
self.gamma = gamma
|
|
414
|
+
self.reduction = reduction
|
|
415
|
+
|
|
416
|
+
def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
|
|
417
|
+
target = targets[0].long()
|
|
418
|
+
|
|
419
|
+
logpt = F.log_softmax(output, dim=1)
|
|
420
|
+
pt = torch.exp(logpt)
|
|
421
|
+
|
|
422
|
+
logpt = logpt.gather(1, target.unsqueeze(1))
|
|
423
|
+
pt = pt.gather(1, target.unsqueeze(1))
|
|
424
|
+
|
|
425
|
+
at = self.alpha[target].unsqueeze(1)
|
|
426
|
+
loss = -at * ((1 - pt) ** self.gamma) * logpt
|
|
427
|
+
|
|
428
|
+
if self.reduction == "mean":
|
|
429
|
+
return loss.mean()
|
|
430
|
+
elif self.reduction == "sum":
|
|
431
|
+
return loss.sum()
|
|
432
|
+
return loss
|
|
410
433
|
|
|
434
|
+
|
|
411
435
|
class FID(Criterion):
|
|
412
436
|
|
|
413
|
-
class InceptionV3(nn.Module):
|
|
437
|
+
class InceptionV3(torch.nn.Module):
|
|
414
438
|
|
|
415
439
|
def __init__(self) -> None:
|
|
416
440
|
super().__init__()
|
|
417
441
|
self.model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
|
|
418
|
-
self.model.fc = nn.Identity()
|
|
442
|
+
self.model.fc = torch.nn.Identity()
|
|
419
443
|
self.model.eval()
|
|
420
444
|
|
|
421
445
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
@@ -456,9 +480,9 @@ class FID(Criterion):
|
|
|
456
480
|
generated_features = FID.get_features(generated_images, self.inception_model)
|
|
457
481
|
|
|
458
482
|
return FID.calculate_fid(real_features, generated_features)
|
|
459
|
-
|
|
483
|
+
|
|
460
484
|
|
|
461
|
-
|
|
485
|
+
class MutualInformationLoss(torch.nn.Module):
|
|
462
486
|
def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
|
|
463
487
|
super().__init__()
|
|
464
488
|
bin_centers = torch.linspace(0.0, 1.0, num_bins)
|
|
@@ -469,13 +493,13 @@ class FID(Criterion):
|
|
|
469
493
|
self.smooth_nr = float(smooth_nr)
|
|
470
494
|
self.smooth_dr = float(smooth_dr)
|
|
471
495
|
|
|
472
|
-
def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) ->
|
|
496
|
+
def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
473
497
|
pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
|
|
474
498
|
target_weight, target_probability = self.parzen_windowing_gaussian(target)
|
|
475
499
|
return pred_weight, pred_probability, target_weight, target_probability
|
|
476
500
|
|
|
477
501
|
|
|
478
|
-
def parzen_windowing_gaussian(self, img: torch.Tensor) ->
|
|
502
|
+
def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
479
503
|
img = torch.clamp(img, 0, 1)
|
|
480
504
|
img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
|
|
481
505
|
weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2) # (batch, num_sample, num_bin)
|
|
@@ -488,4 +512,4 @@ class FID(Criterion):
|
|
|
488
512
|
pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
|
|
489
513
|
papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
|
|
490
514
|
mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
|
|
491
|
-
return torch.mean(mi).neg() # average over the batch and channel ndims
|
|
515
|
+
return torch.mean(mi).neg() # average over the batch and channel ndims
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from konfai.network import network, blocks
|
|
2
|
+
from typing import Union
|
|
3
|
+
import torch
|
|
4
|
+
from konfai.data.patching import ModelPatch
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class NestedUNetBlock(network.ModuleArgsDict):
|
|
8
|
+
|
|
9
|
+
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
|
|
10
|
+
super().__init__()
|
|
11
|
+
if i > 0:
|
|
12
|
+
self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
|
|
13
|
+
|
|
14
|
+
self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
|
|
15
|
+
if len(channels) > 2:
|
|
16
|
+
self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
|
|
17
|
+
for j in range(len(channels)-2):
|
|
18
|
+
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
19
|
+
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
20
|
+
self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
21
|
+
|
|
22
|
+
class NestedUNet(network.Network):
|
|
23
|
+
|
|
24
|
+
class NestedUNetHead(network.ModuleArgsDict):
|
|
25
|
+
|
|
26
|
+
def __init__(self, in_channels: int, nb_class: int, activation: str, dim: int) -> None:
|
|
27
|
+
super().__init__()
|
|
28
|
+
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
|
29
|
+
if activation == "Softmax":
|
|
30
|
+
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
31
|
+
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
32
|
+
elif activation == "Tanh":
|
|
33
|
+
self.add_module("Tanh", torch.nn.Tanh())
|
|
34
|
+
|
|
35
|
+
def __init__( self,
|
|
36
|
+
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
37
|
+
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
38
|
+
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
39
|
+
patch : Union[ModelPatch, None] = None,
|
|
40
|
+
dim : int = 3,
|
|
41
|
+
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
42
|
+
nb_class: int = 2,
|
|
43
|
+
blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
|
|
44
|
+
nb_conv_per_stage: int = 2,
|
|
45
|
+
downSampleMode: str = "MAXPOOL",
|
|
46
|
+
upSampleMode: str = "CONV_TRANSPOSE",
|
|
47
|
+
attention : bool = False,
|
|
48
|
+
blockType: str = "Conv",
|
|
49
|
+
activation: str = "Softmax") -> None:
|
|
50
|
+
super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
51
|
+
|
|
52
|
+
self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
53
|
+
for j in range(len(channels)-2):
|
|
54
|
+
self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, activation=activation, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ResNetEncoderLayer(network.ModuleArgsDict):
|
|
60
|
+
|
|
61
|
+
def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
|
|
62
|
+
super().__init__()
|
|
63
|
+
for i in range(nb_block):
|
|
64
|
+
if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
|
|
65
|
+
self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
|
|
66
|
+
self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
|
|
67
|
+
in_channel = out_channel
|
|
68
|
+
|
|
69
|
+
def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
|
|
70
|
+
modules = []
|
|
71
|
+
modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
|
|
72
|
+
for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
|
|
73
|
+
modules.append(ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
|
|
74
|
+
return modules
|
|
75
|
+
|
|
76
|
+
class NestedUNetBlock(network.ModuleArgsDict):
|
|
77
|
+
|
|
78
|
+
def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
|
|
79
|
+
super().__init__()
|
|
80
|
+
self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
|
|
81
|
+
if len(encoder_channels) > 2:
|
|
82
|
+
self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
|
|
83
|
+
for j in range(len(encoder_channels)-2):
|
|
84
|
+
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=encoder_channels[2], out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
85
|
+
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
86
|
+
self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+((encoder_channels[1]+encoder_channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else (decoder_channels[1]+decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]+encoder_channels[1])), out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
87
|
+
|
|
88
|
+
class NestedUNetHead(network.ModuleArgsDict):
|
|
89
|
+
|
|
90
|
+
def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
|
|
91
|
+
super().__init__()
|
|
92
|
+
self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim))
|
|
93
|
+
self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
|
|
94
|
+
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
|
|
95
|
+
if nb_class > 1:
|
|
96
|
+
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
97
|
+
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
98
|
+
else:
|
|
99
|
+
self.add_module("Tanh", torch.nn.Tanh())
|
|
100
|
+
|
|
101
|
+
class UNetpp(network.Network):
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def __init__( self,
|
|
105
|
+
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
106
|
+
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
107
|
+
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
108
|
+
patch : Union[ModelPatch, None] = None,
|
|
109
|
+
encoder_channels: list[int] = [1,64,64,128,256,512],
|
|
110
|
+
decoder_channels: list[int] = [256,128,64,32,16,1],
|
|
111
|
+
layers: list[int] = [3,4,6,3],
|
|
112
|
+
dim : int = 2) -> None:
|
|
113
|
+
super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
114
|
+
self.add_module("Block_0", NestedUNetBlock(encoder_channels, decoder_channels[::-1], resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
|
|
115
|
+
self.add_module("Head", NestedUNetHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
|
|
@@ -38,7 +38,6 @@ class UNetBlock(network.ModuleArgsDict):
|
|
|
38
38
|
|
|
39
39
|
class UNet(network.Network):
|
|
40
40
|
|
|
41
|
-
@config("UNet")
|
|
42
41
|
def __init__( self,
|
|
43
42
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
44
43
|
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
@@ -19,7 +19,7 @@ class NormMode(Enum):
|
|
|
19
19
|
SYNCBATCH = 5
|
|
20
20
|
INSTANCE_AFFINE = 6
|
|
21
21
|
|
|
22
|
-
def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
|
|
22
|
+
def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module, None]:
|
|
23
23
|
if normMode == NormMode.BATCH:
|
|
24
24
|
return getTorchModule("BatchNorm", dim = dim)(channels, affine=True, track_running_stats=True)
|
|
25
25
|
if normMode == NormMode.INSTANCE:
|
|
@@ -32,7 +32,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
|
|
|
32
32
|
return torch.nn.GroupNorm(num_groups=32, num_channels=channels)
|
|
33
33
|
if normMode == NormMode.LAYER:
|
|
34
34
|
return torch.nn.GroupNorm(num_groups=1, num_channels=channels)
|
|
35
|
-
return
|
|
35
|
+
return None
|
|
36
36
|
|
|
37
37
|
class UpSampleMode(Enum):
|
|
38
38
|
CONV_TRANSPOSE = 0,
|
|
@@ -59,9 +59,9 @@ class BlockConfig():
|
|
|
59
59
|
self.stride = stride
|
|
60
60
|
self.padding = padding
|
|
61
61
|
self.activation = activation
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
62
|
+
if normMode is None:
|
|
63
|
+
self.norm = None
|
|
64
|
+
elif isinstance(normMode, str):
|
|
65
65
|
self.norm = NormMode._member_map_[normMode]
|
|
66
66
|
elif isinstance(normMode, NormMode):
|
|
67
67
|
self.norm = normMode
|
|
@@ -70,9 +70,13 @@ class BlockConfig():
|
|
|
70
70
|
return getTorchModule("Conv", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding, bias=self.bias)
|
|
71
71
|
|
|
72
72
|
def getNorm(self, channels : int, dim: int) -> torch.nn.Module:
|
|
73
|
+
if self.norm is None:
|
|
74
|
+
return None
|
|
73
75
|
return getNorm(self.norm, channels, dim) if isinstance(self.norm, NormMode) else self.norm(channels)
|
|
74
76
|
|
|
75
77
|
def getActivation(self) -> torch.nn.Module:
|
|
78
|
+
if self.activation is None:
|
|
79
|
+
return None
|
|
76
80
|
if isinstance(self.activation, str):
|
|
77
81
|
return getTorchModule(self.activation.split(";")[0])(*[ast.literal_eval(value) for value in self.activation.split(";")[1:]]) if self.activation != "None" else torch.nn.Identity()
|
|
78
82
|
return self.activation()
|
|
@@ -83,21 +87,35 @@ class ConvBlock(network.ModuleArgsDict):
|
|
|
83
87
|
super().__init__()
|
|
84
88
|
for i, blockConfig in enumerate(blockConfigs):
|
|
85
89
|
self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
|
|
86
|
-
|
|
87
|
-
|
|
90
|
+
norm = blockConfig.getNorm(out_channels, dim)
|
|
91
|
+
if norm is not None:
|
|
92
|
+
self.add_module("Norm_{}".format(i), norm, alias=alias[1])
|
|
93
|
+
activation = blockConfig.getActivation()
|
|
94
|
+
if activation is not None:
|
|
95
|
+
self.add_module("Activation_{}".format(i), activation, alias=alias[2])
|
|
88
96
|
in_channels = out_channels
|
|
89
97
|
|
|
90
98
|
class ResBlock(network.ModuleArgsDict):
|
|
91
99
|
|
|
92
|
-
def __init__(self, in_channels : int, out_channels : int,
|
|
100
|
+
def __init__(self, in_channels : int, out_channels : int, blockConfigs : list[BlockConfig], dim : int, alias : list[list[str]]=[[], [], [], [], []]) -> None:
|
|
93
101
|
super().__init__()
|
|
94
|
-
for i in
|
|
95
|
-
self.add_module("
|
|
102
|
+
for i, blockConfig in enumerate(blockConfigs):
|
|
103
|
+
self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
|
|
104
|
+
norm = blockConfig.getNorm(out_channels, dim)
|
|
105
|
+
if norm is not None:
|
|
106
|
+
self.add_module("Norm_{}".format(i), norm, alias=alias[1])
|
|
107
|
+
activation = blockConfig.getActivation()
|
|
108
|
+
if activation is not None:
|
|
109
|
+
self.add_module("Activation_{}".format(i), activation, alias=alias[2])
|
|
110
|
+
|
|
96
111
|
if in_channels != out_channels:
|
|
97
|
-
self.add_module("
|
|
112
|
+
self.add_module("Conv_skip", getTorchModule("Conv", dim)(in_channels, out_channels, 1, blockConfig.stride, bias=blockConfig.bias), alias=alias[3], in_branch=[1], out_branch=[1])
|
|
113
|
+
self.add_module("Norm_skip", blockConfig.getNorm(out_channels, dim), alias=alias[4], in_branch=[1], out_branch=[1])
|
|
98
114
|
in_channels = out_channels
|
|
99
|
-
|
|
100
|
-
|
|
115
|
+
|
|
116
|
+
self.add_module("Add", Add(), in_branch=[0,1])
|
|
117
|
+
self.add_module("Norm_{}".format(i+1), torch.nn.ReLU(inplace=True))
|
|
118
|
+
|
|
101
119
|
def downSample(in_channels: int, out_channels: int, downSampleMode: DownSampleMode, dim: int) -> torch.nn.Module:
|
|
102
120
|
if downSampleMode == DownSampleMode.MAXPOOL:
|
|
103
121
|
return getTorchModule("MaxPool", dim = dim)(2)
|
|
@@ -176,6 +194,7 @@ class Print(torch.nn.Module):
|
|
|
176
194
|
super().__init__()
|
|
177
195
|
|
|
178
196
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
197
|
+
print(input.shape)
|
|
179
198
|
return input
|
|
180
199
|
|
|
181
200
|
class Write(torch.nn.Module):
|
|
@@ -133,9 +133,10 @@ class Measure():
|
|
|
133
133
|
self._loss.clear()
|
|
134
134
|
|
|
135
135
|
def add(self, weight: float, value: torch.Tensor) -> None:
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
136
|
+
if self.isLoss or value != 0:
|
|
137
|
+
self._loss.append(value if self.isLoss else value.detach())
|
|
138
|
+
self._values.append(value.item())
|
|
139
|
+
self._weight.append(weight)
|
|
139
140
|
|
|
140
141
|
def getLastLoss(self) -> torch.Tensor:
|
|
141
142
|
return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)
|
|
@@ -157,7 +158,7 @@ class Measure():
|
|
|
157
158
|
outputs_group_rename = {}
|
|
158
159
|
|
|
159
160
|
modules = []
|
|
160
|
-
for i,_ in model.
|
|
161
|
+
for i,_,_ in model.named_ModuleArgsDict():
|
|
161
162
|
modules.append(i)
|
|
162
163
|
|
|
163
164
|
for output_group in self.outputsCriterions.keys():
|
|
@@ -167,11 +168,12 @@ class Measure():
|
|
|
167
168
|
"Please check that the name matches exactly a submodule or output of your model architecture."
|
|
168
169
|
)
|
|
169
170
|
for target_group in self.outputsCriterions[output_group]:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
171
|
+
for target_group_tmp in target_group.split(";"):
|
|
172
|
+
if target_group_tmp not in group_dest:
|
|
173
|
+
raise MeasureError(
|
|
174
|
+
f"The target_group '{target_group_tmp}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
|
|
175
|
+
"This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
|
|
176
|
+
f"Please make sure that the group '{target_group_tmp}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group_tmp}'' and correctly loaded from the dataset.")
|
|
175
177
|
for criterion in self.outputsCriterions[output_group][target_group]:
|
|
176
178
|
if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
|
|
177
179
|
outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
|
|
@@ -189,7 +191,7 @@ class Measure():
|
|
|
189
191
|
|
|
190
192
|
def update(self, output_group: str, output : torch.Tensor, data_dict: dict[str, torch.Tensor], it: int, nb_patch: int, training: bool) -> None:
|
|
191
193
|
for target_group in self.outputsCriterions[output_group]:
|
|
192
|
-
target = [data_dict[group].to(output[0].device).detach() for group in target_group.split("
|
|
194
|
+
target = [data_dict[group].to(output[0].device).detach() for group in target_group.split(";") if group in data_dict]
|
|
193
195
|
|
|
194
196
|
for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
|
|
195
197
|
if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
|
|
@@ -928,6 +930,16 @@ class Network(ModuleArgsDict, ABC):
|
|
|
928
930
|
if isinstance(module, ModuleArgsDict):
|
|
929
931
|
module._training = state
|
|
930
932
|
|
|
933
|
+
class MinimalModel(Network):
|
|
934
|
+
|
|
935
|
+
def __init__(self, model: Network, optimizer : OptimizerLoader = OptimizerLoader(),
|
|
936
|
+
schedulers : LRSchedulersLoader = LRSchedulersLoader(),
|
|
937
|
+
outputsCriterions: dict[str, TargetCriterionsLoader] = {"default" : TargetCriterionsLoader()},
|
|
938
|
+
patch : Union[ModelPatch, None] = None,
|
|
939
|
+
dim : int = 3, nb_batch_per_step = 1, init_type = "normal", init_gain = 0.02):
|
|
940
|
+
super().__init__(1, optimizer, schedulers, outputsCriterions, patch, nb_batch_per_step, init_type, init_gain, dim)
|
|
941
|
+
self.add_module("Model", model)
|
|
942
|
+
|
|
931
943
|
class ModelLoader():
|
|
932
944
|
|
|
933
945
|
@config("Model")
|
|
@@ -937,11 +949,13 @@ class ModelLoader():
|
|
|
937
949
|
def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
|
|
938
950
|
if not DL_args:
|
|
939
951
|
DL_args="{}.Model".format(KONFAI_ROOT())
|
|
940
|
-
|
|
941
|
-
if not train
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
952
|
+
DL_args += "."+self.name
|
|
953
|
+
model = config(DL_args)(getattr(importlib.import_module(self.module), self.name))(config = None, DL_without = DL_without if not train else [])
|
|
954
|
+
if not isinstance(model, Network):
|
|
955
|
+
model = config(DL_args)(partial(MinimalModel, model))(config = None, DL_without = DL_without+["model"] if not train else [])
|
|
956
|
+
model.setName(self.name)
|
|
957
|
+
return model
|
|
958
|
+
|
|
945
959
|
class CPU_Model():
|
|
946
960
|
|
|
947
961
|
def __init__(self, model: Network) -> None:
|
|
@@ -8,7 +8,7 @@ import os
|
|
|
8
8
|
|
|
9
9
|
from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
|
-
from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
|
|
11
|
+
from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
|
|
12
12
|
from konfai.utils.dataset import Dataset, Attribute
|
|
13
13
|
from konfai.data.data_manager import DataPrediction, DatasetIter
|
|
14
14
|
from konfai.data.patching import Accumulator, PathCombine
|
|
@@ -46,7 +46,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
46
46
|
self.names: dict[int, str] = {}
|
|
47
47
|
self.nb_data_augmentation = 0
|
|
48
48
|
|
|
49
|
-
def load(self, name_layer: str, datasets: list[Dataset]):
|
|
49
|
+
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
50
50
|
transforms_type = ["pre_transforms", "post_transforms", "final_transforms"]
|
|
51
51
|
for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:
|
|
52
52
|
|
|
@@ -137,9 +137,22 @@ class OutSameAsGroupDataset(OutDataset):
|
|
|
137
137
|
if self.inverse_transform:
|
|
138
138
|
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
|
|
139
139
|
layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
|
|
140
|
-
|
|
141
140
|
self.output_layer_accumulator[index_dataset][index_augmentation].addLayer(index_patch, layer)
|
|
142
141
|
|
|
142
|
+
|
|
143
|
+
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
144
|
+
super().load(name_layer, datasets, groups)
|
|
145
|
+
|
|
146
|
+
if self.group_src not in groups.keys():
|
|
147
|
+
raise PredictorError(
|
|
148
|
+
f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
if self.group_dest not in groups[self.group_src]:
|
|
152
|
+
raise PredictorError(
|
|
153
|
+
f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
|
|
154
|
+
)
|
|
155
|
+
|
|
143
156
|
def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
|
|
144
157
|
layer = self.output_layer_accumulator[index][index_augmentation].assemble()
|
|
145
158
|
name = self.names[index]
|
|
@@ -258,7 +271,6 @@ class _Predictor():
|
|
|
258
271
|
desc = lambda : "Prediction : {}".format(description(self.modelComposite))
|
|
259
272
|
self.dataloader_prediction.dataset.load("Prediction")
|
|
260
273
|
with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
|
|
261
|
-
dist.barrier()
|
|
262
274
|
for it, data_dict in batch_iter:
|
|
263
275
|
with torch.amp.autocast('cuda', enabled=self.autocast):
|
|
264
276
|
input = self.getInput(data_dict)
|
|
@@ -272,7 +284,7 @@ class _Predictor():
|
|
|
272
284
|
|
|
273
285
|
batch_iter.set_description(desc())
|
|
274
286
|
self.it += 1
|
|
275
|
-
|
|
287
|
+
|
|
276
288
|
def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
|
|
277
289
|
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
|
|
278
290
|
|
|
@@ -404,9 +416,20 @@ class Predictor(DistributedObject):
|
|
|
404
416
|
|
|
405
417
|
shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
|
|
406
418
|
|
|
419
|
+
|
|
407
420
|
self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
|
|
408
421
|
self.model.init_outputsGroup()
|
|
409
422
|
self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
|
|
423
|
+
|
|
424
|
+
modules = []
|
|
425
|
+
for i,_,_ in self.model.named_ModuleArgsDict():
|
|
426
|
+
modules.append(i)
|
|
427
|
+
for output_group in self.outsDataset.keys():
|
|
428
|
+
if output_group not in modules:
|
|
429
|
+
raise PredictorError("The output group '{}' defined in 'outputsCriterions' does not correspond to any module in the model.".format(output_group),
|
|
430
|
+
"Available modules: {}".format(modules),
|
|
431
|
+
"Please check that the name matches exactly a submodule or output of your model architecture."
|
|
432
|
+
)
|
|
410
433
|
self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
|
|
411
434
|
self.modelComposite.load(self._load())
|
|
412
435
|
|
|
@@ -416,7 +439,7 @@ class Predictor(DistributedObject):
|
|
|
416
439
|
self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
|
|
417
440
|
self.dataloader = self.dataset.getData(world_size//self.size)
|
|
418
441
|
for name, outDataset in self.outsDataset.items():
|
|
419
|
-
outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()))
|
|
442
|
+
outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()), {src : dest for src, inner in self.dataset.groups_src.items() for dest in inner})
|
|
420
443
|
|
|
421
444
|
|
|
422
445
|
def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
|
|
@@ -34,7 +34,7 @@ class EarlyStoppingBase:
|
|
|
34
34
|
class EarlyStopping(EarlyStoppingBase):
|
|
35
35
|
|
|
36
36
|
@config("EarlyStopping")
|
|
37
|
-
def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
|
|
37
|
+
def __init__(self, monitor: Union[list[str], None] = [], patience: int=10, min_delta: float=0.0, mode: str="min"):
|
|
38
38
|
super().__init__()
|
|
39
39
|
self.monitor = [] if monitor is None else monitor
|
|
40
40
|
self.patience = patience
|
|
@@ -120,7 +120,8 @@ class _Trainer():
|
|
|
120
120
|
|
|
121
121
|
def run(self) -> None:
|
|
122
122
|
self.dataloader_training.dataset.load("Train")
|
|
123
|
-
self.dataloader_validation
|
|
123
|
+
if self.dataloader_validation is not None:
|
|
124
|
+
self.dataloader_validation.dataset.load("Validation")
|
|
124
125
|
with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
|
|
125
126
|
for self.epoch in epoch_tqdm:
|
|
126
127
|
self.train()
|
|
@@ -3,7 +3,7 @@ import ruamel.yaml
|
|
|
3
3
|
import inspect
|
|
4
4
|
import collections
|
|
5
5
|
from copy import deepcopy
|
|
6
|
-
from typing import Union, Literal, get_origin, get_args
|
|
6
|
+
from typing import Union, Literal, get_origin, get_args, Any
|
|
7
7
|
import torch
|
|
8
8
|
from konfai import CONFIG_FILE
|
|
9
9
|
from konfai.utils.utils import ConfigError
|
|
@@ -175,8 +175,11 @@ def config(key : Union[str, None] = None):
|
|
|
175
175
|
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
|
176
176
|
kwargs = {}
|
|
177
177
|
for param in list(inspect.signature(function).parameters.values())[len(args):]:
|
|
178
|
-
|
|
178
|
+
if param.name in without:
|
|
179
|
+
continue
|
|
180
|
+
|
|
179
181
|
annotation = param.annotation
|
|
182
|
+
|
|
180
183
|
# --- support Literal ---
|
|
181
184
|
if get_origin(annotation) is Literal:
|
|
182
185
|
allowed_values = get_args(annotation)
|
|
@@ -192,11 +195,10 @@ def config(key : Union[str, None] = None):
|
|
|
192
195
|
for i in annotation.__args__:
|
|
193
196
|
annotation = i
|
|
194
197
|
break
|
|
195
|
-
|
|
196
|
-
continue
|
|
198
|
+
|
|
197
199
|
if not annotation == inspect._empty:
|
|
198
200
|
if annotation not in [int, str, bool, float, torch.Tensor]:
|
|
199
|
-
if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List"):
|
|
201
|
+
if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List") or str(annotation).startswith("typing.Sequence"):
|
|
200
202
|
elem_type = annotation.__args__[0]
|
|
201
203
|
values = config.getValue(param.name, param.default)
|
|
202
204
|
if getattr(elem_type, '__origin__', None) is Union:
|
|
@@ -221,7 +223,7 @@ def config(key : Union[str, None] = None):
|
|
|
221
223
|
elif str(annotation).startswith("dict"):
|
|
222
224
|
if annotation.__args__[0] == str:
|
|
223
225
|
values = config.getValue(param.name, param.default)
|
|
224
|
-
if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
|
|
226
|
+
if values is not None and annotation.__args__[1] not in [int, str, bool, float, Any]:
|
|
225
227
|
try:
|
|
226
228
|
kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
|
|
227
229
|
except ValueError as e:
|
|
@@ -229,6 +231,7 @@ def config(key : Union[str, None] = None):
|
|
|
229
231
|
except Exception as e:
|
|
230
232
|
raise ConfigError("{} {}".format(values, e))
|
|
231
233
|
else:
|
|
234
|
+
|
|
232
235
|
kwargs[param.name] = values
|
|
233
236
|
else:
|
|
234
237
|
raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
|
|
@@ -236,7 +239,7 @@ def config(key : Union[str, None] = None):
|
|
|
236
239
|
try:
|
|
237
240
|
kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
|
|
238
241
|
except Exception as e:
|
|
239
|
-
raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation
|
|
242
|
+
raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation, e))
|
|
240
243
|
|
|
241
244
|
if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
|
|
242
245
|
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
|
@@ -21,7 +21,7 @@ import sys
|
|
|
21
21
|
import re
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def description(model, modelEMA = None, showMemory: bool = True) -> str:
|
|
24
|
+
def description(model, modelEMA = None, showMemory: bool = True, train: bool = True) -> str:
|
|
25
25
|
values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
|
|
26
26
|
model_desc = lambda model : "("+" ".join(["{}({:.6f}) : {}".format(name, network.optimizer.param_groups[0]['lr'] if network.optimizer is not None else 0, values_desc(network.measure.getLastWeights(), network.measure.getLastValues())) for name, network in model.module.getNetworks().items() if network.measure is not None])+")"
|
|
27
27
|
result = "Loss {}".format(model_desc(model))
|
|
@@ -33,9 +33,9 @@ def description(model, modelEMA = None, showMemory: bool = True) -> str:
|
|
|
33
33
|
return result
|
|
34
34
|
|
|
35
35
|
def _getModule(classpath : str, type : str) -> tuple[str, str]:
|
|
36
|
-
if len(classpath.split("
|
|
37
|
-
module = ".".join(classpath.split("
|
|
38
|
-
name = classpath.split("
|
|
36
|
+
if len(classpath.split(":")) > 1:
|
|
37
|
+
module = ".".join(classpath.split(":")[:-1])
|
|
38
|
+
name = classpath.split(":")[-1]
|
|
39
39
|
else:
|
|
40
40
|
module = "konfai."+type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
|
|
41
41
|
name = classpath.split(".")[-1]
|
|
@@ -586,6 +586,11 @@ class EvaluatorError(KonfAIError):
|
|
|
586
586
|
def __init__(self, *message) -> None:
|
|
587
587
|
super().__init__("Evaluator", message)
|
|
588
588
|
|
|
589
|
+
class PredictorError(KonfAIError):
|
|
590
|
+
|
|
591
|
+
def __init__(self, *message) -> None:
|
|
592
|
+
super().__init__("Predictor", message)
|
|
593
|
+
|
|
589
594
|
class TransformError(KonfAIError):
|
|
590
595
|
|
|
591
596
|
def __init__(self, *message) -> None:
|
|
@@ -1,53 +0,0 @@
|
|
|
1
|
-
from typing import Union
|
|
2
|
-
import torch
|
|
3
|
-
|
|
4
|
-
from konfai.network import network, blocks
|
|
5
|
-
from konfai.utils.config import config
|
|
6
|
-
from konfai.data.patching import ModelPatch
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class NestedUNetBlock(network.ModuleArgsDict):
|
|
10
|
-
|
|
11
|
-
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
|
|
12
|
-
super().__init__()
|
|
13
|
-
if i > 0:
|
|
14
|
-
self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
|
|
15
|
-
|
|
16
|
-
self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
|
|
17
|
-
if len(channels) > 2:
|
|
18
|
-
self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
|
|
19
|
-
for j in range(len(channels)-2):
|
|
20
|
-
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
21
|
-
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
22
|
-
self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
23
|
-
|
|
24
|
-
class NestedUNet(network.Network):
|
|
25
|
-
|
|
26
|
-
class NestedUNetHead(network.ModuleArgsDict):
|
|
27
|
-
|
|
28
|
-
def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
|
|
29
|
-
super().__init__()
|
|
30
|
-
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
|
31
|
-
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
32
|
-
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
33
|
-
|
|
34
|
-
@config("NestedUNet")
|
|
35
|
-
def __init__( self,
|
|
36
|
-
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
37
|
-
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
38
|
-
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
39
|
-
patch : Union[ModelPatch, None] = None,
|
|
40
|
-
dim : int = 3,
|
|
41
|
-
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
42
|
-
nb_class: int = 2,
|
|
43
|
-
blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
|
|
44
|
-
nb_conv_per_stage: int = 2,
|
|
45
|
-
downSampleMode: str = "MAXPOOL",
|
|
46
|
-
upSampleMode: str = "CONV_TRANSPOSE",
|
|
47
|
-
attention : bool = False,
|
|
48
|
-
blockType: str = "Conv") -> None:
|
|
49
|
-
super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
50
|
-
|
|
51
|
-
self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
52
|
-
for j in range(len(channels)-2):
|
|
53
|
-
self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|