konfai 1.1.7-py3-none-any.whl → 1.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/models/segmentation/NestedUNet.py
@@ -1,114 +1,385 @@
- from konfai.network import network, blocks
- from typing import Union
  import torch
+
  from konfai.data.patching import ModelPatch
+ from konfai.network import blocks, network
+

  class NestedUNet(network.Network):

      class NestedUNetBlock(network.ModuleArgsDict):

-         def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
+         def __init__(
+             self,
+             channels: list[int],
+             nb_conv_per_stage: int,
+             block_config: blocks.BlockConfig,
+             downsample_mode: blocks.DownsampleMode,
+             upsample_mode: blocks.UpsampleMode,
+             attention: bool,
+             block: type,
+             dim: int,
+             i: int = 0,
+         ) -> None:
              super().__init__()
              if i > 0:
-                 self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
-
-             self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
-             if len(channels) > 2:
-                 self.add_module("UNetBlock_{}".format(i+1), NestedUNet.NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
-                 for j in range(len(channels)-2):
-                     self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
-                     self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-                     self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+                 self.add_module(
+                     downsample_mode.name,
+                     blocks.downsample(
+                         in_channels=channels[0],
+                         out_channels=channels[1],
+                         downsample_mode=downsample_mode,
+                         dim=dim,
+                     ),
+                 )

+             self.add_module(
+                 f"X_{i}_{0}",
+                 block(
+                     in_channels=channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
+                     out_channels=channels[1],
+                     block_configs=[block_config] * nb_conv_per_stage,
+                     dim=dim,
+                 ),
+                 out_branch=[f"X_{i}_{0}"],
+             )
+             if len(channels) > 2:
+                 self.add_module(
+                     f"UNetBlock_{i + 1}",
+                     NestedUNet.NestedUNetBlock(
+                         channels[1:],
+                         nb_conv_per_stage,
+                         block_config,
+                         downsample_mode,
+                         upsample_mode,
+                         attention,
+                         block,
+                         dim,
+                         i + 1,
+                     ),
+                     in_branch=[f"X_{i}_{0}"],
+                     out_branch=[f"X_{i + 1}_{j}" for j in range(len(channels) - 2)],
+                 )
+                 for j in range(len(channels) - 2):
+                     self.add_module(
+                         f"X_{i}_{j + 1}_{upsample_mode.name}",
+                         blocks.upsample(
+                             in_channels=channels[2],
+                             out_channels=channels[1],
+                             upsample_mode=upsample_mode,
+                             dim=dim,
+                         ),
+                         in_branch=[f"X_{i + 1}_{j}"],
+                         out_branch=[f"X_{i + 1}_{j}"],
+                     )
+                     self.add_module(
+                         f"SkipConnection_{i}_{j + 1}",
+                         blocks.Concat(),
+                         in_branch=[f"X_{i + 1}_{j}"] + [f"X_{i}_{r}" for r in range(j + 1)],
+                         out_branch=[f"X_{i}_{j + 1}"],
+                     )
+                     self.add_module(
+                         f"X_{i}_{j + 1}",
+                         block(
+                             in_channels=(
+                                 (channels[1] * (j + 1) + channels[2])
+                                 if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                                 else channels[1] * (j + 2)
+                             ),
+                             out_channels=channels[1],
+                             block_configs=[block_config] * nb_conv_per_stage,
+                             dim=dim,
+                         ),
+                         in_branch=[f"X_{i}_{j + 1}"],
+                         out_branch=[f"X_{i}_{j + 1}"],
+                     )

      class NestedUNetHead(network.ModuleArgsDict):

          def __init__(self, in_channels: int, nb_class: int, activation: str, dim: int) -> None:
              super().__init__()
-             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
+             self.add_module(
+                 "Conv",
+                 blocks.get_torch_module("Conv", dim)(
+                     in_channels=in_channels,
+                     out_channels=nb_class,
+                     kernel_size=1,
+                     stride=1,
+                     padding=0,
+                 ),
+             )
              if activation == "Softmax":
                  self.add_module("Softmax", torch.nn.Softmax(dim=1))
                  self.add_module("Argmax", blocks.ArgMax(dim=1))
              elif activation == "Tanh":
                  self.add_module("Tanh", torch.nn.Tanh())
-
-     def __init__( self,
-                   optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                   schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                   outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                   patch : Union[ModelPatch, None] = None,
-                   dim : int = 3,
-                   channels: list[int]=[1, 64, 128, 256, 512, 1024],
-                   nb_class: int = 2,
-                   blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
-                   nb_conv_per_stage: int = 2,
-                   downSampleMode: str = "MAXPOOL",
-                   upSampleMode: str = "CONV_TRANSPOSE",
-                   attention : bool = False,
-                   blockType: str = "Conv",
-                   activation: str = "Softmax") -> None:
-         super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-
-         self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
-         for j in range(len(channels)-2):
-             self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, activation=activation, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
+
+     def __init__(
+         self,
+         optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+         schedulers: dict[str, network.LRSchedulersLoader] = {
+             "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+         },
+         outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+         patch: ModelPatch | None = None,
+         dim: int = 3,
+         channels: list[int] = [1, 64, 128, 256, 512, 1024],
+         nb_class: int = 2,
+         block_config: blocks.BlockConfig = blocks.BlockConfig(),
+         nb_conv_per_stage: int = 2,
+         downsample_mode: str = "MAXPOOL",
+         upsample_mode: str = "CONV_TRANSPOSE",
+         attention: bool = False,
+         block_type: str = "Conv",
+         activation: str = "Softmax",
+     ) -> None:
+         super().__init__(
+             in_channels=channels[0],
+             optimizer=optimizer,
+             schedulers=schedulers,
+             outputs_criterions=outputs_criterions,
+             patch=patch,
+             dim=dim,
+         )
+
+         self.add_module(
+             "UNetBlock_0",
+             NestedUNet.NestedUNetBlock(
+                 channels,
+                 nb_conv_per_stage,
+                 block_config,
+                 downsample_mode=blocks.DownsampleMode[downsample_mode],
+                 upsample_mode=blocks.UpsampleMode[upsample_mode],
+                 attention=attention,
+                 block=blocks.ConvBlock if block_type == "Conv" else blocks.ResBlock,
+                 dim=dim,
+             ),
+             out_branch=[f"X_0_{j + 1}" for j in range(len(channels) - 2)],
+         )
+         for j in range(len(channels) - 2):
+             self.add_module(
+                 f"Head_{j}",
+                 NestedUNet.NestedUNetHead(
+                     in_channels=channels[1],
+                     nb_class=nb_class,
+                     activation=activation,
+                     dim=dim,
+                 ),
+                 in_branch=[f"X_0_{j + 1}"],
+                 out_branch=[-1],
+             )


  class UNetpp(network.Network):

      class ResNetEncoderLayer(network.ModuleArgsDict):

-         def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
+         def __init__(
+             self,
+             in_channel: int,
+             out_channel: int,
+             nb_block: int,
+             dim: int,
+             downsample_mode: blocks.DownsampleMode,
+         ):
              super().__init__()
              for i in range(nb_block):
-                 if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
-                     self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
-                 self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
+                 if downsample_mode == blocks.DownsampleMode.MAXPOOL and i == 0:
+                     self.add_module(
+                         "DownSample",
+                         blocks.get_torch_module("MaxPool", dim)(
+                             kernel_size=3,
+                             stride=2,
+                             padding=1,
+                             dilation=1,
+                             ceil_mode=False,
+                         ),
+                     )
+                 self.add_module(
+                     f"ResBlock_{i}",
+                     blocks.ResBlock(
+                         in_channel,
+                         out_channel,
+                         [
+                             blocks.BlockConfig(
+                                 3,
+                                 (2 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i == 0 else 1),
+                                 1,
+                                 False,
+                                 "ReLU;True",
+                                 blocks.NormMode.BATCH,
+                             ),
+                             blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH),
+                         ],
+                         dim=dim,
+                     ),
+                 )
                  in_channel = out_channel
-
-     def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
+
+     @staticmethod
+     def resnet_encoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
          modules = []
-         modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
+         modules.append(
+             blocks.ConvBlock(
+                 channels[0],
+                 channels[1],
+                 [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)],
+                 dim=dim,
+             )
+         )
          for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
-             modules.append(UNetpp.ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
+             modules.append(
+                 UNetpp.ResNetEncoderLayer(
+                     in_channel,
+                     out_channel,
+                     layer,
+                     dim,
+                     (blocks.DownsampleMode.MAXPOOL if i == 0 else blocks.DownsampleMode.CONV_STRIDE),
+                 )
+             )
          return modules

      class UNetPPBlock(network.ModuleArgsDict):

-         def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
+         def __init__(
+             self,
+             encoder_channels: list[int],
+             decoder_channels: list[int],
+             encoders: list[torch.nn.Module],
+             upsample_mode: blocks.UpsampleMode,
+             dim: int,
+             i: int = 0,
+         ) -> None:
              super().__init__()
-             self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
+             self.add_module(f"X_{i}_{0}", encoders[0], out_branch=[f"X_{i}_{0}"])
              if len(encoder_channels) > 2:
-                 self.add_module("UNetBlock_{}".format(i+1), UNetpp.UNetPPBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
-                 for j in range(len(encoder_channels)-2):
-                     in_channels = decoder_channels[3] if j == len(encoder_channels)-3 and len(encoder_channels) > 3 else encoder_channels[2]
-                     out_channel = decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]
-                     self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=in_channels, out_channels=out_channel, upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
-                     self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-                     self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*(j+1)+(in_channels if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else out_channel), out_channels=out_channel, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-
+                 self.add_module(
+                     f"UNetBlock_{i + 1}",
+                     UNetpp.UNetPPBlock(
+                         encoder_channels[1:],
+                         decoder_channels[1:],
+                         encoders[1:],
+                         upsample_mode,
+                         dim,
+                         i + 1,
+                     ),
+                     in_branch=[f"X_{i}_{0}"],
+                     out_branch=[f"X_{i + 1}_{j}" for j in range(len(encoder_channels) - 2)],
+                 )
+                 for j in range(len(encoder_channels) - 2):
+                     in_channels = (
+                         decoder_channels[3]
+                         if j == len(encoder_channels) - 3 and len(encoder_channels) > 3
+                         else encoder_channels[2]
+                     )
+                     out_channel = decoder_channels[2] if j == len(encoder_channels) - 3 else encoder_channels[1]
+                     self.add_module(
+                         f"X_{i}_{j + 1}_{upsample_mode.name}",
+                         blocks.upsample(
+                             in_channels=in_channels,
+                             out_channels=out_channel,
+                             upsample_mode=upsample_mode,
+                             dim=dim,
+                         ),
+                         in_branch=[f"X_{i + 1}_{j}"],
+                         out_branch=[f"X_{i + 1}_{j}"],
+                     )
+                     self.add_module(
+                         f"SkipConnection_{i}_{j + 1}",
+                         blocks.Concat(),
+                         in_branch=[f"X_{i + 1}_{j}"] + [f"X_{i}_{r}" for r in range(j + 1)],
+                         out_branch=[f"X_{i}_{j + 1}"],
+                     )
+                     self.add_module(
+                         f"X_{i}_{j + 1}",
+                         blocks.ConvBlock(
+                             in_channels=encoder_channels[1] * (j + 1)
+                             + (in_channels if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE else out_channel),
+                             out_channels=out_channel,
+                             block_configs=[blocks.BlockConfig(3, 1, 1, False, "ReLU;True", blocks.NormMode.BATCH)] * 2,
+                             dim=dim,
+                         ),
+                         in_branch=[f"X_{i}_{j + 1}"],
+                         out_branch=[f"X_{i}_{j + 1}"],
+                     )
+
      class UNetPPHead(network.ModuleArgsDict):

          def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
              super().__init__()
-             self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE, dim=dim))
-             self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
-             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
+             self.add_module(
+                 "Upsample",
+                 blocks.upsample(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     upsample_mode=blocks.UpsampleMode.UPSAMPLE,
+                     dim=dim,
+                 ),
+             )
+             self.add_module(
+                 "ConvBlock",
+                 blocks.ConvBlock(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     block_configs=[blocks.BlockConfig(3, 1, 1, False, "ReLU;True", blocks.NormMode.BATCH)] * 2,
+                     dim=dim,
+                 ),
+             )
+             self.add_module(
+                 "Conv",
+                 blocks.get_torch_module("Conv", dim)(
+                     in_channels=out_channels,
+                     out_channels=nb_class,
+                     kernel_size=3,
+                     stride=1,
+                     padding=1,
+                 ),
+             )
              if nb_class > 1:
                  self.add_module("Softmax", torch.nn.Softmax(dim=1))
                  self.add_module("Argmax", blocks.ArgMax(dim=1))
              else:
                  self.add_module("Tanh", torch.nn.Tanh())
-
-     def __init__( self,
-                   optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                   schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                   outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                   patch : Union[ModelPatch, None] = None,
-                   encoder_channels: list[int] = [1,64,64,128,256,512],
-                   decoder_channels: list[int] = [256,128,64,32,16,1],
-                   layers: list[int] = [3,4,6,3],
-                   dim : int = 2) -> None:
-         super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-         self.add_module("Block_0", UNetpp.UNetPPBlock(encoder_channels, decoder_channels[::-1], UNetpp.resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.CONV_TRANSPOSE, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
-         self.add_module("Head", UNetpp.UNetPPHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
+
+     def __init__(
+         self,
+         optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+         schedulers: dict[str, network.LRSchedulersLoader] = {
+             "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+         },
+         outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+         patch: ModelPatch | None = None,
+         encoder_channels: list[int] = [1, 64, 64, 128, 256, 512],
+         decoder_channels: list[int] = [256, 128, 64, 32, 16, 1],
+         layers: list[int] = [3, 4, 6, 3],
+         dim: int = 2,
+     ) -> None:
+         super().__init__(
+             in_channels=encoder_channels[0],
+             optimizer=optimizer,
+             schedulers=schedulers,
+             outputs_criterions=outputs_criterions,
+             patch=patch,
+             dim=dim,
+         )
+         self.add_module(
+             "Block_0",
+             UNetpp.UNetPPBlock(
+                 encoder_channels,
+                 decoder_channels[::-1],
+                 UNetpp.resnet_encoder(encoder_channels, layers, dim),
+                 blocks.UpsampleMode.UPSAMPLE,
+                 dim=dim,
+             ),
+             out_branch=[f"X_0_{j + 1}" for j in range(len(encoder_channels) - 2)],
+         )
+         self.add_module(
+             "Head",
+             UNetpp.UNetPPHead(
+                 in_channels=decoder_channels[-3],
+                 out_channels=decoder_channels[-2],
+                 nb_class=decoder_channels[-1],
+                 dim=dim,
+             ),
+             in_branch=[f"X_0_{len(encoder_channels) - 2}"],
+             out_branch=[-1],
+         )
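
Nearly all of this hunk is a mechanical style migration: camelCase identifiers become snake_case (blockConfig → block_config, downSampleMode → downsample_mode, outputsCriterions → outputs_criterions), Union[ModelPatch, None] becomes ModelPatch | None, .format() calls become f-strings, and long calls are reflowed one argument per line. Two substantive changes are easy to miss: resnet_encoder is now explicitly marked @staticmethod, and UNetpp wires Block_0 with blocks.UpsampleMode.UPSAMPLE where 1.1.7 used CONV_TRANSPOSE. Below is a minimal sketch, assuming the 1.1.9 wheel is installed, of what the renamed constructor keywords look like at a call site; it is illustrative, not taken from the package, and every value shown is just the default from the signature above:

    from konfai.models.segmentation.NestedUNet import NestedUNet

    model = NestedUNet(
        dim=3,
        channels=[1, 64, 128, 256, 512, 1024],
        nb_class=2,
        nb_conv_per_stage=2,
        downsample_mode="MAXPOOL",       # 1.1.7 spelled this downSampleMode
        upsample_mode="CONV_TRANSPOSE",  # 1.1.7 spelled this upSampleMode
        attention=False,
        block_type="Conv",               # 1.1.7 spelled this blockType
        activation="Softmax",
    )

A plain Python call site that still passes the 1.1.7 camelCase keywords will raise a TypeError on 1.1.9.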
konfai/models/segmentation/UNet.py
@@ -1,55 +1,165 @@
  import torch
- from typing import Union

- from konfai.network import network, blocks
- from konfai.utils.config import config
  from konfai.data.patching import ModelPatch
+ from konfai.network import blocks, network
+

  class UNetHead(network.ModuleArgsDict):

      def __init__(self, in_channels: int, nb_class: int, dim: int, level: int) -> None:
          super().__init__()
-         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
+         self.add_module(
+             "Conv",
+             blocks.get_torch_module("Conv", dim)(
+                 in_channels=in_channels,
+                 out_channels=nb_class,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+             ),
+         )
          self.add_module("Softmax", torch.nn.Softmax(dim=1))
          self.add_module("Argmax", blocks.ArgMax(dim=1))

+
  class UNetBlock(network.ModuleArgsDict):

-     def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0) -> None:
+     def __init__(
+         self,
+         channels: list[int],
+         nb_conv_per_stage: int,
+         block_config: blocks.BlockConfig,
+         downsample_mode: blocks.DownsampleMode,
+         upsample_mode: blocks.UpsampleMode,
+         attention: bool,
+         block: type,
+         nb_class: int,
+         dim: int,
+         i: int = 0,
+     ) -> None:
          super().__init__()
-         blockConfig_stride = blockConfig
+         block_config_stride = block_config
          if i > 0:
-             if downSampleMode != blocks.DownSampleMode.CONV_STRIDE:
-                 self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
+             if downsample_mode != blocks.DownsampleMode.CONV_STRIDE:
+                 self.add_module(
+                     downsample_mode.name,
+                     blocks.downsample(
+                         in_channels=channels[0],
+                         out_channels=channels[1],
+                         downsample_mode=downsample_mode,
+                         dim=dim,
+                     ),
+                 )
              else:
-                 blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
-         self.add_module("DownConvBlock", block(in_channels=channels[0], out_channels=channels[1], blockConfigs=[blockConfig_stride]+[blockConfig]*(nb_conv_per_stage-1), dim=dim))
+                 block_config_stride = blocks.BlockConfig(
+                     block_config.kernel_size,
+                     2,
+                     block_config.padding,
+                     block_config.bias,
+                     block_config.activation,
+                     block_config.norm_mode,
+                 )
+         self.add_module(
+             "DownConvBlock",
+             block(
+                 in_channels=channels[0],
+                 out_channels=channels[1],
+                 block_configs=[block_config_stride] + [block_config] * (nb_conv_per_stage - 1),
+                 dim=dim,
+             ),
+         )
          if len(channels) > 2:
-             self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1))
-             self.add_module("UpConvBlock", block(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
+             self.add_module(
+                 f"UNetBlock_{i + 1}",
+                 UNetBlock(
+                     channels[1:],
+                     nb_conv_per_stage,
+                     block_config,
+                     downsample_mode,
+                     upsample_mode,
+                     attention,
+                     block,
+                     nb_class,
+                     dim,
+                     i + 1,
+                 ),
+             )
+             self.add_module(
+                 "UpConvBlock",
+                 block(
+                     in_channels=(
+                         (channels[1] + channels[2])
+                         if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                         else channels[1] * 2
+                     ),
+                     out_channels=channels[1],
+                     block_configs=[block_config] * nb_conv_per_stage,
+                     dim=dim,
+                 ),
+             )
          if nb_class > 0:
              self.add_module("Head", UNetHead(channels[1], nb_class, dim, i), out_branch=[-1])
          if i > 0:
              if attention:
-                 self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[1, 0], out_branch=[1])
-             self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=2, stride=2))
+                 self.add_module(
+                     "Attention",
+                     blocks.Attention(f_g=channels[1], f_l=channels[0], f_int=channels[0], dim=dim),
+                     in_branch=[1, 0],
+                     out_branch=[1],
+                 )
+             self.add_module(
+                 upsample_mode.name,
+                 blocks.upsample(
+                     in_channels=channels[1],
+                     out_channels=channels[0],
+                     upsample_mode=upsample_mode,
+                     dim=dim,
+                     kernel_size=2,
+                     stride=2,
+                 ),
+             )
              self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])

+
  class UNet(network.Network):

-     def __init__( self,
-                   optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                   schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                   outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                   patch : Union[ModelPatch, None] = None,
-                   dim : int = 3,
-                   channels: list[int]=[1, 64, 128, 256, 512, 1024],
-                   nb_class: int = 2,
-                   blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
-                   nb_conv_per_stage: int = 2,
-                   downSampleMode: str = "MAXPOOL",
-                   upSampleMode: str = "CONV_TRANSPOSE",
-                   attention : bool = False,
-                   blockType: str = "Conv") -> None:
-         super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-         self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim))
+     def __init__(
+         self,
+         optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+         schedulers: dict[str, network.LRSchedulersLoader] = {
+             "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+         },
+         outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+         patch: ModelPatch | None = None,
+         dim: int = 3,
+         channels: list[int] = [1, 64, 128, 256, 512, 1024],
+         nb_class: int = 2,
+         block_config: blocks.BlockConfig = blocks.BlockConfig(),
+         nb_conv_per_stage: int = 2,
+         downsample_mode: str = "MAXPOOL",
+         upsample_mode: str = "CONV_TRANSPOSE",
+         attention: bool = False,
+         block_type: str = "Conv",
+     ) -> None:
+         super().__init__(
+             in_channels=channels[0],
+             optimizer=optimizer,
+             schedulers=schedulers,
+             outputs_criterions=outputs_criterions,
+             patch=patch,
+             dim=dim,
+         )
+         self.add_module(
+             "UNetBlock_0",
+             UNetBlock(
+                 channels,
+                 nb_conv_per_stage,
+                 block_config,
+                 downsample_mode=blocks.DownsampleMode[downsample_mode],
+                 upsample_mode=blocks.UpsampleMode[upsample_mode],
+                 attention=attention,
+                 block=blocks.ConvBlock if block_type == "Conv" else blocks.ResBlock,
+                 nb_class=nb_class,
+                 dim=dim,
+             ),
+         )
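
The same renames apply here, and both files also replace the private enum lookup blocks.DownSampleMode._member_map_[name] with the standard subscript form blocks.DownsampleMode[name]. The two lookups resolve the same member for any valid name, as this self-contained check with a stand-in enum (not the real blocks.DownsampleMode) illustrates:

    from enum import Enum

    class DownsampleMode(Enum):  # stand-in for blocks.DownsampleMode
        MAXPOOL = "maxpool"
        CONV_STRIDE = "conv_stride"

    # Enum[name] is the documented lookup API; _member_map_ is an
    # implementation detail, so 1.1.9's form is the safer choice.
    assert DownsampleMode["MAXPOOL"] is DownsampleMode._member_map_["MAXPOOL"]

Both forms raise KeyError for an unknown mode string, so config handling is unchanged; only the dependence on a private attribute goes away.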