konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic; see the package's registry page for more details.

@@ -0,0 +1,134 @@
1
+ from functools import partial
2
+ import torch
3
+
4
+ from KonfAI.konfai.network import network, blocks
5
+ from KonfAI.konfai.utils.config import config
6
+ from KonfAI.konfai.data.HDF5 import ModelPatch
7
+
8
class Discriminator(network.Network):
    """Convolutional discriminator producing one score per sample.

    Four ConvBlocks (widths 1 -> 16 -> 32 -> 64 -> 64, strides 2, 2, 2, 1) feed a
    head that projects to a single channel, averages over every spatial axis and
    flattens, giving an output of shape (batch, 1).
    """

    class DiscriminatorNLayers(network.ModuleArgsDict):
        """Stack of ConvBlocks, one per consecutive pair of ``channels``."""

        def __init__(self, channels: list[int], strides: list[int], dim: int) -> None:
            super().__init__()
            leaky_relu = partial(torch.nn.LeakyReLU, negative_slope = 0.2, inplace=True)
            make_config = partial(blocks.BlockConfig, kernel_size=4, padding=1, bias=False, activation=leaky_relu, normMode=blocks.NormMode.SYNCBATCH)
            # zip stops at the shortest input, so surplus strides/channels are ignored.
            stages = zip(channels[:-1], channels[1:], strides)
            for index, (c_in, c_out, step) in enumerate(stages):
                # NOTE(review): ConvBlock is called positionally as
                # (in, out, 1, config, dim) here, while the segmentation models
                # call it with the keyword ``blockConfigs=[...]`` — confirm the
                # current blocks.ConvBlock signature.
                layer = blocks.ConvBlock(c_in, c_out, 1, make_config(stride=step), dim)
                self.add_module(f"Layer_{index}", layer)

    class DiscriminatorHead(network.ModuleArgsDict):
        """Single-channel projection, global average pooling and flattening."""

        def __init__(self, channels: int, dim: int) -> None:
            super().__init__()
            projection = blocks.getTorchModule("Conv", dim)(in_channels=channels, out_channels=1, kernel_size=4, stride=1, padding=1)
            self.add_module("Conv", projection)
            # Pool every spatial axis down to size 1.
            pooling = blocks.getTorchModule("AdaptiveAvgPool", dim)((1,) * dim)
            self.add_module("AdaptiveAvgPool", pooling)
            self.add_module("Flatten", torch.nn.Flatten(1))

    @config("Discriminator")
    def __init__(self,
                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                    nb_batch_per_step: int = 64,
                    dim : int = 3) -> None:
        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
        widths = [1, 16, 32, 64, 64]
        self.add_module("Layers", Discriminator.DiscriminatorNLayers(widths, [2, 2, 2, 1], dim))
        self.add_module("Head", Discriminator.DiscriminatorHead(widths[-1], dim))
