konfai 1.1.8-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/models/generation/gan.py
CHANGED
@@ -1,9 +1,10 @@
 from functools import partial
+
 import torch
 
-from konfai.network import network, blocks
-from konfai.utils.config import config
 from konfai.data.patching import ModelPatch
+from konfai.network import blocks, network
+
 
 class Discriminator(network.Network):
 
@@ -11,121 +12,244 @@ class Discriminator(network.Network):
 
         def __init__(self, channels: list[int], strides: list[int], dim: int) -> None:
             super().__init__()
-
+            block_config = partial(
+                blocks.BlockConfig,
+                kernel_size=4,
+                padding=1,
+                bias=False,
+                activation=partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True),
+                norm_mode=blocks.NormMode.SYNCBATCH,
+            )
             for i, (in_channels, out_channels, stride) in enumerate(zip(channels, channels[1:], strides)):
-                self.add_module(
-
+                self.add_module(
+                    f"Layer_{i}",
+                    blocks.ConvBlock(in_channels, out_channels, [block_config(stride=stride)], dim),
+                )
+
     class DiscriminatorHead(network.ModuleArgsDict):
 
         def __init__(self, channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
-
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=channels,
+                    out_channels=1,
+                    kernel_size=4,
+                    stride=1,
+                    padding=1,
+                ),
+            )
+            self.add_module(
+                "AdaptiveAvgPool",
+                blocks.get_torch_module("AdaptiveAvgPool", dim)(tuple([1] * dim)),
+            )
             self.add_module("Flatten", torch.nn.Flatten(1))
-
-    def __init__(
-
-
-
-
-
-
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        nb_batch_per_step: int = 64,
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            dim=dim,
+            nb_batch_per_step=nb_batch_per_step,
+        )
         channels = [1, 16, 32, 64, 64]
-        strides = [2,2,2,1]
+        strides = [2, 2, 2, 1]
         self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
         self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))
 
+
 class Generator(network.Network):
 
     class GeneratorStem(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels,
+                    out_channels,
+                    block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
+                    dim=dim,
+                ),
+            )
 
     class GeneratorHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
-
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels,
+                    in_channels,
+                    block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
+                    dim=dim,
+                ),
+            )
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False),
+            )
             self.add_module("Tanh", torch.nn.Tanh())
 
     class GeneratorDownSample(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
-
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels,
+                    out_channels,
+                    block_configs=[blocks.BlockConfig(stride=2, bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
+                    dim=dim,
+                ),
+            )
+
     class GeneratorUpSample(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
-
-
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels,
+                    out_channels,
+                    block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
+                    dim=dim,
+                ),
+            )
+            self.add_module(
+                "Upsample",
+                torch.nn.Upsample(scale_factor=2, mode="bilinear" if dim < 3 else "trilinear"),
+            )
+
     class GeneratorEncoder(network.ModuleArgsDict):
         def __init__(self, channels: list[int], dim: int) -> None:
             super().__init__()
             for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
