konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/models/generation/ddpm.py
CHANGED
@@ -1,16 +1,14 @@
+from collections.abc import Callable
 
-
-
-from typing import Union, Callable
+import numpy as np
 import torch
 import tqdm
-import numpy as np
 
-from konfai.network import network, blocks
-from konfai.utils.config import config
 from konfai.data.patching import ModelPatch
-from konfai.utils.utils import gpuInfo
 from konfai.metric.measure import Criterion
+from konfai.network import blocks, network
+from konfai.utils.utils import gpu_info
+
 
 def cosine_beta_schedule(timesteps, s=0.008):
     steps = timesteps + 1
@@ -20,198 +18,427 @@ def cosine_beta_schedule(timesteps, s=0.008):
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0.0001, 0.9999)
 
+
 class DDPM(network.Network):
-
-    class
+
+    class DDPMTE(torch.nn.Module):
 
         def __init__(self, in_channels: int, out_channels: int) -> None:
            super().__init__()
            self.linear_0 = torch.nn.Linear(in_channels, out_channels)
            self.siLU = torch.nn.SiLU()
            self.linear_1 = torch.nn.Linear(out_channels, out_channels)
-
-        def forward(self,
-            return
-
-
-
-
+
+        def forward(self, tensor: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+            return tensor + self.linear_1(self.siLU(self.linear_0(t))).reshape(
+                tensor.shape[0], -1, *[1 for _ in range(len(tensor.shape) - 2)]
+            )
+
+    class DDPMUNetBlock(network.ModuleArgsDict):
+
+        def __init__(
+            self,
+            channels: list[int],
+            nb_conv_per_stage: int,
+            block_config: blocks.BlockConfig,
+            downsample_mode: blocks.DownsampleMode,
+            upsample_mode: blocks.UpsampleMode,
+            attention: bool,
+            time_embedding_dim: int,
+            dim: int,
+            i: int = 0,
+        ) -> None:
             super().__init__()
             if i > 0:
-                self.add_module(
-
-
+                self.add_module(
+                    downsample_mode.name,
+                    blocks.downsample(
+                        in_channels=channels[0],
+                        out_channels=channels[1],
+                        downsample_mode=downsample_mode,
+                        dim=dim,
+                    ),
+                )
+            self.add_module(
+                "Te_down",
+                DDPM.DDPMTE(
+                    time_embedding_dim,
+                    channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
+                ),
+                in_branch=[0, 1],
+            )
+            self.add_module(
+                "DownConvBlock",
+                blocks.ResBlock(
+                    in_channels=channels[(1 if downsample_mode == blocks.DownsampleMode.CONV_STRIDE and i > 0 else 0)],
+                    out_channels=channels[1],
+                    block_configs=[block_config] * nb_conv_per_stage,
+                    dim=dim,
+                ),
+            )
             if len(channels) > 2:
-                self.add_module(
-
-
+                self.add_module(
+                    f"UNetBlock_{i + 1}",
+                    DDPM.DDPMUNetBlock(
+                        channels[1:],
+                        nb_conv_per_stage,
+                        block_config,
+                        downsample_mode,
+                        upsample_mode,
+                        attention,
+                        time_embedding_dim,
+                        dim,
+                        i + 1,
+                    ),
+                    in_branch=[0, 1],
+                )
+                self.add_module(
+                    "Te_up",
+                    DDPM.DDPMTE(
+                        time_embedding_dim,
+                        (
+                            (channels[1] + channels[2])
+                            if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                            else channels[1] * 2
+                        ),
+                    ),
+                    in_branch=[0, 1],
+                )
+                self.add_module(
+                    "UpConvBlock",
+                    blocks.ResBlock(
+                        in_channels=(
+                            (channels[1] + channels[2])
+                            if upsample_mode != blocks.UpsampleMode.CONV_TRANSPOSE
+                            else channels[1] * 2
+                        ),
+                        out_channels=channels[1],
+                        block_configs=[block_config] * nb_conv_per_stage,
+                        dim=dim,
+                    ),
+                )
             if i > 0:
                 if attention:
-                    self.add_module(
-
+                    self.add_module(
+                        "Attention",
+                        blocks.Attention(f_g=channels[1], f_l=channels[0], f_int=channels[0], dim=dim),
+                        in_branch=[2, 0],
+                        out_branch=[2],
+                    )
+                self.add_module(
+                    upsample_mode.name,
+                    blocks.upsample(
+                        in_channels=channels[1],
+                        out_channels=channels[0],
+                        upsample_mode=upsample_mode,
+                        dim=dim,
+                    ),
+                )
                 self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 2])
-
-    class
+
+    class DDPMUNetHead(network.ModuleArgsDict):
 
         def __init__(self, in_channels: int, out_channels: int, dim: int) -> None:
             super().__init__()
-            self.add_module(
+            self.add_module(
+                "Conv",
+                blocks.get_torch_module("Conv", dim)(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                ),
+            )
 
-    class
+    class DDPMForwardProcess(torch.nn.Module):
 
-        def __init__(
+        def __init__(
+            self,
+            noise_step: int = 1000,
+            beta_start: float = 1e-4,
+            beta_end: float = 0.02,
+        ) -> None:
             super().__init__()
             self.betas = torch.linspace(beta_start, beta_end, noise_step)
-            self.betas = DDPM.
+            self.betas = DDPM.DDPMForwardProcess.enforce_zero_terminal_snr(self.betas)
             self.alphas = 1 - self.betas
             self.alpha_hat = torch.concat((torch.ones(1), torch.cumprod(self.alphas, dim=0)))
 
+        @staticmethod
         def enforce_zero_terminal_snr(betas: torch.Tensor):
             alphas = 1 - betas
             alphas_bar = alphas.cumprod(0)
             alphas_bar_sqrt = alphas_bar.sqrt()
             alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
-
-            alphas_bar_sqrt -=
-            alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
-
-            alphas_bar = alphas_bar_sqrt ** 2
+            alphas_bar_sqrt_t = alphas_bar_sqrt[-1].clone()
+            alphas_bar_sqrt -= alphas_bar_sqrt_t
+            alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_t)
+            alphas_bar = alphas_bar_sqrt**2
             alphas = alphas_bar[1:] / alphas_bar[:-1]
             alphas = torch.cat([alphas_bar[0:1], alphas])
             betas = 1 - alphas
             return betas
 
-        def forward(self,
-            alpha_hat_t =
-
+        def forward(self, tensor: torch.Tensor, t: torch.Tensor, eta: torch.Tensor) -> torch.Tensor:
+            alpha_hat_t = (
+                self.alpha_hat[t.cpu()]
+                .to(tensor.device)
+                .reshape(tensor.shape[0], *[1 for _ in range(len(tensor.shape) - 1)])
+            )
+            result = alpha_hat_t.sqrt() * tensor + (1 - alpha_hat_t).sqrt() * eta
             return result
-
-    class
+
+    class DDPMSampleT(torch.nn.Module):
 
         def __init__(self, noise_step: int) -> None:
             super().__init__()
             self.noise_step = noise_step
-
-        def forward(self,
-            return torch.randint(0, self.noise_step, (
-
-    class
-
+
+        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+            return torch.randint(0, self.noise_step, (tensor.shape[0],)).to(tensor.device)
+
+    class DDPMTimeEmbedding(torch.nn.Module):
+
+        @staticmethod
         def sinusoidal_embedding(noise_step: int, time_embedding_dim: int):
             noise_step += 1
             embedding = torch.zeros(noise_step, time_embedding_dim)
             wk = torch.tensor([1 / 10_000 ** (2 * j / time_embedding_dim) for j in range(time_embedding_dim)])
             wk = wk.reshape((1, time_embedding_dim))
             t = torch.arange(noise_step).reshape((noise_step, 1))
-            embedding[
-            embedding[:,1::2] = torch.cos(t * wk[
+            embedding[:, ::2] = torch.sin(t * wk[:, ::2])
+            embedding[:, 1::2] = torch.cos(t * wk[:, ::2])
             return embedding
 
-        def __init__(self, noise_step: int=1000, time_embedding_dim: int=100) -> None:
+        def __init__(self, noise_step: int = 1000, time_embedding_dim: int = 100) -> None:
             super().__init__()
             self.time_embed = torch.nn.Embedding(noise_step, time_embedding_dim)
-            self.time_embed.weight.data = DDPM.
+            self.time_embed.weight.data = DDPM.DDPMTimeEmbedding.sinusoidal_embedding(noise_step, time_embedding_dim)
             self.time_embed.requires_grad_(False)
             self.noise_step = noise_step
-
-        def forward(self, input: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
-            return self.time_embed((p*self.noise_step).long().repeat((input.shape[0])))
-
-    class DDPM_UNet(network.ModuleArgsDict):
 
-        def
+        def forward(self, tensor: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+            return self.time_embed((p * self.noise_step).long().repeat(tensor.shape[0]))
+
+    class DDPMUNet(network.ModuleArgsDict):
+
+        def __init__(
+            self,
+            noise_step: int,
+            channels: list[int],
+            block_config: blocks.BlockConfig,
+            nb_conv_per_stage: int,
+            downsample_mode: str,
+            upsample_mode: str,
+            attention: bool,
+            time_embedding_dim: int,
+            dim: int,
+        ) -> None:
             super().__init__()
-            self.add_module(
-
-
+            self.add_module(
+                "t",
+                DDPM.DDPMTimeEmbedding(noise_step, time_embedding_dim),
+                in_branch=[1],
+                out_branch=["te"],
+            )
+            self.add_module(
+                "UNetBlock_0",
+                DDPM.DDPMUNetBlock(
+                    channels,
+                    nb_conv_per_stage,
+                    block_config,
+                    downsample_mode=blocks.DownsampleMode[downsample_mode],
+                    upsample_mode=blocks.UpsampleMode[upsample_mode],
+                    attention=attention,
+                    time_embedding_dim=time_embedding_dim,
+                    dim=dim,
+                ),
+                in_branch=[0, "te"],
+            )
+            self.add_module(
+                "Head",
+                DDPM.DDPMUNetHead(in_channels=channels[1], out_channels=1, dim=dim),
+            )
 
-    class
+    class DDPMInference(torch.nn.Module):
 
-        def __init__(
+        def __init__(
+            self,
+            model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+            train_noise_step: int,
+            inference_noise_step: int,
+            beta_start: float,
+            beta_end: float,
+        ) -> None:
             super().__init__()
             self.model = model
             self.train_noise_step = train_noise_step
             self.inference_noise_step = inference_noise_step
-            self.forwardProcess = DDPM.
-
-        def forward(self,
-            x = torch.randn_like(
+            self.forwardProcess = DDPM.DDPMForwardProcess(train_noise_step, beta_start, beta_end)
+
+        def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+            x = torch.randn_like(tensor).to(tensor.device)
             result = []
             result.append(x.unsqueeze(1))
-            description = lambda : "Inference : "+gpuInfo(input.device)
             offset = self.train_noise_step // self.inference_noise_step
-            t_list = np.round(np.arange(self.train_noise_step-1, 0, -offset)).astype(int).tolist()
-            with tqdm.tqdm(
-
+            t_list = np.round(np.arange(self.train_noise_step - 1, 0, -offset)).astype(int).tolist()
+            with tqdm.tqdm(
+                iterable=enumerate(t_list),
+                desc="Inference : " + gpu_info(),
+                total=len(t_list),
+                leave=False,
+                disable=True,
+            ) as batch_iter:
+                for _, t in batch_iter:
                    alpha_t_hat = self.forwardProcess.alpha_hat[t]
-                    alpha_t_hat_prev =
-
+                    alpha_t_hat_prev = (
+                        self.forwardProcess.alpha_hat[t - offset]
+                        if t - offset >= 0
+                        else self.forwardProcess.alpha_hat[0]
+                    )
+                    eta_theta = self.model(
+                        torch.concat((x, tensor), dim=1),
+                        (torch.ones(tensor.shape[0], 1) * t).to(tensor.device).long(),
+                    )
 
-                    predicted = (
+                    predicted = (
+                        (x - (1 - alpha_t_hat).sqrt() * eta_theta) / alpha_t_hat.sqrt() * alpha_t_hat_prev.sqrt()
+                    )
 
                    variance = ((1 - alpha_t_hat_prev) / (1 - alpha_t_hat)) * (1 - alpha_t_hat / alpha_t_hat_prev)
-
-                    direction = (1-alpha_t_hat_prev-variance).sqrt()*eta_theta
-
-                    x = predicted+direction
-                    x += variance.sqrt()*torch.randn_like(
-
+
+                    direction = (1 - alpha_t_hat_prev - variance).sqrt() * eta_theta
+
+                    x = predicted + direction
+                    x += variance.sqrt() * torch.randn_like(tensor).to(tensor.device)
+
                    result.append(x.unsqueeze(1))
-
-                    batch_iter.set_description(
+
+                    batch_iter.set_description("Inference : " + gpu_info())
            result[-1] = torch.clip(result[-1], -1, 1)
            return torch.concat(result, dim=1)
 
-    class
+    class DDPMVLoss(torch.nn.Module):
 
        def __init__(self, alpha_hat: torch.Tensor) -> None:
            super().__init__()
            self.alpha_hat = alpha_hat
 
-        def v(self,
-            alpha_hat_t =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        def v(self, tensor: torch.Tensor, noise: torch.Tensor, t: torch.Tensor):
+            alpha_hat_t = (
+                self.alpha_hat[t.cpu()]
+                .to(tensor.device)
+                .reshape(tensor.shape[0], *[1 for _ in range(len(tensor.shape) - 1)])
+            )
+            return alpha_hat_t.sqrt() * tensor - (1 - alpha_hat_t).sqrt() * noise
+
+        def forward(self, tensor: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: int) -> torch.Tensor:
+            return torch.concat((self.v(tensor, eta, t), self.v(tensor, eta_hat, t)), dim=1)
+
+    def __init__(
+        self,
+        optimizer: network.OptimizerLoader = network.OptimizerLoader(),
+        schedulers: dict[str, network.LRSchedulersLoader] = {
+            "default:ReduceLROnPlateau": network.LRSchedulersLoader(0)
+        },
+        outputs_criterions: dict[str, network.TargetCriterionsLoader] = {"default": network.TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        train_noise_step: int = 1000,
+        inference_noise_step: int = 100,
+        beta_start: float = 1e-4,
+        beta_end: float = 0.02,
+        time_embedding_dim: int = 100,
+        channels: list[int] = [1, 64, 128, 256, 512, 1024],
+        block_config: blocks.BlockConfig = blocks.BlockConfig(),
+        nb_conv_per_stage: int = 2,
+        downsample_mode: str = "MAXPOOL",
+        upsample_mode: str = "CONV_TRANSPOSE",
+        attention: bool = False,
+        dim: int = 3,
+    ) -> None:
+        super().__init__(
+            in_channels=1,
+            optimizer=optimizer,
+            schedulers=schedulers,
+            outputs_criterions=outputs_criterions,
+            patch=patch,
+            dim=dim,
+        )
        self.add_module("Noise", blocks.NormalNoise(), out_branch=["eta"], training=True)
-        self.add_module(
-
-
-
-
-
-
-
-
-
-
-
+        self.add_module(
+            "Sample",
+            DDPM.DDPMSampleT(train_noise_step),
+            out_branch=["t"],
+            training=True,
+        )
+        self.add_module(
+            "Forward",
+            DDPM.DDPMForwardProcess(train_noise_step, beta_start, beta_end),
+            in_branch=[0, "t", "eta"],
+            out_branch=["x_t"],
+            training=True,
+        )
+        self.add_module(
+            "Concat",
+            blocks.Concat(),
+            in_branch=["x_t", 1],
+            out_branch=["xy_t"],
+            training=True,
+        )
+        self.add_module(
+            "UNet",
+            DDPM.DDPMUNet(
+                train_noise_step,
+                channels,
+                block_config,
+                nb_conv_per_stage,
+                downsample_mode,
+                upsample_mode,
+                attention,
+                time_embedding_dim,
+                dim,
+            ),
+            in_branch=["xy_t", "t"],
+            out_branch=["eta_hat"],
+            training=True,
+        )
+
+        self.add_module(
+            "Noise_optim",
+            DDPM.DDPMVLoss(self["Forward"].alpha_hat),
+            in_branch=[0, "eta", "eta_hat", "t"],
+            out_branch=["noise"],
+            training=True,
+        )
+
+        self.add_module(
+            "Inference",
+            DDPM.DDPMInference(
+                self.inference,
+                train_noise_step,
+                inference_noise_step,
+                beta_start,
+                beta_end,
+            ),
+            in_branch=[1],
+            training=False,
+        )
+        self.add_module(
+            "LastImage",
+            blocks.Select([slice(None, None), slice(-1, None)]),
+            training=False,
+        )
+
+    def inference(self, tensor: torch.Tensor, t: torch.Tensor):
+        return self["UNet"](tensor, t)
+
 
 class MSE(Criterion):
 
     def __init__(self):
         super().__init__()
         self.loss = torch.nn.MSELoss()
-
-    def forward(self,
-        return self.loss(
+
+    def forward(self, tensor: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
+        return self.loss(tensor[:, 0, ...], tensor[:, 1, ...])