konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

Files changed (36) hide show
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,99 +1,168 @@
1
1
  import torch
2
+ import torch.nn.functional as F # noqa: N812
2
3
  from torch.nn.parameter import Parameter
3
- import torch.nn.functional as F
4
4
 
5
- from konfai.network import network, blocks
6
- from konfai.utils.config import config
7
5
  from konfai.models.segmentation import UNet
6
+ from konfai.network import blocks, network
7
+
8
8
 
9
9
  class VoxelMorph(network.Network):
10
-
11
- def __init__( self,
12
- optimizer : network.OptimizerLoader = network.OptimizerLoader(),
13
- schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
14
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
15
- dim : int = 3,
16
- channels : list[int] = [4, 16,32,32,32],
17
- blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
18
- nb_conv_per_stage: int = 2,
19
- downSampleMode: str = "MAXPOOL",
20
- upSampleMode: str = "CONV_TRANSPOSE",
21
- attention : bool = False,
22
- shape : list[int] = [192, 192, 192],
23
- int_steps : int = 7,
24
- int_downsize : int = 2,
25
- nb_batch_per_step : int = 1,
26
- rigid: bool = False):
27
- super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim = dim, nb_batch_per_step=nb_batch_per_step)
28
- self.add_module("Concat", blocks.Concat(), in_branch=[0,1], out_branch=["input_concat"])
29
- self.add_module("UNetBlock_0", UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock, nb_class=0, dim=dim), in_branch=["input_concat"], out_branch=["unet"])
10
+
11
+ def __init__(
12
+ self,
13
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
14
+ schedulers: dict[str, network.LRSchedulersLoader] = {
15
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
16
+ },
17
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
18
+ dim: int = 3,
19
+ channels: list[int] = [4, 16, 32, 32, 32],
20
+ block_config: blocks.BlockConfig = blocks.BlockConfig(),
21
+ nb_conv_per_stage: int = 2,
22
+ downsample_mode: str = "MAXPOOL",
23
+ upsample_mode: str = "CONV_TRANSPOSE",
24
+ attention: bool = False,
25
+ shape: list[int] = [192, 192, 192],
26
+ int_steps: int = 7,
27
+ int_downsize: int = 2,
28
+ nb_batch_per_step: int = 1,
29
+ rigid: bool = False,
30
+ ):
31
+ super().__init__(
32
+ in_channels=channels[0],
33
+ optimizer=optimizer,
34
+ schedulers=schedulers,
35
+ outputs_criterions=outputs_criterions,
36
+ dim=dim,
37
+ nb_batch_per_step=nb_batch_per_step,
38
+ )
39
+ self.add_module("Concat", blocks.Concat(), in_branch=[0, 1], out_branch=["input_concat"])
40
+ self.add_module(
41
+ "UNetBlock_0",
42
+ UNet.UNetBlock(
43
+ channels,
44
+ nb_conv_per_stage,
45
+ block_config,
46
+ downsample_mode=blocks.DownsampleMode[downsample_mode],
47
+ upsample_mode=blocks.UpsampleMode[upsample_mode],
48
+ attention=attention,
49
+ block=blocks.ConvBlock,
50
+ nb_class=0,
51
+ dim=dim,
52
+ ),
53
+ in_branch=["input_concat"],
54
+ out_branch=["unet"],
55
+ )
30
56
 
31
57
  if rigid:
32
- self.add_module("Flow", Rigid(channels[1], dim), in_branch=["unet"], out_branch=["pos_flow"])
58
+ self.add_module(
59
+ "Flow",
60
+ Rigid(channels[1], dim),
61
+ in_branch=["unet"],
62
+ out_branch=["pos_flow"],
63
+ )
33
64
  else:
34
- self.add_module("Flow", Flow(channels[1], int_steps, int_downsize, shape, dim), in_branch=["unet"], out_branch=["pos_flow"])
35
- self.add_module("MovingImageResample", SpatialTransformer(shape, rigid=rigid), in_branch=[1, "pos_flow"], out_branch=["moving_image_resample"])
65
+ self.add_module(
66
+ "Flow",
67
+ Flow(channels[1], int_steps, int_downsize, shape, dim),
68
+ in_branch=["unet"],
69
+ out_branch=["pos_flow"],
70
+ )
71
+ self.add_module(
72
+ "MovingImageResample",
73
+ SpatialTransformer(shape, rigid=rigid),
74
+ in_branch=[1, "pos_flow"],
75
+ out_branch=["moving_image_resample"],
76
+ )
77
+
36
78
 
