konfai 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/data/augmentation.py +1 -2
- konfai/data/data_manager.py +26 -40
- konfai/data/patching.py +7 -19
- konfai/data/transform.py +23 -21
- konfai/evaluator.py +1 -1
- konfai/main.py +0 -2
- konfai/metric/measure.py +86 -17
- konfai/metric/schedulers.py +7 -12
- konfai/models/classification/convNeXt.py +1 -2
- konfai/models/classification/resnet.py +1 -2
- konfai/models/generation/cStyleGan.py +1 -2
- konfai/models/generation/ddpm.py +1 -2
- konfai/models/generation/diffusionGan.py +10 -23
- konfai/models/generation/gan.py +3 -6
- konfai/models/generation/vae.py +0 -2
- konfai/models/registration/registration.py +1 -2
- konfai/models/representation/representation.py +1 -2
- konfai/models/segmentation/NestedUNet.py +61 -62
- konfai/models/segmentation/UNet.py +1 -2
- konfai/network/blocks.py +8 -6
- konfai/network/network.py +54 -47
- konfai/predictor.py +55 -76
- konfai/trainer.py +5 -4
- konfai/utils/utils.py +41 -3
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/METADATA +1 -1
- konfai-1.1.7.dist-info/RECORD +39 -0
- konfai-1.1.5.dist-info/RECORD +0 -39
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/WHEEL +0 -0
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/top_level.txt +0 -0
|
@@ -118,10 +118,9 @@ class Generator(network.Network):
|
|
|
118
118
|
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
|
|
119
119
|
self.add_module("Tanh", torch.nn.Tanh())
|
|
120
120
|
|
|
121
|
-
@config("Generator")
|
|
122
121
|
def __init__(self,
|
|
123
122
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
124
|
-
schedulers
|
|
123
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
125
124
|
patch : ModelPatch = ModelPatch(),
|
|
126
125
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
127
126
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
konfai/models/generation/ddpm.py
CHANGED
|
@@ -175,10 +175,9 @@ class DDPM(network.Network):
|
|
|
175
175
|
def forward(self, input: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: int) -> torch.Tensor:
|
|
176
176
|
return torch.concat((self.v(input, eta, t), self.v(input, eta_hat, t)), dim=1)
|
|
177
177
|
|
|
178
|
-
@config("DDPM")
|
|
179
178
|
def __init__( self,
|
|
180
179
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
181
|
-
schedulers
|
|
180
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
182
181
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
183
182
|
patch : Union[ModelPatch, None] = None,
|
|
184
183
|
train_noise_step: int = 1000,
|
|
@@ -37,10 +37,9 @@ class Discriminator(network.Network):
|
|
|
37
37
|
self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
|
|
38
38
|
self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))
|
|
39
39
|
|
|
40
|
-
@config("Discriminator")
|
|
41
40
|
def __init__(self,
|
|
42
41
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
43
|
-
schedulers
|
|
42
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
44
43
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
45
44
|
channels: list[int] = [1, 16, 32, 64, 64],
|
|
46
45
|
strides: list[int] = [2,2,2,1],
|
|
@@ -169,10 +168,9 @@ class Discriminator_ADA(network.Network):
|
|
|
169
168
|
self.add_module("Layers", Discriminator_ADA.DiscriminatorNLayers(channels, strides, 100, dim), in_branch=[0, "te"])
|
|
170
169
|
self.add_module("Head", Discriminator_ADA.DiscriminatorHead(channels[-1], dim))
|
|
171
170
|
|
|
172
|
-
@config("Discriminator_ADA")
|
|
173
171
|
def __init__(self,
|
|
174
172
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
175
|
-
schedulers
|
|
173
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
176
174
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
177
175
|
channels: list[int] = [1, 16, 32, 64, 64],
|
|
178
176
|
strides: list[int] = [2,2,2,1],
|
|
@@ -261,10 +259,9 @@ class Discriminator_ADA(network.Network):
|
|
|
261
259
|
self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
|
|
262
260
|
self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
|
|
263
261
|
|
|
264
|
-
@config("GeneratorV1")
|
|
265
262
|
def __init__(self,
|
|
266
263
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
267
|
-
schedulers
|
|
264
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
268
265
|
patch : ModelPatch = ModelPatch(),
|
|
269
266
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
270
267
|
dim : int = 3) -> None:
|
|
@@ -349,10 +346,9 @@ class GeneratorV1(network.Network):
|
|
|
349
346
|
self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
|
|
350
347
|
self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
|
|
351
348
|
|
|
352
|
-
@config("GeneratorV1")
|
|
353
349
|
def __init__(self,
|
|
354
350
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
355
|
-
schedulers
|
|
351
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
356
352
|
patch : ModelPatch = ModelPatch(),
|
|
357
353
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
358
354
|
dim : int = 3) -> None:
|
|
@@ -383,10 +379,9 @@ class GeneratorV2(network.Network):
|
|
|
383
379
|
self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
384
380
|
self.add_module("Head", GeneratorV2.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
|
|
385
381
|
|
|
386
|
-
@config("GeneratorV2")
|
|
387
382
|
def __init__( self,
|
|
388
383
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
389
|
-
schedulers
|
|
384
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
390
385
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
391
386
|
patch : Union[ModelPatch, None] = None,
|
|
392
387
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -424,10 +419,9 @@ class GeneratorV3(network.Network):
|
|
|
424
419
|
self.add_module("UNetBlock_0", UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=1, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
425
420
|
self.add_module("Head", GeneratorV3.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
|
|
426
421
|
|
|
427
|
-
@config("GeneratorV3")
|
|
428
422
|
def __init__( self,
|
|
429
423
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
430
|
-
schedulers
|
|
424
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
431
425
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
432
426
|
patch : Union[ModelPatch, None] = None,
|
|
433
427
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -443,7 +437,6 @@ class GeneratorV3(network.Network):
|
|
|
443
437
|
|
|
444
438
|
class DiffusionGan(network.Network):
|
|
445
439
|
|
|
446
|
-
@config("DiffusionGan")
|
|
447
440
|
def __init__(self, generator : GeneratorV1 = GeneratorV1(), discriminator : Discriminator_ADA = Discriminator_ADA()) -> None:
|
|
448
441
|
super().__init__()
|
|
449
442
|
self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
|
|
@@ -454,7 +447,6 @@ class DiffusionGan(network.Network):
|
|
|
454
447
|
|
|
455
448
|
class DiffusionGanV2(network.Network):
|
|
456
449
|
|
|
457
|
-
@config("DiffusionGan")
|
|
458
450
|
def __init__(self, generator : GeneratorV2 = GeneratorV2(), discriminator : Discriminator = Discriminator()) -> None:
|
|
459
451
|
super().__init__()
|
|
460
452
|
self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
|
|
@@ -466,10 +458,9 @@ class DiffusionGanV2(network.Network):
|
|
|
466
458
|
|
|
467
459
|
class CycleGanDiscriminator(network.Network):
|
|
468
460
|
|
|
469
|
-
@config("CycleGanDiscriminator")
|
|
470
461
|
def __init__(self,
|
|
471
462
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
472
|
-
schedulers
|
|
463
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
473
464
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
474
465
|
patch : Union[ModelPatch, None] = None,
|
|
475
466
|
channels: list[int] = [1, 16, 32, 64, 64],
|
|
@@ -485,10 +476,9 @@ class CycleGanDiscriminator(network.Network):
|
|
|
485
476
|
|
|
486
477
|
class CycleGanGeneratorV1(network.Network):
|
|
487
478
|
|
|
488
|
-
@config("CycleGanGeneratorV1")
|
|
489
479
|
def __init__(self,
|
|
490
480
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
491
|
-
schedulers
|
|
481
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
492
482
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
493
483
|
patch : Union[ModelPatch, None] = None,
|
|
494
484
|
dim : int = 3) -> None:
|
|
@@ -498,10 +488,9 @@ class CycleGanGeneratorV1(network.Network):
|
|
|
498
488
|
|
|
499
489
|
class CycleGanGeneratorV2(network.Network):
|
|
500
490
|
|
|
501
|
-
@config("CycleGanGeneratorV2")
|
|
502
491
|
def __init__(self,
|
|
503
492
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
504
|
-
schedulers
|
|
493
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
505
494
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
506
495
|
patch : Union[ModelPatch, None] = None,
|
|
507
496
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -518,10 +507,9 @@ class CycleGanGeneratorV2(network.Network):
|
|
|
518
507
|
|
|
519
508
|
class CycleGanGeneratorV3(network.Network):
|
|
520
509
|
|
|
521
|
-
@config("CycleGanGeneratorV3")
|
|
522
510
|
def __init__(self,
|
|
523
511
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
524
|
-
schedulers
|
|
512
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
525
513
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
526
514
|
patch : Union[ModelPatch, None] = None,
|
|
527
515
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -538,7 +526,6 @@ class CycleGanGeneratorV3(network.Network):
|
|
|
538
526
|
|
|
539
527
|
class DiffusionCycleGan(network.Network):
|
|
540
528
|
|
|
541
|
-
@config("DiffusionCycleGan")
|
|
542
529
|
def __init__(self, generators : CycleGanGeneratorV3 = CycleGanGeneratorV3(), discriminators : CycleGanDiscriminator = CycleGanDiscriminator()) -> None:
|
|
543
530
|
super().__init__()
|
|
544
531
|
self.add_module("Generator", generators, in_branch=[0, 1], out_branch=["pB", "pA"])
|
konfai/models/generation/gan.py
CHANGED
|
@@ -23,10 +23,9 @@ class Discriminator(network.Network):
|
|
|
23
23
|
self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
|
|
24
24
|
self.add_module("Flatten", torch.nn.Flatten(1))
|
|
25
25
|
|
|
26
|
-
@config("Discriminator")
|
|
27
26
|
def __init__(self,
|
|
28
27
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
29
|
-
schedulers
|
|
28
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
30
29
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
31
30
|
nb_batch_per_step: int = 64,
|
|
32
31
|
dim : int = 3) -> None:
|
|
@@ -101,11 +100,10 @@ class Generator(network.Network):
|
|
|
101
100
|
self.add_module("Encoder", Generator.GeneratorEncoder(channels, dim))
|
|
102
101
|
self.add_module("NResBlock", Generator.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
|
|
103
102
|
self.add_module("Decoder", Generator.GeneratorDecoder(channels, dim))
|
|
104
|
-
|
|
105
|
-
@config("Generator")
|
|
103
|
+
|
|
106
104
|
def __init__(self,
|
|
107
105
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
108
|
-
schedulers
|
|
106
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
109
107
|
patch : ModelPatch = ModelPatch(),
|
|
110
108
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
111
109
|
nb_batch_per_step: int = 64,
|
|
@@ -121,7 +119,6 @@ class Generator(network.Network):
|
|
|
121
119
|
|
|
122
120
|
class Gan(network.Network):
|
|
123
121
|
|
|
124
|
-
@config("Gan")
|
|
125
122
|
def __init__(self, generator : Generator = Generator(), discriminator : Discriminator = Discriminator()) -> None:
|
|
126
123
|
super().__init__()
|
|
127
124
|
self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
|
konfai/models/generation/vae.py
CHANGED
|
@@ -25,7 +25,6 @@ class VAE(network.Network):
|
|
|
25
25
|
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
|
|
26
26
|
self.add_module("Tanh", torch.nn.Tanh())
|
|
27
27
|
|
|
28
|
-
@config("VAE")
|
|
29
28
|
def __init__(self,
|
|
30
29
|
optimizer: network.OptimizerLoader = network.OptimizerLoader(),
|
|
31
30
|
schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
@@ -60,7 +59,6 @@ class LinearVAE(network.Network):
|
|
|
60
59
|
self.add_module("Linear", torch.nn.Linear(in_features, out_features))
|
|
61
60
|
self.add_module("Tanh", torch.nn.Tanh())
|
|
62
61
|
|
|
63
|
-
@config("LinearVAE")
|
|
64
62
|
def __init__(self,
|
|
65
63
|
optimizer: network.OptimizerLoader = network.OptimizerLoader(),
|
|
66
64
|
schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
@@ -8,10 +8,9 @@ from konfai.models.segmentation import UNet
|
|
|
8
8
|
|
|
9
9
|
class VoxelMorph(network.Network):
|
|
10
10
|
|
|
11
|
-
@config("VoxelMorph")
|
|
12
11
|
def __init__( self,
|
|
13
12
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
14
|
-
schedulers
|
|
13
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
15
14
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
16
15
|
dim : int = 3,
|
|
17
16
|
channels : list[int] = [4, 16,32,32,32],
|
|
@@ -47,10 +47,9 @@ class Adaptation(torch.nn.Module):
|
|
|
47
47
|
|
|
48
48
|
class Representation(network.Network):
|
|
49
49
|
|
|
50
|
-
@config("Representation")
|
|
51
50
|
def __init__( self,
|
|
52
51
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
53
|
-
schedulers
|
|
52
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
54
53
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
55
54
|
dim : int = 3):
|
|
56
55
|
super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, init_type="kaiming")
|
|
@@ -3,23 +3,23 @@ from typing import Union
|
|
|
3
3
|
import torch
|
|
4
4
|
from konfai.data.patching import ModelPatch
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
class NestedUNetBlock(network.ModuleArgsDict):
|
|
6
|
+
class NestedUNet(network.Network):
|
|
8
7
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
self.add_module("
|
|
17
|
-
|
|
18
|
-
self.add_module("
|
|
19
|
-
|
|
20
|
-
|
|
8
|
+
class NestedUNetBlock(network.ModuleArgsDict):
|
|
9
|
+
|
|
10
|
+
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
|
|
11
|
+
super().__init__()
|
|
12
|
+
if i > 0:
|
|
13
|
+
self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
|
|
14
|
+
|
|
15
|
+
self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
|
|
16
|
+
if len(channels) > 2:
|
|
17
|
+
self.add_module("UNetBlock_{}".format(i+1), NestedUNet.NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
|
|
18
|
+
for j in range(len(channels)-2):
|
|
19
|
+
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
20
|
+
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
21
|
+
self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
21
22
|
|
|
22
|
-
class NestedUNet(network.Network):
|
|
23
23
|
|
|
24
24
|
class NestedUNetHead(network.ModuleArgsDict):
|
|
25
25
|
|
|
@@ -34,7 +34,7 @@ class NestedUNet(network.Network):
|
|
|
34
34
|
|
|
35
35
|
def __init__( self,
|
|
36
36
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
37
|
-
schedulers
|
|
37
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
38
38
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
39
39
|
patch : Union[ModelPatch, None] = None,
|
|
40
40
|
dim : int = 3,
|
|
@@ -49,61 +49,60 @@ class NestedUNet(network.Network):
|
|
|
49
49
|
activation: str = "Softmax") -> None:
|
|
50
50
|
super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
51
51
|
|
|
52
|
-
self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
52
|
+
self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
53
53
|
for j in range(len(channels)-2):
|
|
54
54
|
self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, activation=activation, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
|
|
55
55
|
|
|
56
56
|
|
|
57
|
+
class UNetpp(network.Network):
|
|
57
58
|
|
|
58
|
-
|
|
59
|
-
class ResNetEncoderLayer(network.ModuleArgsDict):
|
|
60
|
-
|
|
61
|
-
def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
|
|
62
|
-
super().__init__()
|
|
63
|
-
for i in range(nb_block):
|
|
64
|
-
if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
|
|
65
|
-
self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
|
|
66
|
-
self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
|
|
67
|
-
in_channel = out_channel
|
|
68
|
-
|
|
69
|
-
def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
|
|
70
|
-
modules = []
|
|
71
|
-
modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
|
|
72
|
-
for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
|
|
73
|
-
modules.append(ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
|
|
74
|
-
return modules
|
|
75
|
-
|
|
76
|
-
class NestedUNetBlock(network.ModuleArgsDict):
|
|
77
|
-
|
|
78
|
-
def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
|
|
79
|
-
super().__init__()
|
|
80
|
-
self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
|
|
81
|
-
if len(encoder_channels) > 2:
|
|
82
|
-
self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
|
|
83
|
-
for j in range(len(encoder_channels)-2):
|
|
84
|
-
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=encoder_channels[2], out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
85
|
-
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
86
|
-
self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+((encoder_channels[1]+encoder_channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else (decoder_channels[1]+decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]+encoder_channels[1])), out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
59
|
+
class ResNetEncoderLayer(network.ModuleArgsDict):
|
|
87
60
|
|
|
88
|
-
|
|
61
|
+
def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
|
|
62
|
+
super().__init__()
|
|
63
|
+
for i in range(nb_block):
|
|
64
|
+
if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
|
|
65
|
+
self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
|
|
66
|
+
self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
|
|
67
|
+
in_channel = out_channel
|
|
68
|
+
|
|
69
|
+
def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
|
|
70
|
+
modules = []
|
|
71
|
+
modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
|
|
72
|
+
for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
|
|
73
|
+
modules.append(UNetpp.ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
|
|
74
|
+
return modules
|
|
89
75
|
|
|
90
|
-
|
|
91
|
-
super().__init__()
|
|
92
|
-
self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim))
|
|
93
|
-
self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
|
|
94
|
-
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
|
|
95
|
-
if nb_class > 1:
|
|
96
|
-
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
97
|
-
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
98
|
-
else:
|
|
99
|
-
self.add_module("Tanh", torch.nn.Tanh())
|
|
100
|
-
|
|
101
|
-
class UNetpp(network.Network):
|
|
76
|
+
class UNetPPBlock(network.ModuleArgsDict):
|
|
102
77
|
|
|
78
|
+
def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
|
|
79
|
+
super().__init__()
|
|
80
|
+
self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
|
|
81
|
+
if len(encoder_channels) > 2:
|
|
82
|
+
self.add_module("UNetBlock_{}".format(i+1), UNetpp.UNetPPBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
|
|
83
|
+
for j in range(len(encoder_channels)-2):
|
|
84
|
+
in_channels = decoder_channels[3] if j == len(encoder_channels)-3 and len(encoder_channels) > 3 else encoder_channels[2]
|
|
85
|
+
out_channel = decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]
|
|
86
|
+
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=in_channels, out_channels=out_channel, upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
87
|
+
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
88
|
+
self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*(j+1)+(in_channels if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else out_channel), out_channels=out_channel, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
89
|
+
|
|
90
|
+
class UNetPPHead(network.ModuleArgsDict):
|
|
103
91
|
|
|
92
|
+
def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
|
|
93
|
+
super().__init__()
|
|
94
|
+
self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE, dim=dim))
|
|
95
|
+
self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
|
|
96
|
+
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
|
|
97
|
+
if nb_class > 1:
|
|
98
|
+
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
99
|
+
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
100
|
+
else:
|
|
101
|
+
self.add_module("Tanh", torch.nn.Tanh())
|
|
102
|
+
|
|
104
103
|
def __init__( self,
|
|
105
104
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
106
|
-
schedulers
|
|
105
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
107
106
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
108
107
|
patch : Union[ModelPatch, None] = None,
|
|
109
108
|
encoder_channels: list[int] = [1,64,64,128,256,512],
|
|
@@ -111,5 +110,5 @@ class UNetpp(network.Network):
|
|
|
111
110
|
layers: list[int] = [3,4,6,3],
|
|
112
111
|
dim : int = 2) -> None:
|
|
113
112
|
super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
114
|
-
self.add_module("Block_0",
|
|
115
|
-
self.add_module("Head",
|
|
113
|
+
self.add_module("Block_0", UNetpp.UNetPPBlock(encoder_channels, decoder_channels[::-1], UNetpp.resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.CONV_TRANSPOSE, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
|
|
114
|
+
self.add_module("Head", UNetpp.UNetPPHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
|
|
@@ -12,7 +12,6 @@ class UNetHead(network.ModuleArgsDict):
|
|
|
12
12
|
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
|
13
13
|
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
14
14
|
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
15
|
-
|
|
16
15
|
|
|
17
16
|
class UNetBlock(network.ModuleArgsDict):
|
|
18
17
|
|
|
@@ -40,7 +39,7 @@ class UNet(network.Network):
|
|
|
40
39
|
|
|
41
40
|
def __init__( self,
|
|
42
41
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
43
|
-
schedulers
|
|
42
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
44
43
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
45
44
|
patch : Union[ModelPatch, None] = None,
|
|
46
45
|
dim : int = 3,
|
konfai/network/blocks.py
CHANGED
|
@@ -36,11 +36,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module,
|
|
|
36
36
|
|
|
37
37
|
class UpSampleMode(Enum):
|
|
38
38
|
CONV_TRANSPOSE = 0,
|
|
39
|
-
|
|
40
|
-
UPSAMPLE_LINEAR = 2,
|
|
41
|
-
UPSAMPLE_BILINEAR = 3,
|
|
42
|
-
UPSAMPLE_BICUBIC = 4,
|
|
43
|
-
UPSAMPLE_TRILINEAR = 5
|
|
39
|
+
UPSAMPLE = 1,
|
|
44
40
|
|
|
45
41
|
class DownSampleMode(Enum):
|
|
46
42
|
MAXPOOL = 0,
|
|
@@ -128,7 +124,13 @@ def upSample(in_channels: int, out_channels: int, upSampleMode: UpSampleMode, di
|
|
|
128
124
|
if upSampleMode == UpSampleMode.CONV_TRANSPOSE:
|
|
129
125
|
return getTorchModule("ConvTranspose", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = 0)
|
|
130
126
|
else:
|
|
131
|
-
|
|
127
|
+
if dim == 3:
|
|
128
|
+
upsampleMethod = "trilinear"
|
|
129
|
+
if dim == 2:
|
|
130
|
+
upsampleMethod = "bilinear"
|
|
131
|
+
if dim == 1:
|
|
132
|
+
upsampleMethod = "linear"
|
|
133
|
+
return torch.nn.Upsample(scale_factor=2, mode=upsampleMethod.lower(), align_corners=False)
|
|
132
134
|
|
|
133
135
|
class Unsqueeze(torch.nn.Module):
|
|
134
136
|
|