konfai 1.1.7-py3-none-any.whl → 1.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/metric/measure.py
CHANGED
@@ -1,334 +1,428 @@
-
+import copy
 import importlib
-import numpy as np
-import torch
-
-from torchvision.models import inception_v3, Inception_V3_Weights
-from torchvision.transforms import functional as F
-from scipy import linalg
-
-import torch.nn.functional as F
 import os
-
-from
+from abc import ABC, abstractmethod
+from collections.abc import Callable
 from functools import partial
-
+
+import numpy as np
+import torch
+import torch.nn.functional as F  # noqa: N812
+from scipy import linalg
 from skimage.metrics import structural_similarity
-import
-from
+from torchvision.models import Inception_V3_Weights, inception_v3
+from tqdm import tqdm
 
-from konfai.utils.config import config
-from konfai.utils.utils import _getModule, download_url
 from konfai.data.patching import ModelPatch
 from konfai.network.blocks import LatentDistribution
 from konfai.network.network import ModelLoader, Network
-from
+from konfai.utils.config import config
+from konfai.utils.utils import download_url, get_module
+
+models_register = {}
 
-modelsRegister = {}
 
 class Criterion(torch.nn.Module, ABC):
 
     def __init__(self) -> None:
         super().__init__()
 
-    def init(self, model
+    def init(self, model: torch.nn.Module, output_group: str, target_group: str) -> str:
         return output_group
 
     @abstractmethod
-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         pass
 
+
 class MaskedLoss(Criterion):
 
-    def __init__(
+    def __init__(
+        self,
+        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        mode_image_masked: bool,
+    ) -> None:
         super().__init__()
         self.loss = loss
         self.mode_image_masked = mode_image_masked
 
-    def
+    def get_mask(self, *targets: torch.Tensor) -> torch.Tensor:
         result = None
         if len(targets) > 0:
             result = targets[0]
             for mask in targets[1:]:
-                result = result*mask
+                result = result * mask
         return result
 
-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         loss = torch.tensor(0, dtype=torch.float32).to(output.device)
-        mask = self.
+        mask = self.get_mask(targets[1:])
         for batch in range(output.shape[0]):
             if mask is not None:
                 if self.mode_image_masked:
                     for i in torch.unique(mask[batch]):
                         if i != 0:
-                            loss += self.loss(
+                            loss += self.loss(
+                                output[batch, ...] * torch.where(mask == i, 1, 0),
+                                targets[0][batch, ...] * torch.where(mask == i, 1, 0),
+                            )
                 else:
                     for i in torch.unique(mask[batch]):
                         if i != 0:
                             index = mask[batch, ...] == i
-                            loss += self.loss(
+                            loss += self.loss(
+                                torch.masked_select(output[batch, ...], index),
+                                torch.masked_select(targets[0][batch, ...], index),
+                            )
             else:
                 loss += self.loss(output[batch, ...], targets[0][batch, ...])
-        return loss/output.shape[0]
-
+        return loss / output.shape[0]
+
+
 class MSE(MaskedLoss):
 
+    @staticmethod
     def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return torch.nn.MSELoss(reduction=reduction)(x, y)
 
     def __init__(self, reduction: str = "mean") -> None:
         super().__init__(partial(MSE._loss, reduction), False)
 
+
 class MAE(MaskedLoss):
 
+    @staticmethod
     def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return torch.nn.L1Loss(reduction=reduction)(x, y)
-
+
     def __init__(self, reduction: str = "mean") -> None:
         super().__init__(partial(MAE._loss, reduction), False)
 
+
 class PSNR(MaskedLoss):
 
-
+    @staticmethod
+    def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         mse = torch.mean((x - y).pow(2))
-        psnr = 10 * torch.log10(dynamic_range**2 / mse)
-        return psnr
+        psnr = 10 * torch.log10(dynamic_range**2 / mse)
+        return psnr
 
-    def __init__(self, dynamic_range:
-        dynamic_range = dynamic_range if dynamic_range else 1024+3071
+    def __init__(self, dynamic_range: float | None = None) -> None:
+        dynamic_range = dynamic_range if dynamic_range else 1024 + 3071
         super().__init__(partial(PSNR._loss, dynamic_range), False)
-
+
+
 class SSIM(MaskedLoss):
-
+
+    @staticmethod
     def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return structural_similarity(
-
-
-
+        return structural_similarity(
+            x[0][0].detach().cpu().numpy(),
+            y[0][0].cpu().numpy(),
+            data_range=dynamic_range,
+            gradient=False,
+            full=False,
+        )
+
+    def __init__(self, dynamic_range: float | None = None) -> None:
+        dynamic_range = dynamic_range if dynamic_range else 1024 + 3000
         super().__init__(partial(SSIM._loss, dynamic_range), True)
 
 
 class LPIPS(MaskedLoss):
 
-
-
-
-
-
-
+    @staticmethod
+    def normalize(tensor: torch.Tensor) -> torch.Tensor:
+        return (tensor - torch.min(tensor)) / (torch.max(tensor) - torch.min(tensor)) * 2 - 1
+
+    @staticmethod
+    def preprocessing(tensor: torch.Tensor) -> torch.Tensor:
+        return tensor.repeat((1, 3, 1, 1))
+
+    @staticmethod
     def _loss(loss_fn_alex, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-
-
+        dataset_patch = ModelPatch([1, 64, 64])
+        dataset_patch.load(x.shape[2:])
 
-
+        patch_iterator = dataset_patch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
         loss = 0
-        with tqdm(
-
+        with tqdm(
+            iterable=enumerate(patch_iterator),
+            leave=False,
+            total=dataset_patch.get_size(0),
+        ) as batch_iter:
+            for _, patch_input in batch_iter:
                 real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
                 loss += loss_fn_alex(real, fake).item()
-        return loss/
+        return loss / dataset_patch.get_size(0)
 
     def __init__(self, model: str = "alex") -> None:
         import lpips
+
         super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model)), True)
 
+
 class Dice(Criterion):
-
-    def __init__(self, smooth
+
+    def __init__(self, smooth: float = 1e-6) -> None:
         super().__init__()
         self.smooth = smooth
-
-    def flatten(self, tensor
+
+    def flatten(self, tensor: torch.Tensor) -> torch.Tensor:
         return tensor.permute((1, 0) + tuple(range(2, tensor.dim()))).contiguous().view(tensor.size(1), -1)
 
-    def dice_per_channel(self,
-
+    def dice_per_channel(self, tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        tensor = self.flatten(tensor)
        target = self.flatten(target)
-        return (2
+        return (2.0 * (tensor * target).sum() + self.smooth) / (tensor.sum() + target.sum() + self.smooth)
 
-    def forward(self, output: torch.Tensor, *targets
-        target = F.interpolate(
-            targets[0],
-            output.shape[2:],
-            mode="nearest")
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = F.interpolate(targets[0], output.shape[2:], mode="nearest")
         if output.shape[1] == 1:
-            output =
-
-
+            output = (
+                F.one_hot(
+                    output.type(torch.int64),
+                    num_classes=int(torch.max(output).item() + 1),
+                )
+                .permute(0, len(target.shape), *[i + 1 for i in range(len(target.shape) - 1)])
+                .float()
+            )
+        target = (
+            F.one_hot(target.type(torch.int64), num_classes=output.shape[1])
+            .permute(0, len(target.shape), *[i + 1 for i in range(len(target.shape) - 1)])
+            .float()
+            .squeeze(2)
+        )
+        return 1 - torch.mean(self.dice_per_channel(output, target))
+
 
 class GradientImages(Criterion):
 
     def __init__(self):
         super().__init__()
-
+
     @staticmethod
-    def
+    def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         dx = image[:, :, 1:, :] - image[:, :, :-1, :]
         dy = image[:, :, :, 1:] - image[:, :, :, :-1]
         return dx, dy
 
     @staticmethod
-    def
+    def _image_gradient_3d(
+        image: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         dx = image[:, :, 1:, :, :] - image[:, :, :-1, :, :]
         dy = image[:, :, :, 1:, :] - image[:, :, :, :-1, :]
         dz = image[:, :, :, :, 1:] - image[:, :, :, :, :-1]
         return dx, dy, dz
-
-    def forward(self, output: torch.Tensor, *targets
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target_0 = targets[0]
         if len(output.shape) == 5:
-            dx, dy, dz = GradientImages.
+            dx, dy, dz = GradientImages._image_gradient_3d(output)
             if target_0 is not None:
-                dx_tmp, dy_tmp, dz_tmp = GradientImages.
+                dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient_3d(target_0)
                 dx -= dx_tmp
                 dy -= dy_tmp
                 dz -= dz_tmp
             return dx.norm() + dy.norm() + dz.norm()
         else:
-            dx, dy = GradientImages.
+            dx, dy = GradientImages._image_gradient_2d(output)
             if target_0 is not None:
-                dx_tmp, dy_tmp = GradientImages.
+                dx_tmp, dy_tmp = GradientImages._image_gradient_2d(target_0)
                 dx -= dx_tmp
                 dy -= dy_tmp
             return dx.norm() + dy.norm()
-
+
+
 class BCE(Criterion):
 
-    def __init__(self, target
+    def __init__(self, target: float = 0) -> None:
         super().__init__()
         self.loss = torch.nn.BCEWithLogitsLoss()
-        self.register_buffer(
+        self.register_buffer("target", torch.tensor(target).type(torch.float32))
 
-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target = self._buffers["target"]
         return self.loss(output, target.to(output.device).expand_as(output))
 
+
 class PatchGanLoss(Criterion):
 
-    def __init__(self, target
+    def __init__(self, target: float = 0) -> None:
         super().__init__()
         self.loss = torch.nn.MSELoss()
-        self.register_buffer(
+        self.register_buffer("target", torch.tensor(target).type(torch.float32))
 
-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target = self._buffers["target"]
-        return self.loss(output, (torch.ones_like(output)*target).to(output.device))
+        return self.loss(output, (torch.ones_like(output) * target).to(output.device))
+
 
 class WGP(Criterion):
 
     def __init__(self) -> None:
         super().__init__()
-
-    def forward(self, output: torch.Tensor, *targets
-        return torch.mean((output - 1)**2)
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        return torch.mean((output - 1) ** 2)
+
 
 class Gram(Criterion):
 
-
-
-
-
+    @staticmethod
+    def compute_gram(tensor: torch.Tensor):
+        (b, ch, w) = tensor.size()
+        with torch.amp.autocast("cuda", enabled=False):
+            return tensor.bmm(tensor.transpose(1, 2)).div(ch * w)
 
     def __init__(self) -> None:
         super().__init__()
-        self.loss = torch.nn.L1Loss(reduction=
+        self.loss = torch.nn.L1Loss(reduction="sum")
 
-    def forward(self, output: torch.Tensor,
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = targets[0]
         if len(output.shape) > 3:
             output = output.view(output.shape[0], output.shape[1], int(np.prod(output.shape[2:])))
         if len(target.shape) > 3:
             target = target.view(target.shape[0], target.shape[1], int(np.prod(target.shape[2:])))
-        return self.loss(Gram.
+        return self.loss(Gram.compute_gram(output), Gram.compute_gram(target))
+
 
 class PerceptualLoss(Criterion):
-
-    class Module
-
+
+    class Module:
+
         @config()
         def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
             self.losses = losses
-            self.
+            self.konfai_args = os.environ["KONFAI_CONFIG_PATH"] if "KONFAI_CONFIG_PATH" in os.environ else ""
 
-        def
+        def get_loss(self) -> dict[torch.nn.Module, float]:
             result: dict[torch.nn.Module, float] = {}
-            for loss,
-                module, name =
-                result[config(self.
+            for loss, loss_value in self.losses.items():
+                module, name = get_module(loss, "konfai.metric.measure")
+                result[config(self.konfai_args)(getattr(importlib.import_module(module), name))(config=None)] = (
+                    loss_value
+                )
             return result
-
-    def __init__(
+
+    def __init__(
+        self,
+        model_loader: ModelLoader = ModelLoader(),
+        path_model: str = "name",
+        modules: dict[str, Module] = {
+            "UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch_nn_L1Loss": 1})
+        },
+        shape: list[int] = [128, 128, 128],
+    ) -> None:
         super().__init__()
         self.path_model = path_model
-        if self.path_model not in
-            self.model =
+        if self.path_model not in models_register:
+            self.model = model_loader.get_model(
+                train=False,
+                konfai_args=os.environ["KONFAI_CONFIG_PATH"].split("PerceptualLoss")[0] + "PerceptualLoss.Model",
+                konfai_without=[
+                    "optimizer",
+                    "schedulers",
+                    "nb_batch_per_step",
+                    "init_type",
+                    "init_gain",
+                    "outputs_criterions",
+                    "drop_p",
+                ],
+            )
             if path_model.startswith("https"):
                 state_dict = torch.hub.load_state_dict_from_url(path_model)
-                state_dict = {"Model": {self.model.
+                state_dict = {"Model": {self.model.get_name(): state_dict["model"]}}
             else:
                 state_dict = torch.load(path_model, weights_only=True)
             self.model.load(state_dict)
-
+            models_register[self.path_model] = self.model
         else:
-            self.model =
+            self.model = models_register[self.path_model]
 
         self.shape = shape
-        self.mode = "trilinear" if
+        self.mode = "trilinear" if len(shape) == 3 else "bilinear"
         self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
         for name, losses in modules.items():
-            self.modules_loss[name.replace(":", ".")] = losses.
+            self.modules_loss[name.replace(":", ".")] = losses.get_loss()
 
         self.model.eval()
         self.model.requires_grad_(False)
         self.models: dict[int, torch.nn.Module] = {}
 
-    def preprocessing(self,
-        #if not all([
-        #
-        #
-        #
-        #
-        #
-
-
-
-
-
-
-
+    def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
+        # if not all([tensor.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
+        # tensor = F.interpolate(tensor, mode=self.mode,
+        # size=tuple(self.shape), align_corners=False).type(torch.float32)
+        # if tensor.shape[1] != self.model.in_channels:
+        # tensor = tensor.repeat(tuple([1,self.model.in_channels] + [1 for _ in range(len(self.shape))]))
+        # tensor = (tensor - torch.min(tensor))/(torch.max(tensor)-torch.min(tensor))
+        # tensor = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor)
+        return tensor
+
+    def _compute(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
+        output_preprocessing = self.preprocessing(output)
+        targets_preprocessing = [self.preprocessing(target) for target in targets]
+        for zipped_output in zip([output_preprocessing], *[[target] for target in targets_preprocessing]):
             output = zipped_output[0]
             targets = zipped_output[1:]
 
-            for zipped_layers in list(
-
-
-
-
+            for zipped_layers in list(
+                zip(
+                    self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()),
+                    *[
+                        self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy())
+                        for target in targets
+                    ],
+                )
+            ):
+                output_layer = zipped_layers[0][1].view(
+                    zipped_layers[0][1].shape[0],
+                    zipped_layers[0][1].shape[1],
+                    int(np.prod(zipped_layers[0][1].shape[2:])),
+                )
+                for (loss_function, loss_value), target_layer in zip(
+                    self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:]
+                ):
+                    target_layer = target_layer[1].view(
+                        target_layer[1].shape[0],
+                        target_layer[1].shape[1],
+                        int(np.prod(target_layer[1].shape[2:])),
+                    )
+                    loss = (
+                        loss
+                        + loss_value * loss_function(output_layer.float(), target_layer.float()) / output_layer.shape[0]
+                    )
         return loss
-
-    def forward(self, output: torch.Tensor, *targets
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         if output.device.index not in self.models:
             del os.environ["device"]
             self.models[output.device.index] = Network.to(copy.deepcopy(self.model).eval(), output.device.index).eval()
-        loss = torch.zeros((1), requires_grad
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         if len(output.shape) == 5 and len(self.shape) == 2:
             for i in range(output.shape[2]):
-                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])/output.shape[2]
+                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets]) / output.shape[2]
         else:
             loss = self._compute(output, targets)
         return loss.to(output)
 
+
 class KLDivergence(Criterion):
-
-    def __init__(self, shape: list[int], dim
+
+    def __init__(self, shape: list[int], dim: int = 100, mu: float = 0, std: float = 1) -> None:
         super().__init__()
-        self.
+        self.latent_dim = dim
         self.mu = torch.Tensor([mu])
         self.std = torch.Tensor([std])
         self.modelDim = 3
         self.shape = shape
         self.loss = torch.nn.KLDivLoss()
-
-    def init(self, model
+
+    def init(self, model: Network, output_group: str, target_group: str) -> str:
         super().init(model, output_group, target_group)
         model._compute_channels_trace(model, model.in_channels, None, None)
 
@@ -338,52 +432,35 @@ class KLDivergence(Criterion):
 
         modules = last_module._modules.copy()
         last_module._modules.clear()
-
+
         for name, value in modules.items():
             last_module._modules[name] = value
             if name == output_group.split(".")[-1]:
-                last_module.add_module(
-
+                last_module.add_module(
+                    "LatentDistribution",
+                    LatentDistribution(shape=self.shape, latent_dim=self.latent_dim),
+                )
+        return ".".join(output_group.split(".")[:-1]) + ".LatentDistribution.Concat"
 
-    def forward(self, output: torch.Tensor, targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         mu = output[:, 0, :]
         log_std = output[:, 1, :]
-        return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim
-
-        """
-        def forward(self, input : torch.Tensor) -> torch.Tensor:
-            mu = input[:, 0, :]
-            log_std = input[:, 1, :]
-
-            z = input[:, 2, :]
+        return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim=1), dim=0)
 
-            q = torch.distributions.Normal(mu, log_std)
 
-            target_mu = torch.ones((self.latentDim)).to(input.device)*self.mu.to(input.device)
-            target_std = torch.ones((self.latentDim)).to(input.device)*self.std.to(input.device)
-
-            p = torch.distributions.Normal(target_mu, target_std)
-
-            log_pz = p.log_prob(z)
-            log_qzx = q.log_prob(z)
-
-            kl = (log_pz - log_qzx)
-            kl = kl.sum(-1)
-            return kl
-        """
-
 class Accuracy(Criterion):
 
     def __init__(self) -> None:
         super().__init__()
-        self.n
-        self.corrects = torch.zeros(
+        self.n: int = 0
+        self.corrects = torch.zeros(1)
 
-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target_0 = targets[0]
         self.n += output.shape[0]
         self.corrects += (torch.argmax(torch.softmax(output, dim=1), dim=1) == target_0).sum().float().cpu()
-        return self.corrects/self.n
+        return self.corrects / self.n
+
 
 class TripletLoss(Criterion):
 
@@ -391,9 +468,10 @@ class TripletLoss(Criterion):
         super().__init__()
         self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
 
-    def forward(self, output: torch.Tensor) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         return self.triplet_loss(output[0], output[1], output[2])
 
+
 class L1LossRepresentation(Criterion):
 
     def __init__(self) -> None:
@@ -403,28 +481,31 @@ class L1LossRepresentation(Criterion):
     def _variance(self, features: torch.Tensor) -> torch.Tensor:
         return torch.mean(torch.clamp(1 - torch.var(features, dim=0), min=0))
 
-    def forward(self, output: torch.Tensor) -> torch.Tensor:
-        return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        return self.loss(output[0], output[1]) + self._variance(output[0]) + self._variance(output[1])
+
 
 class FocalLoss(Criterion):
 
-    def __init__(
+    def __init__(
+        self,
+        gamma: float = 2.0,
+        alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1],
+        reduction: str = "mean",
+    ):
         super().__init__()
         raw_alpha = torch.tensor(alpha, dtype=torch.float32)
         self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
         self.gamma = gamma
         self.reduction = reduction
 
-    def forward(self, output: torch.Tensor, *targets:
-        target = F.interpolate(
-
-            output.shape[2:],
-            mode="nearest").long()
-
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = F.interpolate(targets[0], output.shape[2:], mode="nearest").long()
+
         logpt = F.log_softmax(output, dim=1)
         pt = torch.exp(logpt)
 
-        logpt = logpt.gather(1, target)
+        logpt = logpt.gather(1, target)
         pt = pt.gather(1, target)
 
         at = self.alpha.to(target.device)[target].unsqueeze(1)
@@ -435,7 +516,7 @@ class FocalLoss(Criterion):
         elif self.reduction == "sum":
             return loss.sum()
         return loss
-
+
 
 class FID(Criterion):
 
@@ -449,19 +530,26 @@ class FID(Criterion):
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
             return self.model(x)
-
+
     def __init__(self) -> None:
         super().__init__()
         self.inception_model = FID.InceptionV3().cuda()
-
+
+    @staticmethod
     def preprocess_images(image: torch.Tensor) -> torch.Tensor:
-        return F.normalize(
+        return F.normalize(
+            F.resize(image, (299, 299)).repeat((1, 3, 1, 1)),
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ).cuda()
 
+    @staticmethod
     def get_features(images: torch.Tensor, model: torch.nn.Module) -> np.ndarray:
         with torch.no_grad():
             features = model(images).cpu().numpy()
         return features
 
+    @staticmethod
     def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
         mu1 = np.mean(real_features, axis=0)
         sigma1 = np.cov(real_features, rowvar=False)
@@ -475,7 +563,7 @@ class FID(Criterion):
 
         return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
 
-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
        real_images = FID.preprocess_images(targets[0].squeeze(0).permute([1, 0, 2, 3]))
        generated_images = FID.preprocess_images(output.squeeze(0).permute([1, 0, 2, 3]))
 
@@ -484,8 +572,15 @@ class FID(Criterion):
 
         return FID.calculate_fid(real_features, generated_features)
 
+
 class MutualInformationLoss(torch.nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        num_bins: int = 23,
+        sigma_ratio: float = 0.5,
+        smooth_nr: float = 1e-7,
+        smooth_dr: float = 1e-7,
+    ) -> None:
         super().__init__()
         bin_centers = torch.linspace(0.0, 1.0, num_bins)
         sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
@@ -495,7 +590,9 @@ class MutualInformationLoss(torch.nn.Module):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
 
-    def parzen_windowing(
+    def parzen_windowing(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
         target_weight, target_probability = self.parzen_windowing_gaussian(target)
         return pred_weight, pred_probability, target_weight, target_probability
@@ -503,35 +600,51 @@ class MutualInformationLoss(torch.nn.Module):
     def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         img = torch.clamp(img, 0, 1)
         img = img.reshape(img.shape[0], -1, 1)  # (batch, num_sample, 1)
-        weight = torch.exp(
+        weight = torch.exp(
+            -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2
+        )  # (batch, num_sample, num_bin)
         weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bin)
         probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bin)
         return weight, probability
 
-    def forward(self, pred: torch.Tensor,
-        wa, pa, wb, pb = self.parzen_windowing(pred,
+    def forward(self, pred: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        wa, pa, wb, pb = self.parzen_windowing(pred, targets[0])  # (batch, num_sample, num_bin), (batch, 1, num_bin)
         pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1])  # (batch, num_bins, num_bins)
         papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa))  # (batch, num_bins, num_bins)
-        mi = torch.sum(
+        mi = torch.sum(
+            pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr),
+            dim=(1, 2),
+        )  # (batch)
         return torch.mean(mi).neg()  # average over the batch and channel ndims
 
-class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
 
-
+class IMPACTSynth(Criterion):  # Feature-Oriented Comparison for Unpaired Synthesis
+
+    def __init__(
+        self,
+        model_name: str,
+        shape: list[int] = [0, 0],
+        in_channels: int = 1,
+        losses: dict[str, list[float]] = {"Gram": [1], "torch_nn_L1Loss": [1]},
+    ) -> None:
         super().__init__()
         if model_name is None:
             return
-        self.shape = shape
         self.in_channels = in_channels
         self.losses: dict[torch.nn.Module, list[float]] = {}
         for loss, weights in losses.items():
-            module, name =
-            self.losses[
-
-
-
-            self.
-
+            module, name = get_module(loss, "konfai.metric.measure")
+            self.losses[
+                config(os.environ["KONFAI_CONFIG_PATH"])(getattr(importlib.import_module(module), name))(config=None)
+            ] = weights
+
+        self.model_path = download_url(
+            model_name,
+            "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/",
+        )
+        self.model: torch.nn.Module = torch.jit.load(self.model_path)  # nosec B614
+        self.dim = len(shape)
+        self.shape = shape if all(s > 0 for s in shape) else None
         self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
 
         try:
@@ -540,7 +653,9 @@ class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
             if not isinstance(out, (list, tuple)):
                 raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
             if len(self.weight) != len(out):
-                raise ValueError(
+                raise ValueError(
+                    f"Mismatch between number of weights ({len(self.weight)}) and model outputs ({len(out)})."
+                )
         except Exception as e:
             msg = (
                 f"[Model Sanity Check Failed]\n"
@@ -550,35 +665,39 @@ class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
             )
             raise RuntimeError(msg) from e
         self.model = None
-
-    def preprocessing(self,
-        if self.shape is not None and not all(
-
-
-
-
-
+
+    def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.shape is not None and not all(
+            tensor.shape[-i - 1] == size for i, size in enumerate(reversed(self.shape[2:]))
+        ):
+            tensor = torch.nn.functional.interpolate(
+                tensor, mode=self.mode, size=tuple(self.shape), align_corners=False
+            ).type(torch.float32)
+        if tensor.shape[1] != self.in_channels:
+            tensor = tensor.repeat(tuple([1, 3] + [1 for _ in range(self.dim)]))
+        return tensor
+
     def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
-        loss = torch.zeros((1), requires_grad
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         output = self.preprocessing(output)
         targets = [self.preprocessing(target) for target in targets]
         self.model.to(output.device)
         for zipped_output in zip(self.model(output), *[self.model(target) for target in targets]):
             output_features = zipped_output[0]
-            targets_features = zipped_output[1:]
+            targets_features = zipped_output[1:]
             for target_features, (loss_function, weights) in zip(targets_features, self.losses.items()):
-                for output_feature, target_feature, weight in zip(output_features
-                    loss +=
+                for output_feature, target_feature, weight in zip(output_features, target_features, weights):
+                    loss += weight * loss_function(output_feature, target_feature)
         return loss
-
-    def forward(self, output: torch.Tensor, *targets
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         if self.model is None:
-            self.model = torch.jit.load(self.model_path)
-        loss = torch.zeros((1), requires_grad
+            self.model = torch.jit.load(self.model_path)  # nosec B614
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         if len(output.shape) == 5 and self.dim == 2:
             for i in range(output.shape[2]):
                 loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
-                loss/=output.shape[2]
+            loss /= output.shape[2]
         else:
-            loss = self._compute(output, targets)
-            return loss.to(output)
+            loss = self._compute(output, list(targets))
+        return loss.to(output)
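Every loss in this file follows the `Criterion` convention visible at the top of the diff: configuration lives in the constructor and `forward(output, *targets)` receives the prediction first and any target tensors after it. The sketch below is illustrative only, not taken from KonFAI's documentation or training loop: the tensor shapes and the call site are invented, and it assumes konfai 1.1.9 is installed; it exercises two criteria whose code paths are fully visible in this diff (GradientImages and BCE).

```python
import torch

from konfai.metric.measure import BCE, GradientImages

# Hypothetical 2D batch: (batch, channel, height, width).
prediction = torch.rand(2, 1, 64, 64)
reference = torch.rand(2, 1, 64, 64)

# GradientImages compares finite-difference gradients of the output and targets[0].
edge_loss = GradientImages()(prediction, reference)

# BCE scores logits against a constant target stored as a buffer (here "real" = 1).
adversarial_loss = BCE(target=1.0)(prediction)

print(edge_loss.item(), adversarial_loss.item())
```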