38
+
39
class Generator(network.Network):
    """Image-to-image generator: stem, one-level auto-encoder with a residual bottleneck, tanh head.

    With the fixed base width of 32 the data flow is:
    1 ch -> Stem (32) -> DownSample (64) -> 6 residual blocks (64) -> UpSample (32) -> Head (1 ch, tanh).
    """

    class GeneratorStem(network.ModuleArgsDict):
        """Single ConvBlock lifting ``in_channels`` to ``out_channels``."""

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            # NOTE(review): ConvBlock is called with ``nb_conv=``/``blockConfig=`` here,
            # while other models in this package pass ``blockConfigs=[...]`` — confirm.
            stem_cfg = blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")
            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=stem_cfg, dim=dim))

    class GeneratorHead(network.ModuleArgsDict):
        """ConvBlock at constant width, 1x1 projection to ``out_channels``, then tanh."""

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            refine_cfg = blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")
            refine = blocks.ConvBlock(in_channels, in_channels, nb_conv=1, blockConfig=refine_cfg, dim=dim)
            self.add_module("ConvBlock", refine)
            projection = blocks.getTorchModule("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False)
            self.add_module("Conv", projection)
            self.add_module("Tanh", torch.nn.Tanh())

    class GeneratorDownSample(network.ModuleArgsDict):
        """Stride-2 ConvBlock changing the width from ``in_channels`` to ``out_channels``."""

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            strided_cfg = blocks.BlockConfig(stride=2, bias=False, activation="ReLU", normMode="SYNCBATCH")
            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=strided_cfg, dim=dim))

    class GeneratorUpSample(network.ModuleArgsDict):
        """ConvBlock followed by a x2 interpolating upsample."""

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            conv_cfg = blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")
            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=conv_cfg, dim=dim))
            # NOTE(review): "bilinear" is also picked for dim == 1, where
            # torch.nn.Upsample needs "linear" — confirm dim is always 2 or 3.
            interpolation = "trilinear" if dim >= 3 else "bilinear"
            self.add_module("Upsample", torch.nn.Upsample(scale_factor=2, mode=interpolation))

    class GeneratorEncoder(network.ModuleArgsDict):
        """One ``GeneratorDownSample`` per consecutive pair of ``channels``."""

        def __init__(self, channels: list[int], dim: int) -> None:
            super().__init__()
            for index, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
                self.add_module(f"DownSample_{index}", Generator.GeneratorDownSample(in_channels=c_in, out_channels=c_out, dim=dim))

    class GeneratorResnetBlock(network.ModuleArgsDict):
        """Conv -> LeakyReLU(0.2) -> Conv, then ``blocks.Add`` over branches 0 and 1.

        NOTE(review): branch 1 presumably holds the block input (residual skip);
        confirm against ModuleArgsDict's branch semantics.
        """

        def __init__(self, channels : int, dim : int):
            super().__init__()
            first = blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False)
            self.add_module("Conv_0", first)
            # NOTE(review): registered as "Norm" but this is an activation, not a normalisation.
            self.add_module("Norm", torch.nn.LeakyReLU(0.2, inplace=True))
            second = blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False)
            self.add_module("Conv_1", second)
            self.add_module("Residual", blocks.Add(), in_branch=[0,1])

    class GeneratorNResnetBlock(network.ModuleArgsDict):
        """``nb_conv`` consecutive ``GeneratorResnetBlock`` modules at constant width."""

        def __init__(self, channels: int, nb_conv: int, dim: int) -> None:
            super().__init__()
            for index in range(nb_conv):
                self.add_module(f"ResnetBlock_{index}", Generator.GeneratorResnetBlock(channels=channels, dim=dim))

    class GeneratorDecoder(network.ModuleArgsDict):
        """Mirror of ``GeneratorEncoder``: walks ``channels`` backwards with upsampling blocks."""

        def __init__(self, channels: list[int], dim: int) -> None:
            super().__init__()
            # Pairs (channels[-1], channels[-2]), (channels[-2], channels[-3]), ...
            for index, (c_in, c_out) in enumerate(zip(channels[::-1], channels[-2::-1])):
                self.add_module(f"UpSample_{index}", Generator.GeneratorUpSample(in_channels=c_in, out_channels=c_out, dim=dim))

    class GeneratorAutoEncoder(network.ModuleArgsDict):
        """Encoder (ngf -> 2*ngf), six residual blocks at 2*ngf, decoder back to ngf."""

        def __init__(self, ngf: int, dim: int) -> None:
            super().__init__()
            widths = [ngf, ngf*2]
            self.add_module("Encoder", Generator.GeneratorEncoder(widths, dim))
            self.add_module("NResBlock", Generator.GeneratorNResnetBlock(channels=widths[-1], nb_conv=6, dim=dim))
            self.add_module("Decoder", Generator.GeneratorDecoder(widths, dim))

    @config("Generator")
    def __init__(self,
                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                    patch : ModelPatch = ModelPatch(),
                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                    nb_batch_per_step: int = 64,
                    dim : int = 3) -> None:
        super().__init__(optimizer=optimizer, in_channels=1, schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
        base_width = 32
        self.add_module("Stem", Generator.GeneratorStem(1, base_width, dim))
        self.add_module("AutoEncoder", Generator.GeneratorAutoEncoder(base_width, dim))
        self.add_module("Head", Generator.GeneratorHead(in_channels=base_width, out_channels=1, dim=dim))

    def getName(self):
        """Return the fixed display name of this network."""
        return "Generator"
121
+
122
class Gan(network.Network):
    """Generator/discriminator pair wired as one network graph.

    The same ``discriminator`` object is passed to three ``add_module`` calls
    (on input branch 1, on the detached generator output and on the
    non-detached generator output).
    NOTE(review): this presumably means the three entries share one set of
    weights — confirm how ModuleArgsDict treats a module registered more than once.
    """

    @config("Gan")
    def __init__(self, generator : Generator = Generator(), discriminator : Discriminator = Discriminator()) -> None:
        # NOTE(review): these defaults are built once, when the class body runs;
        # presumably @config supplies fresh instances from the configuration — confirm.
        super().__init__()
        # Discriminator applied to input branch 1 (presumably the real target B — confirm).
        self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
        # Generator maps input branch 0 to the prediction stored in branch "pB".
        self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])

        # Detached copy of pB, so this discriminator pass should not backpropagate
        # into the generator (assuming blocks.Detach wraps Tensor.detach — confirm).
        self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
        self.add_module("Discriminator_pB_detach", discriminator, in_branch=["pB_detach"], out_branch=[-1])

        # Discriminator on the non-detached pB with requires_grad=False: presumably
        # the adversarial pass that trains only the generator — confirm.
        self.add_module("Discriminator_pB", discriminator, in_branch=["pB"], out_branch=[-1], requires_grad=False)
