konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,14 @@
+from collections.abc import Callable
 
-
-
-from typing import Union, Callable
+import numpy as np
 import torch
 import tqdm
-import numpy as np
 
-from konfai.network import network, blocks
-from konfai.utils.config import config
 from konfai.data.patching import ModelPatch
-from konfai.utils.utils import gpuInfo
 from konfai.metric.measure import Criterion
+from konfai.network import blocks, network
+from konfai.utils.utils import gpu_info
+
 
 def cosine_beta_schedule(timesteps, s=0.008):
     steps = timesteps + 1
@@ -20,198 +18,427 @@ def cosine_beta_schedule(timesteps, s=0.008):
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0.0001, 0.9999)
 
+
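The `cosine_beta_schedule` helper carried over above is the cosine noise schedule of Nichol & Dhariwal (2021): the cumulative ᾱ(t) follows a squared cosine, so β_t stays tiny early in the chain and grows toward the clip value near t = T. The hunk only shows the first and last lines of the body, so the interior of the sketch below is filled in from the standard formulation and should be read as an assumption, not a copy of the konfai source:

```python
import torch


def cosine_beta_schedule(timesteps, s=0.008):
    # alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi/2), normalized to 1 at t = 0
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # beta_t = 1 - alpha_bar(t) / alpha_bar(t-1), clipped exactly as in the diff
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


betas = cosine_beta_schedule(1000)
print(betas.shape, betas[0].item(), betas[-1].item())  # (1000,): tiny at t=0, ~0.9999 at t=T
```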
 class DDPM(network.Network):
-
-    class DDPM_TE(torch.nn.Module):
+
+    class DDPMTE(torch.nn.Module):
 
         def __init__(self, in_channels: int, out_channels: int) -> None:
             super().__init__()
             self.linear_0 = torch.nn.Linear(in_channels, out_channels)
             self.siLU = torch.nn.SiLU()
             self.linear_1 = torch.nn.Linear(out_channels, out_channels)
-
-        def forward(self, input: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-            return input + self.linear_1(self.siLU(self.linear_0(t))).reshape(input.shape[0], -1, *[1 for _ in range(len(input.shape)-2)])
-
-    class DDPM_UNetBlock(network.ModuleArgsDict):
-
-        def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, time_embedding_dim: int, dim: int, i : int = 0) -> None:
+
+        def forward(self, tensor: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+            return tensor + self.linear_1(self.siLU(self.linear_0(t))).reshape(
+                tensor.shape[0], -1, *[1 for _ in range(len(tensor.shape) - 2)]
+            )
+
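`DDPMTE` (renamed from `DDPM_TE`) injects the time embedding as a per-channel bias: a two-layer SiLU MLP maps the embedding to `out_channels` values, which are reshaped to `(B, C, 1, ..., 1)` so they broadcast over every spatial position of the feature map. A minimal standalone sketch of that reshape-and-broadcast trick (shapes chosen arbitrarily here):

```python
import torch

# same layer sequence as DDPMTE: Linear -> SiLU -> Linear
mlp = torch.nn.Sequential(torch.nn.Linear(100, 64), torch.nn.SiLU(), torch.nn.Linear(64, 64))
feat = torch.randn(2, 64, 8, 8, 8)  # (B, C, D, H, W): the dim=3 case
t_emb = torch.randn(2, 100)         # one 100-d time embedding per sample

bias = mlp(t_emb).reshape(feat.shape[0], -1, *[1 for _ in range(len(feat.shape) - 2)])
print(bias.shape)           # torch.Size([2, 64, 1, 1, 1])
print((feat + bias).shape)  # torch.Size([2, 64, 8, 8, 8]): broadcast over D, H, W
```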
+    class DDPMUNetBlock(network.ModuleArgsDict):
+
+        def __init__(
+            self,
+            channels: list[int],
+            nb_conv_per_stage: int,
+            block_config: blocks.BlockConfig,
+            downsample_mode: blocks.DownsampleMode,
+            upsample_mode: blocks.UpsampleMode,
+            attention: bool,
+            time_embedding_dim: int,
+            dim: int,
+            i: int = 0,
+        ) -> None:
             super().__init__()
             if i > 0:
-                self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
-            self.add_module("Te_down", DDPM.DDPM_TE(time_embedding_dim, channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0]), in_branch=[0, 1])
-            self.add_module("DownConvBlock", blocks.ResBlock(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
+                self.add_module(
+                    downsample_mode.name,
+                    blocks.downsample(
+                        in_channels=channels[0],
+                        out_channels=channels[1],
+                        downsample_mode=downsample_mode,
+                        dim=dim,
+                    ),
+                )
+            self.add_module(
+                "Te_down",
+                DDPM.DDPMTE(
+                    time_embedding_dim,
+                    channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
+                ),
+                in_branch=[0, 1],
+            )
+            self.add_module(
+                "DownConvBlock",
+                blocks.ResBlock(
+                    in_channels=channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
+                    out_channels=channels[1],
+                    block_configs=[block_config] * nb_conv_per_stage,
+                    dim=dim,
+                ),
+            )
             if len(channels) > 2:
-                self.add_module("UNetBlock_{}".format(i+1), DDPM.DDPM_UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, time_embedding_dim, dim, i+1), in_branch=[0,1])
-                self.add_module("Te_up", DDPM.DDPM_TE(time_embedding_dim, (channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2), in_branch=[0, 1])
-                self.add_module("UpConvBlock", blocks.ResBlock(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], nb_conv=nb_conv_per_stage, blockConfig=blockConfig, dim=dim))
+                self.add_module(
+                    f"UNetBlock_{i + 1}",
+                    DDPM.DDPMUNetBlock(
+                        channels[1:],
+                        nb_conv_per_stage,
+                        block_config,
+                        downsample_mode,
+                        upsample_mode,
+                        attention,
+                        time_embedding_dim,
+                        dim,
+                        i + 1,
+                    ),
+                    in_branch=[0, 1],
+                )
+                self.add_module(
+                    "Te_up",
+                    DDPM.DDPMTE(
+                        time_embedding_dim,
+                        (
+                            (channels[1] + channels[2])
+                            if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                            else channels[1] * 2
+                        ),
+                    ),
+                    in_branch=[0, 1],
+                )
+                self.add_module(
+                    "UpConvBlock",
+                    blocks.ResBlock(
+                        in_channels=(
+                            (channels[1] + channels[2])
+                            if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                            else channels[1] * 2
+                        ),
+                        out_channels=channels[1],
+                        block_configs=[block_config] * nb_conv_per_stage,
+                        dim=dim,
+                    ),
+                )
             if i > 0:
                 if attention:
-                    self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[2, 0], out_branch=[2])
-                self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim))
+                    self.add_module(
+                        "Attention",
+                        blocks.Attention(f_g=channels[1], f_l=channels[0], f_int=channels[0], dim=dim),
+                        in_branch=[2, 0],
+                        out_branch=[2],
+                    )
+                self.add_module(
+                    upsample_mode.name,
+                    blocks.upsample(
+                        in_channels=channels[1],
+                        out_channels=channels[0],
+                        upsample_mode=upsample_mode,
+                        dim=dim,
+                    ),
+                )
                 self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 2])
-
-    class DDPM_UNetHead(network.ModuleArgsDict):
+
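`DDPMUNetBlock` builds the encoder-decoder recursively: each level convolves `channels[0] -> channels[1]` on the way down, recurses on `channels[1:]` while `len(channels) > 2`, and on the way back up concatenates the skip, so the decoder input is `channels[1] + channels[2]` (or `channels[1] * 2` with `CONV_TRANSPOSE`). A throwaway sketch that prints the levels this recursion would produce for the default channel list:

```python
def describe(channels: list[int], i: int = 0) -> None:
    # mirrors DDPMUNetBlock: down-conv at every level, recurse on channels[1:],
    # then an up path fed by the skip concat (non-CONV_TRANSPOSE case shown)
    print("  " * i + f"level {i}: down {channels[0]} -> {channels[1]}")
    if len(channels) > 2:
        describe(channels[1:], i + 1)
        print("  " * i + f"level {i}: up   {channels[1] + channels[2]} -> {channels[1]}")


describe([1, 64, 128, 256, 512, 1024])  # the DDPM default further down
```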
+    class DDPMUNetHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                ),
+            )
 
-    class DDPM_ForwardProcess(torch.nn.Module):
+    class DDPMForwardProcess(torch.nn.Module):
 
-        def __init__(self, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
+        def __init__(
+            self,
+            noise_step: int = 1000,
+            beta_start: float = 1e-4,
+            beta_end: float = 0.02,
+        ) -> None:
             super().__init__()
             self.betas = torch.linspace(beta_start, beta_end, noise_step)
-            self.betas = DDPM.DDPM_ForwardProcess.enforce_zero_terminal_snr(self.betas)
+            self.betas = DDPM.DDPMForwardProcess.enforce_zero_terminal_snr(self.betas)
             self.alphas = 1 - self.betas
             self.alpha_hat = torch.concat((torch.ones(1), torch.cumprod(self.alphas, dim=0)))
 
+        @staticmethod
         def enforce_zero_terminal_snr(betas: torch.Tensor):
             alphas = 1 - betas
             alphas_bar = alphas.cumprod(0)
             alphas_bar_sqrt = alphas_bar.sqrt()
             alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-            alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
-            alphas_bar_sqrt -= alphas_bar_sqrt_T
-            alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
-                alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
-            alphas_bar = alphas_bar_sqrt ** 2
+            alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
+            alphas_bar_sqrt -= alphas_bar_sqrt_t
+            alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)
+            alphas_bar = alphas_bar_sqrt**2
             alphas = alphas_bar[1:] / alphas_bar[:-1]
             alphas = torch.cat([alphas_bar[0:1], alphas])
             betas = 1 - alphas
             return betas
 
-        def forward(self, input: torch.Tensor, t: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
-            alpha_hat_t = self.alpha_hat[t.cpu()].to(input.device).reshape(input.shape[0], *[1 for _ in range(len(input.shape)-1)])
-            result = alpha_hat_t.sqrt() * input + (1 - alpha_hat_t).sqrt() * eta
+        def forward(self, tensor: torch.Tensor, t: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
+            alpha_hat_t = (
+                self.alpha_hat[t.cpu()]
+                .to(tensor.device)
+                .reshape(tensor.shape[0], *[1 for _ in range(len(tensor.shape) - 1)])
+            )
+            result = alpha_hat_t.sqrt() * tensor + (1 - alpha_hat_t).sqrt() * eta
             return result
-
-    class DDPM_SampleT(torch.nn.Module):
+
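`DDPMForwardProcess` samples the closed form q(x_t | x_0) = N(sqrt(ᾱ_t) x_0, (1 − ᾱ_t) I) in one step, after rescaling the linear β schedule so the terminal SNR is exactly zero (Lin et al., "Common Diffusion Noise Schedules and Sample Steps are Flawed"). A standalone sketch of both pieces; unlike the class above, it skips the `torch.ones(1)` that konfai prepends to `alpha_hat`:

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000)

# zero-terminal-SNR rescale: shift/scale sqrt(alpha_bar) so alpha_bar[-1] == 0
alphas_bar_sqrt = (1 - betas).cumprod(0).sqrt()
a_0, a_t = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
alphas_bar = ((alphas_bar_sqrt - a_t) * a_0 / (a_0 - a_t)) ** 2
print(alphas_bar[-1].item())  # 0.0: the final step is pure noise

# one-shot forward sample: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eta
x_0 = torch.randn(4, 1, 16, 16)
eta = torch.randn_like(x_0)
t = torch.tensor([999, 500, 100, 0])
a_bar_t = alphas_bar[t].reshape(x_0.shape[0], *[1 for _ in range(len(x_0.shape) - 1)])
x_t = a_bar_t.sqrt() * x_0 + (1 - a_bar_t).sqrt() * eta
print(x_t.shape)  # torch.Size([4, 1, 16, 16])
```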
+    class DDPMSampleT(torch.nn.Module):
 
         def __init__(self, noise_step: int) -> None:
             super().__init__()
             self.noise_step = noise_step
-
-        def forward(self, input: torch.Tensor) -> torch.Tensor:
-            return torch.randint(0, self.noise_step, (input.shape[0],)).to(input.device)
-
-    class DDPM_TimeEmbedding(torch.nn.Module):
-
+
+        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+            return torch.randint(0, self.noise_step, (tensor.shape[0],)).to(tensor.device)
+
+    class DDPMTimeEmbedding(torch.nn.Module):
+
+        @staticmethod
         def sinusoidal_embedding(noise_step: int, time_embedding_dim: int):
             noise_step += 1
             embedding = torch.zeros(noise_step, time_embedding_dim)
             wk = torch.tensor([1 / 10_000 ** (2 * j / time_embedding_dim) for j in range(time_embedding_dim)])
             wk = wk.reshape((1, time_embedding_dim))
             t = torch.arange(noise_step).reshape((noise_step, 1))
-            embedding[:,::2] = torch.sin(t * wk[:,::2])
-            embedding[:,1::2] = torch.cos(t * wk[:,::2])
+            embedding[:, ::2] = torch.sin(t * wk[:, ::2])
+            embedding[:, 1::2] = torch.cos(t * wk[:, ::2])
             return embedding
 
-        def __init__(self, noise_step: int=1000, time_embedding_dim: int=100) -> None:
+        def __init__(self, noise_step: int = 1000, time_embedding_dim: int = 100) -> None:
             super().__init__()
             self.time_embed = torch.nn.Embedding(noise_step, time_embedding_dim)
-            self.time_embed.weight.data = DDPM.DDPM_TimeEmbedding.sinusoidal_embedding(noise_step, time_embedding_dim)
+            self.time_embed.weight.data = DDPM.DDPMTimeEmbedding.sinusoidal_embedding(noise_step, time_embedding_dim)
             self.time_embed.requires_grad_(False)
             self.noise_step = noise_step
-
-        def forward(self, input: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-            return self.time_embed((p*self.noise_step).long().repeat((input.shape[0])))
-
-    class DDPM_UNet(network.ModuleArgsDict):
 
-        def __init__(self, noise_step: int, channels: list[int], blockConfig: blocks.BlockConfig, nb_conv_per_stage: int, downSampleMode: str, upSampleMode: str, attention : bool, time_embedding_dim: int, dim : int) -> None:
+        def forward(self, tensor: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            return self.time_embed((p * self.noise_step).long().repeat(tensor.shape[0]))
+
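`DDPMTimeEmbedding` stores the classic transformer-style sinusoidal table as frozen `torch.nn.Embedding` weights: even columns hold sin(t · w_j), odd columns cos(t · w_j), with geometrically spaced frequencies w_j = 10000^(−2j/d). A self-contained sketch of the table construction, same math as the staticmethod above but outside the class:

```python
import torch


def sinusoidal_embedding(noise_step: int, dim: int) -> torch.Tensor:
    noise_step += 1  # one extra row, as in the diff above
    embedding = torch.zeros(noise_step, dim)
    wk = torch.tensor([1 / 10_000 ** (2 * j / dim) for j in range(dim)]).reshape(1, dim)
    t = torch.arange(noise_step).reshape(noise_step, 1)
    embedding[:, ::2] = torch.sin(t * wk[:, ::2])   # even columns: sine
    embedding[:, 1::2] = torch.cos(t * wk[:, ::2])  # odd columns: cosine
    return embedding


table = sinusoidal_embedding(1000, 100)
emb = torch.nn.Embedding.from_pretrained(table, freeze=True)  # frozen, like requires_grad_(False)
print(emb(torch.tensor([0, 500, 999])).shape)  # torch.Size([3, 100])
```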
+    class DDPMUNet(network.ModuleArgsDict):
+
+        def __init__(
+            self,
+            noise_step: int,
+            channels: list[int],
+            block_config: blocks.BlockConfig,
+            nb_conv_per_stage: int,
+            downsample_mode: str,
+            upsample_mode: str,
+            attention: bool,
+            time_embedding_dim: int,
+            dim: int,
+        ) -> None:
             super().__init__()
-            self.add_module("t", DDPM.DDPM_TimeEmbedding(noise_step, time_embedding_dim), in_branch=[1], out_branch=["te"])
-            self.add_module("UNetBlock_0", DDPM.DDPM_UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, time_embedding_dim=time_embedding_dim, dim=dim), in_branch=[0,"te"])
-            self.add_module("Head", DDPM.DDPM_UNetHead(in_channels=channels[1], out_channels=1, dim=dim))
+            self.add_module(
+                "t",
+                DDPM.DDPMTimeEmbedding(noise_step, time_embedding_dim),
+                in_branch=[1],
+                out_branch=["te"],
+            )
+            self.add_module(
+                "UNetBlock_0",
+                DDPM.DDPMUNetBlock(
+                    channels,
+                    nb_conv_per_stage,
+                    block_config,
+                    downsample_mode=blocks.DownsampleMode[downsample_mode],
+                    upsample_mode=blocks.UpsampleMode[upsample_mode],
+                    attention=attention,
+                    time_embedding_dim=time_embedding_dim,
+                    dim=dim,
+                ),
+                in_branch=[0, "te"],
+            )
+            self.add_module(
+                "Head",
+                DDPM.DDPMUNetHead(in_channels=channels[1], out_channels=1, dim=dim),
+            )
 
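A small cleanup in `DDPMUNet`: the mode strings are now resolved with the public `Enum[name]` subscript lookup instead of the private `_member_map_` attribute. Both return the same member, and the subscript form raises a clear `KeyError` on an unknown name. Illustrated with a stand-in enum (not the real `blocks.DownsampleMode`):

```python
from enum import Enum, auto


class DownsampleMode(Enum):  # stand-in for blocks.DownsampleMode
    MAXPOOL = auto()
    CONV_STRIDE = auto()


assert DownsampleMode["MAXPOOL"] is DownsampleMode._member_map_["MAXPOOL"]  # same member
# DownsampleMode["AVGPOOL"] would raise KeyError: 'AVGPOOL'
```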
-    class DDPM_Inference(torch.nn.Module):
+    class DDPMInference(torch.nn.Module):
 
-        def __init__(self, model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], train_noise_step: int, inference_noise_step: int, beta_start: float, beta_end: float) -> None:
+        def __init__(
+            self,
+            model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+            train_noise_step: int,
+            inference_noise_step: int,
+            beta_start: float,
+            beta_end: float,
+        ) -> None:
             super().__init__()
             self.model = model
             self.train_noise_step = train_noise_step
             self.inference_noise_step = inference_noise_step
-            self.forwardProcess = DDPM.DDPM_ForwardProcess(train_noise_step, beta_start, beta_end)
-
-        def forward(self, input: torch.Tensor) -> torch.Tensor:
-            x = torch.randn_like(input).to(input.device)
+            self.forwardProcess = DDPM.DDPMForwardProcess(train_noise_step, beta_start, beta_end)
+
+        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+            x = torch.randn_like(tensor).to(tensor.device)
             result = []
             result.append(x.unsqueeze(1))
-            description = lambda : "Inference : "+gpuInfo(input.device)
             offset = self.train_noise_step // self.inference_noise_step
-            t_list = np.round(np.arange(self.train_noise_step-1, 0, -offset)).astype(int).tolist()
-            with tqdm.tqdm(iterable = enumerate(t_list), desc = description(), total=len(t_list), leave=False, disable=True) as batch_iter:
-                for i, t in batch_iter:
+            t_list = np.round(np.arange(self.train_noise_step - 1, 0, -offset)).astype(int).tolist()
+            with tqdm.tqdm(
+                iterable=enumerate(t_list),
+                desc="Inference : " + gpu_info(),
+                total=len(t_list),
+                leave=False,
+                disable=True,
+            ) as batch_iter:
+                for _, t in batch_iter:
                     alpha_t_hat = self.forwardProcess.alpha_hat[t]
-                    alpha_t_hat_prev = self.forwardProcess.alpha_hat[t - offset] if t - offset >= 0 else self.forwardProcess.alpha_hat[0]
-                    eta_theta = self.model(torch.concat((x, input), dim=1), (torch.ones(input.shape[0], 1) * t).to(input.device).long())
+                    alpha_t_hat_prev = (
+                        self.forwardProcess.alpha_hat[t - offset]
+                        if t - offset >= 0
+                        else self.forwardProcess.alpha_hat[0]
+                    )
+                    eta_theta = self.model(
+                        torch.concat((x, tensor), dim=1),
+                        (torch.ones(tensor.shape[0], 1) * t).to(tensor.device).long(),
+                    )
 
-                    predicted = (x - (1 - alpha_t_hat).sqrt() * eta_theta)/alpha_t_hat.sqrt()*alpha_t_hat_prev.sqrt()
+                    predicted = (
+                        (x - (1 - alpha_t_hat).sqrt() * eta_theta) / alpha_t_hat.sqrt() * alpha_t_hat_prev.sqrt()
+                    )
 
                     variance = ((1 - alpha_t_hat_prev) / (1 - alpha_t_hat)) * (1 - alpha_t_hat / alpha_t_hat_prev)
-
-                    direction = (1-alpha_t_hat_prev-variance).sqrt()*eta_theta
-
-                    x = predicted+direction
-                    x += variance.sqrt()*torch.randn_like(input).to(input.device)
-
+
+                    direction = (1 - alpha_t_hat_prev - variance).sqrt() * eta_theta
+
+                    x = predicted + direction
+                    x += variance.sqrt() * torch.randn_like(tensor).to(tensor.device)
+
                     result.append(x.unsqueeze(1))
-
-                    batch_iter.set_description(description())
+
+                    batch_iter.set_description("Inference : " + gpu_info())
             result[-1] = torch.clip(result[-1], -1, 1)
             return torch.concat(result, dim=1)
 
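`DDPMInference` is a strided, DDIM-style sampler: it visits only `inference_noise_step` of the `train_noise_step` training timesteps, recovers x̂_0 from the noise estimate, and steps back as x_{t−k} = sqrt(ᾱ_{t−k}) x̂_0 + sqrt(1 − ᾱ_{t−k} − σ²) ε_θ + σ z, with σ² set to the DDPM posterior variance. A quick check of the stride bookkeeping, using the default step counts from the constructor below:

```python
import numpy as np

train_noise_step, inference_noise_step = 1000, 100
offset = train_noise_step // inference_noise_step  # 10
t_list = np.round(np.arange(train_noise_step - 1, 0, -offset)).astype(int).tolist()
print(len(t_list), t_list[:3], t_list[-1])  # 100 [999, 989, 979] 9: one model call per kept step
```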
- class DDPM_VLoss(torch.nn.Module):
322
+ class DDPMVLoss(torch.nn.Module):
166
323
 
167
324
  def __init__(self, alpha_hat: torch.Tensor) -> None:
168
325
  super().__init__()
169
326
  self.alpha_hat = alpha_hat
170
327
 
171
- def v(self, input: torch.Tensor, noise: torch.Tensor, t: torch.Tensor):
172
- alpha_hat_t = self.alpha_hat[t.cpu()].to(input.device).reshape(input.shape[0], *[1 for _ in range(len(input.shape)-1)])
173
- return alpha_hat_t.sqrt()*input-(1-alpha_hat_t).sqrt()*noise
174
-
175
- def forward(self, input: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: int) -> torch.Tensor:
176
- return torch.concat((self.v(input, eta, t), self.v(input, eta_hat, t)), dim=1)
177
-
178
- def __init__( self,
179
- optimizer : network.OptimizerLoader = network.OptimizerLoader(),
180
- schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
181
- outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
182
- patch : Union[ModelPatch, None] = None,
183
- train_noise_step: int = 1000,
184
- inference_noise_step: int = 100,
185
- beta_start: float = 1e-4,
186
- beta_end: float = 0.02,
187
- time_embedding_dim: int = 100,
188
- channels: list[int]=[1, 64, 128, 256, 512, 1024],
189
- blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
190
- nb_conv_per_stage: int = 2,
191
- downSampleMode: str = "MAXPOOL",
192
- upSampleMode: str = "CONV_TRANSPOSE",
193
- attention : bool = False,
194
- dim : int = 3) -> None:
195
- super().__init__(in_channels=1, optimizer=optimizer, schedulers=schedulers, outputsCriterions=outputsCriterions, patch=patch, dim=dim)
328
+ def v(self, tensor: torch.Tensor, noise: torch.Tensor, t: torch.Tensor):
329
+ alpha_hat_t = (
330
+ self.alpha_hat[t.cpu()]
331
+ .to(tensor.device)
332
+ .reshape(tensor.shape[0], *[1 for _ in range(len(tensor.shape) - 1)])
333
+ )
334
+ return alpha_hat_t.sqrt() * tensor - (1 - alpha_hat_t).sqrt() * noise
335
+
336
+ def forward(self, tensor: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: int) -> torch.Tensor:
337
+ return torch.concat((self.v(tensor, eta, t), self.v(tensor, eta_hat, t)), dim=1)
338
+
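`DDPMVLoss` converts both the true noise η and the prediction η̂ into the v-parameterization of Salimans & Ho (2022), v = sqrt(ᾱ_t) x − sqrt(1 − ᾱ_t) η, and stacks the two along dim=1 so a downstream criterion (the `MSE` at the bottom of this file) can compare them. A standalone numeric sketch of that target construction:

```python
import torch

# alpha_hat with the leading ones(1) padding, matching DDPMForwardProcess
alpha_hat = torch.cat((torch.ones(1), (1 - torch.linspace(1e-4, 0.02, 1000)).cumprod(0)))


def v(x: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    a = alpha_hat[t].reshape(x.shape[0], *[1 for _ in range(len(x.shape) - 1)])
    return a.sqrt() * x - (1 - a).sqrt() * noise


x = torch.randn(4, 1, 16, 16)
eta, eta_hat = torch.randn_like(x), torch.randn_like(x)
t = torch.randint(0, 1000, (4,))
stacked = torch.concat((v(x, eta, t), v(x, eta_hat, t)), dim=1)
print(stacked.shape)  # torch.Size([4, 2, 16, 16]): target v and predicted v, side by side
```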
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        train_noise_step: int = 1000,
+        inference_noise_step: int = 100,
+        beta_start: float = 1e-4,
+        beta_end: float = 0.02,
+        time_embedding_dim: int = 100,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
         self.add_module("Noise", blocks.NormalNoise(), out_branch=["eta"], training=True)
-        self.add_module("Sample", DDPM.DDPM_SampleT(train_noise_step), out_branch=["t"], training=True)
-        self.add_module("Forward", DDPM.DDPM_ForwardProcess(train_noise_step, beta_start, beta_end), in_branch=[0, "t", "eta"], out_branch=["x_t"], training=True)
-        self.add_module("Concat", blocks.Concat(), in_branch=["x_t",1], out_branch=["xy_t"], training=True)
-        self.add_module("UNet", DDPM.DDPM_UNet(train_noise_step, channels, blockConfig, nb_conv_per_stage, downSampleMode, upSampleMode, attention, time_embedding_dim, dim), in_branch=["xy_t", "t"], out_branch=["eta_hat"], training=True)
-
-        self.add_module("Noise_optim", DDPM.DDPM_VLoss(self["Forward"].alpha_hat), in_branch=[0, "eta", "eta_hat", "t"], out_branch=["noise"], training=True)
-
-        self.add_module("Inference", DDPM.DDPM_Inference(self.inference, train_noise_step, inference_noise_step, beta_start, beta_end), in_branch=[1], training=False)
-        self.add_module("LastImage", blocks.Select([slice(None, None), slice(-1, None)]), training=False)
-
-    def inference(self, input: torch.Tensor, t: torch.Tensor):
-        return self["UNet"](input, t)
+        self.add_module(
+            "Sample",
+            DDPM.DDPMSampleT(train_noise_step),
+            out_branch=["t"],
+            training=True,
+        )
+        self.add_module(
+            "Forward",
+            DDPM.DDPMForwardProcess(train_noise_step, beta_start, beta_end),
+            in_branch=[0, "t", "eta"],
+            out_branch=["x_t"],
+            training=True,
+        )
+        self.add_module(
+            "Concat",
+            blocks.Concat(),
+            in_branch=["x_t", 1],
+            out_branch=["xy_t"],
+            training=True,
+        )
+        self.add_module(
+            "UNet",
+            DDPM.DDPMUNet(
+                train_noise_step,
+                channels,
+                block_config,
+                nb_conv_per_stage,
+                downsample_mode,
+                upsample_mode,
+                attention,
+                time_embedding_dim,
+                dim,
+            ),
+            in_branch=["xy_t", "t"],
+            out_branch=["eta_hat"],
+            training=True,
+        )
+
+        self.add_module(
+            "Noise_optim",
+            DDPM.DDPMVLoss(self["Forward"].alpha_hat),
+            in_branch=[0, "eta", "eta_hat", "t"],
+            out_branch=["noise"],
+            training=True,
+        )
+
+        self.add_module(
+            "Inference",
+            DDPM.DDPMInference(
+                self.inference,
+                train_noise_step,
+                inference_noise_step,
+                beta_start,
+                beta_end,
+            ),
+            in_branch=[1],
+            training=False,
+        )
+        self.add_module(
+            "LastImage",
+            blocks.Select([slice(None, None), slice(-1, None)]),
+            training=False,
+        )
+
+    def inference(self, tensor: torch.Tensor, t: torch.Tensor):
+        return self["UNet"](tensor, t)
+
 
 class MSE(Criterion):
 
     def __init__(self):
         super().__init__()
         self.loss = torch.nn.MSELoss()
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return self.loss(input[:, 0, ...], input[:, 1, ...])
+
+    def forward(self, tensor: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
+        return self.loss(tensor[:, 0, ...], tensor[:, 1, ...])
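The updated `MSE` signature accepts `*targets`, presumably to match the broader `Criterion` call convention, but still ignores them: the two v-tensors to compare arrive stacked in a single input (from `DDPMVLoss` above), so the loss simply splits them on dim=1. Roughly:

```python
import torch

loss = torch.nn.MSELoss()
stacked = torch.randn(4, 2, 16, 16)  # [:, 0]: target v, [:, 1]: predicted v
print(loss(stacked[:, 0, ...], stacked[:, 1, ...]))  # scalar training loss
```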