konfai-1.1.7-py3-none-any.whl → konfai-1.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic; see the registry's advisory page for this release for more details.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,10 @@
1
1
  from functools import partial
2
+
2
3
  import torch
3
4
 
4
- from konfai.network import network, blocks
5
- from konfai.utils.config import config
6
5
  from konfai.data.patching import ModelPatch
6
+ from konfai.network import blocks, network
7
+
7
8
 
8
9
  class Discriminator(network.Network):
9
10
 
@@ -11,121 +12,244 @@ class Discriminator(network.Network):
11
12
 
12
13
  def __init__(self, channels: list[int], strides: list[int], dim: int) -> None:
13
14
  super().__init__()
14
- blockConfig = partial(blocks.BlockConfig, kernel_size=4, padding=1, bias=False, activation=partial(torch.nn.LeakyReLU, negative_slope = 0.2, inplace=True), normMode=blocks.NormMode.SYNCBATCH)
15
+ block_config = partial(
16
+ blocks.BlockConfig,
17
+ kernel_size=4,
18
+ padding=1,
19
+ bias=False,
20
+ activation=partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True),
21
+ norm_mode=blocks.NormMode.SYNCBATCH,
22
+ )
15
23
  for i, (in_channels, out_channels, stride) in enumerate(zip(channels, channels[1:], strides)):
16
- self.add_module("Layer_{}".format(i), blocks.ConvBlock(in_channels, out_channels, 1, blockConfig(stride=stride), dim))
17
-
24
+ self.add_module(
25
+ f"Layer_{i}",
26
+ blocks.ConvBlock(in_channels, out_channels, [block_config(stride=stride)], dim),
27
+ )
28
+
18
29
  class DiscriminatorHead(network.ModuleArgsDict):
19
30
 
20
31
  def __init__(self, channels: int, dim: int) -> None:
21
32
  super().__init__()
22
- self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels=channels, out_channels=1, kernel_size=4, stride=1, padding=1))
23
- self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
33
+ self.add_module(
34
+ "Conv",
35
+ blocks.get_torch_module("Conv", dim)(
36
+ in_channels=channels,
37
+ out_channels=1,
38
+ kernel_size=4,
39
+ stride=1,
40
+ padding=1,
41
+ ),
42
+ )
43
+ self.add_module(
44
+ "AdaptiveAvgPool",
45
+ blocks.get_torch_module("AdaptiveAvgPool", dim)(tuple([1] * dim)),
46
+ )
24
47
  self.add_module("Flatten", torch.nn.Flatten(1))
25
-
26
- def __init__(self,
27
- optimizer : network.OptimizerLoader = network.OptimizerLoader(),
28
- schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
29
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
30
- nb_batch_per_step: int = 64,
31
- dim : int = 3) -> None:
32
- super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
48
+
49
+ def __init__(
50
+ self,
51
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
52
+ schedulers: dict[str, network.LRSchedulersLoader] = {
53
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
54
+ },
55
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
56
+ nb_batch_per_step: int = 64,
57
+ dim: int = 3,
58
+ ) -> None:
59
+ super().__init__(
60
+ in_channels=1,
61
+ optimizer=optimizer,
62
+ schedulers=schedulers,
63
+ outputs_criterions=outputs_criterions,
64
+ dim=dim,
65
+ nb_batch_per_step=nb_batch_per_step,
66
+ )
33
67
  channels = [1, 16, 32, 64, 64]
34
- strides = [2,2,2,1]
68
+ strides = [2, 2, 2, 1]
35
69
  self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
36
70
  self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))
37
71
 
72
+
38
73
  class Generator(network.Network):
39
74
 
40
75
  class GeneratorStem(network.ModuleArgsDict):
41
76
 
42
77
  def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
43
78
  super().__init__()
44
- self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH"), dim=dim))
79
+ self.add_module(
80
+ "ConvBlock",
81
+ blocks.ConvBlock(
82
+ in_channels,
83
+ out_channels,
84
+ block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
85
+ dim=dim,
86
+ ),
87
+ )
45
88
 
46
89
  class GeneratorHead(network.ModuleArgsDict):
47
90
 
48
91
  def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