37
79
  class Flow(network.ModuleArgsDict):
38
80
 
39
- def __init__(self, in_channels: int, int_steps: int, int_downsize: int, shape: list[int], dim: int) -> None:
81
+ def __init__(
82
+ self,
83
+ in_channels: int,
84
+ int_steps: int,
85
+ int_downsize: int,
86
+ shape: list[int],
87
+ dim: int,
88
+ ) -> None:
40
89
  super().__init__()
41
- self.add_module("Head", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = dim, kernel_size = 3, stride = 1, padding = 1))
90
+ self.add_module(
91
+ "Head",
92
+ blocks.get_torch_module("Conv", dim)(
93
+ in_channels=in_channels,
94
+ out_channels=dim,
95
+ kernel_size=3,
96
+ stride=1,
97
+ padding=1,
98
+ ),
99
+ )
42
100
  self["Head"].weight = Parameter(torch.distributions.Normal(0, 1e-5).sample(self["Head"].weight.shape))
43
101
  self["Head"].bias = Parameter(torch.zeros(self["Head"].bias.shape))
44
-
102
+
45
103
  if int_steps > 0 and int_downsize > 1:
46
104
  self.add_module("DownSample", ResizeTransform(int_downsize))
47
105
 
48
106
  if int_steps > 0:
49
- self.add_module("Integrate_pos_flow", VecInt([int(dim / int_downsize) for dim in shape], int_steps))
50
-
107
+ self.add_module(
108
+ "Integrate_pos_flow",
109
+ VecInt([int(dim / int_downsize) for dim in shape], int_steps),
110
+ )
111
+
51
112
  if int_steps > 0 and int_downsize > 1:
52
113
  self.add_module("Upsample_pos_flow", ResizeTransform(1 / int_downsize))
53
114
 
115
+
54
116
  class Rigid(network.ModuleArgsDict):
55
117
 
56
118
  def __init__(self, in_channels: int, dim: int) -> None:
57
119
  super().__init__()
58
120
  self.add_module("ToFeatures", torch.nn.Flatten(1))
59
- self.add_module("Head", torch.nn.Linear(in_channels*512*512, 2))
121
+ self.add_module("Head", torch.nn.Linear(in_channels * 512 * 512, 2))
60
122
 
61
123
  def init(self, init_type: str, init_gain: float):
62
124
  self["Head"].weight.data.fill_(0)
63
125
  self["Head"].bias.data.copy_(torch.tensor([0, 0], dtype=torch.float))
64
-
126
+
127
+
65
128
  class MaskFlow(torch.nn.Module):
66
129
 
67
130
  def __init__(self):
68
131
  super().__init__()
69
-
132
+
70
133
  def forward(self, mask: torch.Tensor, *flows: torch.Tensor):
71
134
  result = torch.zeros_like(flows[0])
72
135
  for i, flow in enumerate(flows):
73
- result = result+torch.where(mask == i+1, flow, torch.tensor(0))
136
+ result = result + torch.where(mask == i + 1, flow, torch.tensor(0))
74
137
  return result
75
138
 
139
+
76
140
  class SpatialTransformer(torch.nn.Module):
77
-
78
- def __init__(self, size : list[int], rigid: bool = False):
141
+
142
+ def __init__(self, size: list[int], rigid: bool = False):
79
143
  super().__init__()
80
144
  self.rigid = rigid
81
145
  if not rigid:
82
146
  vectors = [torch.arange(0, s) for s in size]
83
- grids = torch.meshgrid(vectors, indexing='ij')
147
+ grids = torch.meshgrid(vectors, indexing="ij")
84
148
  grid = torch.stack(grids)
85
149
  grid = torch.unsqueeze(grid, 0)
86
150
  grid = grid.type(torch.float)