134
+
@@ -0,0 +1,72 @@
1
+ import torch
2
+
3
+ from KonfAI.konfai.network import network, blocks
4
+ from KonfAI.konfai.utils.config import config
5
+
6
class VAE(network.Network):
    """Convolutional encoder/decoder assembled from nested ``AutoEncoderBlock`` levels.

    NOTE(review): despite the class name, no latent sampling / reparameterisation
    module is added here and there are no skip connections — the graph is a plain
    auto-encoder. Confirm that this is intended.
    """

    class AutoEncoderBlock(network.ModuleArgsDict):
        """One resolution level of the auto-encoder, recursively holding the deeper levels.

        Args:
            channels: widths from this level's input (``channels[0]``) down to the
                innermost level; ``channels[1]`` is the width used at this level.
            nb_conv_per_stage: number of ``blockConfig`` copies given to each block.
            blockConfig: configuration shared by every convolution.
            downSampleMode / upSampleMode: resampling strategies between levels.
            dim: spatial dimensionality.
            block: block class (``blocks.ConvBlock`` or ``blocks.ResBlock``).
            i: depth of this level; level 0 adds no down-/up-sampling layer.
        """

        def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, dim: int, block: type, i : int = 0) -> None:
            super().__init__()
            if i > 0:
                self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
            # With CONV_STRIDE the down-sampling layer already outputs channels[1];
            # the other modes presumably keep channels[0] — confirm in blocks.downSample.
            self.add_module("DownBlock", block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
            if len(channels) > 2:
                self.add_module("AutoEncoder_{}".format(i+1), VAE.AutoEncoderBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, dim, block, i+1))
                # The child's up-sampling outputs channels[1] with CONV_TRANSPOSE and,
                # presumably, keeps channels[2] otherwise — confirm in blocks.upSample.
                self.add_module("UpBlock", block(in_channels=channels[2] if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
            if i > 0:
                self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim))

    class VAE_Head(network.ModuleArgsDict):
        """3x3 convolution (stride 1, padding 1) to ``out_channels`` followed by tanh."""

        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
            super().__init__()
            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
            self.add_module("Tanh", torch.nn.Tanh())

    @config("VAE")
    def __init__(self,
                    optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                    schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                    dim : int = 3,
                    channels: list[int]=[1, 64, 128, 256, 512, 1024],
                    blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
                    nb_conv_per_stage: int = 2,
                    downSampleMode: str = "MAXPOOL",
                    upSampleMode: str = "CONV_TRANSPOSE",
                    blockType: str = "Conv") -> None:

        super().__init__(in_channels = channels[0], init_type="normal", optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, nb_batch_per_step=1)
        # Mode strings are looked up by enum member name; any blockType other than
        # "Conv" selects blocks.ResBlock.
        self.add_module("AutoEncoder_0", VAE.AutoEncoderBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], dim=dim, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock))
        # Level 0 ends with a block producing channels[1]; the head maps back to channels[0].
        self.add_module("Head", VAE.VAE_Head(channels[1], channels[0], dim))
44
+
45
+
46
class LinearVAE(network.Network):
    """Fully connected auto-encoder: 23343 features -> 5 -> 23343 with a tanh output."""

    class LinearVAE_DenseLayer(network.ModuleArgsDict):
        """Linear projection followed by a default LeakyReLU."""

        def __init__(self, in_features: int, out_features: int) -> None:
            super().__init__()
            dense = torch.nn.Linear(in_features, out_features)
            self.add_module("Linear", dense)
            self.add_module("Activation", torch.nn.LeakyReLU())

    class LinearVAE_Head(network.ModuleArgsDict):
        """Linear projection followed by tanh."""

        def __init__(self, in_features: int, out_features: int) -> None:
            super().__init__()
            dense = torch.nn.Linear(in_features, out_features)
            self.add_module("Linear", dense)
            self.add_module("Tanh", torch.nn.Tanh())

    @config("LinearVAE")
    def __init__(self,
                    optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                    schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},) -> None:
        super().__init__(in_channels = 1, init_type="normal", optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=1, nb_batch_per_step=1)
        # NOTE(review): the flattened input size (23343) is hard-coded for one dataset.
        n_features, n_latent = 23343, 5
        self.add_module("DenseLayer_0", LinearVAE.LinearVAE_DenseLayer(n_features, n_latent))
        self.add_module("Head", LinearVAE.LinearVAE_Head(n_latent, n_features))
