konfai 1.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

@@ -0,0 +1,557 @@
+ from functools import partial
+ from typing import Union
+ import torch
+ import numpy as np
+
+ from KonfAI.konfai.network import network, blocks
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.data.HDF5 import ModelPatch, Attribute
+ from KonfAI.konfai.data import augmentation
+ from KonfAI.konfai.models.segmentation import UNet, NestedUNet
+ from KonfAI.konfai.models.generation.ddpm import DDPM
+
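+ # Editor's overview (hedged): this module defines PatchGAN-style
+ # discriminators (plain and with adaptive discriminator augmentation),
+ # three generator variants (ResNet autoencoder, NestedUNet, UNet), and
+ # GAN / CycleGAN wrappers that wire them together through KonfAI's
+ # in_branch/out_branch system. A minimal usage sketch, assuming the
+ # config-driven constructors work with the defaults declared below:
+ #
+ #     d = Discriminator(dim=2)            # 2D PatchGAN discriminator
+ #     g = GeneratorV2(dim=2)              # NestedUNet-based generator
+ #     gan = DiffusionGanV2(generator=g, discriminator=d)
+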
+ class Discriminator(network.Network):
+
+     class DiscriminatorNLayers(network.ModuleArgsDict):
+
+         def __init__(self, channels: list[int], strides: list[int], dim: int) -> None:
+             super().__init__()
+             blockConfig = partial(blocks.BlockConfig, kernel_size=4, padding=1, bias=False, activation=partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True), normMode=blocks.NormMode.SYNCBATCH)
+             for i, (in_channels, out_channels, stride) in enumerate(zip(channels, channels[1:], strides)):
+                 self.add_module("Layer_{}".format(i), blocks.ConvBlock(in_channels, out_channels, [blockConfig(stride=stride)], dim))
+
+     class DiscriminatorHead(network.ModuleArgsDict):
+
+         def __init__(self, channels: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels=channels, out_channels=1, kernel_size=4, stride=1, padding=1))
+             # self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
+             # self.add_module("Flatten", torch.nn.Flatten(1))
+
+     class DiscriminatorBlock(network.ModuleArgsDict):
+
+         def __init__(self, channels: list[int] = [1, 16, 32, 64, 64],
+                      strides: list[int] = [2, 2, 2, 1],
+                      dim: int = 3) -> None:
+             super().__init__()
+             self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
+             self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))
+
+     @config("Discriminator")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  channels: list[int] = [1, 16, 32, 64, 64],
+                  strides: list[int] = [2, 2, 2, 1],
+                  nb_batch_per_step: int = 1,
+                  dim: int = 3) -> None:
+         super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, nb_batch_per_step=nb_batch_per_step, dim=dim, init_type="kaiming")
+         self.add_module("DiscriminatorModel", Discriminator.DiscriminatorBlock(channels, strides, dim))
+
+
+ class Discriminator_ADA(network.Network):
+
+     class DDPM_TE(torch.nn.Module):
+
+         def __init__(self, in_channels: int, out_channels: int) -> None:
+             super().__init__()
+             self.linear_0 = torch.nn.Linear(in_channels, out_channels)
+             self.siLU = torch.nn.SiLU()
+             self.linear_1 = torch.nn.Linear(out_channels, out_channels)
+
+         def forward(self, input: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+             # Broadcast the time embedding over all spatial dimensions and add it channel-wise.
+             return input + self.linear_1(self.siLU(self.linear_0(t))).reshape(input.shape[0], -1, *[1 for _ in range(len(input.shape)-2)])
+
+     class DiscriminatorNLayers(network.ModuleArgsDict):
+
+         def __init__(self, channels: list[int], strides: list[int], time_embedding_dim: int, dim: int) -> None:
+             super().__init__()
+             blockConfig = partial(blocks.BlockConfig, kernel_size=4, padding=1, bias=False, activation=partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True), normMode=blocks.NormMode.SYNCBATCH)
+             for i, (in_channels, out_channels, stride) in enumerate(zip(channels, channels[1:], strides)):
+                 self.add_module("Te_{}".format(i), Discriminator_ADA.DDPM_TE(time_embedding_dim, in_channels), in_branch=[0, 1])
+                 self.add_module("Layer_{}".format(i), blocks.ConvBlock(in_channels, out_channels, [blockConfig(stride=stride)], dim))
+
+     class DiscriminatorHead(network.ModuleArgsDict):
+
+         def __init__(self, channels: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels=channels, out_channels=1, kernel_size=4, stride=1, padding=1))
+             # self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
+             # self.add_module("Flatten", torch.nn.Flatten(1))
+
+     class UpdateP(torch.nn.Module):
+
+         def __init__(self):
+             super().__init__()
+             self._it = 0
+             self.n = 4
+             self.ada_target = 0.25
+             self.ada_interval = 0.001
+             self.ada_kimg = 500
+
+             self.measure = None
+             self.names = None
+             self.p = 0
+
+         def setMeasure(self, measure: network.Measure, names: list[str]):
+             self.measure = measure
+             self.names = names
+
+         def forward(self, input: torch.Tensor) -> torch.Tensor:
+             if self.measure is not None and self._it % self.n == 0:
+                 value = sum([v for k, v in self.measure.getLastValues(self.n).items() if k in self.names])
+                 adjust = np.sign(self.ada_target - value) * self.ada_interval
+                 self.p += adjust
+                 self.p = np.clip(self.p, 0, 1)
+             self._it += 1
+             return torch.tensor(self.p).to(input.device)
+
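+     # Editor's note (hedged): UpdateP above mirrors the adaptive
+     # discriminator augmentation (ADA) heuristic of StyleGAN2-ADA. Every n
+     # steps, the summed discriminator measure is compared against ada_target
+     # and the augmentation probability p is nudged by ada_interval, then
+     # clipped to [0, 1]:
+     #
+     #     p <- clip(p + sign(ada_target - measure) * ada_interval, 0, 1)
+     #
+     # ada_kimg is retained from the reference implementation but unused in
+     # this simplified update.
+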
+     class DiscriminatorAugmentation(torch.nn.Module):
+
+         def __init__(self, dim: int):
+             super().__init__()
+
+             # Per-augmentation weights: each entry maps an augmentation to a
+             # multiplier applied to the global probability p; a weight of 0
+             # effectively disables that augmentation (p * 0).
+             self.dataAugmentations: dict[augmentation.DataAugmentation, float] = {}
+             pixel_blitting = {
+                 augmentation.Flip([1/3]*3 if dim == 3 else [1/2]*2): 0,
+                 augmentation.Rotate(a_min=0, a_max=360, is_quarter=True): 0,
+                 augmentation.Translate([(-0.5, 0.5)] * (3 if dim == 3 else 2), is_int=True): 0
+             }
+             self.dataAugmentations.update(pixel_blitting)
+
+             geometric = {
+                 augmentation.Scale([0.2]): 0,  # single factor: isotropic scale
+                 augmentation.Rotate(a_min=0, a_max=360): 0,
+                 augmentation.Scale([0.2]*3 if dim == 3 else [0.2]*2): 0,  # per-axis factors: anisotropic scale
+                 augmentation.Rotate(a_min=0, a_max=360): 0,
+                 augmentation.Translate([(-0.5, 0.5)] * (3 if dim == 3 else 2)): 0,
+                 augmentation.Elastix(16, 16): 0.5
+             }
+             self.dataAugmentations.update(geometric)
+
+             color = {
+                 augmentation.Brightness(0.2): 0,
+                 augmentation.Contrast(0.5): 0,
+                 augmentation.Saturation(1): 0,
+                 augmentation.HUE(1): 0,
+                 augmentation.LumaFlip(): 0
+             }
+             self.dataAugmentations.update(color)
+
+             corruptions = {
+                 augmentation.Noise(1): 1,
+                 augmentation.CutOUT(0.5, 0.5, -1): 0.3
+             }
+             self.dataAugmentations.update(corruptions)
+
+         def _setP(self, prob: float):
+             for aug, p in self.dataAugmentations.items():  # "aug" avoids shadowing the augmentation module
+                 aug.load(prob*p)
+
+         def forward(self, input: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
+             self._setP(prob.item())
+             out = input
+             for aug in self.dataAugmentations.keys():
+                 aug.state_init(None, [input.shape[2:]]*input.shape[0], [Attribute()]*input.shape[0])
+                 out = aug(0, [data for data in out], None)
+             return torch.cat([data.unsqueeze(0) for data in out], 0)
+
+
+     class DiscriminatorBlock(network.ModuleArgsDict):
+
+         def __init__(self, channels: list[int] = [1, 16, 32, 64, 64],
+                      strides: list[int] = [2, 2, 2, 1],
+                      dim: int = 3) -> None:
+             super().__init__()
+             self.add_module("Prob", Discriminator_ADA.UpdateP(), out_branch=["p"])
+             self.add_module("Sample", Discriminator_ADA.DiscriminatorAugmentation(dim), in_branch=[0, "p"])
+             self.add_module("t", DDPM.DDPM_TimeEmbedding(1000, 100), in_branch=[0, "p"], out_branch=["te"])
+             self.add_module("Layers", Discriminator_ADA.DiscriminatorNLayers(channels, strides, 100, dim), in_branch=[0, "te"])
+             self.add_module("Head", Discriminator_ADA.DiscriminatorHead(channels[-1], dim))
+
+     @config("Discriminator_ADA")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  channels: list[int] = [1, 16, 32, 64, 64],
+                  strides: list[int] = [2, 2, 2, 1],
+                  nb_batch_per_step: int = 1,
+                  dim: int = 3) -> None:
+         super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, nb_batch_per_step=nb_batch_per_step, dim=dim, init_type="kaiming")
+         self.add_module("DiscriminatorModel", Discriminator_ADA.DiscriminatorBlock(channels, strides, dim))
+
+     def initialized(self):
+         self["DiscriminatorModel"]["Prob"].setMeasure(self.measure, ["Discriminator_B.DiscriminatorModel.Head.Conv:None:PatchGanLoss"])
+
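+ # Editor's note (hedged): Discriminator_ADA chains an ADA probability
+ # ("Prob"), input augmentation applied with that probability ("Sample"),
+ # and a DDPM time embedding ("t") added channel-wise before each conv stage
+ # via DDPM_TE — in the spirit of Diffusion-GAN-style training, where the
+ # discriminator scores augmented/diffused samples conditioned on the step.
+ # The measure key in initialized() assumes this network is registered as
+ # "Discriminator_B", as in DiffusionGan below.
+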
+ class GeneratorV1(network.Network):
+
+     class GeneratorStem(network.ModuleArgsDict):
+
+         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
+
+     class GeneratorHead(network.ModuleArgsDict):
+
+         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("ConvBlock", blocks.ConvBlock(in_channels, in_channels, blockConfigs=[blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
+             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False))
+             self.add_module("Tanh", torch.nn.Tanh())
+
+     class GeneratorDownSample(network.ModuleArgsDict):
+
+         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(stride=2, bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
+
+     class GeneratorUpSample(network.ModuleArgsDict):
+
+         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
+             self.add_module("Upsample", torch.nn.Upsample(scale_factor=2, mode="bilinear" if dim < 3 else "trilinear"))
+
+     class GeneratorEncoder(network.ModuleArgsDict):
+
+         def __init__(self, channels: list[int], dim: int) -> None:
+             super().__init__()
+             for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
+                 self.add_module("DownSample_{}".format(i), GeneratorV1.GeneratorDownSample(in_channels=in_channels, out_channels=out_channels, dim=dim))
+
+     class GeneratorResnetBlock(network.ModuleArgsDict):
+
+         def __init__(self, channels: int, dim: int):
+             super().__init__()
+             self.add_module("Conv_0", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
+             self.add_module("Norm_0", torch.nn.SyncBatchNorm(channels))
+             self.add_module("Activation_0", torch.nn.LeakyReLU(0.2, inplace=True))
+             self.add_module("Conv_1", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
+             self.add_module("Norm_1", torch.nn.SyncBatchNorm(channels))
+             self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
+
+     class GeneratorNResnetBlock(network.ModuleArgsDict):
+
+         def __init__(self, channels: int, nb_conv: int, dim: int) -> None:
+             super().__init__()
+             for i in range(nb_conv):
+                 self.add_module("ResnetBlock_{}".format(i), GeneratorV1.GeneratorResnetBlock(channels=channels, dim=dim))
+
+     class GeneratorDecoder(network.ModuleArgsDict):
+
+         def __init__(self, channels: list[int], dim: int) -> None:
+             super().__init__()
+             for i, (in_channels, out_channels) in enumerate(zip(reversed(channels), reversed(channels[:-1]))):
+                 self.add_module("UpSample_{}".format(i), GeneratorV1.GeneratorUpSample(in_channels=in_channels, out_channels=out_channels, dim=dim))
+
+     class GeneratorAutoEncoder(network.ModuleArgsDict):
+
+         def __init__(self, ngf: int, dim: int) -> None:
+             super().__init__()
+             channels = [ngf, ngf*2]
+             self.add_module("Encoder", GeneratorV1.GeneratorEncoder(channels, dim))
+             self.add_module("NResBlock", GeneratorV1.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
+             self.add_module("Decoder", GeneratorV1.GeneratorDecoder(channels, dim))
+
+     class GeneratorBlock(network.ModuleArgsDict):
+
+         def __init__(self, ngf: int, dim: int) -> None:
+             super().__init__()
+             self.add_module("Stem", GeneratorV1.GeneratorStem(3, ngf, dim))
+             self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
+             self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
+
+     @config("GeneratorV1")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  patch: ModelPatch = ModelPatch(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  dim: int = 3) -> None:
+         super().__init__(optimizer=optimizer, in_channels=3, schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim)
+         self.add_module("GeneratorModel", GeneratorV1.GeneratorBlock(32, dim))
+
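+ # Editor's note (hedged): GeneratorV1 is a compact ResNet-style translator:
+ # a stem conv (3 -> ngf channels), one stride-2 downsampling to 2*ngf, six
+ # residual blocks, one 2x upsampling back to ngf, then a 1x1 conv + Tanh
+ # head producing a single output channel. With ngf = 32 and dim = 3, an
+ # input of shape (B, 3, D, H, W) should come back as (B, 1, D, H, W),
+ # assuming the ConvBlock defaults preserve spatial size.
+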
+ class GeneratorV2(network.Network):
+
+     class NestedUNetHead(network.ModuleArgsDict):
+
+         def __init__(self, in_channels: list[int], dim: int) -> None:
+             super().__init__()
+             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels=in_channels[1], out_channels=1, kernel_size=1, stride=1, padding=0))
+             self.add_module("Tanh", torch.nn.Tanh())
+
+     class GeneratorBlock(network.ModuleArgsDict):
+
+         def __init__(self,
+                      channels: list[int],
+                      blockConfig: blocks.BlockConfig,
+                      nb_conv_per_stage: int,
+                      downSampleMode: str,
+                      upSampleMode: str,
+                      attention: bool,
+                      blockType: str,
+                      dim: int) -> None:
+             super().__init__()
+             self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode[downSampleMode], upSampleMode=blocks.UpSampleMode[upSampleMode], attention=attention, block=blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
+             self.add_module("Head", GeneratorV2.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
+
+     @config("GeneratorV2")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  patch: Union[ModelPatch, None] = None,
+                  channels: list[int] = [1, 64, 128, 256, 512, 1024],
+                  blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
+                  nb_conv_per_stage: int = 2,
+                  downSampleMode: str = "MAXPOOL",
+                  upSampleMode: str = "CONV_TRANSPOSE",
+                  attention: bool = False,
+                  blockType: str = "Conv",
+                  dim: int = 3) -> None:
+         super().__init__(in_channels=channels[0], optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
+         self.add_module("GeneratorModel", GeneratorV2.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim))
+
+ class GeneratorV3(network.Network):
+
+     class NestedUNetHead(network.ModuleArgsDict):
+
+         def __init__(self, in_channels: list[int], dim: int) -> None:
+             super().__init__()
+             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels=in_channels[1], out_channels=1, kernel_size=1, stride=1, padding=0))
+             self.add_module("Tanh", torch.nn.Tanh())
+
+     class GeneratorBlock(network.ModuleArgsDict):
+
+         def __init__(self,
+                      channels: list[int],
+                      blockConfig: blocks.BlockConfig,
+                      nb_conv_per_stage: int,
+                      downSampleMode: str,
+                      upSampleMode: str,
+                      attention: bool,
+                      blockType: str,
+                      dim: int) -> None:
+             super().__init__()
+             self.add_module("UNetBlock_0", UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode[downSampleMode], upSampleMode=blocks.UpSampleMode[upSampleMode], attention=attention, block=blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=1, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
+             self.add_module("Head", GeneratorV3.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
+
+     @config("GeneratorV3")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  patch: Union[ModelPatch, None] = None,
+                  channels: list[int] = [1, 64, 128, 256, 512, 1024],
+                  blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
+                  nb_conv_per_stage: int = 2,
+                  downSampleMode: str = "MAXPOOL",
+                  upSampleMode: str = "CONV_TRANSPOSE",
+                  attention: bool = False,
+                  blockType: str = "Conv",
+                  dim: int = 3) -> None:
+         super().__init__(in_channels=channels[0], optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
+         self.add_module("GeneratorModel", GeneratorV3.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), out_branch=["pB"])
+
+ class DiffusionGan(network.Network):
+
+     @config("DiffusionGan")
+     def __init__(self, generator: GeneratorV1 = GeneratorV1(), discriminator: Discriminator_ADA = Discriminator_ADA()) -> None:
+         super().__init__()
+         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
+         self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
+         self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
+         self.add_module("Discriminator_pB_detach", discriminator, in_branch=["pB_detach"], out_branch=[-1])
+         self.add_module("Discriminator_pB", discriminator, in_branch=["pB"], out_branch=[-1], requires_grad=False)
+
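+ # Editor's note (hedged): the three discriminator passes implement the usual
+ # GAN bookkeeping under KonfAI branches — "Discriminator_B" sees real B with
+ # gradients enabled (D update on real data), "Discriminator_pB_detach" sees
+ # the detached fake (D update without backprop into the generator), and
+ # "Discriminator_pB" re-scores the attached fake with requires_grad=False so
+ # that only the generator receives gradients through it.
+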
+ class DiffusionGanV2(network.Network):
+
+     @config("DiffusionGanV2")
+     def __init__(self, generator: GeneratorV2 = GeneratorV2(), discriminator: Discriminator = Discriminator()) -> None:
+         super().__init__()
+         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
+         self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
+         self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
+         self.add_module("Discriminator_pB_detach", discriminator, in_branch=["pB_detach"], out_branch=[-1])
+         self.add_module("Discriminator_pB", discriminator, in_branch=["pB"], out_branch=[-1], requires_grad=False)
+
+
+ class CycleGanDiscriminator(network.Network):
+
+     @config("CycleGanDiscriminator")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  patch: Union[ModelPatch, None] = None,
+                  channels: list[int] = [1, 16, 32, 64, 64],
+                  strides: list[int] = [2, 2, 2, 1],
+                  dim: int = 3) -> None:
+         super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
+         self.add_module("Discriminator_A", Discriminator.DiscriminatorBlock(channels, strides, dim), in_branch=[0], out_branch=[0])
+         self.add_module("Discriminator_B", Discriminator.DiscriminatorBlock(channels, strides, dim), in_branch=[1], out_branch=[1])
+
+     def initialized(self):
+         # Hedged assumption: setMeasure is defined on the ADA "Prob" module,
+         # so these hooks apply when ADA-style discriminator blocks are used;
+         # the measure keys target "Head.Conv", matching the criterion
+         # attachment in Discriminator_ADA.initialized above.
+         self["Discriminator_A"]["Prob"].setMeasure(self.measure, ["Discriminator.Discriminator_A.Head.Conv:None:PatchGanLoss"])
+         self["Discriminator_B"]["Prob"].setMeasure(self.measure, ["Discriminator.Discriminator_B.Head.Conv:None:PatchGanLoss"])
+
+ class CycleGanGeneratorV1(network.Network):
+
+     @config("CycleGanGeneratorV1")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  patch: Union[ModelPatch, None] = None,
+                  dim: int = 3) -> None:
+         super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
+         self.add_module("Generator_A_to_B", GeneratorV1.GeneratorBlock(32, dim), in_branch=[0], out_branch=["pB"])
+         self.add_module("Generator_B_to_A", GeneratorV1.GeneratorBlock(32, dim), in_branch=[1], out_branch=["pA"])
+
+ class CycleGanGeneratorV2(network.Network):
+
+     @config("CycleGanGeneratorV2")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  patch: Union[ModelPatch, None] = None,
+                  channels: list[int] = [1, 64, 128, 256, 512, 1024],
+                  blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
+                  nb_conv_per_stage: int = 2,
+                  downSampleMode: str = "MAXPOOL",
+                  upSampleMode: str = "CONV_TRANSPOSE",
+                  attention: bool = False,
+                  blockType: str = "Conv",
+                  dim: int = 3) -> None:
+         super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
+         self.add_module("Generator_A_to_B", GeneratorV2.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), in_branch=[0], out_branch=["pB"])
+         self.add_module("Generator_B_to_A", GeneratorV2.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), in_branch=[1], out_branch=["pA"])
+
+ class CycleGanGeneratorV3(network.Network):
+
+     @config("CycleGanGeneratorV3")
+     def __init__(self,
+                  optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers: network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+                  patch: Union[ModelPatch, None] = None,
+                  channels: list[int] = [1, 64, 128, 256, 512, 1024],
+                  blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
+                  nb_conv_per_stage: int = 2,
+                  downSampleMode: str = "MAXPOOL",
+                  upSampleMode: str = "CONV_TRANSPOSE",
+                  attention: bool = False,
+                  blockType: str = "Conv",
+                  dim: int = 3) -> None:
+         super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
+         self.add_module("Generator_A_to_B", GeneratorV3.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), in_branch=[0], out_branch=["pB"])
+         self.add_module("Generator_B_to_A", GeneratorV3.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), in_branch=[1], out_branch=["pA"])
+
+ class DiffusionCycleGan(network.Network):
+
+     @config("DiffusionCycleGan")
+     def __init__(self, generators: CycleGanGeneratorV3 = CycleGanGeneratorV3(), discriminators: CycleGanDiscriminator = CycleGanDiscriminator()) -> None:
+         super().__init__()
+         self.add_module("Generator", generators, in_branch=[0, 1], out_branch=["pB", "pA"])
+         self.add_module("Discriminator", discriminators, in_branch=[0, 1], out_branch=[-1], requires_grad=True)
+
+         self.add_module("Generator_identity", generators, in_branch=[1, 0], out_branch=[-1])
+
+         self.add_module("Generator_p", generators, in_branch=["pA", "pB"], out_branch=[-1])
+
+         self.add_module("detach_pA", blocks.Detach(), in_branch=["pA"], out_branch=["pA_detach"])
+         self.add_module("detach_pB", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
+
+         self.add_module("Discriminator_p_detach", discriminators, in_branch=["pA_detach", "pB_detach"], out_branch=[-1])
+         self.add_module("Discriminator_p", discriminators, in_branch=["pA", "pB"], out_branch=[-1], requires_grad=False)
+
+