konfai 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/data/augmentation.py +0 -1
- konfai/data/data_manager.py +25 -39
- konfai/data/patching.py +7 -19
- konfai/data/transform.py +1 -2
- konfai/metric/measure.py +85 -12
- konfai/metric/schedulers.py +2 -16
- konfai/models/classification/convNeXt.py +1 -2
- konfai/models/classification/resnet.py +1 -2
- konfai/models/generation/cStyleGan.py +1 -2
- konfai/models/generation/ddpm.py +1 -2
- konfai/models/generation/diffusionGan.py +10 -23
- konfai/models/generation/gan.py +3 -6
- konfai/models/generation/vae.py +0 -2
- konfai/models/registration/registration.py +1 -2
- konfai/models/representation/representation.py +1 -2
- konfai/models/segmentation/NestedUNet.py +10 -9
- konfai/models/segmentation/UNet.py +1 -1
- konfai/network/blocks.py +8 -6
- konfai/network/network.py +35 -49
- konfai/predictor.py +39 -59
- konfai/utils/utils.py +40 -2
- {konfai-1.1.6.dist-info → konfai-1.1.8.dist-info}/METADATA +1 -1
- konfai-1.1.8.dist-info/RECORD +39 -0
- konfai-1.1.6.dist-info/RECORD +0 -39
- {konfai-1.1.6.dist-info → konfai-1.1.8.dist-info}/WHEEL +0 -0
- {konfai-1.1.6.dist-info → konfai-1.1.8.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.6.dist-info → konfai-1.1.8.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.6.dist-info → konfai-1.1.8.dist-info}/top_level.txt +0 -0
|
@@ -37,10 +37,9 @@ class Discriminator(network.Network):
|
|
|
37
37
|
self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
|
|
38
38
|
self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))
|
|
39
39
|
|
|
40
|
-
@config("Discriminator")
|
|
41
40
|
def __init__(self,
|
|
42
41
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
43
|
-
schedulers
|
|
42
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
44
43
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
45
44
|
channels: list[int] = [1, 16, 32, 64, 64],
|
|
46
45
|
strides: list[int] = [2,2,2,1],
|
|
@@ -169,10 +168,9 @@ class Discriminator_ADA(network.Network):
|
|
|
169
168
|
self.add_module("Layers", Discriminator_ADA.DiscriminatorNLayers(channels, strides, 100, dim), in_branch=[0, "te"])
|
|
170
169
|
self.add_module("Head", Discriminator_ADA.DiscriminatorHead(channels[-1], dim))
|
|
171
170
|
|
|
172
|
-
@config("Discriminator_ADA")
|
|
173
171
|
def __init__(self,
|
|
174
172
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
175
|
-
schedulers
|
|
173
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
176
174
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
177
175
|
channels: list[int] = [1, 16, 32, 64, 64],
|
|
178
176
|
strides: list[int] = [2,2,2,1],
|
|
@@ -261,10 +259,9 @@ class Discriminator_ADA(network.Network):
|
|
|
261
259
|
self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
|
|
262
260
|
self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
|
|
263
261
|
|
|
264
|
-
@config("GeneratorV1")
|
|
265
262
|
def __init__(self,
|
|
266
263
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
267
|
-
schedulers
|
|
264
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
268
265
|
patch : ModelPatch = ModelPatch(),
|
|
269
266
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
270
267
|
dim : int = 3) -> None:
|
|
@@ -349,10 +346,9 @@ class GeneratorV1(network.Network):
|
|
|
349
346
|
self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
|
|
350
347
|
self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
|
|
351
348
|
|
|
352
|
-
@config("GeneratorV1")
|
|
353
349
|
def __init__(self,
|
|
354
350
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
355
|
-
schedulers
|
|
351
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
356
352
|
patch : ModelPatch = ModelPatch(),
|
|
357
353
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
358
354
|
dim : int = 3) -> None:
|
|
@@ -383,10 +379,9 @@ class GeneratorV2(network.Network):
|
|
|
383
379
|
self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
384
380
|
self.add_module("Head", GeneratorV2.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
|
|
385
381
|
|
|
386
|
-
@config("GeneratorV2")
|
|
387
382
|
def __init__( self,
|
|
388
383
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
389
|
-
schedulers
|
|
384
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
390
385
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
391
386
|
patch : Union[ModelPatch, None] = None,
|
|
392
387
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -424,10 +419,9 @@ class GeneratorV3(network.Network):
|
|
|
424
419
|
self.add_module("UNetBlock_0", UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=1, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
|
|
425
420
|
self.add_module("Head", GeneratorV3.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
|
|
426
421
|
|
|
427
|
-
@config("GeneratorV3")
|
|
428
422
|
def __init__( self,
|
|
429
423
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
430
|
-
schedulers
|
|
424
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
431
425
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
432
426
|
patch : Union[ModelPatch, None] = None,
|
|
433
427
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -443,7 +437,6 @@ class GeneratorV3(network.Network):
|
|
|
443
437
|
|
|
444
438
|
class DiffusionGan(network.Network):
|
|
445
439
|
|
|
446
|
-
@config("DiffusionGan")
|
|
447
440
|
def __init__(self, generator : GeneratorV1 = GeneratorV1(), discriminator : Discriminator_ADA = Discriminator_ADA()) -> None:
|
|
448
441
|
super().__init__()
|
|
449
442
|
self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
|
|
@@ -454,7 +447,6 @@ class DiffusionGan(network.Network):
|
|
|
454
447
|
|
|
455
448
|
class DiffusionGanV2(network.Network):
|
|
456
449
|
|
|
457
|
-
@config("DiffusionGan")
|
|
458
450
|
def __init__(self, generator : GeneratorV2 = GeneratorV2(), discriminator : Discriminator = Discriminator()) -> None:
|
|
459
451
|
super().__init__()
|
|
460
452
|
self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
|
|
@@ -466,10 +458,9 @@ class DiffusionGanV2(network.Network):
|
|
|
466
458
|
|
|
467
459
|
class CycleGanDiscriminator(network.Network):
|
|
468
460
|
|
|
469
|
-
@config("CycleGanDiscriminator")
|
|
470
461
|
def __init__(self,
|
|
471
462
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
472
|
-
schedulers
|
|
463
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
473
464
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
474
465
|
patch : Union[ModelPatch, None] = None,
|
|
475
466
|
channels: list[int] = [1, 16, 32, 64, 64],
|
|
@@ -485,10 +476,9 @@ class CycleGanDiscriminator(network.Network):
|
|
|
485
476
|
|
|
486
477
|
class CycleGanGeneratorV1(network.Network):
|
|
487
478
|
|
|
488
|
-
@config("CycleGanGeneratorV1")
|
|
489
479
|
def __init__(self,
|
|
490
480
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
491
|
-
schedulers
|
|
481
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
492
482
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
493
483
|
patch : Union[ModelPatch, None] = None,
|
|
494
484
|
dim : int = 3) -> None:
|
|
@@ -498,10 +488,9 @@ class CycleGanGeneratorV1(network.Network):
|
|
|
498
488
|
|
|
499
489
|
class CycleGanGeneratorV2(network.Network):
|
|
500
490
|
|
|
501
|
-
@config("CycleGanGeneratorV2")
|
|
502
491
|
def __init__(self,
|
|
503
492
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
504
|
-
schedulers
|
|
493
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
505
494
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
506
495
|
patch : Union[ModelPatch, None] = None,
|
|
507
496
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -518,10 +507,9 @@ class CycleGanGeneratorV2(network.Network):
|
|
|
518
507
|
|
|
519
508
|
class CycleGanGeneratorV3(network.Network):
|
|
520
509
|
|
|
521
|
-
@config("CycleGanGeneratorV3")
|
|
522
510
|
def __init__(self,
|
|
523
511
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
524
|
-
schedulers
|
|
512
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
525
513
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
526
514
|
patch : Union[ModelPatch, None] = None,
|
|
527
515
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -538,7 +526,6 @@ class CycleGanGeneratorV3(network.Network):
|
|
|
538
526
|
|
|
539
527
|
class DiffusionCycleGan(network.Network):
|
|
540
528
|
|
|
541
|
-
@config("DiffusionCycleGan")
|
|
542
529
|
def __init__(self, generators : CycleGanGeneratorV3 = CycleGanGeneratorV3(), discriminators : CycleGanDiscriminator = CycleGanDiscriminator()) -> None:
|
|
543
530
|
super().__init__()
|
|
544
531
|
self.add_module("Generator", generators, in_branch=[0, 1], out_branch=["pB", "pA"])
|
konfai/models/generation/gan.py
CHANGED
|
@@ -23,10 +23,9 @@ class Discriminator(network.Network):
|
|
|
23
23
|
self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
|
|
24
24
|
self.add_module("Flatten", torch.nn.Flatten(1))
|
|
25
25
|
|
|
26
|
-
@config("Discriminator")
|
|
27
26
|
def __init__(self,
|
|
28
27
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
29
|
-
schedulers
|
|
28
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
30
29
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
31
30
|
nb_batch_per_step: int = 64,
|
|
32
31
|
dim : int = 3) -> None:
|
|
@@ -101,11 +100,10 @@ class Generator(network.Network):
|
|
|
101
100
|
self.add_module("Encoder", Generator.GeneratorEncoder(channels, dim))
|
|
102
101
|
self.add_module("NResBlock", Generator.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
|
|
103
102
|
self.add_module("Decoder", Generator.GeneratorDecoder(channels, dim))
|
|
104
|
-
|
|
105
|
-
@config("Generator")
|
|
103
|
+
|
|
106
104
|
def __init__(self,
|
|
107
105
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
108
|
-
schedulers
|
|
106
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
109
107
|
patch : ModelPatch = ModelPatch(),
|
|
110
108
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
111
109
|
nb_batch_per_step: int = 64,
|
|
@@ -121,7 +119,6 @@ class Generator(network.Network):
|
|
|
121
119
|
|
|
122
120
|
class Gan(network.Network):
|
|
123
121
|
|
|
124
|
-
@config("Gan")
|
|
125
122
|
def __init__(self, generator : Generator = Generator(), discriminator : Discriminator = Discriminator()) -> None:
|
|
126
123
|
super().__init__()
|
|
127
124
|
self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
|
konfai/models/generation/vae.py
CHANGED
|
@@ -25,7 +25,6 @@ class VAE(network.Network):
|
|
|
25
25
|
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
|
|
26
26
|
self.add_module("Tanh", torch.nn.Tanh())
|
|
27
27
|
|
|
28
|
-
@config("VAE")
|
|
29
28
|
def __init__(self,
|
|
30
29
|
optimizer: network.OptimizerLoader = network.OptimizerLoader(),
|
|
31
30
|
schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
@@ -60,7 +59,6 @@ class LinearVAE(network.Network):
|
|
|
60
59
|
self.add_module("Linear", torch.nn.Linear(in_features, out_features))
|
|
61
60
|
self.add_module("Tanh", torch.nn.Tanh())
|
|
62
61
|
|
|
63
|
-
@config("LinearVAE")
|
|
64
62
|
def __init__(self,
|
|
65
63
|
optimizer: network.OptimizerLoader = network.OptimizerLoader(),
|
|
66
64
|
schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
@@ -8,10 +8,9 @@ from konfai.models.segmentation import UNet
|
|
|
8
8
|
|
|
9
9
|
class VoxelMorph(network.Network):
|
|
10
10
|
|
|
11
|
-
@config("VoxelMorph")
|
|
12
11
|
def __init__( self,
|
|
13
12
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
14
|
-
schedulers
|
|
13
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
15
14
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
16
15
|
dim : int = 3,
|
|
17
16
|
channels : list[int] = [4, 16,32,32,32],
|
|
@@ -47,10 +47,9 @@ class Adaptation(torch.nn.Module):
|
|
|
47
47
|
|
|
48
48
|
class Representation(network.Network):
|
|
49
49
|
|
|
50
|
-
@config("Representation")
|
|
51
50
|
def __init__( self,
|
|
52
51
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
53
|
-
schedulers
|
|
52
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
54
53
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
55
54
|
dim : int = 3):
|
|
56
55
|
super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, init_type="kaiming")
|
|
@@ -34,7 +34,7 @@ class NestedUNet(network.Network):
|
|
|
34
34
|
|
|
35
35
|
def __init__( self,
|
|
36
36
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
37
|
-
schedulers
|
|
37
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
38
38
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
39
39
|
patch : Union[ModelPatch, None] = None,
|
|
40
40
|
dim : int = 3,
|
|
@@ -81,15 +81,17 @@ class UNetpp(network.Network):
|
|
|
81
81
|
if len(encoder_channels) > 2:
|
|
82
82
|
self.add_module("UNetBlock_{}".format(i+1), UNetpp.UNetPPBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
|
|
83
83
|
for j in range(len(encoder_channels)-2):
|
|
84
|
-
|
|
84
|
+
in_channels = decoder_channels[3] if j == len(encoder_channels)-3 and len(encoder_channels) > 3 else encoder_channels[2]
|
|
85
|
+
out_channel = decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]
|
|
86
|
+
self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=in_channels, out_channels=out_channel, upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
|
|
85
87
|
self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
86
|
-
self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+
|
|
87
|
-
|
|
88
|
+
self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*(j+1)+(in_channels if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else out_channel), out_channels=out_channel, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
|
|
89
|
+
|
|
88
90
|
class UNetPPHead(network.ModuleArgsDict):
|
|
89
91
|
|
|
90
92
|
def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
|
|
91
93
|
super().__init__()
|
|
92
|
-
self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.
|
|
94
|
+
self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE, dim=dim))
|
|
93
95
|
self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
|
|
94
96
|
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
|
|
95
97
|
if nb_class > 1:
|
|
@@ -100,7 +102,7 @@ class UNetpp(network.Network):
|
|
|
100
102
|
|
|
101
103
|
def __init__( self,
|
|
102
104
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
103
|
-
schedulers
|
|
105
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
104
106
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
105
107
|
patch : Union[ModelPatch, None] = None,
|
|
106
108
|
encoder_channels: list[int] = [1,64,64,128,256,512],
|
|
@@ -108,6 +110,5 @@ class UNetpp(network.Network):
|
|
|
108
110
|
layers: list[int] = [3,4,6,3],
|
|
109
111
|
dim : int = 2) -> None:
|
|
110
112
|
super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
111
|
-
self.add_module("Block_0", UNetpp.UNetPPBlock(encoder_channels, decoder_channels[::-1], UNetpp.resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.
|
|
112
|
-
self.add_module("Head", UNetpp.UNetPPHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
|
|
113
|
-
|
|
113
|
+
self.add_module("Block_0", UNetpp.UNetPPBlock(encoder_channels, decoder_channels[::-1], UNetpp.resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.CONV_TRANSPOSE, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
|
|
114
|
+
self.add_module("Head", UNetpp.UNetPPHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
|
|
@@ -39,7 +39,7 @@ class UNet(network.Network):
|
|
|
39
39
|
|
|
40
40
|
def __init__( self,
|
|
41
41
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
42
|
-
schedulers
|
|
42
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
43
43
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
44
44
|
patch : Union[ModelPatch, None] = None,
|
|
45
45
|
dim : int = 3,
|
konfai/network/blocks.py
CHANGED
|
@@ -36,11 +36,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module,
|
|
|
36
36
|
|
|
37
37
|
class UpSampleMode(Enum):
|
|
38
38
|
CONV_TRANSPOSE = 0,
|
|
39
|
-
|
|
40
|
-
UPSAMPLE_LINEAR = 2,
|
|
41
|
-
UPSAMPLE_BILINEAR = 3,
|
|
42
|
-
UPSAMPLE_BICUBIC = 4,
|
|
43
|
-
UPSAMPLE_TRILINEAR = 5
|
|
39
|
+
UPSAMPLE = 1,
|
|
44
40
|
|
|
45
41
|
class DownSampleMode(Enum):
|
|
46
42
|
MAXPOOL = 0,
|
|
@@ -128,7 +124,13 @@ def upSample(in_channels: int, out_channels: int, upSampleMode: UpSampleMode, di
|
|
|
128
124
|
if upSampleMode == UpSampleMode.CONV_TRANSPOSE:
|
|
129
125
|
return getTorchModule("ConvTranspose", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = 0)
|
|
130
126
|
else:
|
|
131
|
-
|
|
127
|
+
if dim == 3:
|
|
128
|
+
upsampleMethod = "trilinear"
|
|
129
|
+
if dim == 2:
|
|
130
|
+
upsampleMethod = "bilinear"
|
|
131
|
+
if dim == 1:
|
|
132
|
+
upsampleMethod = "linear"
|
|
133
|
+
return torch.nn.Upsample(scale_factor=2, mode=upsampleMethod.lower(), align_corners=False)
|
|
132
134
|
|
|
133
135
|
class Unsqueeze(torch.nn.Module):
|
|
134
136
|
|
konfai/network/network.py
CHANGED
|
@@ -41,57 +41,41 @@ class OptimizerLoader():
|
|
|
41
41
|
def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
|
|
42
42
|
return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
|
|
43
43
|
|
|
44
|
-
class SchedulerStep():
|
|
45
|
-
|
|
46
|
-
@config(None)
|
|
47
|
-
def __init__(self, nb_step : int = 0) -> None:
|
|
48
|
-
self.nb_step = nb_step
|
|
49
44
|
|
|
50
45
|
class LRSchedulersLoader():
|
|
51
46
|
|
|
52
|
-
@config(
|
|
53
|
-
def __init__(self,
|
|
54
|
-
self.
|
|
55
|
-
|
|
56
|
-
def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) ->
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
if
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
schedulers[config("{}.Model.{}.Schedulers.{}".format(KONFAI_ROOT(), key, name))(getattr(importlib.import_module(module), name))(optimizer, config = None)] = step.nb_step
|
|
65
|
-
ok = True
|
|
66
|
-
if not ok:
|
|
67
|
-
raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(nameTmp))
|
|
68
|
-
return schedulers
|
|
69
|
-
|
|
70
|
-
class SchedulersLoader():
|
|
47
|
+
@config()
|
|
48
|
+
def __init__(self, nb_step : int = 0) -> None:
|
|
49
|
+
self.nb_step = nb_step
|
|
50
|
+
|
|
51
|
+
def getschedulers(self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
|
|
52
|
+
for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
|
|
53
|
+
module, name = _getModule(scheduler_classname, m)
|
|
54
|
+
if hasattr(importlib.import_module(module), name):
|
|
55
|
+
return config("{}.Model.{}.schedulers.{}".format(KONFAI_ROOT(), key, scheduler_classname))(getattr(importlib.import_module(module), name))(optimizer, config = None)
|
|
56
|
+
raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(scheduler_classname))
|
|
57
|
+
|
|
58
|
+
class LossSchedulersLoader():
|
|
71
59
|
|
|
72
|
-
@config(
|
|
73
|
-
def __init__(self,
|
|
74
|
-
self.
|
|
75
|
-
|
|
76
|
-
def getschedulers(self, key: str) ->
|
|
77
|
-
|
|
78
|
-
for name, step in self.params.items():
|
|
79
|
-
if name:
|
|
80
|
-
schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
|
|
81
|
-
return schedulers
|
|
60
|
+
@config()
|
|
61
|
+
def __init__(self, nb_step : int = 0) -> None:
|
|
62
|
+
self.nb_step = nb_step
|
|
63
|
+
|
|
64
|
+
def getschedulers(self, key: str, scheduler_classname: str) -> torch.optim.lr_scheduler._LRScheduler:
|
|
65
|
+
return config("{}.{}".format(key, scheduler_classname))(getattr(importlib.import_module('konfai.metric.schedulers'), scheduler_classname))(config = None)
|
|
82
66
|
|
|
83
67
|
class CriterionsAttr():
|
|
84
68
|
|
|
85
69
|
@config()
|
|
86
|
-
def __init__(self,
|
|
87
|
-
self.
|
|
70
|
+
def __init__(self, schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)}, isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
|
|
71
|
+
self.schedulersLoader = schedulers
|
|
88
72
|
self.isTorchCriterion = True
|
|
89
73
|
self.isLoss = isLoss
|
|
90
74
|
self.stepStart = stepStart
|
|
91
75
|
self.stepStop = stepStop
|
|
92
76
|
self.group = group
|
|
93
77
|
self.accumulation = accumulation
|
|
94
|
-
self.
|
|
78
|
+
self.schedulers: dict[Scheduler, int] = {}
|
|
95
79
|
|
|
96
80
|
class CriterionsLoader():
|
|
97
81
|
|
|
@@ -104,7 +88,8 @@ class CriterionsLoader():
|
|
|
104
88
|
for module_classpath, criterionsAttr in self.criterionsLoader.items():
|
|
105
89
|
module, name = _getModule(module_classpath, "konfai.metric.measure")
|
|
106
90
|
criterionsAttr.isTorchCriterion = module.startswith("torch")
|
|
107
|
-
|
|
91
|
+
for scheduler_classname, schedulers in criterionsAttr.schedulersLoader.items():
|
|
92
|
+
criterionsAttr.schedulers[schedulers.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}.schedulers".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath), scheduler_classname)] = schedulers.nb_step
|
|
108
93
|
criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
|
|
109
94
|
return criterions
|
|
110
95
|
|
|
@@ -202,7 +187,7 @@ class Measure():
|
|
|
202
187
|
|
|
203
188
|
for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
|
|
204
189
|
if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
|
|
205
|
-
scheduler = self.update_scheduler(criterionsAttr.
|
|
190
|
+
scheduler = self.update_scheduler(criterionsAttr.schedulers, it)
|
|
206
191
|
self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)].add(scheduler.get_value(), criterion(output, *target))
|
|
207
192
|
if training and len(np.unique([len(l) for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss])) == 1:
|
|
208
193
|
if criterionsAttr.isLoss:
|
|
@@ -542,7 +527,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
542
527
|
def __init__( self,
|
|
543
528
|
in_channels : int = 1,
|
|
544
529
|
optimizer: Union[OptimizerLoader, None] = None,
|
|
545
|
-
schedulers: Union[LRSchedulersLoader, None] = None,
|
|
530
|
+
schedulers: Union[dict[str, LRSchedulersLoader], None] = None,
|
|
546
531
|
outputsCriterions: Union[dict[str, TargetCriterionsLoader], None] = None,
|
|
547
532
|
patch : Union[ModelPatch, None] = None,
|
|
548
533
|
nb_batch_per_step : int = 1,
|
|
@@ -556,7 +541,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
556
541
|
self.optimizer : Union[torch.optim.Optimizer, None] = None
|
|
557
542
|
|
|
558
543
|
self.LRSchedulersLoader = schedulers
|
|
559
|
-
self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] =
|
|
544
|
+
self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = {}
|
|
560
545
|
|
|
561
546
|
self.outputsCriterionsLoader = outputsCriterions
|
|
562
547
|
self.measure : Union[Measure, None] = None
|
|
@@ -669,8 +654,8 @@ class Network(ModuleArgsDict, ABC):
|
|
|
669
654
|
self._it = int(state_dict["{}_it".format(self.getName())])
|
|
670
655
|
if "{}_nb_lr_update".format(self.getName()) in state_dict:
|
|
671
656
|
self._nb_lr_update = int(state_dict["{}_nb_lr_update".format(self.getName())])
|
|
672
|
-
|
|
673
|
-
|
|
657
|
+
for scheduler in self.schedulers:
|
|
658
|
+
if scheduler.last_epoch == -1:
|
|
674
659
|
scheduler.last_epoch = self._nb_lr_update
|
|
675
660
|
self.initialized()
|
|
676
661
|
|
|
@@ -744,7 +729,8 @@ class Network(ModuleArgsDict, ABC):
|
|
|
744
729
|
self.optimizer.zero_grad()
|
|
745
730
|
|
|
746
731
|
if self.LRSchedulersLoader and self.optimizer:
|
|
747
|
-
|
|
732
|
+
for schedulers_classname, schedulers in self.LRSchedulersLoader.items():
|
|
733
|
+
self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = schedulers.nb_step
|
|
748
734
|
|
|
749
735
|
def initialized(self):
|
|
750
736
|
pass
|
|
@@ -892,6 +878,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
892
878
|
model._requires_grad(list(self.measure.outputsCriterions.keys()))
|
|
893
879
|
for loss in self.measure.getLoss():
|
|
894
880
|
self.scaler.scale(loss / self.nb_batch_per_step).backward()
|
|
881
|
+
|
|
895
882
|
if self._it % self.nb_batch_per_step == 0:
|
|
896
883
|
self.scaler.step(self.optimizer)
|
|
897
884
|
self.scaler.update()
|
|
@@ -903,11 +890,10 @@ class Network(ModuleArgsDict, ABC):
|
|
|
903
890
|
self._nb_lr_update+=1
|
|
904
891
|
step = 0
|
|
905
892
|
scheduler = None
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
step += value
|
|
893
|
+
for scheduler, value in self.schedulers.items():
|
|
894
|
+
if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
|
|
895
|
+
break
|
|
896
|
+
step += value
|
|
911
897
|
if scheduler:
|
|
912
898
|
if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
|
|
913
899
|
if self.measure:
|