konfai 1.1.7-py3-none-any.whl → 1.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
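Only one file is expanded below: konfai/models/generation/diffusionGan.py (+757/-358). Almost all of its churn is mechanical: a camelCase-to-snake_case rename pass (outputsCriterions → outputs_criterions, blockConfig → block_config, normMode → norm_mode, getTorchModule → get_torch_module, setMeasure → set_measure, downSampleMode/upSampleMode → downsample_mode/upsample_mode), class renames (Discriminator_ADA → DiscriminatorADA, DDPM_TE → DDPMTE, DDPM_TimeEmbedding → DDPMTimeEmbedding), Union[X, None] → X | None, and black-style line wrapping. A few behavioural changes are mixed in, though: Translate and Scale take different arguments, CutOUT(0.5, 0.5, -1) becomes CutOUT(0.5, 1, -1), and augmentations are now invoked as aug("", 0, ...) instead of aug(0, ..., None).

If you maintain code against konfai's internals, a throwaway scanner like the one below can flag the renamed identifiers before upgrading. The rename table is read off this diff only, so it is not an official migration list, and the script itself is a hypothetical helper, not part of konfai:

import re
import sys
from pathlib import Path

# 1.1.7 -> 1.1.9 renames observed in this diff (diffusionGan.py only; not exhaustive).
RENAMES = {
    "outputsCriterions": "outputs_criterions",
    "getTorchModule": "get_torch_module",
    "normMode": "norm_mode",
    "blockConfigs": "block_configs",
    "blockConfig": "block_config",
    "downSampleMode": "downsample_mode",
    "upSampleMode": "upsample_mode",
    "blockType": "block_type",
    "setMeasure": "set_measure",
    "Discriminator_ADA": "DiscriminatorADA",
    "DDPM_TE": "DDPMTE",
}
# Longest-first alternation so "blockConfigs" wins over its prefix "blockConfig".
PATTERN = re.compile("|".join(sorted((re.escape(k) for k in RENAMES), key=len, reverse=True)))

root = Path(sys.argv[1]) if len(sys.argv) > 1 else Path(".")
for path in root.rglob("*.py"):
    for lineno, line in enumerate(path.read_text(errors="ignore").splitlines(), 1):
        for match in PATTERN.finditer(line):
            print(f"{path}:{lineno}: {match.group(0)} -> {RENAMES[match.group(0)]}")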
diff --git a/konfai/models/generation/diffusionGan.py b/konfai/models/generation/diffusionGan.py
--- a/konfai/models/generation/diffusionGan.py
+++ b/konfai/models/generation/diffusionGan.py
@@ -1,83 +1,152 @@
 from functools import partial
-from typing import Union
-import torch
+from typing import cast
+
 import numpy as np
+import torch
 
-from konfai.network import network, blocks
-from konfai.utils.config import config
-from konfai.data.patching import ModelPatch, Attribute
 from konfai.data import augmentation
-from konfai.models.segmentation import UNet, NestedUNet
+from konfai.data.patching import Attribute, ModelPatch
 from konfai.models.generation.ddpm import DDPM
+from konfai.models.segmentation import NestedUNet, UNet
+from konfai.network import blocks, network
+
 
 class Discriminator(network.Network):
-
+
     class DiscriminatorNLayers(network.ModuleArgsDict):
 
         def __init__(self, channels: list[int], strides: list[int], dim: int) -> None:
             super().__init__()
-            blockConfig = partial(blocks.BlockConfig, kernel_size=4, padding=1, bias=False, activation=partial(torch.nn.LeakyReLU, negative_slope = 0.2, inplace=True), normMode=blocks.NormMode.SYNCBATCH)
+            block_config = partial(
+                blocks.BlockConfig,
+                kernel_size=4,
+                padding=1,
+                bias=False,
+                activation=partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True),
+                norm_mode=blocks.NormMode.SYNCBATCH,
+            )
             for i, (in_channels, out_channels, stride) in enumerate(zip(channels, channels[1:], strides)):
-                self.add_module("Layer_{}".format(i), blocks.ConvBlock(in_channels, out_channels, [blockConfig(stride=stride)], dim))
-
+                self.add_module(
+                    f"Layer_{i}",
+                    blocks.ConvBlock(in_channels, out_channels, [block_config(stride=stride)], dim),
+                )
+
     class DiscriminatorHead(network.ModuleArgsDict):
 
         def __init__(self, channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels=channels, out_channels=1, kernel_size=4, stride=1, padding=1))
-            #self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
-            #self.add_module("Flatten", torch.nn.Flatten(1))
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=channels,
+                    out_channels=1,
+                    kernel_size=4,
+                    stride=1,
+                    padding=1,
+                ),
+            )
+            # self.add_module("AdaptiveAvgPool", blocks.get_torch_module("AdaptiveAvgPool", dim)(tuple([1]*dim)))
+            # self.add_module("Flatten", torch.nn.Flatten(1))
 
     class DiscriminatorBlock(network.ModuleArgsDict):
 
-        def __init__(self, channels: list[int] = [1, 16, 32, 64, 64],
-                     strides: list[int] = [2,2,2,1],
-                     dim : int = 3) -> None:
+        def __init__(
+            self,
+            channels: list[int] = [1, 16, 32, 64, 64],
+            strides: list[int] = [2, 2, 2, 1],
+            dim: int = 3,
+        ) -> None:
             super().__init__()
             self.add_module("Layers", Discriminator.DiscriminatorNLayers(channels, strides, dim))
             self.add_module("Head", Discriminator.DiscriminatorHead(channels[-1], dim))
 
-    def __init__(self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 channels: list[int] = [1, 16, 32, 64, 64],
-                 strides: list[int] = [2,2,2,1],
-                 nb_batch_per_step: int = 1,
-                 dim : int = 3) -> None:
-        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, nb_batch_per_step=nb_batch_per_step, dim=dim, init_type="kaiming")
-        self.add_module("DiscriminatorModel", Discriminator.DiscriminatorBlock(channels, strides, dim))
-
-
-class Discriminator_ADA(network.Network):
-
-    class DDPM_TE(torch.nn.Module):
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        channels: list[int] = [1, 16, 32, 64, 64],
+        strides: list[int] = [2, 2, 2, 1],
+        nb_batch_per_step: int = 1,
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            nb_batch_per_step=nb_batch_per_step,
+            dim=dim,
+            init_type="kaiming",
+        )
+        self.add_module(
+            "DiscriminatorModel",
+            Discriminator.DiscriminatorBlock(channels, strides, dim),
+        )
+
+
+class DiscriminatorADA(network.Network):
+
+    class DDPMTE(torch.nn.Module):
 
         def __init__(self, in_channels: int, out_channels: int) -> None:
             super().__init__()
             self.linear_0 = torch.nn.Linear(in_channels, out_channels)
             self.siLU = torch.nn.SiLU()
             self.linear_1 = torch.nn.Linear(out_channels, out_channels)
-
-        def forward(self, input: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-            return input + self.linear_1(self.siLU(self.linear_0(t))).reshape(input.shape[0], -1, *[1 for _ in range(len(input.shape)-2)])
-
+
+        def forward(self, tensor: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+            return tensor + self.linear_1(self.siLU(self.linear_0(t))).reshape(
+                tensor.shape[0], -1, *[1 for _ in range(len(tensor.shape) - 2)]
+            )
+
     class DiscriminatorNLayers(network.ModuleArgsDict):
 
-        def __init__(self, channels: list[int], strides: list[int], time_embedding_dim: int, dim: int) -> None:
-            super().__init__()
-            blockConfig = partial(blocks.BlockConfig, kernel_size=4, padding=1, bias=False, activation=partial(torch.nn.LeakyReLU, negative_slope = 0.2, inplace=True), normMode=blocks.NormMode.SYNCBATCH)
+        def __init__(
+            self,
+            channels: list[int],
+            strides: list[int],
+            time_embedding_dim: int,
+            dim: int,
+        ) -> None:
+            super().__init__()
+            block_config = partial(
+                blocks.BlockConfig,
+                kernel_size=4,
+                padding=1,
+                bias=False,
+                activation=partial(torch.nn.LeakyReLU, negative_slope=0.2, inplace=True),
+                norm_mode=blocks.NormMode.SYNCBATCH,
+            )
             for i, (in_channels, out_channels, stride) in enumerate(zip(channels, channels[1:], strides)):
-                self.add_module("Te_{}".format(i), Discriminator_ADA.DDPM_TE(time_embedding_dim, in_channels), in_branch=[0, 1])
-                self.add_module("Layer_{}".format(i), blocks.ConvBlock(in_channels, out_channels, [blockConfig(stride=stride)], dim))
-
+                self.add_module(
+                    f"Te_{i}",
+                    DiscriminatorADA.DDPMTE(time_embedding_dim, in_channels),
+                    in_branch=[0, 1],
+                )
+                self.add_module(
+                    f"Layer_{i}",
+                    blocks.ConvBlock(in_channels, out_channels, [block_config(stride=stride)], dim),
+                )
+
     class DiscriminatorHead(network.ModuleArgsDict):
 
         def __init__(self, channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels=channels, out_channels=1, kernel_size=4, stride=1, padding=1))
-            #self.add_module("AdaptiveAvgPool", blocks.getTorchModule("AdaptiveAvgPool", dim)(tuple([1]*dim)))
-            #self.add_module("Flatten", torch.nn.Flatten(1))
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=channels,
+                    out_channels=1,
+                    kernel_size=4,
+                    stride=1,
+                    padding=1,
+                ),
+            )
+            # self.add_module("AdaptiveAvgPool", blocks.get_torch_module("AdaptiveAvgPool", dim)(tuple([1]*dim)))
+            # self.add_module("Flatten", torch.nn.Flatten(1))
 
     class UpdateP(torch.nn.Module):
 
@@ -90,183 +159,137 @@ class Discriminator_ADA(network.Network):
             self.ada_kimg = 500
 
             self.measure = None
-            self.names = None
+            self.names = []
             self.p = 0
-
-        def setMeasure(self, measure: network.Measure, names: list[str]):
+
+        def set_measure(self, measure: network.Measure, names: list[str]):
             self.measure = measure
             self.names = names
-
-        def forward(self, input: torch.Tensor) -> torch.Tensor:
+
+        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
             if self.measure is not None and self._it % self.n == 0:
-                value = sum([v for k, v in self.measure.getLastValues(self.n).items() if k in self.names])
-                adjust = np.sign(self.ada_target-value) * (self.ada_interval)
+                value = sum([v for k, v in self.measure.get_last_values(self.n).items() if k in self.names])
+                adjust = np.sign(self.ada_target - value) * (self.ada_interval)
                 self.p += adjust
                 self.p = np.clip(self.p, 0, 1)
             self._it += 1
-            return torch.tensor(self.p).to(input.device)
-
+            return torch.tensor(self.p).to(tensor.device)
+
     class DiscriminatorAugmentation(torch.nn.Module):
 
         def __init__(self, dim: int):
             super().__init__()
 
-            self.dataAugmentations : dict[augmentation.DataAugmentation, float] = {}
+            self.data_augmentations: dict[augmentation.DataAugmentation, float] = {}
             pixel_blitting = {
-                augmentation.Flip([1/3]*3 if dim == 3 else [1/2]*2) : 0,
-                augmentation.Rotate(a_min=0, a_max=360, is_quarter = True): 0,
-                augmentation.Translate([(-0.5, 0.5)]* (3 if dim == 3 else 2), is_int=True) : 0
-            }
-
-            self.dataAugmentations.update(pixel_blitting)
+                augmentation.Flip([1 / 3] * 3 if dim == 3 else [1 / 2] * 2): 0,
+                augmentation.Rotate(a_min=0, a_max=360, is_quarter=True): 0,
+                augmentation.Translate(-5, 5, is_int=True): 0,
+            }
+
+            self.data_augmentations.update(
+                {cast(augmentation.DataAugmentation, k): v for k, v in pixel_blitting.items()}
+            )
             geometric = {
-                augmentation.Scale([0.2]) : 0,
+                augmentation.Scale(0.2): 0,
                 augmentation.Rotate(a_min=0, a_max=360): 0,
-                augmentation.Scale([0.2]*3 if dim == 3 else [0.2]*2) : 0,
+                augmentation.Scale(0.2): 0,
                 augmentation.Rotate(a_min=0, a_max=360): 0,
-                augmentation.Translate([(-0.5, 0.5)]* (3 if dim == 3 else 2)) : 0,
-                augmentation.Elastix(16, 16) : 0.5
-            }
-            self.dataAugmentations.update(geometric)
+                augmentation.Translate(-5, 5): 0,
+                augmentation.Elastix(16, 16): 0.5,
+            }
+            self.data_augmentations.update({cast(augmentation.DataAugmentation, k): v for k, v in geometric.items()})
             color = {
-                augmentation.Brightness(0.2) : 0,
-                augmentation.Contrast(0.5) : 0,
-                augmentation.Saturation(1): 0,
-                augmentation.HUE(1) : 0,
-                augmentation.LumaFlip(): 0
+                augmentation.Brightness(0.2): 0.0,
+                augmentation.Contrast(0.5): 0.0,
+                augmentation.Saturation(1): 0.0,
+                augmentation.HUE(1): 0.0,
+                augmentation.LumaFlip(): 0.0,
             }
-            self.dataAugmentations.update(color)
-
-            corruptions = {
-                augmentation.Noise(1) : 1,
-                augmentation.CutOUT(0.5, 0.5, -1) : 0.3
+            self.data_augmentations.update({cast(augmentation.DataAugmentation, k): v for k, v in color.items()})
+
+            corruptions = {
+                augmentation.Noise(1): 1,
+                augmentation.CutOUT(0.5, 1, -1): 0.3,
             }
-            self.dataAugmentations.update(corruptions)
-
-        def _setP(self, prob: float):
-            for augmentation, p in self.dataAugmentations.items():
-                augmentation.load(prob*p)
-
-        def forward(self, input: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
-            self._setP(prob.item())
-            out = input
-            for augmentation in self.dataAugmentations.keys():
-                augmentation.state_init(None, [input.shape[2:]]*input.shape[0], [Attribute()]*input.shape[0])
-                out = augmentation(0, [data for data in out], None)
+            self.data_augmentations.update({cast(augmentation.DataAugmentation, k): v for k, v in corruptions.items()})
+
+        def _set_p(self, prob: float):
+            for aug, p in self.data_augmentations.items():
+                aug.load(prob * p)
+
+        def forward(self, tensor: torch.Tensor, prob: torch.Tensor) -> torch.Tensor:
+            self._set_p(prob.item())
+            out = tensor
+            for aug in self.data_augmentations.keys():
+                aug.state_init(
+                    None,
+                    [tensor.shape[2:]] * tensor.shape[0],
+                    [Attribute()] * tensor.shape[0],
+                )
+                out = aug("", 0, list(out))
             return torch.cat([data.unsqueeze(0) for data in out], 0)
 
-
     class DiscriminatorBlock(network.ModuleArgsDict):
 
-        def __init__(self, channels: list[int] = [1, 16, 32, 64, 64],
-                     strides: list[int] = [2,2,2,1],
-                     dim : int = 3) -> None:
-            super().__init__()
-            self.add_module("Prob", Discriminator_ADA.UpdateP(), out_branch=["p"])
-            self.add_module("Sample", Discriminator_ADA.DiscriminatorAugmentation(dim), in_branch=[0, "p"])
-            self.add_module("t", DDPM.DDPM_TimeEmbedding(1000, 100), in_branch=[0, "p"], out_branch=["te"])
-            self.add_module("Layers", Discriminator_ADA.DiscriminatorNLayers(channels, strides, 100, dim), in_branch=[0, "te"])
-            self.add_module("Head", Discriminator_ADA.DiscriminatorHead(channels[-1], dim))
-
-    def __init__(self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 channels: list[int] = [1, 16, 32, 64, 64],
-                 strides: list[int] = [2,2,2,1],
-                 nb_batch_per_step: int = 1,
-                 dim : int = 3) -> None:
-        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, nb_batch_per_step=nb_batch_per_step, dim=dim, init_type="kaiming")
-        self.add_module("DiscriminatorModel", Discriminator_ADA.DiscriminatorBlock(channels, strides, dim))
+        def __init__(
+            self,
+            channels: list[int] = [1, 16, 32, 64, 64],
+            strides: list[int] = [2, 2, 2, 1],
+            dim: int = 3,
+        ) -> None:
+            super().__init__()
+            self.add_module("Prob", DiscriminatorADA.UpdateP(), out_branch=["p"])
+            self.add_module(
+                "Sample",
+                DiscriminatorADA.DiscriminatorAugmentation(dim),
+                in_branch=[0, "p"],
+            )
+            self.add_module(
+                "t",
+                DDPM.DDPMTimeEmbedding(1000, 100),
+                in_branch=[0, "p"],
+                out_branch=["te"],
+            )
+            self.add_module(
+                "Layers",
+                DiscriminatorADA.DiscriminatorNLayers(channels, strides, 100, dim),
+                in_branch=[0, "te"],
+            )
+            self.add_module("Head", DiscriminatorADA.DiscriminatorHead(channels[-1], dim))
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        channels: list[int] = [1, 16, 32, 64, 64],
+        strides: list[int] = [2, 2, 2, 1],
+        nb_batch_per_step: int = 1,
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            nb_batch_per_step=nb_batch_per_step,
+            dim=dim,
+            init_type="kaiming",
+        )
+        self.add_module(
+            "DiscriminatorModel",
+            DiscriminatorADA.DiscriminatorBlock(channels, strides, dim),
+        )
 
     def initialized(self):
-        self["DiscriminatorModel"]["Prob"].setMeasure(self.measure, ["Discriminator_B.DiscriminatorModel.Head.Conv:None:PatchGanLoss"])
-
-"""class GeneratorV1(network.Network):
-
-    class GeneratorStem(network.ModuleArgsDict):
+        self["DiscriminatorModel"]["Prob"].set_measure(
+            self.measure,
+            ["Discriminator_B.DiscriminatorModel.Head.Conv:None:PatchGanLoss"],
+        )
 
-        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
-            super().__init__()
-            self.add_module("ReflectionPad2d", torch.nn.ReflectionPad2d(3))
-            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(kernel_size=7, padding=0, bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
-
-    class GeneratorHead(network.ModuleArgsDict):
-
-        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
-            super().__init__()
-            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, in_channels, blockConfigs=[blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False))
-            self.add_module("Tanh", torch.nn.Tanh())
-
-    class GeneratorDownSample(network.ModuleArgsDict):
-
-        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
-            super().__init__()
-            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(stride=2, bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
-
-    class GeneratorUpSample(network.ModuleArgsDict):
-
-        def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
-            super().__init__()
-            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
-            self.add_module("Upsample", torch.nn.Upsample(scale_factor=2, mode="bilinear" if dim < 3 else "trilinear"))
-
-    class GeneratorEncoder(network.ModuleArgsDict):
-        def __init__(self, channels: list[int], dim: int) -> None:
-            super().__init__()
-            for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
-                self.add_module("DownSample_{}".format(i), GeneratorV1.GeneratorDownSample(in_channels=in_channels, out_channels=out_channels, dim=dim))
-
-    class GeneratorResnetBlock(network.ModuleArgsDict):
-
-        def __init__(self, channels : int, dim : int):
-            super().__init__()
-            self.add_module("Conv_0", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
-            self.add_module("Norm_0", torch.nn.SyncBatchNorm(channels))
-            self.add_module("Activation_0", torch.nn.LeakyReLU(0.2, inplace=True))
-            self.add_module("Conv_1", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
-            self.add_module("Norm_1", torch.nn.SyncBatchNorm(channels))
-            self.add_module("Residual", blocks.Add(), in_branch=[0,1])
-
-    class GeneratorNResnetBlock(network.ModuleArgsDict):
-
-        def __init__(self, channels: int, nb_conv: int, dim: int) -> None:
-            super().__init__()
-            for i in range(nb_conv):
-                self.add_module("ResnetBlock_{}".format(i), GeneratorV1.GeneratorResnetBlock(channels=channels, dim=dim))
-
-    class GeneratorDecoder(network.ModuleArgsDict):
-        def __init__(self, channels: list[int], dim: int) -> None:
-            super().__init__()
-            for i, (in_channels, out_channels) in enumerate(zip(reversed(channels), reversed(channels[:-1]))):
-                self.add_module("UpSample_{}".format(i), GeneratorV1.GeneratorUpSample(in_channels=in_channels, out_channels=out_channels, dim=dim))
-
-    class GeneratorAutoEncoder(network.ModuleArgsDict):
-
-        def __init__(self, ngf: int, dim: int) -> None:
-            super().__init__()
-            channels = [ngf, ngf*2]
-            self.add_module("Encoder", GeneratorV1.GeneratorEncoder(channels, dim))
-            self.add_module("NResBlock", GeneratorV1.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
-            self.add_module("Decoder", GeneratorV1.GeneratorDecoder(channels, dim))
-
-    class GeneratorBlock(network.ModuleArgsDict):
-
-        def __init__(self, ngf: int, dim: int) -> None:
-            super().__init__()
-            self.add_module("Stem", GeneratorV1.GeneratorStem(3, ngf, dim))
-            self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
-            self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
-
-    def __init__(self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 patch : ModelPatch = ModelPatch(),
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 dim : int = 3) -> None:
-        super().__init__(optimizer=optimizer, in_channels=3, schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim)
-        self.add_module("GeneratorModel", GeneratorV1.GeneratorBlock(32, dim))"""
 
 class GeneratorV1(network.Network):
 
@@ -274,68 +297,131 @@ class GeneratorV1(network.Network):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels,
+                    out_channels,
+                    block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
+                    dim=dim,
+                ),
+            )
 
     class GeneratorHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, in_channels, blockConfigs=[blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False))
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels,
+                    in_channels,
+                    block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
+                    dim=dim,
+                ),
+            )
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(in_channels, out_channels, kernel_size=1, bias=False),
+            )
             self.add_module("Tanh", torch.nn.Tanh())
 
     class GeneratorDownSample(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(stride=2, bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
-
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels,
+                    out_channels,
+                    block_configs=[
+                        blocks.BlockConfig(
+                            stride=2,
+                            bias=False,
+                            activation="ReLU",
+                            norm_mode="SYNCBATCH",
+                        )
+                    ],
+                    dim=dim,
+                ),
+            )
+
     class GeneratorUpSample(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module("ConvBlock", blocks.ConvBlock(in_channels, out_channels, blockConfigs=[blocks.BlockConfig(bias=False, activation="ReLU", normMode="SYNCBATCH")], dim=dim))
-            self.add_module("Upsample", torch.nn.Upsample(scale_factor=2, mode="bilinear" if dim < 3 else "trilinear"))
-
+            self.add_module(
+                "ConvBlock",
+                blocks.ConvBlock(
+                    in_channels,
+                    out_channels,
+                    block_configs=[blocks.BlockConfig(bias=False, activation="ReLU", norm_mode="SYNCBATCH")],
+                    dim=dim,
+                ),
+            )
+            self.add_module(
+                "Upsample",
+                torch.nn.Upsample(scale_factor=2, mode="bilinear" if dim < 3 else "trilinear"),
+            )
+
     class GeneratorEncoder(network.ModuleArgsDict):
         def __init__(self, channels: list[int], dim: int) -> None:
             super().__init__()
             for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:])):
-                self.add_module("DownSample_{}".format(i), GeneratorV1.GeneratorDownSample(in_channels=in_channels, out_channels=out_channels, dim=dim))
-
+                self.add_module(
+                    f"DownSample_{i}",
+                    GeneratorV1.GeneratorDownSample(in_channels=in_channels, out_channels=out_channels, dim=dim),
+                )
+
     class GeneratorResnetBlock(network.ModuleArgsDict):
 
-        def __init__(self, channels : int, dim : int):
+        def __init__(self, channels: int, dim: int):
             super().__init__()
-            self.add_module("Conv_0", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
+            self.add_module(
+                "Conv_0",
+                blocks.get_torch_module("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False),
+            )
             self.add_module("Norm_0", torch.nn.SyncBatchNorm(channels))
             self.add_module("Activation_0", torch.nn.LeakyReLU(0.2, inplace=True))
-            #self.add_module("Norm", torch.nn.LeakyReLU(0.2, inplace=True))
-
-            self.add_module("Conv_1", blocks.getTorchModule("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False))
+            # self.add_module("Norm", torch.nn.LeakyReLU(0.2, inplace=True))
+
+            self.add_module(
+                "Conv_1",
+                blocks.get_torch_module("Conv", dim)(channels, channels, kernel_size=3, padding=1, bias=False),
+            )
             self.add_module("Norm_1", torch.nn.SyncBatchNorm(channels))
-            self.add_module("Residual", blocks.Add(), in_branch=[0,1])
+            self.add_module("Residual", blocks.Add(), in_branch=[0, 1])
 
     class GeneratorNResnetBlock(network.ModuleArgsDict):
 
         def __init__(self, channels: int, nb_conv: int, dim: int) -> None:
             super().__init__()
             for i in range(nb_conv):
-                self.add_module("ResnetBlock_{}".format(i), GeneratorV1.GeneratorResnetBlock(channels=channels, dim=dim))
+                self.add_module(
+                    f"ResnetBlock_{i}",
+                    GeneratorV1.GeneratorResnetBlock(channels=channels, dim=dim),
+                )
 
     class GeneratorDecoder(network.ModuleArgsDict):
         def __init__(self, channels: list[int], dim: int) -> None:
             super().__init__()
             for i, (in_channels, out_channels) in enumerate(zip(reversed(channels), reversed(channels[:-1]))):
-                self.add_module("UpSample_{}".format(i), GeneratorV1.GeneratorUpSample(in_channels=in_channels, out_channels=out_channels, dim=dim))
-
+                self.add_module(
+                    f"UpSample_{i}",
+                    GeneratorV1.GeneratorUpSample(in_channels=in_channels, out_channels=out_channels, dim=dim),
+                )
+
     class GeneratorAutoEncoder(network.ModuleArgsDict):
 
         def __init__(self, ngf: int, dim: int) -> None:
             super().__init__()
-            channels = [ngf, ngf*2]
+            channels = [ngf, ngf * 2]
             self.add_module("Encoder", GeneratorV1.GeneratorEncoder(channels, dim))
-            self.add_module("NResBlock", GeneratorV1.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim))
+            self.add_module(
+                "NResBlock",
+                GeneratorV1.GeneratorNResnetBlock(channels=channels[-1], nb_conv=6, dim=dim),
+            )
             self.add_module("Decoder", GeneratorV1.GeneratorDecoder(channels, dim))
 
     class GeneratorBlock(network.ModuleArgsDict):
@@ -344,56 +430,123 @@ class GeneratorV1(network.Network):
             super().__init__()
             self.add_module("Stem", GeneratorV1.GeneratorStem(3, ngf, dim))
             self.add_module("AutoEncoder", GeneratorV1.GeneratorAutoEncoder(ngf, dim))
-            self.add_module("Head", GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim))
-
-    def __init__(self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 patch : ModelPatch = ModelPatch(),
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 dim : int = 3) -> None:
-        super().__init__(optimizer=optimizer, in_channels=3, schedulers=schedulers, patch=patch, outputsCriterions=outputsCriterions, dim=dim)
+            self.add_module(
+                "Head",
+                GeneratorV1.GeneratorHead(in_channels=ngf, out_channels=1, dim=dim),
+            )
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        patch: ModelPatch = ModelPatch(),
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            optimizer=optimizer,
+            in_channels=3,
+            schedulers=schedulers,
+            patch=patch,
+            outputs_criterions=outputs_criterions,
+            dim=dim,
+        )
         self.add_module("GeneratorModel", GeneratorV1.GeneratorBlock(32, dim))
-
+
+
 class GeneratorV2(network.Network):
 
     class NestedUNetHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: list[int], dim: int) -> None:
             super().__init__()
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels[1], out_channels = 1, kernel_size = 1, stride = 1, padding = 0))
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=in_channels[1],
+                    out_channels=1,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                ),
+            )
             self.add_module("Tanh", torch.nn.Tanh())
-
+
     class GeneratorBlock(network.ModuleArgsDict):
 
-        def __init__(self,
-                     channels: list[int],
-                     blockConfig: blocks.BlockConfig,
-                     nb_conv_per_stage: int,
-                     downSampleMode: str,
-                     upSampleMode: str,
-                     attention : bool,
-                     blockType: str,
-                     dim : int,) -> None:
-            super().__init__()
-            self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
-            self.add_module("Head", GeneratorV2.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
-
-    def __init__( self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 patch : Union[ModelPatch, None] = None,
-                 channels: list[int]=[1, 64, 128, 256, 512, 1024],
-                 blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
-                 nb_conv_per_stage: int = 2,
-                 downSampleMode: str = "MAXPOOL",
-                 upSampleMode: str = "CONV_TRANSPOSE",
-                 attention : bool = False,
-                 blockType: str = "Conv",
-                 dim : int = 3) -> None:
-        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-        self.add_module("GeneratorModel", GeneratorV2.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim))
+        def __init__(
+            self,
+            channels: list[int],
+            block_config: blocks.BlockConfig,
+            nb_conv_per_stage: int,
+            downsample_mode: str,
+            upsample_mode: str,
+            attention: bool,
+            block_type: str,
+            dim: int,
+        ) -> None:
+            super().__init__()
+            self.add_module(
+                "UNetBlock_0",
+                NestedUNet.NestedUNet.NestedUNetBlock(
+                    channels,
+                    nb_conv_per_stage,
+                    block_config,
+                    downsample_mode=blocks.DownsampleMode[downsample_mode],
+                    upsample_mode=blocks.UpsampleMode[upsample_mode],
+                    attention=attention,
+                    block=blocks.ConvBlock if block_type == "Conv" else blocks.ResBlock,
+                    dim=dim,
+                ),
+                out_branch=[f"X_0_{j + 1}" for j in range(len(channels) - 2)],
+            )
+            self.add_module(
+                "Head",
+                GeneratorV2.NestedUNetHead(channels[:2], dim=dim),
+                in_branch=[f"X_0_{len(channels) - 2}"],
+            )
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        block_type: str = "Conv",
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=channels[0],
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+        self.add_module(
+            "GeneratorModel",
+            GeneratorV2.GeneratorBlock(
+                channels,
+                block_config,
+                nb_conv_per_stage,
+                downsample_mode,
+                upsample_mode,
+                attention,
+                block_type,
+                dim,
+            ),
+        )
+
 
 class GeneratorV3(network.Network):
 
@@ -401,144 +554,390 @@ class GeneratorV3(network.Network):
 
         def __init__(self, in_channels: list[int], dim: int) -> None:
             super().__init__()
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels[1], out_channels = 1, kernel_size = 1, stride = 1, padding = 0))
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=in_channels[1],
+                    out_channels=1,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                ),
+            )
             self.add_module("Tanh", torch.nn.Tanh())
-
+
     class GeneratorBlock(network.ModuleArgsDict):
 
-        def __init__(self,
-                     channels: list[int],
-                     blockConfig: blocks.BlockConfig,
-                     nb_conv_per_stage: int,
-                     downSampleMode: str,
-                     upSampleMode: str,
-                     attention : bool,
-                     blockType: str,
-                     dim : int,) -> None:
-            super().__init__()
-            self.add_module("UNetBlock_0", UNet.UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=1, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
-            self.add_module("Head", GeneratorV3.NestedUNetHead(channels[:2], dim=dim), in_branch=["X_0_{}".format(len(channels)-2)])
-
-    def __init__( self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 patch : Union[ModelPatch, None] = None,
-                 channels: list[int]=[1, 64, 128, 256, 512, 1024],
-                 blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
-                 nb_conv_per_stage: int = 2,
-                 downSampleMode: str = "MAXPOOL",
-                 upSampleMode: str = "CONV_TRANSPOSE",
-                 attention : bool = False,
-                 blockType: str = "Conv",
-                 dim : int = 3) -> None:
-        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-        self.add_module("GeneratorModel", GeneratorV3.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), out_branch=["pB"])
+        def __init__(
+            self,
+            channels: list[int],
+            block_config: blocks.BlockConfig,
+            nb_conv_per_stage: int,
+            downsample_mode: str,
+            upsample_mode: str,
+            attention: bool,
+            block_type: str,
+            dim: int,
+        ) -> None:
+            super().__init__()
+            self.add_module(
+                "UNetBlock_0",
+                UNet.UNetBlock(
+                    channels,
+                    nb_conv_per_stage,
+                    block_config,
+                    downsample_mode=blocks.DownsampleMode[downsample_mode],
+                    upsample_mode=blocks.UpsampleMode[upsample_mode],
+                    attention=attention,
+                    block=blocks.ConvBlock if block_type == "Conv" else blocks.ResBlock,
+                    nb_class=1,
+                    dim=dim,
+                ),
+                out_branch=[f"X_0_{j + 1}" for j in range(len(channels) - 2)],
+            )
+            self.add_module(
+                "Head",
+                GeneratorV3.NestedUNetHead(channels[:2], dim=dim),
+                in_branch=[f"X_0_{len(channels) - 2}"],
+            )
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        block_type: str = "Conv",
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=channels[0],
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+        self.add_module(
+            "GeneratorModel",
+            GeneratorV3.GeneratorBlock(
+                channels,
+                block_config,
+                nb_conv_per_stage,
+                downsample_mode,
+                upsample_mode,
+                attention,
+                block_type,
+                dim,
+            ),
+            out_branch=["pB"],
+        )
+
 
 class DiffusionGan(network.Network):
 
-    def __init__(self, generator : GeneratorV1 = GeneratorV1(), discriminator : Discriminator_ADA = Discriminator_ADA()) -> None:
+    def __init__(
+        self,
+        generator: GeneratorV1 = GeneratorV1(),
+        discriminator: DiscriminatorADA = DiscriminatorADA(),
+    ) -> None:
         super().__init__()
         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
-        self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
+        self.add_module(
+            "Discriminator_B",
+            discriminator,
+            in_branch=[1],
+            out_branch=[-1],
+            requires_grad=True,
+        )
         self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
-        self.add_module("Discriminator_pB_detach", discriminator, in_branch=["pB_detach"], out_branch=[-1])
-        self.add_module("Discriminator_pB", discriminator, in_branch=["pB"], out_branch=[-1], requires_grad=False)
+        self.add_module(
+            "Discriminator_pB_detach",
+            discriminator,
+            in_branch=["pB_detach"],
+            out_branch=[-1],
+        )
+        self.add_module(
+            "Discriminator_pB",
+            discriminator,
+            in_branch=["pB"],
+            out_branch=[-1],
+            requires_grad=False,
+        )
+
 
 class DiffusionGanV2(network.Network):
 
-    def __init__(self, generator : GeneratorV2 = GeneratorV2(), discriminator : Discriminator = Discriminator()) -> None:
+    def __init__(
+        self,
+        generator: GeneratorV2 = GeneratorV2(),
+        discriminator: Discriminator = Discriminator(),
+    ) -> None:
         super().__init__()
         self.add_module("Generator_A_to_B", generator, in_branch=[0], out_branch=["pB"])
-        self.add_module("Discriminator_B", discriminator, in_branch=[1], out_branch=[-1], requires_grad=True)
+        self.add_module(
+            "Discriminator_B",
+            discriminator,
+            in_branch=[1],
+            out_branch=[-1],
+            requires_grad=True,
+        )
         self.add_module("detach", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
-        self.add_module("Discriminator_pB_detach", discriminator, in_branch=["pB_detach"], out_branch=[-1])
-        self.add_module("Discriminator_pB", discriminator, in_branch=["pB"], out_branch=[-1], requires_grad=False)
+        self.add_module(
+            "Discriminator_pB_detach",
+            discriminator,
+            in_branch=["pB_detach"],
+            out_branch=[-1],
+        )
+        self.add_module(
+            "Discriminator_pB",
+            discriminator,
+            in_branch=["pB"],
+            out_branch=[-1],
+            requires_grad=False,
+        )
 
 
 class CycleGanDiscriminator(network.Network):
 
-    def __init__(self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 patch : Union[ModelPatch, None] = None,
-                 channels: list[int] = [1, 16, 32, 64, 64],
-                 strides: list[int] = [2,2,2,1],
-                 dim : int = 3) -> None:
-        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim=dim)
-        self.add_module("Discriminator_A", Discriminator.DiscriminatorBlock(channels, strides, dim), in_branch=[0], out_branch=[0])
-        self.add_module("Discriminator_B", Discriminator.DiscriminatorBlock(channels, strides, dim), in_branch=[1], out_branch=[1])
-
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        channels: list[int] = [1, 16, 32, 64, 64],
+        strides: list[int] = [2, 2, 2, 1],
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+        self.add_module(
+            "Discriminator_A",
+            Discriminator.DiscriminatorBlock(channels, strides, dim),
+            in_branch=[0],
+            out_branch=[0],
+        )
+        self.add_module(
+            "Discriminator_B",
+            Discriminator.DiscriminatorBlock(channels, strides, dim),
+            in_branch=[1],
+            out_branch=[1],
+        )
+
     def initialized(self):
-        self["Discriminator_A"]["Sample"].setMeasure(self.measure, ["Discriminator.Discriminator_A.Head.Flatten:None:PatchGanLoss"])
-        self["Discriminator_B"]["Sample"].setMeasure(self.measure, ["Discriminator.Discriminator_B.Head.Flatten:None:PatchGanLoss"])
+        self["Discriminator_A"]["Sample"].set_measure(
+            self.measure,
+            ["Discriminator.Discriminator_A.Head.Flatten:None:PatchGanLoss"],
+        )
+        self["Discriminator_B"]["Sample"].set_measure(
+            self.measure,
+            ["Discriminator.Discriminator_B.Head.Flatten:None:PatchGanLoss"],
+        )
+
 
 class CycleGanGeneratorV1(network.Network):
 
-    def __init__(self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 patch : Union[ModelPatch, None] = None,
-                 dim : int = 3) -> None:
-        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim=dim)
-        self.add_module("Generator_A_to_B", GeneratorV1.GeneratorBlock(32, dim), in_branch=[0], out_branch=["pB"])
-        self.add_module("Generator_B_to_A", GeneratorV1.GeneratorBlock(32, dim), in_branch=[1], out_branch=["pA"])
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+        self.add_module(
+            "Generator_A_to_B",
+            GeneratorV1.GeneratorBlock(32, dim),
+            in_branch=[0],
+            out_branch=["pB"],
+        )
+        self.add_module(
+            "Generator_B_to_A",
+            GeneratorV1.GeneratorBlock(32, dim),
+            in_branch=[1],
+            out_branch=["pA"],
+        )
+
 
 class CycleGanGeneratorV2(network.Network):
 
-    def __init__(self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 patch : Union[ModelPatch, None] = None,
-                 channels: list[int]=[1, 64, 128, 256, 512, 1024],
-                 blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
-                 nb_conv_per_stage: int = 2,
-                 downSampleMode: str = "MAXPOOL",
-                 upSampleMode: str = "CONV_TRANSPOSE",
-                 attention : bool = False,
-                 blockType: str = "Conv",
-                 dim : int = 3) -> None:
-        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim=dim)
-        self.add_module("Generator_A_to_B", GeneratorV2.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), in_branch=[0], out_branch=["pB"])
-        self.add_module("Generator_B_to_A", GeneratorV2.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), in_branch=[1], out_branch=["pA"])
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        block_type: str = "Conv",
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+        self.add_module(
+            "Generator_A_to_B",
+            GeneratorV2.GeneratorBlock(
+                channels,
+                block_config,
+                nb_conv_per_stage,
+                downsample_mode,
+                upsample_mode,
+                attention,
+                block_type,
+                dim,
+            ),
+            in_branch=[0],
+            out_branch=["pB"],
+        )
+        self.add_module(
+            "Generator_B_to_A",
+            GeneratorV2.GeneratorBlock(
+                channels,
+                block_config,
+                nb_conv_per_stage,
+                downsample_mode,
+                upsample_mode,
+                attention,
+                block_type,
+                dim,
+            ),
+            in_branch=[1],
+            out_branch=["pA"],
+        )
+
 
 class CycleGanGeneratorV3(network.Network):
 
-    def __init__(self,
-                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                 schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
-                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                 patch : Union[ModelPatch, None] = None,
-                 channels: list[int]=[1, 64, 128, 256, 512, 1024],
-                 blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
-                 nb_conv_per_stage: int = 2,
-                 downSampleMode: str = "MAXPOOL",
-                 upSampleMode: str = "CONV_TRANSPOSE",
-                 attention : bool = False,
-                 blockType: str = "Conv",
-                 dim : int = 3) -> None:
-        super().__init__(in_channels = 1, optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim=dim)
-        self.add_module("Generator_A_to_B", GeneratorV3.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), in_branch=[0], out_branch=["pB"])
-        self.add_module("Generator_B_to_A", GeneratorV3.GeneratorBlock(channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, blockType, dim), in_branch=[1], out_branch=["pA"])
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        block_type: str = "Conv",
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
+        self.add_module(
+            "Generator_A_to_B",
+            GeneratorV3.GeneratorBlock(
+                channels,
+                block_config,
+                nb_conv_per_stage,
+                downsample_mode,
+                upsample_mode,
+                attention,
+                block_type,
+                dim,
+            ),
+            in_branch=[0],
+            out_branch=["pB"],
+        )
+        self.add_module(
+            "Generator_B_to_A",
+            GeneratorV3.GeneratorBlock(
+                channels,
+                block_config,
+                nb_conv_per_stage,
+                downsample_mode,
+                upsample_mode,
+                attention,
+                block_type,
+                dim,
+            ),
+            in_branch=[1],
+            out_branch=["pA"],
+        )
+
 
 class DiffusionCycleGan(network.Network):
 
-    def __init__(self, generators : CycleGanGeneratorV3 = CycleGanGeneratorV3(), discriminators : CycleGanDiscriminator = CycleGanDiscriminator()) -> None:
+    def __init__(
+        self,
+        generators: CycleGanGeneratorV3 = CycleGanGeneratorV3(),
+        discriminators: CycleGanDiscriminator = CycleGanDiscriminator(),
+    ) -> None:
         super().__init__()
         self.add_module("Generator", generators, in_branch=[0, 1], out_branch=["pB", "pA"])
-        self.add_module("Discriminator", discriminators, in_branch=[0, 1], out_branch=[-1], requires_grad=True)
-
+        self.add_module(
+            "Discriminator",
+            discriminators,
+            in_branch=[0, 1],
+            out_branch=[-1],
+            requires_grad=True,
+        )
+
         self.add_module("Generator_identity", generators, in_branch=[1, 0], out_branch=[-1])
-
+
         self.add_module("Generator_p", generators, in_branch=["pA", "pB"], out_branch=[-1])
-
+
         self.add_module("detach_pA", blocks.Detach(), in_branch=["pA"], out_branch=["pA_detach"])
         self.add_module("detach_pB", blocks.Detach(), in_branch=["pB"], out_branch=["pB_detach"])
 
-        self.add_module("Discriminator_p_detach", discriminators, in_branch=["pA_detach", "pB_detach"], out_branch=[-1])
-        self.add_module("Discriminator_p", discriminators, in_branch=["pA", "pB"], out_branch=[-1], requires_grad=False)
-
-
+        self.add_module(
+            "Discriminator_p_detach",
+            discriminators,
+            in_branch=["pA_detach", "pB_detach"],
+            out_branch=[-1],
+        )
+        self.add_module(
+            "Discriminator_p",
+            discriminators,
+            in_branch=["pA", "pB"],
+            out_branch=[-1],
+            requires_grad=False,
+        )
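A note on the one piece of genuinely stateful logic above: DiscriminatorADA.UpdateP implements the adaptive discriminator augmentation (ADA) heuristic. Every n steps it nudges the augmentation probability p by ±ada_interval, depending on whether a tracked loss signal sits above or below ada_target, and clips p to [0, 1]. A standalone sketch of that update rule follows; only the formula comes from the code, and the numeric values are invented for illustration:

import numpy as np

# UpdateP's rule, extracted from the diff:
#   p += sign(ada_target - value) * ada_interval;  p = clip(p, 0, 1)
p = 0.0
ada_target, ada_interval = 0.6, 0.05     # illustrative values, not konfai's defaults
for value in [0.2, 0.4, 0.7, 0.9, 0.5]:  # pretend loss readings, one per update
    p = float(np.clip(p + np.sign(ada_target - value) * ada_interval, 0, 1))
    print(f"value={value:.1f} -> p={p:.2f}")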