87
- self.register_buffer('grid', grid)
151
+ self.register_buffer("grid", grid)
88
152
 
89
153
  def forward(self, src: torch.Tensor, flow: torch.Tensor):
90
154
  if self.rigid:
91
155
  new_locs = torch.zeros((flow.shape[0], 2, 3)).to(flow.device)
92
- new_locs[:, 0,0] = 1
93
- new_locs[:, 1,1] = 1
94
- new_locs[:, 0,2] = flow[:, 0]
95
- new_locs[:, 1,2] = flow[:, 1]
96
- return F.grid_sample(src, F.affine_grid(new_locs, src.size()), align_corners=True, mode="bilinear")
156
+ new_locs[:, 0, 0] = 1
157
+ new_locs[:, 1, 1] = 1
158
+ new_locs[:, 0, 2] = flow[:, 0]
159
+ new_locs[:, 1, 2] = flow[:, 1]
160
+ return F.grid_sample(
161
+ src,
162
+ F.affine_grid(new_locs, src.size()),
163
+ align_corners=True,
164
+ mode="bilinear",
165
+ )
97
166
  else:
98
167
  new_locs = self.grid + flow
99
168
  shape = flow.shape[2:]
@@ -102,13 +171,15 @@ class SpatialTransformer(torch.nn.Module):
102
171
  new_locs = new_locs.permute(0, 2, 3, 1)
103
172
  return F.grid_sample(src, new_locs[..., [1, 0]], align_corners=True, mode="bilinear")
104
173
 
174
+
105
175
  class VecInt(torch.nn.Module):
106
176
 
107
- def __init__(self, inshape, nsteps):
177
+ def __init__(self, inshape: list[int], nsteps: int):
108
178
  super().__init__()
109
- assert nsteps >= 0, 'nsteps should be >= 0, found: %d' % nsteps
179
+ if nsteps < 0:
180
+ raise NameError(f"nsteps should be >= 0, found: {nsteps}")
110
181
  self.nsteps = nsteps
111
- self.scale = 1.0 / (2 ** self.nsteps)
182
+ self.scale = 1.0 / (2**self.nsteps)
112
183
  self.transformer = SpatialTransformer(inshape)
113
184
 
114
185
  def forward(self, vec: torch.Tensor):
@@ -119,16 +190,28 @@ class VecInt(torch.nn.Module):
119
190
 
120
191
 
121
192
  class ResizeTransform(torch.nn.Module):
122
-
123
- def __init__(self, size):
193
+
194
+ def __init__(self, size: float):
124
195
  super().__init__()
125
196
  self.factor = 1.0 / size
126
197
 
127
198
  def forward(self, x: torch.Tensor):
128
199
  if self.factor < 1:
129
- x = F.interpolate(x, align_corners=True, scale_factor=self.factor, mode="bilinear", recompute_scale_factor = True)
200
+ x = F.interpolate(
201
+ x,
202
+ align_corners=True,
203
+ scale_factor=self.factor,
204
+ mode="bilinear",
205
+ recompute_scale_factor=True,
206
+ )
130
207
  x = self.factor * x
131
208
  elif self.factor > 1:
132
209
  x = self.factor * x
133
- x = F.interpolate(x, align_corners=True, scale_factor=self.factor, mode="bilinear", recompute_scale_factor = True)
134
- return x
210
+ x = F.interpolate(
211
+ x,
212
+ align_corners=True,
213
+ scale_factor=self.factor,
214
+ mode="bilinear",
215
+ recompute_scale_factor=True,
216
+ )
217
+ return x
@@ -1,21 +1,35 @@
1
1
  import torch
2
2
 
3
- from konfai.network import network, blocks
4
- from konfai.utils.config import config
3
+ from konfai.network import blocks, network
4
+
5
5
 
6
6
  class ConvBlock(torch.nn.Module):
7
-
8
- def __init__(self, in_channels : int, out_channels : int) -> None:
7
+
8
+ def __init__(self, in_channels: int, out_channels: int) -> None:
9
9
  super().__init__()