-                self.add_module(
-
+                self.add_module(
+                    f"DownSample_{i}",
+                    Generator.GeneratorDownSample(in_channels=in_channels, out_channels=out_channels, dim=dim),
+                )
+
     class GeneratorResnetBlock(network.ModuleArgsDict):
 
-        def __init__(self, channels
+        def __init__(self, channels: int, dim: int):
             super().__init__()
-            self.add_module(
+            self.add_module(
+                "Conv_0",
+                blocks.get_torch_module("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False),
+            )
             self.add_module("Norm", torch.nn.LeakyReLU(0.2, inplace=True))
-            self.add_module(
-
+            self.add_module(
+                "Conv_1",
+                blocks.get_torch_module("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False),
+            )
+            self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
 
     class GeneratorNResnetBlock(network.ModuleArgsDict):
 
         def __init__(self, channels: int, nb_conv: int, dim: int) -> None:
             super().__init__()
             for i in range(nb_conv):
-                self.add_module(
+                self.add_module(
+                    f"ResnetBlock_{i}",
+                    Generator.GeneratorResnetBlock(channels=channels, dim=dim),
+                )
 
     class GeneratorDecoder(network.ModuleArgsDict):
         def __init__(self, channels: list[int], dim: int) -> None:
             super().__init__()
             for i, (in_channels, out_channels) in enumerate(zip(reversed(channels), reversed(channels[:-1]))):
-                self.add_module(
-
+                self.add_module(
+                    f"UpSample_{i}",
+                    Generator.GeneratorUpSample(in_channels=in_channels, out_channels=out_channels, dim=dim),
+                )
+
     class GeneratorAutoEncoder(network.ModuleArgsDict):
 
         def __init__(self, ngf: int, dim: int) -> None:
             super().__init__()
-            channels = [ngf, ngf*2]
+            channels = [ngf, ngf * 2]
             self.add_module("Encoder", Generator.GeneratorEncoder(channels, dim))
-            self.add_module(
+            self.add_module(
+                "NResBlock",
+                Generator.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim),
+            )
            self.add_module("Decoder", Generator.GeneratorDecoder(channels, dim))
-
-    def __init__(
-
-
-
-
-
-
-
-
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        patch: ModelPatch = ModelPatch(),
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        nb_batch_per_step: int = 64,
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            optimizer=optimizer,
+            in_channels=1,
+            schedulers=schedulers,
+            patch=patch,
+            outputs_criterions=outputs_criterions,
+            dim=dim,
+            nb_batch_per_step=nb_batch_per_step,
+        )
+        ngf = 32
         self.add_module("Stem", Generator.GeneratorStem(1, ngf, dim))
         self.add_module("AutoEncoder", Generator.GeneratorAutoEncoder(ngf, dim))
         self.add_module("Head", Generator.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
 
-    def
+    def get_name(self):
         return "Generator"
 
+
 class Gan(network.Network):
 
-    def __init__(
+    def __init__(
+        self,
+        generator: Generator = Generator(),
+        discriminator: Discriminator = Discriminator(),
+    ) -> None:
         super().__init__()
-        self.add_module(
+        self.add_module(
+            "Discriminator_B",
+            discriminator,
+            in_branch=[1],
+            out_branch=[-1],
+            requires_grad=True,
+        )
         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
-
-        self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
-        self.add_module("Discriminator_pB_detach", discriminator, in_branch=["pB_detach"], out_branch=[-1])
-
-        self.add_module("Discriminator_pB", discriminator, in_branch=["pB"], out_branch=[-1], requires_grad=False)
 
+        self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
+        self.add_module(
+            "Discriminator_pB_detach",
+            discriminator,
+            in_branch=["pB_detach"],
+            out_branch=[-1],
+        )
+
+        self.add_module(
+            "Discriminator_pB",
+            discriminator,
+            in_branch=["pB"],
+            out_branch=[-1],
+            requires_grad=False,
+        )
konfai/models/generation/vae.py
CHANGED
@@ -1,70 +1,170 @@
 import torch
 
-from konfai.network import
-
+from konfai.network import blocks, network
+
 
 class VAE(network.Network):
 
     class AutoEncoderBlock(network.ModuleArgsDict):
 
-        def __init__(
+        def __init__(
+            self,
+            channels: list[int],
+            nb_conv_per_stage: int,
+            block_config: blocks.BlockConfig,
+            downsample_mode: blocks.DownsampleMode,
+            upsample_mode: blocks.UpsampleMode,
+            dim: int,
+            block: type,
+            i: int = 0,
+        ) -> None:
             super().__init__()
             if i > 0:
-                self.add_module(
-
+                self.add_module(
+                    downsample_mode.name,
+                    blocks.downsample(
+                        in_channels=channels[0],
+                        out_channels=channels[1],
+                        downsample_mode=downsample_mode,
+                        dim=dim,
+                    ),
+                )
+            self.add_module(
+                "DownBlock",
+                block(
+                    in_channels=channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
+                    out_channels=channels[1],
+                    block_configs=[block_config] * nb_conv_per_stage,
+                    dim=dim,
+                ),
+            )
             if len(channels) > 2:
-                self.add_module(
-
+                self.add_module(
+                    f"AutoEncoder_{i + 1}",
+                    VAE.AutoEncoderBlock(
+                        channels[1:],
+                        nb_conv_per_stage,
+                        block_config,
+                        downsample_mode,
+                        upsample_mode,
+                        dim,
+                        block,
+                        i + 1,
+                    ),
+                )
+                self.add_module(
+                    "UpBlock",
+                    block(
+                        in_channels=(
+                            channels[2] if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE else channels[1]
+                        ),
+                        out_channels=channels[1],
+                        block_configs=[block_config] * nb_conv_per_stage,
+                        dim=dim,
+                    ),
+                )
             if i > 0:
-
-
-
+                self.add_module(
+                    upsample_mode.name,
+                    blocks.upsample(
+                        in_channels=channels[1],
+                        out_channels=channels[0],
+                        upsample_mode=upsample_mode,
+                        dim=dim,
+                    ),
+                )
+
+    class VAEHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                ),
+            )
+            self.add_module("Tanh", torch.nn.Tanh())
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        dim: int = 3,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        block_type: str = "Conv",
+    ) -> None:
+
+        super().__init__(
+            in_channels=channels[0],
+            init_type="normal",
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            dim=dim,
+            nb_batch_per_step=1,
+        )
+        self.add_module(
+            "AutoEncoder_0",
+            VAE.AutoEncoderBlock(
+                channels,
+                nb_conv_per_stage,
+                block_config,
+                downsample_mode=blocks.DownsampleMode[downsample_mode],
+                upsample_mode=blocks.UpsampleMode[upsample_mode],
+                dim=dim,
+                block=blocks.ConvBlock if block_type == "Conv" else blocks.ResBlock,
+            ),
+        )
+        self.add_module("Head", VAE.VAEHead(channels[1], channels[0], dim))
 
 
 class LinearVAE(network.Network):
 
-    class
+    class LinearVAEDenseLayer(network.ModuleArgsDict):
 
         def __init__(self, in_features: int, out_features: int) -> None:
             super().__init__()
             self.add_module("Linear", torch.nn.Linear(in_features, out_features))
-            #self.add_module("Norm", torch.nn.BatchNorm1d(out_features))
+            # self.add_module("Norm", torch.nn.BatchNorm1d(out_features))
             self.add_module("Activation", torch.nn.LeakyReLU())
 
-    class
+    class LinearVAEHead(network.ModuleArgsDict):
 
         def __init__(self, in_features: int, out_features: int) -> None:
             super().__init__()
             self.add_module("Linear", torch.nn.Linear(in_features, out_features))
             self.add_module("Tanh", torch.nn.Tanh())
 
-    def __init__(
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            init_type="normal",
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            dim=1,
+            nb_batch_per_step=1,
+        )
+        self.add_module("DenseLayer_0", LinearVAE.LinearVAEDenseLayer(23343, 5))
+        # self.add_module("Head", LinearVAE.DenseLayer(100, 28590))
+        self.add_module("Head", LinearVAE.LinearVAEHead(5, 23343))
+        # self.add_module("DenseLayer_5", LinearVAE.DenseLayer(5000, 28590))