konfai 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -37,10 +37,9 @@ class Discriminator(network.Network):
         self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
         self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))
 
-    @config("Discriminator")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  channels: list[int] = [1, 16, 32, 64, 64],
                  strides: list[int] = [2,2,2,1],
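The recurring change across these model hunks is the `schedulers` parameter, which moves from a single `network.LRSchedulersLoader` to a dict whose keys follow the `alias:SchedulerClassName` convention visible in the default value, and whose values carry the number of LR updates (`nb_step`) each scheduler runs for. A minimal sketch of how a caller might chain two schedulers under the new signature (the import path and the `warmup` alias are illustrative assumptions, not from the package):

    from konfai.network import network  # assumed module path

    # Hypothetical: run LinearLR for the first 10 LR updates, then hand
    # over to ReduceLROnPlateau (nb_step=0 marks the final, open-ended entry).
    schedulers = {
        "warmup:LinearLR": network.LRSchedulersLoader(10),
        "default:ReduceLROnPlateau": network.LRSchedulersLoader(0),
    }
    model = Discriminator(schedulers=schedulers)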
@@ -169,10 +168,9 @@ class Discriminator_ADA(network.Network):
         self.add_module("Layers", Discriminator_ADA.DiscriminatorNLayers(channels, strides, 100, dim), in_branch=[0, "te"])
         self.add_module("Head", Discriminator_ADA.DiscriminatorHead(channels[-1], dim))
 
-    @config("Discriminator_ADA")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  channels: list[int] = [1, 16, 32, 64, 64],
                  strides: list[int] = [2,2,2,1],
@@ -261,10 +259,9 @@ class Discriminator_ADA(network.Network):
         self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
         self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
 
-    @config("GeneratorV1")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  patch : ModelPatch = ModelPatch(),
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  dim : int = 3) -> None:
@@ -349,10 +346,9 @@ class GeneratorV1(network.Network):
         self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
         self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
 
-    @config("GeneratorV1")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  patch : ModelPatch = ModelPatch(),
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  dim : int = 3) -> None:
@@ -383,10 +379,9 @@ class GeneratorV2(network.Network):
         self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
         self.add_module("Head", GeneratorV2.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
 
-    @config("GeneratorV2")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -424,10 +419,9 @@ class GeneratorV3(network.Network):
         self.add_module("UNetBlock_0", UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=1, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
         self.add_module("Head", GeneratorV3.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
 
-    @config("GeneratorV3")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -443,7 +437,6 @@ class GeneratorV3(network.Network):
 
 class DiffusionGan(network.Network):
 
-    @config("DiffusionGan")
     def __init__(self, generator : GeneratorV1 = GeneratorV1(), discriminator : Discriminator_ADA = Discriminator_ADA()) -> None:
         super().__init__()
         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
@@ -454,7 +447,6 @@ class DiffusionGan(network.Network):
 
 class DiffusionGanV2(network.Network):
 
-    @config("DiffusionGan")
     def __init__(self, generator : GeneratorV2 = GeneratorV2(), discriminator : Discriminator = Discriminator()) -> None:
         super().__init__()
         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
@@ -466,10 +458,9 @@ class DiffusionGanV2(network.Network):
 
 class CycleGanDiscriminator(network.Network):
 
-    @config("CycleGanDiscriminator")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int] = [1, 16, 32, 64, 64],
@@ -485,10 +476,9 @@ class CycleGanDiscriminator(network.Network):
 
 class CycleGanGeneratorV1(network.Network):
 
-    @config("CycleGanGeneratorV1")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  dim : int = 3) -> None:
@@ -498,10 +488,9 @@ class CycleGanGeneratorV1(network.Network):
 
 class CycleGanGeneratorV2(network.Network):
 
-    @config("CycleGanGeneratorV2")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -518,10 +507,9 @@ class CycleGanGeneratorV2(network.Network):
 
 class CycleGanGeneratorV3(network.Network):
 
-    @config("CycleGanGeneratorV3")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -538,7 +526,6 @@ class CycleGanGeneratorV3(network.Network):
 
 class DiffusionCycleGan(network.Network):
 
-    @config("DiffusionCycleGan")
     def __init__(self, generators : CycleGanGeneratorV3 = CycleGanGeneratorV3(), discriminators : CycleGanDiscriminator = CycleGanDiscriminator()) -> None:
         super().__init__()
         self.add_module("Generator", generators, in_branch=[0, 1], out_branch=["pB", "pA"])
@@ -23,10 +23,9 @@ class Discriminator(network.Network):
         self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
         self.add_module("Flatten", torch.nn.Flatten(1))
 
-    @config("Discriminator")
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  nb_batch_per_step: int = 64,
                  dim : int = 3) -> None:
@@ -101,11 +100,10 @@ class Generator(network.Network):
         self.add_module("Encoder", Generator.GeneratorEncoder(channels, dim))
         self.add_module("NResBlock", Generator.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
         self.add_module("Decoder", Generator.GeneratorDecoder(channels, dim))
-
-    @config("Generator")
+
     def __init__(self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  patch : ModelPatch = ModelPatch(),
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  nb_batch_per_step: int = 64,
@@ -121,7 +119,6 @@ class Generator(network.Network):
 
 class Gan(network.Network):
 
-    @config("Gan")
     def __init__(self, generator : Generator = Generator(), discriminator : Discriminator = Discriminator()) -> None:
         super().__init__()
         self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
@@ -25,7 +25,6 @@ class VAE(network.Network):
         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
         self.add_module("Tanh", torch.nn.Tanh())
 
-    @config("VAE")
     def __init__(self,
                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
@@ -60,7 +59,6 @@ class LinearVAE(network.Network):
         self.add_module("Linear", torch.nn.Linear(in_features, out_features))
         self.add_module("Tanh", torch.nn.Tanh())
 
-    @config("LinearVAE")
     def __init__(self,
                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
@@ -8,10 +8,9 @@ from konfai.models.segmentation import UNet
 
 class VoxelMorph(network.Network):
 
-    @config("VoxelMorph")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  dim : int = 3,
                  channels : list[int] = [4, 16,32,32,32],
@@ -47,10 +47,9 @@ class Adaptation(torch.nn.Module):
 
 class Representation(network.Network):
 
-    @config("Representation")
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  dim : int = 3):
         super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, init_type="kaiming")
@@ -34,7 +34,7 @@ class NestedUNet(network.Network):
 
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  dim : int = 3,
@@ -81,15 +81,17 @@ class UNetpp(network.Network):
             if len(encoder_channels) > 2:
                 self.add_module("UNetBlock_{}".format(i+1), UNetpp.UNetPPBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
             for j in range(len(encoder_channels)-2):
-                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=encoder_channels[2], out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
+                in_channels = decoder_channels[3] if j == len(encoder_channels)-3 and len(encoder_channels) > 3 else encoder_channels[2]
+                out_channel = decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]
+                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=in_channels, out_channels=out_channel, upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
                 self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-                self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+((encoder_channels[1]+encoder_channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else (decoder_channels[1]+decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]+encoder_channels[1])), out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-
+                self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*(j+1)+(in_channels if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else out_channel), out_channels=out_channel, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+
     class UNetPPHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
             super().__init__()
-            self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim))
+            self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE, dim=dim))
             self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
             if nb_class > 1:
@@ -100,7 +102,7 @@ class UNetpp(network.Network):
 
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  encoder_channels: list[int] = [1,64,64,128,256,512],
@@ -108,6 +110,5 @@ class UNetpp(network.Network):
                  layers: list[int] = [3,4,6,3],
                  dim : int = 2) -> None:
         super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-        self.add_module("Block_0", UNetpp.UNetPPBlock(encoder_channels, decoder_channels[::-1], UNetpp.resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
-        self.add_module("Head", UNetpp.UNetPPHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
-
+        self.add_module("Block_0", UNetpp.UNetPPBlock(encoder_channels, decoder_channels[::-1], UNetpp.resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.CONV_TRANSPOSE, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
+        self.add_module("Head", UNetpp.UNetPPHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
@@ -39,7 +39,7 @@ class UNet(network.Network):
 
     def __init__( self,
                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                  patch : Union[ModelPatch, None] = None,
                  dim : int = 3,
konfai/network/blocks.py CHANGED
@@ -36,11 +36,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module,
 
 class UpSampleMode(Enum):
     CONV_TRANSPOSE = 0,
-    UPSAMPLE_NEAREST = 1,
-    UPSAMPLE_LINEAR = 2,
-    UPSAMPLE_BILINEAR = 3,
-    UPSAMPLE_BICUBIC = 4,
-    UPSAMPLE_TRILINEAR = 5
+    UPSAMPLE = 1,
 
 class DownSampleMode(Enum):
     MAXPOOL = 0,
@@ -128,7 +124,13 @@ def upSample(in_channels: int, out_channels: int, upSampleMode: UpSampleMode, di
     if upSampleMode == UpSampleMode.CONV_TRANSPOSE:
         return getTorchModule("ConvTranspose", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size, stride = stride, padding = 0)
     else:
-        return torch.nn.Upsample(scale_factor=2, mode=upSampleMode.name.replace("UPSAMPLE_", "").lower(), align_corners=False)
+        if dim == 3:
+            upsampleMethod = "trilinear"
+        if dim == 2:
+            upsampleMethod = "bilinear"
+        if dim == 1:
+            upsampleMethod = "linear"
+        return torch.nn.Upsample(scale_factor=2, mode=upsampleMethod.lower(), align_corners=False)
 
 class Unsqueeze(torch.nn.Module):
 
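With the enum collapsed to a single `UPSAMPLE` member, the interpolation flavor is now derived from `dim` instead of being encoded in the mode name. A self-contained sketch of the equivalent selection logic in plain PyTorch (no konfai imports; names are ours, not the package's):

    import torch

    def upsample_for_dim(dim: int) -> torch.nn.Upsample:
        # Mirrors the new upSample() branch: pick the linear interpolation
        # mode that matches the tensor dimensionality.
        mode = {1: "linear", 2: "bilinear", 3: "trilinear"}[dim]
        return torch.nn.Upsample(scale_factor=2, mode=mode, align_corners=False)

    x = torch.randn(1, 8, 16, 16)        # N, C, H, W
    print(upsample_for_dim(2)(x).shape)  # torch.Size([1, 8, 32, 32])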
konfai/network/network.py CHANGED
@@ -41,57 +41,41 @@ class OptimizerLoader():
     def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
         return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
 
-class SchedulerStep():
-
-    @config(None)
-    def __init__(self, nb_step : int = 0) -> None:
-        self.nb_step = nb_step
 
 class LRSchedulersLoader():
 
-    @config("Schedulers")
-    def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
-        self.params = params
-
-    def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
-        schedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
-        for nameTmp, step in self.params.items():
-            if nameTmp:
-                ok = False
-                for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
-                    module, name = _getModule(nameTmp, m)
-                    if hasattr(importlib.import_module(module), name):
-                        schedulers[config("{}.Model.{}.Schedulers.{}".format(KONFAI_ROOT(), key, name))(getattr(importlib.import_module(module), name))(optimizer, config = None)] = step.nb_step
-                        ok = True
-                if not ok:
-                    raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(nameTmp))
-        return schedulers
-
-class SchedulersLoader():
+    @config()
+    def __init__(self, nb_step : int = 0) -> None:
+        self.nb_step = nb_step
+
+    def getschedulers(self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+        for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
+            module, name = _getModule(scheduler_classname, m)
+            if hasattr(importlib.import_module(module), name):
+                return config("{}.Model.{}.schedulers.{}".format(KONFAI_ROOT(), key, scheduler_classname))(getattr(importlib.import_module(module), name))(optimizer, config = None)
+        raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(scheduler_classname))
+
+class LossSchedulersLoader():
 
-    @config("Schedulers")
-    def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
-        self.params = params
-
-    def getschedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
-        schedulers : dict[Scheduler, int] = {}
-        for name, step in self.params.items():
-            if name:
-                schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
-        return schedulers
+    @config()
+    def __init__(self, nb_step : int = 0) -> None:
+        self.nb_step = nb_step
+
+    def getschedulers(self, key: str, scheduler_classname: str) -> torch.optim.lr_scheduler._LRScheduler:
+        return config("{}.{}".format(key, scheduler_classname))(getattr(importlib.import_module('konfai.metric.schedulers'), scheduler_classname))(config = None)
 
 class CriterionsAttr():
 
     @config()
-    def __init__(self, l: SchedulersLoader = SchedulersLoader(), isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
-        self.l = l
+    def __init__(self, schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)}, isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
+        self.schedulersLoader = schedulers
         self.isTorchCriterion = True
         self.isLoss = isLoss
         self.stepStart = stepStart
        self.stepStop = stepStop
         self.group = group
         self.accumulation = accumulation
-        self.sheduler = None
+        self.schedulers: dict[Scheduler, int] = {}
 
 class CriterionsLoader():
 
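The refactor removes the `SchedulerStep`/`params` indirection: each loader now carries only its `nb_step` budget, the scheduler class name travels in the dict key, and `LossSchedulersLoader` takes over the loss-weight scheduling that `SchedulersLoader` used to handle. A hedged sketch of how the two loaders would now be instantiated (aliases and values mirror the defaults in the diff; the wiring itself is illustrative):

    # LR schedulers: one dict entry per scheduler, resolved later against
    # torch.optim.lr_scheduler and konfai.metric.schedulers.
    lr_schedulers = {"default:ReduceLROnPlateau": LRSchedulersLoader(0)}

    # Loss-weight schedulers: resolved against konfai.metric.schedulers only.
    attr = CriterionsAttr(schedulers={"default:Constant": LossSchedulersLoader(0)})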
@@ -104,7 +88,8 @@ class CriterionsLoader():
         for module_classpath, criterionsAttr in self.criterionsLoader.items():
             module, name = _getModule(module_classpath, "konfai.metric.measure")
             criterionsAttr.isTorchCriterion = module.startswith("torch")
-            criterionsAttr.sheduler = criterionsAttr.l.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
+            for scheduler_classname, schedulers in criterionsAttr.schedulersLoader.items():
+                criterionsAttr.schedulers[schedulers.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}.schedulers".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath), scheduler_classname)] = schedulers.nb_step
             criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
         return criterions
 
@@ -202,7 +187,7 @@ class Measure():
 
         for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
             if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
-                scheduler = self.update_scheduler(criterionsAttr.sheduler, it)
+                scheduler = self.update_scheduler(criterionsAttr.schedulers, it)
                 self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)].add(scheduler.get_value(), criterion(output, *target))
                 if training and len(np.unique([len(l) for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss])) == 1:
                     if criterionsAttr.isLoss:
@@ -542,7 +527,7 @@ class Network(ModuleArgsDict, ABC):
     def __init__( self,
                  in_channels : int = 1,
                  optimizer: Union[OptimizerLoader, None] = None,
-                 schedulers: Union[LRSchedulersLoader, None] = None,
+                 schedulers: Union[dict[str, LRSchedulersLoader], None] = None,
                  outputsCriterions: Union[dict[str, TargetCriterionsLoader], None] = None,
                  patch : Union[ModelPatch, None] = None,
                  nb_batch_per_step : int = 1,
@@ -556,7 +541,7 @@ class Network(ModuleArgsDict, ABC):
         self.optimizer : Union[torch.optim.Optimizer, None] = None
 
         self.LRSchedulersLoader = schedulers
-        self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = None
+        self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = {}
 
         self.outputsCriterionsLoader = outputsCriterions
         self.measure : Union[Measure, None] = None
@@ -669,8 +654,8 @@ class Network(ModuleArgsDict, ABC):
         self._it = int(state_dict["{}_it".format(self.getName())])
         if "{}_nb_lr_update".format(self.getName()) in state_dict:
             self._nb_lr_update = int(state_dict["{}_nb_lr_update".format(self.getName())])
-        if self.schedulers:
-            for scheduler in self.schedulers:
+        for scheduler in self.schedulers:
+            if scheduler.last_epoch == -1:
                 scheduler.last_epoch = self._nb_lr_update
         self.initialized()
 
@@ -744,7 +729,8 @@ class Network(ModuleArgsDict, ABC):
             self.optimizer.zero_grad()
 
         if self.LRSchedulersLoader and self.optimizer:
-            self.schedulers = self.LRSchedulersLoader.getschedulers(key, self.optimizer)
+            for schedulers_classname, schedulers in self.LRSchedulersLoader.items():
+                self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = schedulers.nb_step
 
     def initialized(self):
         pass
@@ -892,6 +878,7 @@ class Network(ModuleArgsDict, ABC):
                 model._requires_grad(list(self.measure.outputsCriterions.keys()))
                 for loss in self.measure.getLoss():
                     self.scaler.scale(loss / self.nb_batch_per_step).backward()
+
                 if self._it % self.nb_batch_per_step == 0:
                     self.scaler.step(self.optimizer)
                     self.scaler.update()
@@ -903,11 +890,10 @@ class Network(ModuleArgsDict, ABC):
         self._nb_lr_update+=1
         step = 0
         scheduler = None
-        if self.schedulers:
-            for scheduler, value in self.schedulers.items():
-                if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
-                    break
-                step += value
+        for scheduler, value in self.schedulers.items():
+            if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
+                break
+            step += value
         if scheduler:
             if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                 if self.measure:
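The selection loop above walks `self.schedulers` in insertion order, treating each value as a budget of LR updates: the first scheduler whose window contains `_nb_lr_update` wins, and the last entry is kept once every budget is spent. A standalone sketch of the same arithmetic (names and values are illustrative, not package defaults):

    def pick_scheduler(schedulers: dict[str, int], nb_lr_update: int) -> str | None:
        # Mirrors the diff: walk insertion order, accumulate step budgets,
        # stop at the first window containing the current update count.
        step = 0
        chosen = None
        for name, value in schedulers.items():
            chosen = name
            if value is None or step <= nb_lr_update < step + value:
                break
            step += value
        return chosen

    print(pick_scheduler({"warmup:LinearLR": 10, "default:ReduceLROnPlateau": 0}, 3))   # warmup:LinearLR
    print(pick_scheduler({"warmup:LinearLR": 10, "default:ReduceLROnPlateau": 0}, 25))  # default:ReduceLROnPlateau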