@@ -0,0 +1,136 @@
1
+ import torch
2
+ from torch.nn.parameter import Parameter
3
+ import torch.nn.functional as F
4
+
5
+ from KonfAI.konfai.network import network, blocks
6
+ from KonfAI.konfai.utils.config import config
7
+ from KonfAI.konfai.models.segmentation import UNet
8
+
9
class VoxelMorph(network.Network):
    """VoxelMorph-style registration network.

    Input branches 0 and 1 are concatenated and fed to a U-Net backbone. A flow
    head (dense ``Flow`` or translation-only ``Rigid``) predicts "pos_flow". That
    flow is then used to resample input branch 1 into "moving_image_resample".
    """

    @config("VoxelMorph")
    def __init__( self,
                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                    dim : int = 3,
                    channels : list[int] = [4, 16,32,32,32],
                    blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
                    nb_conv_per_stage: int = 2,
                    downSampleMode: str = "MAXPOOL",
                    upSampleMode: str = "CONV_TRANSPOSE",
                    attention : bool = False,
                    shape : list[int] = [192, 192, 192],
                    int_steps : int = 7,
                    int_downsize : int = 2,
                    nb_batch_per_step : int = 1,
                    rigid: bool = False):
        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim = dim, nb_batch_per_step=nb_batch_per_step)
        self.add_module("Concat", blocks.Concat(), in_branch=[0,1], out_branch=["input_concat"])

        down_mode = blocks.DownSampleMode._member_map_[downSampleMode]
        up_mode = blocks.UpSampleMode._member_map_[upSampleMode]
        backbone = UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=down_mode, upSampleMode=up_mode, attention=attention, block = blocks.ConvBlock, nb_class=0, dim=dim)
        self.add_module("UNetBlock_0", backbone, in_branch=["input_concat"], out_branch=["unet"])

        flow_head = Rigid(channels[1], dim) if rigid else Flow(channels[1], int_steps, int_downsize, shape, dim)
        self.add_module("Flow", flow_head, in_branch=["unet"], out_branch=["pos_flow"])
        self.add_module("MovingImageResample", SpatialTransformer(shape, rigid=rigid), in_branch=[1, "pos_flow"], out_branch=["moving_image_resample"])
37
+
38
class Flow(network.ModuleArgsDict):
    """Predict a displacement field and optionally integrate it.

    A 3x3 convolution maps ``in_channels`` to ``dim`` flow components. Its
    weights start from N(0, 1e-5) and its bias from zero. When ``int_steps > 0``,
    a ``VecInt`` integration step follows; when in addition ``int_downsize > 1``,
    it runs on a field downsampled by ``int_downsize``, which is resized back
    afterwards.
    """

    def __init__(self, in_channels: int, int_steps: int, int_downsize: int, shape: list[int], dim: int) -> None:
        super().__init__()
        conv = blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = dim, kernel_size = 3, stride = 1, padding = 1)
        self.add_module("Head", conv)
        head = self["Head"]
        head.weight = Parameter(torch.distributions.Normal(0, 1e-5).sample(head.weight.shape))
        head.bias = Parameter(torch.zeros(head.bias.shape))

        if int_steps > 0:
            rescale = int_downsize > 1
            if rescale:
                self.add_module("DownSample", ResizeTransform(int_downsize))
            reduced_shape = [int(s / int_downsize) for s in shape]
            self.add_module("Integrate_pos_flow", VecInt(reduced_shape, int_steps))
            if rescale:
                self.add_module("Upsample_pos_flow", ResizeTransform(1 / int_downsize))
