konfai 1.1.5-py3-none-any.whl → 1.1.7-py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

This version of konfai has been flagged as potentially problematic.

@@ -118,10 +118,9 @@ class Generator(network.Network):
             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
             self.add_module("Tanh", torch.nn.Tanh())
 
-    @config("Generator")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  patch : ModelPatch = ModelPatch(),
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
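
Two changes repeat throughout this diff: every `@config(...)` decorator is dropped from the `__init__` methods, and `schedulers` changes from a single `network.LRSchedulersLoader` to a dict whose keys appear to pair an optimizer group name with a torch scheduler class name (`"default:ReduceLROnPlateau"`). Note that `VAE` and `LinearVAE` further down keep the old single-loader signature. A minimal usage sketch of the new parameter, assuming only what the new default value shows; the semantics of the key format and of the `0` argument are inferred from the diff, not from konfai documentation:

    from konfai.network import network  # import path assumed from the file layout in this diff

    # Hypothetical call against the Generator class in the hunk above; the
    # key format and LRSchedulersLoader(0) argument mirror the new default.
    model = Generator(
        optimizer=network.OptimizerLoader(),
        schedulers={"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
    )
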
@@ -175,10 +175,9 @@ class DDPM(network.Network):
     def forward(self, input: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: int) -> torch.Tensor:
         return torch.concat((self.v(input, eta, t), self.v(input, eta_hat, t)), dim=1)
 
-    @config("DDPM")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  train_noise_step: int = 1000,
@@ -37,10 +37,9 @@ class Discriminator(network.Network):
             self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
             self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))
 
-    @config("Discriminator")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  channels: list[int] = [1, 16, 32, 64, 64],
                  strides: list[int] = [2,2,2,1],
@@ -169,10 +168,9 @@ class Discriminator_ADA(network.Network):
             self.add_module("Layers", Discriminator_ADA.DiscriminatorNLayers(channels, strides, 100, dim), in_branch=[0, "te"])
             self.add_module("Head", Discriminator_ADA.DiscriminatorHead(channels[-1], dim))
 
-    @config("Discriminator_ADA")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  channels: list[int] = [1, 16, 32, 64, 64],
                  strides: list[int] = [2,2,2,1],
@@ -261,10 +259,9 @@ class Discriminator_ADA(network.Network):
             self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
             self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
 
-    @config("GeneratorV1")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  patch : ModelPatch = ModelPatch(),
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  dim : int = 3) -> None:
@@ -349,10 +346,9 @@ class GeneratorV1(network.Network):
             self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
             self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
 
-    @config("GeneratorV1")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  patch : ModelPatch = ModelPatch(),
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  dim : int = 3) -> None:
@@ -383,10 +379,9 @@ class GeneratorV2(network.Network):
             self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
             self.add_module("Head", GeneratorV2.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
 
-    @config("GeneratorV2")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -424,10 +419,9 @@ class GeneratorV3(network.Network):
             self.add_module("UNetBlock_0", UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=1, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
             self.add_module("Head", GeneratorV3.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
 
-    @config("GeneratorV3")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -443,7 +437,6 @@ class GeneratorV3(network.Network):
 
 class DiffusionGan(network.Network):
 
-    @config("DiffusionGan")
     def __init__(self, generator : GeneratorV1 = GeneratorV1(), discriminator : Discriminator_ADA = Discriminator_ADA()) -> None:
         super().__init__()
         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
@@ -454,7 +447,6 @@ class DiffusionGan(network.Network):
 
 class DiffusionGanV2(network.Network):
 
-    @config("DiffusionGan")
     def __init__(self, generator : GeneratorV2 = GeneratorV2(), discriminator : Discriminator = Discriminator()) -> None:
         super().__init__()
         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
@@ -466,10 +458,9 @@ class DiffusionGanV2(network.Network):
 
 class CycleGanDiscriminator(network.Network):
 
-    @config("CycleGanDiscriminator")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int] = [1, 16, 32, 64, 64],
@@ -485,10 +476,9 @@ class CycleGanDiscriminator(network.Network):
 
 class CycleGanGeneratorV1(network.Network):
 
-    @config("CycleGanGeneratorV1")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  dim : int = 3) -> None:
@@ -498,10 +488,9 @@ class CycleGanGeneratorV1(network.Network):
 
 class CycleGanGeneratorV2(network.Network):
 
-    @config("CycleGanGeneratorV2")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -518,10 +507,9 @@ class CycleGanGeneratorV2(network.Network):
 
 class CycleGanGeneratorV3(network.Network):
 
-    @config("CycleGanGeneratorV3")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -538,7 +526,6 @@ class CycleGanGeneratorV3(network.Network):
 
 class DiffusionCycleGan(network.Network):
 
-    @config("DiffusionCycleGan")
     def __init__(self, generators : CycleGanGeneratorV3 = CycleGanGeneratorV3(), discriminators : CycleGanDiscriminator = CycleGanDiscriminator()) -> None:
         super().__init__()
         self.add_module("Generator", generators, in_branch=[0, 1], out_branch=["pB", "pA"])
@@ -23,10 +23,9 @@ class Discriminator(network.Network):
             self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
             self.add_module("Flatten", torch.nn.Flatten(1))
 
-    @config("Discriminator")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  nb_batch_per_step: int = 64,
                  dim : int = 3) -> None:
@@ -101,11 +100,10 @@ class Generator(network.Network):
             self.add_module("Encoder", Generator.GeneratorEncoder(channels, dim))
             self.add_module("NResBlock", Generator.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
             self.add_module("Decoder", Generator.GeneratorDecoder(channels, dim))
-
-    @config("Generator")
+
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  patch : ModelPatch = ModelPatch(),
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  nb_batch_per_step: int = 64,
@@ -121,7 +119,6 @@ class Generator(network.Network):
 
 class Gan(network.Network):
 
-    @config("Gan")
     def __init__(self, generator : Generator = Generator(), discriminator : Discriminator = Discriminator()) -> None:
         super().__init__()
         self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
@@ -25,7 +25,6 @@ class VAE(network.Network):
             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
             self.add_module("Tanh", torch.nn.Tanh())
 
-    @config("VAE")
     def __init__(self,
                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
@@ -60,7 +59,6 @@ class LinearVAE(network.Network):
             self.add_module("Linear", torch.nn.Linear(in_features, out_features))
             self.add_module("Tanh", torch.nn.Tanh())
 
-    @config("LinearVAE")
     def __init__(self,
                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
@@ -8,10 +8,9 @@ from konfai.models.segmentation import UNet
 
 class VoxelMorph(network.Network):
 
-    @config("VoxelMorph")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  dim : int = 3,
                  channels : list[int] = [4, 16,32,32,32],
@@ -47,10 +47,9 @@ class Adaptation(torch.nn.Module):
 
 class Representation(network.Network):
 
-    @config("Representation")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  dim : int = 3):
         super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, init_type="kaiming")
@@ -3,23 +3,23 @@ from typing import Union
 import torch
 from konfai.data.patching import ModelPatch
 
-
-class NestedUNetBlock(network.ModuleArgsDict):
+class NestedUNet(network.Network):
 
-    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
-        super().__init__()
-        if i > 0:
-            self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
-
-        self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
-        if len(channels) > 2:
-            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
-            for j in range(len(channels)-2):
-                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
-                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-                self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+    class NestedUNetBlock(network.ModuleArgsDict):
+
+        def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
+            super().__init__()
+            if i > 0:
+                self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
+
+            self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
+            if len(channels) > 2:
+                self.add_module("UNetBlock_{}".format(i+1), NestedUNet.NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
+                for j in range(len(channels)-2):
+                    self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
+                    self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+                    self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
 
-class NestedUNet(network.Network):
 
     class NestedUNetHead(network.ModuleArgsDict):
@@ -34,7 +34,7 @@ class NestedUNet(network.Network):
 
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  dim : int = 3,
@@ -49,61 +49,60 @@ class NestedUNet(network.Network):
                  activation: str = "Softmax") -> None:
         super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
 
-        self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
+        self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
         for j in range(len(channels)-2):
             self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, activation=activation, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
 
 
+class UNetpp(network.Network):
 
-
-class ResNetEncoderLayer(network.ModuleArgsDict):
-
-    def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
-        super().__init__()
-        for i in range(nb_block):
-            if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
-                self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
-            self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
-            in_channel = out_channel
-
-def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
-    modules = []
-    modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
-    for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
-        modules.append(ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
-    return modules
-
-class NestedUNetBlock(network.ModuleArgsDict):
-
-    def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
-        super().__init__()
-        self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
-        if len(encoder_channels) > 2:
-            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
-            for j in range(len(encoder_channels)-2):
-                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=encoder_channels[2], out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
-                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-                self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+((encoder_channels[1]+encoder_channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else (decoder_channels[1]+decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]+encoder_channels[1])), out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+    class ResNetEncoderLayer(network.ModuleArgsDict):
 
-class NestedUNetHead(network.ModuleArgsDict):
+        def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
+            super().__init__()
+            for i in range(nb_block):
+                if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
+                    self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
+                self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
+                in_channel = out_channel
+
+    def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
+        modules = []
+        modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
+        for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
+            modules.append(UNetpp.ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
+        return modules
 
-    def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
-        super().__init__()
-        self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim))
-        self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
-        self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
-        if nb_class > 1:
-            self.add_module("Softmax", torch.nn.Softmax(dim=1))
-            self.add_module("Argmax", blocks.ArgMax(dim=1))
-        else:
-            self.add_module("Tanh", torch.nn.Tanh())
-
-class UNetpp(network.Network):
+    class UNetPPBlock(network.ModuleArgsDict):
 
+        def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
+            super().__init__()
+            self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
+            if len(encoder_channels) > 2:
+                self.add_module("UNetBlock_{}".format(i+1), UNetpp.UNetPPBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
+                for j in range(len(encoder_channels)-2):
+                    in_channels = decoder_channels[3] if j == len(encoder_channels)-3 and len(encoder_channels) > 3 else encoder_channels[2]
+                    out_channel = decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]
+                    self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=in_channels, out_channels=out_channel, upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
+                    self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+                    self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*(j+1)+(in_channels if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else out_channel), out_channels=out_channel, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+
+    class UNetPPHead(network.ModuleArgsDict):
+
+        def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
+            super().__init__()
+            self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE, dim=dim))
+            self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
+            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
+            if nb_class > 1:
+                self.add_module("Softmax", torch.nn.Softmax(dim=1))
+                self.add_module("Argmax", blocks.ArgMax(dim=1))
+            else:
+                self.add_module("Tanh", torch.nn.Tanh())
 
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  encoder_channels: list[int] = [1,64,64,128,256,512],
@@ -111,5 +110,5 @@ class UNetpp(network.Network):
                  layers: list[int] = [3,4,6,3],
                  dim : int = 2) -> None:
         super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-        self.add_module("Block_0", NestedUNetBlock(encoder_channels, decoder_channels[::-1], resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
-        self.add_module("Head", NestedUNetHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
+        self.add_module("Block_0", UNetpp.UNetPPBlock(encoder_channels, decoder_channels[::-1], UNetpp.resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.CONV_TRANSPOSE, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
+        self.add_module("Head", UNetpp.UNetPPHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
@@ -12,7 +12,6 @@ class UNetHead(network.ModuleArgsDict):
         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
         self.add_module("Softmax", torch.nn.Softmax(dim=1))
         self.add_module("Argmax", blocks.ArgMax(dim=1))
-
 
 class UNetBlock(network.ModuleArgsDict):
 
@@ -40,7 +39,7 @@ class UNet(network.Network):
 
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  dim : int = 3,
konfai/network/blocks.py CHANGED
@@ -36,11 +36,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module,
 
 class UpSampleMode(Enum):
     CONV_TRANSPOSE = 0,
-    UPSAMPLE_NEAREST = 1,
-    UPSAMPLE_LINEAR = 2,
-    UPSAMPLE_BILINEAR = 3,
-    UPSAMPLE_BICUBIC = 4,
-    UPSAMPLE_TRILINEAR = 5
+    UPSAMPLE = 1,
 
 class DownSampleMode(Enum):
     MAXPOOL = 0,
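
Together with the `upSample` change below, this collapses the five interpolation members into a single `UPSAMPLE` and takes the mode choice away from the caller: under 1.1.5 the mode was derived from the member name, so `nearest` and `bicubic` were expressible, whereas 1.1.7 only ever picks a linear-family mode from `dim`. A sketch of the old name-to-mode mapping, per the removed return statement in the next hunk:

    # Old behavior (1.1.5): the enum member name selected the torch mode directly.
    for name in ("UPSAMPLE_NEAREST", "UPSAMPLE_BILINEAR", "UPSAMPLE_BICUBIC"):
        print(name.replace("UPSAMPLE_", "").lower())  # nearest, bilinear, bicubic
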
@@ -128,7 +124,13 @@ def upSample(in_channels: int, out_channels: int, upSampleMode: UpSampleMode, di
     if upSampleMode == UpSampleMode.CONV_TRANSPOSE:
         return getTorchModule("ConvTranspose", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = 0)
     else:
-        return torch.nn.Upsample(scale_factor=2, mode=upSampleMode.name.replace("UPSAMPLE_", "").lower(), align_corners=False)
+        if dim == 3:
+            upsampleMethod = "trilinear"
+        if dim == 2:
+            upsampleMethod = "bilinear"
+        if dim == 1:
+            upsampleMethod = "linear"
+        return torch.nn.Upsample(scale_factor=2, mode=upsampleMethod.lower(), align_corners=False)
 
 class Unsqueeze(torch.nn.Module):
 
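
A standalone sketch of the new selection logic: the interpolation mode now follows the spatial rank rather than an enum member. Note that in the added code `upsampleMethod` is only bound for `dim` in 1 to 3; any other value would raise a NameError at the return, which the callers in this diff presumably never hit. The sketch assumes the same `scale_factor=2` and `align_corners=False` as the diff:

    import torch

    def upsample_for_dim(dim: int) -> torch.nn.Upsample:
        # Equivalent mapping to the three if-statements added in 1.1.7.
        modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
        return torch.nn.Upsample(scale_factor=2, mode=modes[dim], align_corners=False)

    assert upsample_for_dim(2).mode == "bilinear"
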