49
92
  super().__init__()
50
- self.add_module("ConvBlock", blocks.ConvBlock(in_channels, in_channels, nb_conv=1, blockConfig=blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH"), dim=dim))
51
- self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False))
93
+ self.add_module(
94
+ "ConvBlock",
95
+ blocks.ConvBlock(
96
+ in_channels,
97
+ in_channels,
98
+ block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
99
+ dim=dim,
100
+ ),
101
+ )
102
+ self.add_module(
103
+ "Conv",
104
+ blocks.get_torch_module("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False),
105
+ )
52
106
  self.add_module("Tanh", torch.nn.Tanh())
53
107
 
54
108
  class GeneratorDownSample(network.ModuleArgsDict):
55
109
 
56
110
  def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
57
111
  super().__init__()
58
- self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blocks.BlockConfig(stride=2, bias=False, activation="ReLU", normMode="SYNCBATCH"), dim=dim))
59
-
112
+ self.add_module(
113
+ "ConvBlock",
114
+ blocks.ConvBlock(
115
+ in_channels,
116
+ out_channels,
117
+ block_configs=[blocks.BlockConfig(stride=2, bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
118
+ dim=dim,
119
+ ),
120
+ )
121
+
60
122
  class GeneratorUpSample(network.ModuleArgsDict):
61
123
 
62
124
  def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
63
125
  super().__init__()
64
- self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH"), dim=dim))
65
- self.add_module("Upsample", torch.nn.Upsample(scale_factor=2, mode="bilinear" if dim < 3 else "trilinear"))
66
-
126
+ self.add_module(
127
+ "ConvBlock",
128
+ blocks.ConvBlock(
129
+ in_channels,
130
+ out_channels,
131
+ block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
132
+ dim=dim,
133
+ ),
134
+ )
135
+ self.add_module(
136
+ "Upsample",
137
+ torch.nn.Upsample(scale_factor=2, mode="bilinear" if dim < 3 else "trilinear"),
138
+ )
139
+
67
140
  class GeneratorEncoder(network.ModuleArgsDict):
68
141
  def __init__(self, channels: list[int], dim: int) -> None:
69
142
  super().__init__()
70
143
  for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
71
- self.add_module("DownSample_{}".format(i), Generator.GeneratorDownSample(in_channels=in_channels, out_channels=out_channels, dim=dim))
72
-
144
+ self.add_module(
145
+ f"DownSample_{i}",
146
+ Generator.GeneratorDownSample(in_channels=in_channels, out_channels=out_channels, dim=dim),
147
+ )
148
+
73
149
  class GeneratorResnetBlock(network.ModuleArgsDict):
74
150
 
75
- def __init__(self, channels : int, dim : int):
151
+ def __init__(self, channels: int, dim: int):
76
152
  super().__init__()
77
- self.add_module("Conv_0", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
153
+ self.add_module(
154
+ "Conv_0",
155
+ blocks.get_torch_module("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False),
156
+ )
78
157
  self.add_module("Norm", torch.nn.LeakyReLU(0.2, inplace=True))
79
- self.add_module("Conv_1", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
80
- self.add_module("Residual", blocks.Add(), in_branch=[0,1])
158
+ self.add_module(
159
+ "Conv_1",
160
+ blocks.get_torch_module("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False),
161
+ )
162
+ self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
81
163
 
82
164
  class GeneratorNResnetBlock(network.ModuleArgsDict):
83
165
 
84
166
  def __init__(self, channels: int, nb_conv: int, dim: int) -> None:
85
167
  super().__init__()
86
168
  for i in range(nb_conv):
87
- self.add_module("ResnetBlock_{}".format(i), Generator.GeneratorResnetBlock(channels=channels, dim=dim))
169
+ self.add_module(
170
+ f"ResnetBlock_{i}",
171
+ Generator.GeneratorResnetBlock(channels=channels, dim=dim),
172
+ )
88
173
 
89
174
  class GeneratorDecoder(network.ModuleArgsDict):
90
175
  def __init__(self, channels: list[int], dim: int) -> None:
91
176
  super().__init__()
92
177
  for i, (in_channels, out_channels) in enumerate(zip(reversed(channels), reversed(channels[:-1]))):
93
- self.add_module("UpSample_{}".format(i), Generator.GeneratorUpSample(in_channels=in_channels, out_channels=out_channels, dim=dim))
94
-
178
+ self.add_module(
179
+ f"UpSample_{i}",
180
+ Generator.GeneratorUpSample(in_channels=in_channels, out_channels=out_channels, dim=dim),
181
+ )
182
+
95
183
  class GeneratorAutoEncoder(network.ModuleArgsDict):
96
184
 
97
185
  def __init__(self, ngf: int, dim: int) -> None:
98
186
  super().__init__()
99
- channels = [ngf, ngf*2]
187
+ channels = [ngf, ngf * 2]
100
188
  self.add_module("Encoder", Generator.GeneratorEncoder(channels, dim))
101
- self.add_module("NResBlock", Generator.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
189
+ self.add_module(
190
+ "NResBlock",
191
+ Generator.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim),
192
+ )
102
193
  self.add_module("Decoder", Generator.GeneratorDecoder(channels, dim))
103
-
104
- def __init__(self,
105
- optimizer : network.OptimizerLoader = network.OptimizerLoader(),
106
- schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
107
- patch : ModelPatch = ModelPatch(),
108
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
109
- nb_batch_per_step: int = 64,
110
- dim : int = 3) -> None:
111
- super().__init__(optimizer=optimizer, in_channels=1, schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim, nb_batch_per_step=nb_batch_per_step)
112
- ngf=32
194
+
195
+ def __init__(
196
+ self,
197
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
198
+ schedulers: dict[str, network.LRSchedulersLoader] = {
199
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
200
+ },
201
+ patch: ModelPatch = ModelPatch(),
202
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
203
+ nb_batch_per_step: int = 64,
204
+ dim: int = 3,
205
+ ) -> None:
206
+ super().__init__(
207
+ optimizer=optimizer,
208
+ in_channels=1,
209
+ schedulers=schedulers,
210
+ patch=patch,
211
+ outputs_criterions=outputs_criterions,
212
+ dim=dim,
213
+ nb_batch_per_step=nb_batch_per_step,
214
+ )
215
+ ngf = 32
113
216
  self.add_module("Stem", Generator.GeneratorStem(1, ngf, dim))
114
217
  self.add_module("AutoEncoder", Generator.GeneratorAutoEncoder(ngf, dim))
115
218
  self.add_module("Head", Generator.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
116
219
 
117
- def getName(self):
220
+ def get_name(self):
118
221
  return "Generator"
119
222
 
223
+
120
224
  class Gan(network.Network):
121
225
 
122
- def __init__(self, generator : Generator = Generator(), discriminator : Discriminator = Discriminator()) -> None:
226
+ def __init__(
227
+ self,
228
+ generator: Generator = Generator(),
229
+ discriminator: Discriminator = Discriminator(),
230
+ ) -> None:
123
231
  super().__init__()
124
- self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
232
+ self.add_module(
233
+ "Discriminator_B",
234
+ discriminator,
235
+ in_branch=[1],
236
+ out_branch=[-1],
237
+ requires_grad=True,
238
+ )
125
239
  self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
126
-
127
- self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
128
- self.add_module("Discriminator_pB_detach", discriminator, in_branch=["pB_detach"], out_branch=[-1])
129
-
130
- self.add_module("Discriminator_pB", discriminator, in_branch=["pB"], out_branch=[-1], requires_grad=False)
131
240
 
241
+ self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
242
+ self.add_module(
243
+ "Discriminator_pB_detach",
244
+ discriminator,
245
+ in_branch=["pB_detach"],
246
+ out_branch=[-1],
247
+ )
248
+
249
+ self.add_module(
250
+ "Discriminator_pB",
251
+ discriminator,
252
+ in_branch=["pB"],
253
+ out_branch=[-1],
254
+ requires_grad=False,
255
+ )
@@ -1,70 +1,170 @@
1
1
  import torch
2
2
 
3
- from konfai.network import network, blocks
4
- from konfai.utils.config import config
3
+ from konfai.network import blocks, network
4
+
5
5
 
6
6
  class VAE(network.Network):
7
7
 
8
8
  class AutoEncoderBlock(network.ModuleArgsDict):
9
9
 
10
- def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, dim: int, block: type, i : int = 0) -> None:
10
+ def __init__(
11
+ self,
12
+ channels: list[int],
13
+ nb_conv_per_stage: int,
14
+ block_config: blocks.BlockConfig,
15
+ downsample_mode: blocks.DownsampleMode,
16
+ upsample_mode: blocks.UpsampleMode,
17
+ dim: int,
18
+ block: type,
19
+ i: int = 0,
20
+ ) -> None:
11
21
  super().__init__()
12
22
  if i > 0:
13
- self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
14
- self.add_module("DownBlock", block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
23
+ self.add_module(
24
+ downsample_mode.name,
25
+ blocks.downsample(
26
+ in_channels=channels[0],
27
+ out_channels=channels[1],
28
+ downsample_mode=downsample_mode,
29
+ dim=dim,
30
+ ),
31
+ )
32
+ self.add_module(
33
+ "DownBlock",
34
+ block(
35
+ in_channels=channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
36
+ out_channels=channels[1],
37
+ block_configs=[block_config] * nb_conv_per_stage,
38
+ dim=dim,
39
+ ),
40
+ )
15
41
  if len(channels) > 2:
16
- self.add_module("AutoEncoder_{}".format(i+1), VAE.AutoEncoderBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, dim, block, i+1))
17
- self.add_module("UpBlock", block(in_channels=channels[2] if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
42
+ self.add_module(
43
+ f"AutoEncoder_{i + 1}",
44
+ VAE.AutoEncoderBlock(
45
+ channels[1:],
46
+ nb_conv_per_stage,
47
+ block_config,
48
+ downsample_mode,
49
+ upsample_mode,
50
+ dim,
51
+ block,
52
+ i + 1,
53
+ ),
54
+ )
55
+ self.add_module(
56
+ "UpBlock",
57
+ block(
58
+ in_channels=(
59
+ channels[2] if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE else channels[1]
60
+ ),
61
+ out_channels=channels[1],
62
+ block_configs=[block_config] * nb_conv_per_stage,
63
+ dim=dim,
64
+ ),
65
+ )
18
66
  if i > 0:
19
- self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim))
20
-
21
- class VAE_Head(network.ModuleArgsDict):
67
+ self.add_module(
68
+ upsample_mode.name,
69
+ blocks.upsample(
70
+ in_channels=channels[1],
71
+ out_channels=channels[0],
72
+ upsample_mode=upsample_mode,
73
+ dim=dim,
74
+ ),
75
+ )
76
+
77
+ class VAEHead(network.ModuleArgsDict):
22
78
 
23
79
  def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
24
80
  super().__init__()
25
- self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
26
- self.add_module("Tanh", torch.nn.Tanh())
27
-
28
- def __init__(self,
29
- optimizer: network.OptimizerLoader = network.OptimizerLoader(),
30
- schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
31
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
32
- dim : int = 3,
33
- channels: list[int]=[1, 64, 128, 256, 512, 1024],
34
- blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
35
- nb_conv_per_stage: int = 2,
36
- downSampleMode: str = "MAXPOOL",
37
- upSampleMode: str = "CONV_TRANSPOSE",
38
- blockType: str = "Conv") -> None:
39
-
40
- super().__init__(in_channels = channels[0], init_type="normal", optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=dim, nb_batch_per_step=1)
41
- self.add_module("AutoEncoder_0", VAE.AutoEncoderBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], dim=dim, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock))
42
- self.add_module("Head", VAE.VAE_Head(channels[1], channels[0], dim))
81
+ self.add_module(
82
+ "Conv",
83
+ blocks.get_torch_module("Conv", dim)(
84
+ in_channels=in_channels,
85
+ out_channels=out_channels,
86
+ kernel_size=3,
87
+ stride=1,
88
+ padding=1,
89
+ ),
90
+ )
91
+ self.add_module("Tanh", torch.nn.Tanh())
92
+
93
+ def __init__(
94
+ self,
95
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
96
+ schedulers: dict[str, network.LRSchedulersLoader] = {
97
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
98
+ },
99
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
100
+ dim: int = 3,
101
+ channels: list[int] = [1, 64, 128, 256, 512, 1024],
102
+ block_config: blocks.BlockConfig = blocks.BlockConfig(),
103
+ nb_conv_per_stage: int = 2,
104
+ downsample_mode: str = "MAXPOOL",
105
+ upsample_mode: str = "CONV_TRANSPOSE",
106
+ block_type: str = "Conv",
107
+ ) -> None:
108
+
109
+ super().__init__(
110
+ in_channels=channels[0],
111
+ init_type="normal",
112
+ optimizer=optimizer,
113
+ schedulers=schedulers,
114
+ outputs_criterions=outputs_criterions,
115
+ dim=dim,
116
+ nb_batch_per_step=1,
117
+ )
118
+ self.add_module(
119
+ "AutoEncoder_0",
120
+ VAE.AutoEncoderBlock(
121
+ channels,
122
+ nb_conv_per_stage,
123
+ block_config,
124
+ downsample_mode=blocks.DownsampleMode[downsample_mode],
125
+ upsample_mode=blocks.UpsampleMode[upsample_mode],
126
+ dim=dim,
127
+ block=blocks.ConvBlock if block_type == "Conv" else blocks.ResBlock,
128
+ ),
129
+ )
130
+ self.add_module("Head", VAE.VAEHead(channels[1], channels[0], dim))
43
131
 
44
132
 
45
133
  class LinearVAE(network.Network):
46
134
 
47
- class LinearVAE_DenseLayer(network.ModuleArgsDict):
135
+ class LinearVAEDenseLayer(network.ModuleArgsDict):
48
136
 
49
137
  def __init__(self, in_features: int, out_features: int) -> None:
50
138
  super().__init__()
51
139
  self.add_module("Linear", torch.nn.Linear(in_features, out_features))
52
- #self.add_module("Norm", torch.nn.BatchNorm1d(out_features))
140
+ # self.add_module("Norm", torch.nn.BatchNorm1d(out_features))
53
141
  self.add_module("Activation", torch.nn.LeakyReLU())
54
142
 
55
- class LinearVAE_Head(network.ModuleArgsDict):
143
+ class LinearVAEHead(network.ModuleArgsDict):
56
144
 
57
145
  def __init__(self, in_features: int, out_features: int) -> None:
58
146
  super().__init__()
59
147
  self.add_module("Linear", torch.nn.Linear(in_features, out_features))
60
148
  self.add_module("Tanh", torch.nn.Tanh())
61
149
 
62
- def __init__(self,
63
- optimizer: network.OptimizerLoader = network.OptimizerLoader(),
64
- schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
65
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},) -> None:
66
- super().__init__(in_channels = 1, init_type="normal", optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, dim=1, nb_batch_per_step=1)
67
- self.add_module("DenseLayer_0", LinearVAE.LinearVAE_DenseLayer(23343, 5))
68
- #self.add_module("Head", LinearVAE.DenseLayer(100, 28590))
69
- self.add_module("Head", LinearVAE.LinearVAE_Head(5, 23343))
70
- #self.add_module("DenseLayer_5", LinearVAE.DenseLayer(5000, 28590))
150
+ def __init__(
151
+ self,
152
+ optimizer: network.OptimizerLoader = network.OptimizerLoader(),
153
+ schedulers: dict[str, network.LRSchedulersLoader] = {
154
+ "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
155
+ },
156
+ outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
157
+ ) -> None:
158
+ super().__init__(
159
+ in_channels=1,
160
+ init_type="normal",
161
+ optimizer=optimizer,
162
+ schedulers=schedulers,
163
+ outputs_criterions=outputs_criterions,
164
+ dim=1,
165
+ nb_batch_per_step=1,
166
+ )
167
+ self.add_module("DenseLayer_0", LinearVAE.LinearVAEDenseLayer(23343, 5))
168
+ # self.add_module("Head", LinearVAE.DenseLayer(100, 28590))
169
+ self.add_module("Head", LinearVAE.LinearVAEHead(5, 23343))
170
+ # self.add_module("DenseLayer_5", LinearVAE.DenseLayer(5000, 28590))