54
+
55
class Rigid(network.ModuleArgsDict):
    """Regress a 2-component translation from flattened features.

    NOTE(review): the linear layer's input size assumes 512x512 feature maps,
    and ``dim`` is accepted but unused.
    """

    def __init__(self, in_channels: int, dim: int) -> None:
        super().__init__()
        self.add_module("ToFeatures", torch.nn.Flatten(1))
        n_inputs = in_channels * 512 * 512
        self.add_module("Head", torch.nn.Linear(n_inputs, 2))

    def init(self, init_type: str, init_gain: float):
        """Start from a zero translation: zero weights and a zero bias.

        ``init_type`` and ``init_gain`` are ignored.
        """
        head = self["Head"]
        head.weight.data.fill_(0)
        head.bias.data.copy_(torch.zeros(2, dtype=torch.float))
65
+
66
class MaskFlow(torch.nn.Module):
    """Compose one displacement field from per-label flows.

    Wherever ``mask == k`` (k starting at 1), the output takes the value of
    ``flows[k - 1]``; positions matching no label are zero.
    """

    def __init__(self):
        super().__init__()

    def forward(self, mask: torch.Tensor, *flows: torch.Tensor):
        """Select, per position, the flow whose 1-based index equals ``mask``.

        Args:
            mask: integer label map, broadcastable against each flow.
            *flows: one tensor per label; at least one is required.

        Returns:
            Tensor shaped like ``flows[0]``.
        """
        result = torch.zeros_like(flows[0])
        for label, flow in enumerate(flows, start=1):
            # A zero with the flow's own dtype/device, instead of a fresh CPU
            # int64 tensor that torch.where would have to promote and move.
            zero = flow.new_zeros(())
            result = result + torch.where(mask == label, flow, zero)
        return result
