konfai-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/models/generation/gan.py
@@ -0,0 +1,134 @@
from functools import partial
import torch

from KonfAI.konfai.network import network, blocks
from KonfAI.konfai.utils.config import config
from KonfAI.konfai.data.HDF5 import ModelPatch

class Discriminator(network.Network):

    class DiscriminatorNLayers(network.ModuleArgsDict):

        def __init__(self, channels: list[int], strides: list[int], dim: int) -> None:
            super().__init__()
            blockConfig = partial(blocks.BlockConfig, kernel_size=4, padding=1, bias=False, activation=partial(torch.nn.LeakyReLU, negative_slope = 0.2, inplace=True), normMode=blocks.NormMode.SYNCBATCH)
            for i, (in_channels, out_channels, stride) in enumerate(zip(channels, channels[1:], strides)):
                self.add_module("Layer_{}".format(i), blocks.ConvBlock(in_channels, out_channels, 1, blockConfig(stride=stride), dim))

    class DiscriminatorHead(network.ModuleArgsDict):

        def __init__(self, channels: int, dim: int) -> None:
            super().__init__()
            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels=channels, out_channels=1, kernel_size=4, stride=1, padding=1))
            self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
            self.add_module("Flatten", torch.nn.Flatten(1))

    @config("Discriminator")
    def __init__(self,
                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                 nb_batch_per_step: int = 64,
                 dim : int = 3) -> None:
        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
        channels = [1, 16, 32, 64, 64]
        strides = [2,2,2,1]
        self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
        self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))

class Generator(network.Network):

    class GeneratorStem(network.ModuleArgsDict):

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH"), dim=dim))

    class GeneratorHead(network.ModuleArgsDict):

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, in_channels, nb_conv=1, blockConfig=blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH"), dim=dim))
            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False))
            self.add_module("Tanh", torch.nn.Tanh())

    class GeneratorDownSample(network.ModuleArgsDict):

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blocks.BlockConfig(stride=2, bias=False, activation="ReLU", normMode="SYNCBATCH"), dim=dim))

    class GeneratorUpSample(network.ModuleArgsDict):

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH"), dim=dim))
            self.add_module("Upsample", torch.nn.Upsample(scale_factor=2, mode="bilinear" if dim < 3 else "trilinear"))

    class GeneratorEncoder(network.ModuleArgsDict):
        def __init__(self, channels: list[int], dim: int) -> None:
            super().__init__()
            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
                self.add_module("DownSample_{}".format(i), Generator.GeneratorDownSample(in_channels=in_channels, out_channels=out_channels, dim=dim))

    class GeneratorResnetBlock(network.ModuleArgsDict):

        def __init__(self, channels : int, dim : int):
            super().__init__()
            self.add_module("Conv_0", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
            self.add_module("Norm", torch.nn.LeakyReLU(0.2, inplace=True))
            self.add_module("Conv_1", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
            self.add_module("Residual", blocks.Add(), in_branch=[0,1])

    class GeneratorNResnetBlock(network.ModuleArgsDict):

        def __init__(self, channels: int, nb_conv: int, dim: int) -> None:
            super().__init__()
            for i in range(nb_conv):
                self.add_module("ResnetBlock_{}".format(i), Generator.GeneratorResnetBlock(channels=channels, dim=dim))

    class GeneratorDecoder(network.ModuleArgsDict):
        def __init__(self, channels: list[int], dim: int) -> None:
            super().__init__()
            for i, (in_channels, out_channels) in enumerate(zip(reversed(channels), reversed(channels[:-1]))):
                self.add_module("UpSample_{}".format(i), Generator.GeneratorUpSample(in_channels=in_channels, out_channels=out_channels, dim=dim))

    class GeneratorAutoEncoder(network.ModuleArgsDict):

        def __init__(self, ngf: int, dim: int) -> None:
            super().__init__()
            channels = [ngf, ngf*2]
            self.add_module("Encoder", Generator.GeneratorEncoder(channels, dim))
            self.add_module("NResBlock", Generator.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
            self.add_module("Decoder", Generator.GeneratorDecoder(channels, dim))

    @config("Generator")
    def __init__(self,
                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                 patch : ModelPatch = ModelPatch(),
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                 nb_batch_per_step: int = 64,
                 dim : int = 3) -> None:
        super().__init__(optimizer=optimizer, in_channels=1, schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
        ngf=32
        self.add_module("Stem", Generator.GeneratorStem(1, ngf, dim))
        self.add_module("AutoEncoder", Generator.GeneratorAutoEncoder(ngf, dim))
        self.add_module("Head", Generator.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))

    def getName(self):
        return "Generator"

class Gan(network.Network):

    @config("Gan")
    def __init__(self, generator : Generator = Generator(), discriminator : Discriminator = Discriminator()) -> None:
        super().__init__()
        self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
        self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])

        self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
        self.add_module("Discriminator_pB_detach", discriminator, in_branch=["pB_detach"], out_branch=[-1])

        self.add_module("Discriminator_pB", discriminator, in_branch=["pB"], out_branch=[-1], requires_grad=False)
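The branch wiring above follows the usual two-pass adversarial update: the discriminator scores real B against a detached fake (Discriminator_pB_detach), and the generator is then scored through a discriminator pass with requires_grad=False (Discriminator_pB). A minimal plain-PyTorch sketch of that pattern follows; the toy modules, shapes, optimizers and BCE loss are illustrative assumptions, not the konfai API.

import torch

# Toy stand-ins for the Generator / Discriminator defined above (hypothetical sizes).
G = torch.nn.Sequential(torch.nn.Conv3d(1, 1, 3, padding=1), torch.nn.Tanh())
D = torch.nn.Sequential(torch.nn.Conv3d(1, 1, 4, stride=2, padding=1),
                        torch.nn.AdaptiveAvgPool3d(1), torch.nn.Flatten(1))
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = torch.nn.BCEWithLogitsLoss()

A = torch.randn(2, 1, 16, 16, 16)   # source-domain batch
B = torch.randn(2, 1, 16, 16, 16)   # target-domain batch

pB = G(A)

# Discriminator step: real B vs. a detached fake (corresponds to "Discriminator_pB_detach").
d_loss = bce(D(B), torch.ones(2, 1)) + bce(D(pB.detach()), torch.zeros(2, 1))
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# Generator step: fool a frozen discriminator (corresponds to "Discriminator_pB", requires_grad=False).
for p in D.parameters():
    p.requires_grad_(False)
g_loss = bce(D(pB), torch.ones(2, 1))
opt_g.zero_grad()
g_loss.backward()
opt_g.step()
for p in D.parameters():
    p.requires_grad_(True)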
konfai/models/generation/vae.py
@@ -0,0 +1,72 @@
import torch

from KonfAI.konfai.network import network, blocks
from KonfAI.konfai.utils.config import config

class VAE(network.Network):

    class AutoEncoderBlock(network.ModuleArgsDict):

        def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, dim: int, block: type, i : int = 0) -> None:
            super().__init__()
            if i > 0:
                self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
            self.add_module("DownBlock", block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
            if len(channels) > 2:
                self.add_module("AutoEncoder_{}".format(i+1), VAE.AutoEncoderBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, dim, block, i+1))
                self.add_module("UpBlock", block(in_channels=channels[2] if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
            if i > 0:
                self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim))

    class VAE_Head(network.ModuleArgsDict):

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
            self.add_module("Tanh", torch.nn.Tanh())

    @config("VAE")
    def __init__(self,
                 optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                 dim : int = 3,
                 channels: list[int]=[1, 64, 128, 256, 512, 1024],
                 blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
                 nb_conv_per_stage: int = 2,
                 downSampleMode: str = "MAXPOOL",
                 upSampleMode: str = "CONV_TRANSPOSE",
                 blockType: str = "Conv") -> None:

        super().__init__(in_channels = channels[0], init_type="normal", optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, nb_batch_per_step=1)
        self.add_module("AutoEncoder_0", VAE.AutoEncoderBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], dim=dim, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock))
        self.add_module("Head", VAE.VAE_Head(channels[1], channels[0], dim))


class LinearVAE(network.Network):

    class LinearVAE_DenseLayer(network.ModuleArgsDict):

        def __init__(self, in_features: int, out_features: int) -> None:
            super().__init__()
            self.add_module("Linear", torch.nn.Linear(in_features, out_features))
            #self.add_module("Norm", torch.nn.BatchNorm1d(out_features))
            self.add_module("Activation", torch.nn.LeakyReLU())

    class LinearVAE_Head(network.ModuleArgsDict):

        def __init__(self, in_features: int, out_features: int) -> None:
            super().__init__()
            self.add_module("Linear", torch.nn.Linear(in_features, out_features))
            self.add_module("Tanh", torch.nn.Tanh())

    @config("LinearVAE")
    def __init__(self,
                 optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},) -> None:
        super().__init__(in_channels = 1, init_type="normal", optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=1, nb_batch_per_step=1)
        self.add_module("DenseLayer_0", LinearVAE.LinearVAE_DenseLayer(23343, 5))
        #self.add_module("Head", LinearVAE.DenseLayer(100, 28590))
        self.add_module("Head", LinearVAE.LinearVAE_Head(5, 23343))
        #self.add_module("DenseLayer_5", LinearVAE.DenseLayer(5000, 28590))
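LinearVAE above maps a 23343-feature input through a 5-unit bottleneck and back; no sampling step appears in the code shown, so it behaves as a deterministic bottleneck autoencoder. A minimal plain torch.nn sketch of the same shape path (the Sequential stand-in and the batch size are assumptions, not the package API):

import torch

# Plain-PyTorch equivalent of LinearVAE's DenseLayer_0 -> Head path.
model = torch.nn.Sequential(
    torch.nn.Linear(23343, 5),     # DenseLayer_0: Linear + LeakyReLU
    torch.nn.LeakyReLU(),
    torch.nn.Linear(5, 23343),     # Head: Linear + Tanh
    torch.nn.Tanh(),
)

x = torch.randn(4, 23343)          # batch of flattened inputs
assert model(x).shape == (4, 23343)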
konfai/models/registration/registration.py
@@ -0,0 +1,136 @@
import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F

from KonfAI.konfai.network import network, blocks
from KonfAI.konfai.utils.config import config
from KonfAI.konfai.models.segmentation import UNet

class VoxelMorph(network.Network):

    @config("VoxelMorph")
    def __init__( self,
                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                 dim : int = 3,
                 channels : list[int] = [4, 16,32,32,32],
                 blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
                 nb_conv_per_stage: int = 2,
                 downSampleMode: str = "MAXPOOL",
                 upSampleMode: str = "CONV_TRANSPOSE",
                 attention : bool = False,
                 shape : list[int] = [192, 192, 192],
                 int_steps : int = 7,
                 int_downsize : int = 2,
                 nb_batch_per_step : int = 1,
                 rigid: bool = False):
        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim = dim, nb_batch_per_step=nb_batch_per_step)
        self.add_module("Concat", blocks.Concat(), in_branch=[0,1], out_branch=["input_concat"])
        self.add_module("UNetBlock_0", UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock, nb_class=0, dim=dim), in_branch=["input_concat"], out_branch=["unet"])

        if rigid:
            self.add_module("Flow", Rigid(channels[1], dim), in_branch=["unet"], out_branch=["pos_flow"])
        else:
            self.add_module("Flow", Flow(channels[1], int_steps, int_downsize, shape, dim), in_branch=["unet"], out_branch=["pos_flow"])
        self.add_module("MovingImageResample", SpatialTransformer(shape, rigid=rigid), in_branch=[1, "pos_flow"], out_branch=["moving_image_resample"])

class Flow(network.ModuleArgsDict):

    def __init__(self, in_channels: int, int_steps: int, int_downsize: int, shape: list[int], dim: int) -> None:
        super().__init__()
        self.add_module("Head", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = dim, kernel_size = 3, stride = 1, padding = 1))
        self["Head"].weight = Parameter(torch.distributions.Normal(0, 1e-5).sample(self["Head"].weight.shape))
        self["Head"].bias = Parameter(torch.zeros(self["Head"].bias.shape))

        if int_steps > 0 and int_downsize > 1:
            self.add_module("DownSample", ResizeTransform(int_downsize))

        if int_steps > 0:
            self.add_module("Integrate_pos_flow", VecInt([int(dim / int_downsize) for dim in shape], int_steps))

        if int_steps > 0 and int_downsize > 1:
            self.add_module("Upsample_pos_flow", ResizeTransform(1 / int_downsize))

class Rigid(network.ModuleArgsDict):

    def __init__(self, in_channels: int, dim: int) -> None:
        super().__init__()
        self.add_module("ToFeatures", torch.nn.Flatten(1))
        self.add_module("Head", torch.nn.Linear(in_channels*512*512, 2))

    def init(self, init_type: str, init_gain: float):
        self["Head"].weight.data.fill_(0)
        self["Head"].bias.data.copy_(torch.tensor([0, 0], dtype=torch.float))

class MaskFlow(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, mask: torch.Tensor, *flows: torch.Tensor):
        result = torch.zeros_like(flows[0])
        for i, flow in enumerate(flows):
            result = result+torch.where(mask == i+1, flow, torch.tensor(0))
        return result

class SpatialTransformer(torch.nn.Module):

    def __init__(self, size : list[int], rigid: bool = False):
        super().__init__()
        self.rigid = rigid
        if not rigid:
            vectors = [torch.arange(0, s) for s in size]
            grids = torch.meshgrid(vectors, indexing='ij')
            grid = torch.stack(grids)
            grid = torch.unsqueeze(grid, 0)
            grid = grid.type(torch.float)
            self.register_buffer('grid', grid)

    def forward(self, src: torch.Tensor, flow: torch.Tensor):
        if self.rigid:
            new_locs = torch.zeros((flow.shape[0], 2, 3)).to(flow.device)
            new_locs[:, 0,0] = 1
            new_locs[:, 1,1] = 1
            new_locs[:, 0,2] = flow[:, 0]
            new_locs[:, 1,2] = flow[:, 1]
            print(new_locs)
            return F.grid_sample(src, F.affine_grid(new_locs, src.size()), align_corners=True, mode="bilinear")
        else:
            new_locs = self.grid + flow
            shape = flow.shape[2:]
            for i in range(len(shape)):
                new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
            new_locs = new_locs.permute(0, 2, 3, 1)
            return F.grid_sample(src, new_locs[..., [1, 0]], align_corners=True, mode="bilinear")

class VecInt(torch.nn.Module):

    def __init__(self, inshape, nsteps):
        super().__init__()
        assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
        self.nsteps = nsteps
        self.scale = 1.0 / (2 ** self.nsteps)
        self.transformer = SpatialTransformer(inshape)

    def forward(self, vec: torch.Tensor):
        vec = vec * self.scale
        for _ in range(self.nsteps):
            vec = vec + self.transformer(vec, vec)
        return vec


class ResizeTransform(torch.nn.Module):

    def __init__(self, size):
        super().__init__()
        self.factor = 1.0 / size

    def forward(self, x: torch.Tensor):
        if self.factor < 1:
            x = F.interpolate(x, align_corners=True, scale_factor=self.factor, mode="bilinear", recompute_scale_factor = True)
            x = self.factor * x
        elif self.factor > 1:
            x = self.factor * x
            x = F.interpolate(x, align_corners=True, scale_factor=self.factor, mode="bilinear", recompute_scale_factor = True)
        return x
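In the non-rigid branch, SpatialTransformer adds the predicted flow to an identity grid, rescales coordinates to [-1, 1] and resamples with grid_sample; as written (permute(0, 2, 3, 1) and the [1, 0] channel swap) that dense path handles 2-D inputs. Below is a self-contained sketch of the same warping math on a toy image, assuming a displacement field in pixel units with channel order (dy, dx); it does not import the package.

import torch
import torch.nn.functional as F

def dense_warp_2d(src: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Warp src (N, C, H, W) by a displacement field flow (N, 2, H, W) given in pixels."""
    _, _, h, w = flow.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    grid = torch.stack((ys, xs)).float().unsqueeze(0)             # identity grid, (1, 2, H, W)
    new_locs = grid + flow                                        # absolute sampling positions
    for i, size in enumerate((h, w)):
        new_locs[:, i] = 2 * (new_locs[:, i] / (size - 1) - 0.5)  # normalize to [-1, 1]
    new_locs = new_locs.permute(0, 2, 3, 1)                       # (N, H, W, 2)
    # grid_sample expects (x, y) order, hence the [1, 0] swap as in the code above.
    return F.grid_sample(src, new_locs[..., [1, 0]], align_corners=True, mode="bilinear")

img = torch.arange(16.).reshape(1, 1, 4, 4)
zero_flow = torch.zeros(1, 2, 4, 4)
assert torch.allclose(dense_warp_2d(img, zero_flow), img)        # zero flow leaves the image unchanged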
konfai/models/representation/representation.py
@@ -0,0 +1,57 @@
import torch

from KonfAI.konfai.network import network, blocks
from KonfAI.konfai.utils.config import config

class ConvBlock(torch.nn.Module):

    def __init__(self, in_channels : int, out_channels : int) -> None:
        super().__init__()
        self.Conv_0 = torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.Norm_0 = torch.nn.InstanceNorm3d(num_features=out_channels, affine=True)
        self.Activation_0 = torch.nn.LeakyReLU(negative_slope=0.01)
        self.Conv_1 = torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.Norm_1 = torch.nn.InstanceNorm3d(num_features=out_channels, affine=True)
        self.Activation_1 = torch.nn.LeakyReLU(negative_slope=0.01)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = self.Conv_0(input)
        output = self.Norm_0(output)
        output = self.Activation_0(output)
        output = self.Conv_1(output)
        output = self.Norm_1(output)
        output = self.Activation_1(output)
        return output

class UnetCPP_1_Layers(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.DownConvBlock_0 = ConvBlock(in_channels=1, out_channels=32)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return self.DownConvBlock_0(input)

class Adaptation(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.Encoder_1 = UnetCPP_1_Layers()
        self.ToFeatures = blocks.ToFeatures(3)
        self.FCT_1 = torch.nn.Linear(32, 32, bias=True)

    def forward(self, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor) -> list[torch.Tensor]:
        self.Encoder_1.requires_grad_(False)
        self.FCT_1.requires_grad_(True)
        return self.FCT_1(self.ToFeatures(self.Encoder_1(A))), self.FCT_1(self.ToFeatures(self.Encoder_1(B))), self.FCT_1(self.ToFeatures(self.Encoder_1(C)))

class Representation(network.Network):

    @config("Representation")
    def __init__( self,
                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                 dim : int = 3):
        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, init_type="kaiming")
        self.add_module("Model", Adaptation(), in_branch=[0,1,2])
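Adaptation above pushes three inputs through a shared frozen encoder and a trainable linear head, returning one 32-dimensional embedding per input. Assuming the three outputs are meant to be compared as anchor/positive/negative (an assumption; the file does not say so, and the actual criterion is configured elsewhere via outputsCriterions), a hedged sketch of scoring such a triplet:

import torch

# Hypothetical use of three 32-d embeddings as (anchor, positive, negative).
emb_a = torch.randn(8, 32)
emb_b = torch.randn(8, 32)
emb_c = torch.randn(8, 32)

triplet = torch.nn.TripletMarginLoss(margin=1.0)
loss = triplet(emb_a, emb_b, emb_c)
print(loss.item())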
konfai/models/segmentation/NestedUNet.py
@@ -0,0 +1,53 @@
from typing import Union
import torch

from KonfAI.konfai.network import network, blocks
from KonfAI.konfai.utils.config import config
from KonfAI.konfai.data.HDF5 import ModelPatch


class NestedUNetBlock(network.ModuleArgsDict):

    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
        super().__init__()
        if i > 0:
            self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))

        self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
        if len(channels) > 2:
            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
            for j in range(len(channels)-2):
                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
                self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])

class NestedUNet(network.Network):

    class NestedUNetHead(network.ModuleArgsDict):

        def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
            super().__init__()
            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
            self.add_module("Softmax", torch.nn.Softmax(dim=1))
            self.add_module("Argmax", blocks.ArgMax(dim=1))

    @config("NestedUNet")
    def __init__( self,
                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                 patch : Union[ModelPatch, None] = None,
                 dim : int = 3,
                 channels: list[int]=[1, 64, 128, 256, 512, 1024],
                 nb_class: int = 2,
                 blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
                 nb_conv_per_stage: int = 2,
                 downSampleMode: str = "MAXPOOL",
                 upSampleMode: str = "CONV_TRANSPOSE",
                 attention : bool = False,
                 blockType: str = "Conv") -> None:
        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)

        self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
        for j in range(len(channels)-2):
            self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
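The in_channels expression for node X_{i}_{j+1} above concatenates the upsampled deeper feature (channels[2], or channels[1] after a transposed convolution) with the j+1 earlier nodes at the same depth (channels[1] each). A quick check of that channel bookkeeping with torch.cat, using toy sizes rather than the package's blocks:

import torch

c1, c2 = 64, 128            # channels[1], channels[2] at some depth
j = 1                       # building node X_i_2 from X_i_0, X_i_1 and the upsampled deeper node

up = torch.randn(1, c2, 32, 32)                              # upsampled X_{i+1}_{j}, no channel reduction
skips = [torch.randn(1, c1, 32, 32) for _ in range(j + 1)]   # X_i_0 ... X_i_j

x = torch.cat([up] + skips, dim=1)
assert x.shape[1] == c1 * (j + 1) + c2                       # matches the in_channels formula above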
konfai/models/segmentation/UNet.py
@@ -0,0 +1,58 @@
import torch
from typing import Union

from KonfAI.konfai.network import network, blocks
from KonfAI.konfai.utils.config import config
from KonfAI.konfai.data.HDF5 import ModelPatch

class UNetHead(network.ModuleArgsDict):

    def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
        super().__init__()
        self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
        self.add_module("Softmax", torch.nn.Softmax(dim=1))
        self.add_module("Argmax", blocks.ArgMax(dim=1))

class UNetBlock(network.ModuleArgsDict):

    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0, mri: bool = False) -> None:
        super().__init__()
        blockConfig_stride = blockConfig
        if i > 0:
            if downSampleMode != blocks.DownSampleMode.CONV_STRIDE:
                self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
            else:
                blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, (1,2,2) if mri and i > 4 else 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
        self.add_module("DownConvBlock", block(in_channels=channels[0], out_channels=channels[1], blockConfigs=[blockConfig_stride]+[blockConfig]*(nb_conv_per_stage-1), dim=dim))
        if len(channels) > 2:
            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1, mri=mri))
            self.add_module("UpConvBlock", block(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
            if nb_class > 0:
                self.add_module("Head", UNetHead(channels[1], nb_class, dim), out_branch=[-1])
        if i > 0:
            if attention:
                self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[1, 0], out_branch=[1])
            self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=(1,2,2) if mri and i > 4 else 2, stride=(1,2,2) if mri and i > 4 else 2))
            self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])

class UNet(network.Network):

    @config("UNet")
    def __init__( self,
                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                 patch : Union[ModelPatch, None] = None,
                 dim : int = 3,
                 channels: list[int]=[1, 64, 128, 256, 512, 1024],
                 nb_class: int = 2,
                 blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
                 nb_conv_per_stage: int = 2,
                 downSampleMode: str = "MAXPOOL",
                 upSampleMode: str = "CONV_TRANSPOSE",
                 attention : bool = False,
                 blockType: str = "Conv",
                 mri: bool = False) -> None:
        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
        self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim, mri = mri))
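With mri=True and i > 4, the stride and up-sampling kernel switch to (1, 2, 2), halving only the in-plane dimensions while preserving the slice axis. A quick shape check using an anisotropic pooling layer purely for illustration (the code above applies the stride inside its conv blocks instead):

import torch

x = torch.randn(1, 32, 16, 128, 128)                 # (N, C, slices, H, W)
pooled = torch.nn.MaxPool3d(kernel_size=(1, 2, 2))(x)
assert pooled.shape == (1, 32, 16, 64, 64)            # slice count preserved, in-plane halved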
|