10
- self.Conv_0 = torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True)
10
+ self.Conv_0 = torch.nn.Conv3d(
11
+ in_channels=in_channels,
12
+ out_channels=out_channels,
13
+ kernel_size=3,
14
+ stride=1,
15
+ padding=1,
16
+ bias=True,
17
+ )
11
18
  self.Norm_0 = torch.nn.InstanceNorm3d(num_features=out_channels, affine=True)
12
19
  self.Activation_0 = torch.nn.LeakyReLU(negative_slope=0.01)
13
- self.Conv_1 = torch.nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1, bias=True)
20
+ self.Conv_1 = torch.nn.Conv3d(
21
+ in_channels=out_channels,
22
+ out_channels=out_channels,
23
+ kernel_size=3,
24
+ stride=1,
25
+ padding=1,
26
+ bias=True,
27
+ )
14
28
  self.Norm_1 = torch.nn.InstanceNorm3d(num_features=out_channels, affine=True)
15
29
  self.Activation_1 = torch.nn.LeakyReLU(negative_slope=0.01)
16
-
17
- def forward(self, input: torch.Tensor) -> torch.Tensor:
18
- output = self.Conv_0(input)
30
+
31
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
32
+ output = self.Conv_0(tensor)
19
33
  output = self.Norm_0(output)
20
34
  output = self.Activation_0(output)
21
35
  output = self.Conv_1(output)
@@ -23,34 +37,54 @@ class ConvBlock(torch.nn.Module):
23
37
  output = self.Activation_1(output)
24
38
  return output
25
39
 
26
- class UnetCPP_1_Layers(torch.nn.Module):
40
+
41
+ class UnetCPP1Layers(torch.nn.Module):
27
42
 
28
43
  def __init__(self) -> None:
29
44
  super().__init__()
30
45
  self.DownConvBlock_0 = ConvBlock(in_channels=1, out_channels=32)
31
-
32
- def forward(self, input: torch.Tensor) -> torch.Tensor:
33
- return self.DownConvBlock_0(input)
46
+
47
+ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
48
+ return self.DownConvBlock_0(tensor)
49
+
34
50
 
35
51
  class Adaptation(torch.nn.Module):
36
52
 
37
53
  def __init__(self) -> None:
38
54
  super().__init__()
39
- self.Encoder_1 = UnetCPP_1_Layers()
55
+ self.Encoder_1 = UnetCPP1Layers()
40
56
  self.ToFeatures = blocks.ToFeatures(3)
41
57
  self.FCT_1 = torch.nn.Linear(32, 32, bias=True)
42
58
 
43
- def forward(self, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor) -> list[torch.Tensor]:
59
+ def forward(
60
+ self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
61
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
44
62
  self.Encoder_1.requires_grad_(False)
45
63
  self.FCT_1.requires_grad_(True)
46
- return self.FCT_1(self.ToFeatures(self.Encoder_1(A))), self.FCT_1(self.ToFeatures(self.Encoder_1(B))), self.FCT_1(self.ToFeatures(self.Encoder_1(C)))
64
+ return (
65
+ self.FCT_1(self.ToFeatures(self.Encoder_1(a))),
66
+ self.FCT_1(self.ToFeatures(self.Encoder_1(b))),
67
+ self.FCT_1(self.ToFeatures(self.Encoder_1(c))),
68
+ )
69
+
47
70
 
48
71
  class Representation(network.Network):
49
72
 
50
- def __init__( self,
51
- optimizer : network.OptimizerLoader = network.OptimizerLoader(),
52
- schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
53
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
54
- dim : int = 3):
55
- super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, init_type="kaiming")
56
- self.add_module("Model", Adaptation(), in_branch=[0,1,2])
73
+ def __init__(
74
+ self,
75
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
76
+ schedulers: dict[str, network.LRSchedulersLoader] = {
77
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
78
+ },
79
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
80
+ dim: int = 3,
81
+ ):
82
+ super().__init__(
83
+ in_channels=1,
84
+ optimizer=optimizer,
85
+ schedulers=schedulers,
86
+ outputs_criterions=outputs_criterions,
87
+ dim=dim,
88
+ init_type="kaiming",
89
+ )
90
+ self.add_module("Model", Adaptation(), in_branch=[0, 1, 2])