76
+
77
class SpatialTransformer(torch.nn.Module):
    """Resample ``src`` with a dense displacement field or a 2D rigid translation.

    Args:
        size: spatial shape of the displacement fields (non-rigid mode only);
            used to build the identity voxel grid. Any number of spatial axes.
        rigid: when True, ``forward`` expects ``flow`` of shape (B, 2) holding an
            (x, y) translation in normalised [-1, 1] coordinates, and only 2D
            (B, C, H, W) images are supported.
    """

    def __init__(self, size : list[int], rigid: bool = False):
        super().__init__()
        self.rigid = rigid
        if not rigid:
            # Identity grid of shape (1, len(size), *size); channel k holds the
            # voxel index along axis k. A buffer follows the module across .to().
            vectors = [torch.arange(0, s) for s in size]
            grids = torch.meshgrid(vectors, indexing='ij')
            grid = torch.stack(grids)
            grid = torch.unsqueeze(grid, 0)
            grid = grid.type(torch.float)
            self.register_buffer('grid', grid)

    def forward(self, src: torch.Tensor, flow: torch.Tensor):
        """Warp ``src`` by ``flow`` (voxel displacements, or a rigid translation)."""
        if self.rigid:
            return self._warp_rigid(src, flow)
        return self._warp_dense(src, flow)

    @staticmethod
    def _warp_rigid(src: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """Translate 2D ``src`` by the per-sample offsets in ``flow`` (B, 2)."""
        theta = torch.zeros((flow.shape[0], 2, 3), dtype=src.dtype, device=flow.device)
        theta[:, 0, 0] = 1
        theta[:, 1, 1] = 1
        theta[:, 0, 2] = flow[:, 0]
        theta[:, 1, 2] = flow[:, 1]
        # affine_grid and grid_sample must agree on align_corners; otherwise an
        # identity theta does not reproduce the input image.
        sampling_grid = F.affine_grid(theta, src.size(), align_corners=True)
        return F.grid_sample(src, sampling_grid, align_corners=True, mode="bilinear")

    def _warp_dense(self, src: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
        """Sample ``src`` at ``identity grid + flow`` for any number of spatial axes."""
        new_locs = self.grid + flow
        shape = flow.shape[2:]
        ndim = len(shape)
        # Map voxel indices 0..s-1 onto grid_sample's normalised range [-1, 1].
        for i in range(ndim):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        # (B, ndim, *spatial) -> (B, *spatial, ndim), then reverse the component
        # order: grid_sample expects (x, y[, z]), i.e. the last axis first.
        new_locs = new_locs.permute(0, *range(2, ndim + 2), 1)
        new_locs = new_locs[..., list(range(ndim - 1, -1, -1))]
        return F.grid_sample(src, new_locs, align_corners=True, mode="bilinear")
106
+
107
class VecInt(torch.nn.Module):
    """Integrate a displacement field by scaling and squaring.

    ``forward`` first scales the field by ``1 / 2**nsteps``, then ``nsteps``
    times adds to it the field warped by itself, using a ``SpatialTransformer``
    built for ``inshape``.
    """

    def __init__(self, inshape, nsteps):
        super().__init__()
        assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
        self.nsteps = nsteps
        self.scale = 1.0 / (2 ** self.nsteps)
        self.transformer = SpatialTransformer(inshape)

    def forward(self, vec: torch.Tensor):
        field = vec * self.scale
        for _step in range(self.nsteps):
            warped = self.transformer(field, field)
            field = field + warped
        return field
121
+
122
+
123
class ResizeTransform(torch.nn.Module):
    """Resample a displacement field and rescale its vectors by the same factor.

    ``size`` is the down-sampling factor: the field is resized by ``1 / size`` and
    its values are multiplied by ``1 / size``, so displacements stay expressed in
    voxels of the resized grid. A factor of exactly 1 returns the input unchanged.
    """

    # F.interpolate needs the linear mode matching the number of spatial axes of
    # an (N, C, *spatial) tensor. Other ranks fall back to "bilinear" so that
    # F.interpolate raises for them as it did before.
    _INTERP_MODES = {3: "linear", 4: "bilinear", 5: "trilinear"}

    def __init__(self, size):
        super().__init__()
        self.factor = 1.0 / size

    def forward(self, x: torch.Tensor):
        """Resize ``x`` (N, C, *spatial) by ``self.factor`` and scale its values."""
        mode = self._INTERP_MODES.get(x.dim(), "bilinear")
        if self.factor < 1:
            x = F.interpolate(x, align_corners=True, scale_factor=self.factor, mode=mode, recompute_scale_factor = True)
            x = self.factor * x
        elif self.factor > 1:
            x = self.factor * x
            x = F.interpolate(x, align_corners=True, scale_factor=self.factor, mode=mode, recompute_scale_factor = True)
        return x
@@ -0,0 +1,57 @@
1
+ import torch
2
+
3
+ from KonfAI.konfai.network import network, blocks
4
+ from KonfAI.konfai.utils.config import config
5
+
6
class ConvBlock(torch.nn.Module):
    """Two Conv3d(3x3x3) -> InstanceNorm3d(affine) -> LeakyReLU(0.01) stages.

    The first convolution maps ``in_channels`` to ``out_channels`` and the second
    keeps ``out_channels``. Stride 1 with padding 1 preserves the spatial size.
    """

    def __init__(self, in_channels : int, out_channels : int) -> None:
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=1, padding=1, bias=True)
        # Attribute names are kept as-is because they form the state_dict keys.
        self.Conv_0 = torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, **conv_opts)
        self.Norm_0 = torch.nn.InstanceNorm3d(num_features=out_channels, affine=True)
        self.Activation_0 = torch.nn.LeakyReLU(negative_slope=0.01)
        self.Conv_1 = torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, **conv_opts)
        self.Norm_1 = torch.nn.InstanceNorm3d(num_features=out_channels, affine=True)
        self.Activation_1 = torch.nn.LeakyReLU(negative_slope=0.01)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pipeline = (self.Conv_0, self.Norm_0, self.Activation_0,
                    self.Conv_1, self.Norm_1, self.Activation_1)
        output = input
        for layer in pipeline:
            output = layer(output)
        return output
25
+
26
class UnetCPP_1_Layers(torch.nn.Module):
    """Single-stage encoder: one ``ConvBlock`` from 1 to 32 channels."""

    def __init__(self) -> None:
        super().__init__()
        self.DownConvBlock_0 = ConvBlock(in_channels=1, out_channels=32)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        stage = self.DownConvBlock_0
        return stage(input)
