konfai 1.1.4__tar.gz → 1.1.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- {konfai-1.1.4 → konfai-1.1.5}/PKG-INFO +1 -1
- {konfai-1.1.4 → konfai-1.1.5}/konfai/data/data_manager.py +1 -1
- {konfai-1.1.4 → konfai-1.1.5}/konfai/data/patching.py +14 -17
- {konfai-1.1.4 → konfai-1.1.5}/konfai/metric/measure.py +37 -13
- konfai-1.1.5/konfai/models/segmentation/NestedUNet.py +115 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/network/blocks.py +32 -13
- {konfai-1.1.4 → konfai-1.1.5}/konfai/network/network.py +4 -3
- {konfai-1.1.4 → konfai-1.1.5}/konfai/predictor.py +28 -4
- {konfai-1.1.4 → konfai-1.1.5}/konfai/trainer.py +2 -1
- {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/utils.py +5 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/PKG-INFO +1 -1
- {konfai-1.1.4 → konfai-1.1.5}/pyproject.toml +1 -1
- konfai-1.1.4/konfai/models/segmentation/NestedUNet.py +0 -52
- {konfai-1.1.4 → konfai-1.1.5}/LICENSE +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/README.md +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/__init__.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/data/__init__.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/data/augmentation.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/data/transform.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/evaluator.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/main.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/metric/__init__.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/metric/schedulers.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/classification/convNeXt.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/classification/resnet.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/cStyleGan.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/ddpm.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/diffusionGan.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/gan.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/vae.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/registration/registration.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/representation/representation.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/models/segmentation/UNet.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/network/__init__.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/ITK.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/__init__.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/config.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/dataset.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/registration.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/SOURCES.txt +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/requires.txt +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/setup.cfg +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/tests/test_config.py +0 -0
- {konfai-1.1.4 → konfai-1.1.5}/tests/test_dataset.py +0 -0
|
@@ -129,7 +129,7 @@ class DatasetIter(data.Dataset):
|
|
|
129
129
|
for group_dest in self.groups_src[group_src]:
|
|
130
130
|
self.data[group_dest][index].unloadAugmentation()
|
|
131
131
|
self.data[group_dest][index].resetAugmentation()
|
|
132
|
-
|
|
132
|
+
self.load(label + " Augmentation")
|
|
133
133
|
|
|
134
134
|
def load(self, label: str):
|
|
135
135
|
if self.use_cache:
|
|
@@ -125,23 +125,20 @@ class Accumulator():
|
|
|
125
125
|
return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])
|
|
126
126
|
|
|
127
127
|
def assemble(self) -> torch.Tensor:
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
|
|
143
|
-
else:
|
|
144
|
-
result = torch.cat(tuple(self._layer_accumulator), dim=0)
|
|
128
|
+
N = 2 if self.batch else 1
|
|
129
|
+
result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
|
|
130
|
+
for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
|
|
131
|
+
slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
|
|
132
|
+
|
|
133
|
+
for dim, s in enumerate(patch_slice):
|
|
134
|
+
if s.stop-s.start == 1:
|
|
135
|
+
data = data.unsqueeze(dim=dim+N)
|
|
136
|
+
if self.patchCombine is not None:
|
|
137
|
+
result[slices_dest] += self.patchCombine(data)
|
|
138
|
+
else:
|
|
139
|
+
result[slices_dest] = data
|
|
140
|
+
result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
|
|
141
|
+
|
|
145
142
|
self._layer_accumulator.clear()
|
|
146
143
|
return result
|
|
147
144
|
|
|
@@ -3,6 +3,9 @@ import importlib
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
5
5
|
|
|
6
|
+
from torchvision.models import inception_v3, Inception_V3_Weights
|
|
7
|
+
from torchvision.transforms import functional as F
|
|
8
|
+
from scipy import linalg
|
|
6
9
|
|
|
7
10
|
import torch.nn.functional as F
|
|
8
11
|
import os
|
|
@@ -401,21 +404,42 @@ class L1LossRepresentation(Criterion):
|
|
|
401
404
|
def forward(self, output: torch.Tensor) -> torch.Tensor:
|
|
402
405
|
return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])
|
|
403
406
|
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
407
|
+
class FocalLoss(Criterion):
|
|
408
|
+
|
|
409
|
+
def __init__(self, gamma: float = 2.0, alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1], reduction: str = "mean"):
|
|
410
|
+
super().__init__()
|
|
411
|
+
raw_alpha = torch.tensor(alpha, dtype=torch.float32)
|
|
412
|
+
self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
|
|
413
|
+
self.gamma = gamma
|
|
414
|
+
self.reduction = reduction
|
|
415
|
+
|
|
416
|
+
def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
|
|
417
|
+
target = targets[0].long()
|
|
418
|
+
|
|
419
|
+
logpt = F.log_softmax(output, dim=1)
|
|
420
|
+
pt = torch.exp(logpt)
|
|
421
|
+
|
|
422
|
+
logpt = logpt.gather(1, target.unsqueeze(1))
|
|
423
|
+
pt = pt.gather(1, target.unsqueeze(1))
|
|
424
|
+
|
|
425
|
+
at = self.alpha[target].unsqueeze(1)
|
|
426
|
+
loss = -at * ((1 - pt) ** self.gamma) * logpt
|
|
427
|
+
|
|
428
|
+
if self.reduction == "mean":
|
|
429
|
+
return loss.mean()
|
|
430
|
+
elif self.reduction == "sum":
|
|
431
|
+
return loss.sum()
|
|
432
|
+
return loss
|
|
410
433
|
|
|
434
|
+
|
|
411
435
|
class FID(Criterion):
|
|
412
436
|
|
|
413
|
-
class InceptionV3(nn.Module):
|
|
437
|
+
class InceptionV3(torch.nn.Module):
|
|
414
438
|
|
|
415
439
|
def __init__(self) -> None:
|
|
416
440
|
super().__init__()
|
|
417
441
|
self.model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
|
|
418
|
-
self.model.fc = nn.Identity()
|
|
442
|
+
self.model.fc = torch.nn.Identity()
|
|
419
443
|
self.model.eval()
|
|
420
444
|
|
|
421
445
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
@@ -456,9 +480,9 @@ class FID(Criterion):
|
|
|
456
480
|
generated_features = FID.get_features(generated_images, self.inception_model)
|
|
457
481
|
|
|
458
482
|
return FID.calculate_fid(real_features, generated_features)
|
|
459
|
-
|
|
483
|
+
|
|
460
484
|
|
|
461
|
-
|
|
485
|
+
class MutualInformationLoss(torch.nn.Module):
|
|
462
486
|
def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
|
|
463
487
|
super().__init__()
|
|
464
488
|
bin_centers = torch.linspace(0.0, 1.0, num_bins)
|
|
@@ -469,13 +493,13 @@ class FID(Criterion):
|
|
|
469
493
|
self.smooth_nr = float(smooth_nr)
|
|
470
494
|
self.smooth_dr = float(smooth_dr)
|
|
471
495
|
|
|
472
|
-
def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) ->
|
|
496
|
+
def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
473
497
|
pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
|
|
474
498
|
target_weight, target_probability = self.parzen_windowing_gaussian(target)
|
|
475
499
|
return pred_weight, pred_probability, target_weight, target_probability
|
|
476
500
|
|
|
477
501
|
|
|
478
|
-
def parzen_windowing_gaussian(self, img: torch.Tensor) ->
|
|
502
|
+
def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
479
503
|
img = torch.clamp(img, 0, 1)
|
|
480
504
|
img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
|
|
481
505
|
weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2) # (batch, num_sample, num_bin)
|
|
@@ -488,4 +512,4 @@ class FID(Criterion):
|
|
|
488
512
|
pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
|
|
489
513
|
papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
|
|
490
514
|
mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
|
|
491
|
-
return torch.mean(mi).neg() # average over the batch and channel ndims
|
|
515
|
+
return torch.mean(mi).neg() # average over the batch and channel ndims
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from konfai.network import network, blocks
|
|
2
|
+
from typing import Union
|
|
3
|
+
import torch
|
|
4
|
+
from konfai.data.patching import ModelPatch
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class NestedUNetBlock(network.ModuleArgsDict):
|
|
8
|
+
|
|
9
|
+
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
|
|
10
|
+
super().__init__()
|
|
11
|
+
if i > 0:
|
|
12
|
+
self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
|
|
13
|
+
|
|
14
|
+
self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
|
|
15
|
+
if len(channels) > 2:
|
|
16
|
+
self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
|
|
17
|
+
for j in range(len(channels)-2):
|
|
18
|
+
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
19
|
+
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
20
|
+
self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
21
|
+
|
|
22
|
+
class NestedUNet(network.Network):
|
|
23
|
+
|
|
24
|
+
class NestedUNetHead(network.ModuleArgsDict):
|
|
25
|
+
|
|
26
|
+
def __init__(self, in_channels: int, nb_class: int, activation: str, dim: int) -> None:
|
|
27
|
+
super().__init__()
|
|
28
|
+
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
|
29
|
+
if activation == "Softmax":
|
|
30
|
+
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
31
|
+
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
32
|
+
elif activation == "Tanh":
|
|
33
|
+
self.add_module("Tanh", torch.nn.Tanh())
|
|
34
|
+
|
|
35
|
+
def __init__( self,
|
|
36
|
+
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
37
|
+
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
38
|
+
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
39
|
+
patch : Union[ModelPatch, None] = None,
|
|
40
|
+
dim : int = 3,
|
|
41
|
+
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
42
|
+
nb_class: int = 2,
|
|
43
|
+
blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
|
|
44
|
+
nb_conv_per_stage: int = 2,
|
|
45
|
+
downSampleMode: str = "MAXPOOL",
|
|
46
|
+
upSampleMode: str = "CONV_TRANSPOSE",
|
|
47
|
+
attention : bool = False,
|
|
48
|
+
blockType: str = "Conv",
|
|
49
|
+
activation: str = "Softmax") -> None:
|
|
50
|
+
super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
51
|
+
|
|
52
|
+
self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
53
|
+
for j in range(len(channels)-2):
|
|
54
|
+
self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, activation=activation, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class ResNetEncoderLayer(network.ModuleArgsDict):
|
|
60
|
+
|
|
61
|
+
def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
|
|
62
|
+
super().__init__()
|
|
63
|
+
for i in range(nb_block):
|
|
64
|
+
if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
|
|
65
|
+
self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
|
|
66
|
+
self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
|
|
67
|
+
in_channel = out_channel
|
|
68
|
+
|
|
69
|
+
def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
|
|
70
|
+
modules = []
|
|
71
|
+
modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
|
|
72
|
+
for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
|
|
73
|
+
modules.append(ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
|
|
74
|
+
return modules
|
|
75
|
+
|
|
76
|
+
class NestedUNetBlock(network.ModuleArgsDict):
|
|
77
|
+
|
|
78
|
+
def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
|
|
79
|
+
super().__init__()
|
|
80
|
+
self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
|
|
81
|
+
if len(encoder_channels) > 2:
|
|
82
|
+
self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
|
|
83
|
+
for j in range(len(encoder_channels)-2):
|
|
84
|
+
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=encoder_channels[2], out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
85
|
+
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
86
|
+
self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+((encoder_channels[1]+encoder_channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else (decoder_channels[1]+decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]+encoder_channels[1])), out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
87
|
+
|
|
88
|
+
class NestedUNetHead(network.ModuleArgsDict):
|
|
89
|
+
|
|
90
|
+
def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
|
|
91
|
+
super().__init__()
|
|
92
|
+
self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim))
|
|
93
|
+
self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
|
|
94
|
+
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
|
|
95
|
+
if nb_class > 1:
|
|
96
|
+
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
97
|
+
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
98
|
+
else:
|
|
99
|
+
self.add_module("Tanh", torch.nn.Tanh())
|
|
100
|
+
|
|
101
|
+
class UNetpp(network.Network):
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def __init__( self,
|
|
105
|
+
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
106
|
+
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
107
|
+
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
108
|
+
patch : Union[ModelPatch, None] = None,
|
|
109
|
+
encoder_channels: list[int] = [1,64,64,128,256,512],
|
|
110
|
+
decoder_channels: list[int] = [256,128,64,32,16,1],
|
|
111
|
+
layers: list[int] = [3,4,6,3],
|
|
112
|
+
dim : int = 2) -> None:
|
|
113
|
+
super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
114
|
+
self.add_module("Block_0", NestedUNetBlock(encoder_channels, decoder_channels[::-1], resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
|
|
115
|
+
self.add_module("Head", NestedUNetHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
|
|
@@ -19,7 +19,7 @@ class NormMode(Enum):
|
|
|
19
19
|
SYNCBATCH = 5
|
|
20
20
|
INSTANCE_AFFINE = 6
|
|
21
21
|
|
|
22
|
-
def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
|
|
22
|
+
def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module, None]:
|
|
23
23
|
if normMode == NormMode.BATCH:
|
|
24
24
|
return getTorchModule("BatchNorm", dim = dim)(channels, affine=True, track_running_stats=True)
|
|
25
25
|
if normMode == NormMode.INSTANCE:
|
|
@@ -32,7 +32,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
|
|
|
32
32
|
return torch.nn.GroupNorm(num_groups=32, num_channels=channels)
|
|
33
33
|
if normMode == NormMode.LAYER:
|
|
34
34
|
return torch.nn.GroupNorm(num_groups=1, num_channels=channels)
|
|
35
|
-
return
|
|
35
|
+
return None
|
|
36
36
|
|
|
37
37
|
class UpSampleMode(Enum):
|
|
38
38
|
CONV_TRANSPOSE = 0,
|
|
@@ -59,9 +59,9 @@ class BlockConfig():
|
|
|
59
59
|
self.stride = stride
|
|
60
60
|
self.padding = padding
|
|
61
61
|
self.activation = activation
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
62
|
+
if normMode is None:
|
|
63
|
+
self.norm = None
|
|
64
|
+
elif isinstance(normMode, str):
|
|
65
65
|
self.norm = NormMode._member_map_[normMode]
|
|
66
66
|
elif isinstance(normMode, NormMode):
|
|
67
67
|
self.norm = normMode
|
|
@@ -70,9 +70,13 @@ class BlockConfig():
|
|
|
70
70
|
return getTorchModule("Conv", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding, bias=self.bias)
|
|
71
71
|
|
|
72
72
|
def getNorm(self, channels : int, dim: int) -> torch.nn.Module:
|
|
73
|
+
if self.norm is None:
|
|
74
|
+
return None
|
|
73
75
|
return getNorm(self.norm, channels, dim) if isinstance(self.norm, NormMode) else self.norm(channels)
|
|
74
76
|
|
|
75
77
|
def getActivation(self) -> torch.nn.Module:
|
|
78
|
+
if self.activation is None:
|
|
79
|
+
return None
|
|
76
80
|
if isinstance(self.activation, str):
|
|
77
81
|
return getTorchModule(self.activation.split(";")[0])(*[ast.literal_eval(value) for value in self.activation.split(";")[1:]]) if self.activation != "None" else torch.nn.Identity()
|
|
78
82
|
return self.activation()
|
|
@@ -83,21 +87,35 @@ class ConvBlock(network.ModuleArgsDict):
|
|
|
83
87
|
super().__init__()
|
|
84
88
|
for i, blockConfig in enumerate(blockConfigs):
|
|
85
89
|
self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
|
|
86
|
-
|
|
87
|
-
|
|
90
|
+
norm = blockConfig.getNorm(out_channels, dim)
|
|
91
|
+
if norm is not None:
|
|
92
|
+
self.add_module("Norm_{}".format(i), norm, alias=alias[1])
|
|
93
|
+
activation = blockConfig.getActivation()
|
|
94
|
+
if activation is not None:
|
|
95
|
+
self.add_module("Activation_{}".format(i), activation, alias=alias[2])
|
|
88
96
|
in_channels = out_channels
|
|
89
97
|
|
|
90
98
|
class ResBlock(network.ModuleArgsDict):
|
|
91
99
|
|
|
92
|
-
def __init__(self, in_channels : int, out_channels : int,
|
|
100
|
+
def __init__(self, in_channels : int, out_channels : int, blockConfigs : list[BlockConfig], dim : int, alias : list[list[str]]=[[], [], [], [], []]) -> None:
|
|
93
101
|
super().__init__()
|
|
94
|
-
for i in
|
|
95
|
-
self.add_module("
|
|
102
|
+
for i, blockConfig in enumerate(blockConfigs):
|
|
103
|
+
self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
|
|
104
|
+
norm = blockConfig.getNorm(out_channels, dim)
|
|
105
|
+
if norm is not None:
|
|
106
|
+
self.add_module("Norm_{}".format(i), norm, alias=alias[1])
|
|
107
|
+
activation = blockConfig.getActivation()
|
|
108
|
+
if activation is not None:
|
|
109
|
+
self.add_module("Activation_{}".format(i), activation, alias=alias[2])
|
|
110
|
+
|
|
96
111
|
if in_channels != out_channels:
|
|
97
|
-
self.add_module("
|
|
112
|
+
self.add_module("Conv_skip", getTorchModule("Conv", dim)(in_channels, out_channels, 1, blockConfig.stride, bias=blockConfig.bias), alias=alias[3], in_branch=[1], out_branch=[1])
|
|
113
|
+
self.add_module("Norm_skip", blockConfig.getNorm(out_channels, dim), alias=alias[4], in_branch=[1], out_branch=[1])
|
|
98
114
|
in_channels = out_channels
|
|
99
|
-
|
|
100
|
-
|
|
115
|
+
|
|
116
|
+
self.add_module("Add", Add(), in_branch=[0,1])
|
|
117
|
+
self.add_module("Norm_{}".format(i+1), torch.nn.ReLU(inplace=True))
|
|
118
|
+
|
|
101
119
|
def downSample(in_channels: int, out_channels: int, downSampleMode: DownSampleMode, dim: int) -> torch.nn.Module:
|
|
102
120
|
if downSampleMode == DownSampleMode.MAXPOOL:
|
|
103
121
|
return getTorchModule("MaxPool", dim = dim)(2)
|
|
@@ -176,6 +194,7 @@ class Print(torch.nn.Module):
|
|
|
176
194
|
super().__init__()
|
|
177
195
|
|
|
178
196
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
197
|
+
print(input.shape)
|
|
179
198
|
return input
|
|
180
199
|
|
|
181
200
|
class Write(torch.nn.Module):
|
|
@@ -133,9 +133,10 @@ class Measure():
|
|
|
133
133
|
self._loss.clear()
|
|
134
134
|
|
|
135
135
|
def add(self, weight: float, value: torch.Tensor) -> None:
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
136
|
+
if self.isLoss or value != 0:
|
|
137
|
+
self._loss.append(value if self.isLoss else value.detach())
|
|
138
|
+
self._values.append(value.item())
|
|
139
|
+
self._weight.append(weight)
|
|
139
140
|
|
|
140
141
|
def getLastLoss(self) -> torch.Tensor:
|
|
141
142
|
return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)
|
|
@@ -8,7 +8,7 @@ import os
|
|
|
8
8
|
|
|
9
9
|
from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
|
-
from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
|
|
11
|
+
from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
|
|
12
12
|
from konfai.utils.dataset import Dataset, Attribute
|
|
13
13
|
from konfai.data.data_manager import DataPrediction, DatasetIter
|
|
14
14
|
from konfai.data.patching import Accumulator, PathCombine
|
|
@@ -46,7 +46,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
46
46
|
self.names: dict[int, str] = {}
|
|
47
47
|
self.nb_data_augmentation = 0
|
|
48
48
|
|
|
49
|
-
def load(self, name_layer: str, datasets: list[Dataset]):
|
|
49
|
+
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
50
50
|
transforms_type = ["pre_transforms", "post_transforms", "final_transforms"]
|
|
51
51
|
for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:
|
|
52
52
|
|
|
@@ -137,9 +137,22 @@ class OutSameAsGroupDataset(OutDataset):
|
|
|
137
137
|
if self.inverse_transform:
|
|
138
138
|
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
|
|
139
139
|
layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
|
|
140
|
-
|
|
141
140
|
self.output_layer_accumulator[index_dataset][index_augmentation].addLayer(index_patch, layer)
|
|
142
141
|
|
|
142
|
+
|
|
143
|
+
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
144
|
+
super().load(name_layer, datasets, groups)
|
|
145
|
+
|
|
146
|
+
if self.group_src not in groups.keys():
|
|
147
|
+
raise PredictorError(
|
|
148
|
+
f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
if self.group_dest not in groups[self.group_src]:
|
|
152
|
+
raise PredictorError(
|
|
153
|
+
f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
|
|
154
|
+
)
|
|
155
|
+
|
|
143
156
|
def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
|
|
144
157
|
layer = self.output_layer_accumulator[index][index_augmentation].assemble()
|
|
145
158
|
name = self.names[index]
|
|
@@ -403,9 +416,20 @@ class Predictor(DistributedObject):
|
|
|
403
416
|
|
|
404
417
|
shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
|
|
405
418
|
|
|
419
|
+
|
|
406
420
|
self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
|
|
407
421
|
self.model.init_outputsGroup()
|
|
408
422
|
self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
|
|
423
|
+
|
|
424
|
+
modules = []
|
|
425
|
+
for i,_,_ in self.model.named_ModuleArgsDict():
|
|
426
|
+
modules.append(i)
|
|
427
|
+
for output_group in self.outsDataset.keys():
|
|
428
|
+
if output_group not in modules:
|
|
429
|
+
raise PredictorError("The output group '{}' defined in 'outputsCriterions' does not correspond to any module in the model.".format(output_group),
|
|
430
|
+
"Available modules: {}".format(modules),
|
|
431
|
+
"Please check that the name matches exactly a submodule or output of your model architecture."
|
|
432
|
+
)
|
|
409
433
|
self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
|
|
410
434
|
self.modelComposite.load(self._load())
|
|
411
435
|
|
|
@@ -415,7 +439,7 @@ class Predictor(DistributedObject):
|
|
|
415
439
|
self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
|
|
416
440
|
self.dataloader = self.dataset.getData(world_size//self.size)
|
|
417
441
|
for name, outDataset in self.outsDataset.items():
|
|
418
|
-
outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()))
|
|
442
|
+
outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()), {src : dest for src, inner in self.dataset.groups_src.items() for dest in inner})
|
|
419
443
|
|
|
420
444
|
|
|
421
445
|
def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
|
|
@@ -120,7 +120,8 @@ class _Trainer():
|
|
|
120
120
|
|
|
121
121
|
def run(self) -> None:
|
|
122
122
|
self.dataloader_training.dataset.load("Train")
|
|
123
|
-
self.dataloader_validation
|
|
123
|
+
if self.dataloader_validation is not None:
|
|
124
|
+
self.dataloader_validation.dataset.load("Validation")
|
|
124
125
|
with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
|
|
125
126
|
for self.epoch in epoch_tqdm:
|
|
126
127
|
self.train()
|
|
@@ -586,6 +586,11 @@ class EvaluatorError(KonfAIError):
|
|
|
586
586
|
def __init__(self, *message) -> None:
|
|
587
587
|
super().__init__("Evaluator", message)
|
|
588
588
|
|
|
589
|
+
class PredictorError(KonfAIError):
|
|
590
|
+
|
|
591
|
+
def __init__(self, *message) -> None:
|
|
592
|
+
super().__init__("Predictor", message)
|
|
593
|
+
|
|
589
594
|
class TransformError(KonfAIError):
|
|
590
595
|
|
|
591
596
|
def __init__(self, *message) -> None:
|
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
from typing import Union
|
|
2
|
-
import torch
|
|
3
|
-
|
|
4
|
-
from konfai.network import network, blocks
|
|
5
|
-
from konfai.utils.config import config
|
|
6
|
-
from konfai.data.patching import ModelPatch
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class NestedUNetBlock(network.ModuleArgsDict):
|
|
10
|
-
|
|
11
|
-
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
|
|
12
|
-
super().__init__()
|
|
13
|
-
if i > 0:
|
|
14
|
-
self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
|
|
15
|
-
|
|
16
|
-
self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
|
|
17
|
-
if len(channels) > 2:
|
|
18
|
-
self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
|
|
19
|
-
for j in range(len(channels)-2):
|
|
20
|
-
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
21
|
-
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
22
|
-
self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
23
|
-
|
|
24
|
-
class NestedUNet(network.Network):
|
|
25
|
-
|
|
26
|
-
class NestedUNetHead(network.ModuleArgsDict):
|
|
27
|
-
|
|
28
|
-
def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
|
|
29
|
-
super().__init__()
|
|
30
|
-
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
|
31
|
-
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
32
|
-
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
33
|
-
|
|
34
|
-
def __init__( self,
|
|
35
|
-
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
36
|
-
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
37
|
-
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
38
|
-
patch : Union[ModelPatch, None] = None,
|
|
39
|
-
dim : int = 3,
|
|
40
|
-
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
41
|
-
nb_class: int = 2,
|
|
42
|
-
blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
|
|
43
|
-
nb_conv_per_stage: int = 2,
|
|
44
|
-
downSampleMode: str = "MAXPOOL",
|
|
45
|
-
upSampleMode: str = "CONV_TRANSPOSE",
|
|
46
|
-
attention : bool = False,
|
|
47
|
-
blockType: str = "Conv") -> None:
|
|
48
|
-
super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
49
|
-
|
|
50
|
-
self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
51
|
-
for j in range(len(channels)-2):
|
|
52
|
-
self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|