34
+
35
class Adaptation(torch.nn.Module):
    """Frozen shared encoder plus a trainable linear projection, applied to three inputs.

    ``A``, ``B`` and ``C`` each go through ``Encoder_1`` -> ``ToFeatures`` ->
    ``FCT_1`` with the same modules.
    """

    def __init__(self) -> None:
        super().__init__()
        self.Encoder_1 = UnetCPP_1_Layers()
        # Presumably reduces the 32-channel encoder output to (batch, 32) features
        # to match FCT_1's input size — confirm in blocks.ToFeatures.
        self.ToFeatures = blocks.ToFeatures(3)
        self.FCT_1 = torch.nn.Linear(32, 32, bias=True)

    def _embed(self, x: torch.Tensor) -> torch.Tensor:
        """Encode one input and project its features with ``FCT_1``."""
        return self.FCT_1(self.ToFeatures(self.Encoder_1(x)))

    def forward(self, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the embeddings of ``A``, ``B`` and ``C`` as a tuple, in that order."""
        # Re-applied on every call so that the encoder stays frozen and the
        # projection trainable, presumably overriding any requires_grad the
        # framework set on the whole module.
        self.Encoder_1.requires_grad_(False)
        self.FCT_1.requires_grad_(True)
        return self._embed(A), self._embed(B), self._embed(C)
47
+
48
class Representation(network.Network):
    """Network wrapper running ``Adaptation`` on input branches 0, 1 and 2."""

    @config("Representation")
    def __init__( self,
                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                    dim : int = 3):
        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, init_type="kaiming")
        model = Adaptation()
        self.add_module("Model", model, in_branch=list(range(3)))
@@ -0,0 +1,53 @@
1
+ from typing import Union
2
+ import torch
3
+
4
+ from KonfAI.konfai.network import network, blocks
5
+ from KonfAI.konfai.utils.config import config
6
+ from KonfAI.konfai.data.HDF5 import ModelPatch
7
+
8
+
9
class NestedUNetBlock(network.ModuleArgsDict):
    """Row ``i`` of a UNet++ (nested U-Net) grid, recursively containing the deeper rows.

    Node ``X_i_0`` is this row's encoder node. For ``j >= 1``, node ``X_i_j`` is a
    block applied to the upsampled ``X_{i+1}_{j-1}`` concatenated with every earlier
    node ``X_i_0 .. X_i_{j-1}`` of the same row. The child row's nodes
    ``X_{i+1}_0 .. X_{i+1}_{len(channels)-3}`` are exposed through its out_branch.

    Args:
        channels: widths from this row's input (``channels[0]``) down to the
            deepest row; ``channels[1]`` is this row's width.
        nb_conv_per_stage: number of ``blockConfig`` copies given to each block.
        blockConfig: configuration shared by every convolution.
        downSampleMode / upSampleMode: resampling strategies between rows.
        attention: accepted and forwarded to child rows, but not used in this block.
        block: block class (``blocks.ConvBlock`` or ``blocks.ResBlock``).
        dim: spatial dimensionality.
        i: row index (0 = outermost, which has no down-sampling layer).
    """

    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
        super().__init__()
        if i > 0:
            self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))

        # With CONV_STRIDE the down-sampling layer already outputs channels[1];
        # the other modes presumably keep channels[0] — confirm in blocks.downSample.
        self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
        if len(channels) > 2:
            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
            for j in range(len(channels)-2):
                # Upsample X_{i+1}_j in place (its branch name is reused for the result).
                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
                # Dense skip: concat upsampled X_{i+1}_j with X_i_0 .. X_i_j.
                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
                # Input width = (j+1) same-row nodes of channels[1] plus the upsampled node:
                # channels[1] after CONV_TRANSPOSE, otherwise (presumably) channels[2].
                self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
23
+
24
class NestedUNet(network.Network):
    """UNet++ segmentation network with one prediction head per nested top-row node."""

    class NestedUNetHead(network.ModuleArgsDict):
        """1x1 convolution to ``nb_class`` channels, softmax over dim 1, then ``blocks.ArgMax(dim=1)``."""

        def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
            super().__init__()
            classifier = blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0)
            self.add_module("Conv", classifier)
            self.add_module("Softmax", torch.nn.Softmax(dim=1))
            self.add_module("Argmax", blocks.ArgMax(dim=1))

    @config("NestedUNet")
    def __init__( self,
                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                    patch : Union[ModelPatch, None] = None,
                    dim : int = 3,
                    channels: list[int]=[1, 64, 128, 256, 512, 1024],
                    nb_class: int = 2,
                    blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
                    nb_conv_per_stage: int = 2,
                    downSampleMode: str = "MAXPOOL",
                    upSampleMode: str = "CONV_TRANSPOSE",
                    attention : bool = False,
                    blockType: str = "Conv") -> None:
        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)

        n_nested = len(channels) - 2
        down_mode = blocks.DownSampleMode._member_map_[downSampleMode]
        up_mode = blocks.UpSampleMode._member_map_[upSampleMode]
        # Any blockType other than "Conv" selects the residual block.
        block_cls = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock
        grid = NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=down_mode, upSampleMode=up_mode, attention=attention, block = block_cls, dim=dim)
        self.add_module("UNetBlock_0", grid, out_branch=[f"X_0_{j}" for j in range(1, n_nested + 1)])
        for j in range(n_nested):
            head = NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, dim=dim)
            self.add_module(f"Head_{j}", head, in_branch=[f"X_0_{j + 1}"], out_branch=[-1])
@@ -0,0 +1,58 @@
1
+ import torch
2
+ from typing import Union
3
+
4
+ from KonfAI.konfai.network import network, blocks
5
+ from KonfAI.konfai.utils.config import config
6
+ from KonfAI.konfai.data.HDF5 import ModelPatch
7
+
8
class UNetHead(network.ModuleArgsDict):
    """1x1 convolution to ``nb_class`` channels, softmax over dim 1, then ``blocks.ArgMax(dim=1)``."""

    def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
        super().__init__()
        conv_cls = blocks.getTorchModule("Conv", dim)
        classifier = conv_cls(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0)
        self.add_module("Conv", classifier)
        self.add_module("Softmax", torch.nn.Softmax(dim=1))
        self.add_module("Argmax", blocks.ArgMax(dim=1))
15
+
16
class UNetBlock(network.ModuleArgsDict):
    """One resolution level of a U-Net, recursively containing all deeper levels.

    Args:
        channels: widths from this level's input (``channels[0]``) down to the
            bottleneck; ``channels[1]`` is the width produced at this level.
        nb_conv_per_stage: number of convolution configs per down/up block.
        blockConfig: configuration shared by the convolutions.
        downSampleMode / upSampleMode: resampling strategies between levels.
        attention: add ``blocks.Attention`` before the up-sampling (levels ``i > 0``).
        block: block class (``blocks.ConvBlock`` or ``blocks.ResBlock``).
        nb_class: when > 0, a ``UNetHead`` is added at every level that has a child.
        dim: spatial dimensionality.
        i: depth of this level (0 = outermost, which has no resampling/skip layers).
        mri: for levels deeper than 4, use anisotropic (1, 2, 2) strides
            (strided down-sampling conv and up-sampling kernel/stride).
    """

    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0, mri: bool = False) -> None:
        super().__init__()
        blockConfig_stride = blockConfig
        if i > 0:
            if downSampleMode != blocks.DownSampleMode.CONV_STRIDE:
                self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
            else:
                # CONV_STRIDE: no separate layer; DownConvBlock's first conv strides instead.
                blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, (1,2,2) if mri and i > 4 else 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
        # Only the first conv uses blockConfig_stride; the remaining ones use blockConfig.
        self.add_module("DownConvBlock", block(in_channels=channels[0], out_channels=channels[1], blockConfigs=[blockConfig_stride]+[blockConfig]*(nb_conv_per_stage-1), dim=dim))
        if len(channels) > 2:
            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1, mri=mri))
            # The child ends with its up-sampled features concatenated to this level's
            # channels[1] maps: channels[1]*2 after CONV_TRANSPOSE, otherwise
            # (presumably) channels[1]+channels[2] — confirm in blocks.upSample.
            self.add_module("UpConvBlock", block(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
            if nb_class > 0:
                # Per-level prediction written to branch -1 (presumably the network outputs).
                self.add_module("Head", UNetHead(channels[1], nb_class, dim), out_branch=[-1])
        if i > 0:
            if attention:
                self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[1, 0], out_branch=[1])
            self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=(1,2,2) if mri and i > 4 else 2, stride=(1,2,2) if mri and i > 4 else 2))
            # Concatenate branch 0 (up-sampled features) with branch 1 — presumably this
            # level's input kept for the skip connection; confirm in ModuleArgsDict.
            self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
37
+
38
class UNet(network.Network):
    """U-Net segmentation network built from a single recursive ``UNetBlock``."""

    @config("UNet")
    def __init__( self,
                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                    patch : Union[ModelPatch, None] = None,
                    dim : int = 3,
                    channels: list[int]=[1, 64, 128, 256, 512, 1024],
                    nb_class: int = 2,
                    blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
                    nb_conv_per_stage: int = 2,
                    downSampleMode: str = "MAXPOOL",
                    upSampleMode: str = "CONV_TRANSPOSE",
                    attention : bool = False,
                    blockType: str = "Conv",
                    mri: bool = False) -> None:
        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
        # Mode strings are looked up by enum member name.
        down_mode = blocks.DownSampleMode._member_map_[downSampleMode]
        up_mode = blocks.UpSampleMode._member_map_[upSampleMode]
        # Any blockType other than "Conv" selects the residual block.
        block_cls = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock
        root = UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=down_mode, upSampleMode=up_mode, attention=attention, block = block_cls, nb_class=nb_class, dim=dim, mri = mri)
        self.add_module("UNetBlock_0", root)
58
+
File without changes