konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/metric/measure.py
CHANGED
@@ -1,334 +1,434 @@
-
+import copy
 import importlib
-import numpy as np
-import torch
-
-from torchvision.models import inception_v3, Inception_V3_Weights
-from torchvision.transforms import functional as F
-from scipy import linalg
-
-import torch.nn.functional as F
 import os
-
-from
+from abc import ABC, abstractmethod
+from collections.abc import Callable
 from functools import partial
-
+
+import numpy as np
+import torch
+import torch.nn.functional as F  # noqa: N812
+from scipy import linalg
 from skimage.metrics import structural_similarity
-import
-from
+from torchvision.models import Inception_V3_Weights, inception_v3
+from tqdm import tqdm

-from konfai.utils.config import config
-from konfai.utils.utils import _getModule, download_url
 from konfai.data.patching import ModelPatch
 from konfai.network.blocks import LatentDistribution
 from konfai.network.network import ModelLoader, Network
-from
+from konfai.utils.config import config
+from konfai.utils.utils import download_url, get_module
+
+models_register = {}

-modelsRegister = {}

 class Criterion(torch.nn.Module, ABC):

     def __init__(self) -> None:
         super().__init__()

-    def init(self, model
+    def init(self, model: torch.nn.Module, output_group: str, target_group: str) -> str:
         return output_group

     @abstractmethod
-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         pass

+
 class MaskedLoss(Criterion):

-    def __init__(
+    def __init__(
+        self,
+        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        mode_image_masked: bool,
+    ) -> None:
         super().__init__()
         self.loss = loss
         self.mode_image_masked = mode_image_masked

-    def
+    def get_mask(self, targets: list[torch.Tensor]) -> torch.Tensor | None:
         result = None
         if len(targets) > 0:
             result = targets[0]
             for mask in targets[1:]:
-                result = result*mask
+                result = result * mask
         return result

-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> tuple[torch.Tensor, float]:
         loss = torch.tensor(0, dtype=torch.float32).to(output.device)
-
-
-
+        true_loss = 0
+        true_nb = 0
+        mask = self.get_mask(list(targets[1:]))
+        if mask is not None:
+            for batch in range(output.shape[0]):
                 if self.mode_image_masked:
-
-
-
+                    loss_b = self.loss(
+                        output[batch, ...] * torch.where(mask == 1, 1, 0),
+                        targets[0][batch, ...] * torch.where(mask == 1, 1, 0),
+                    )
                 else:
-
-
-
-
-
-
-
-
+                    index = mask[batch, ...] == 1
+                    loss_b = self.loss(
+                        torch.masked_select(output[batch, ...], index),
+                        torch.masked_select(targets[0][batch, ...], index),
+                    )
+
+                loss += loss_b
+                if torch.any(mask[batch] == 1):
+                    true_loss += loss_b.item()
+                    true_nb += 1
+        else:
+            loss_tmp = self.loss(output, targets[0])
+            loss += loss_tmp
+            true_loss += loss_tmp.item()
+            true_nb += 1
+        return loss / output.shape[0], np.nan if true_nb == 0 else true_loss / true_nb
+
+
 class MSE(MaskedLoss):

+    @staticmethod
     def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return torch.nn.MSELoss(reduction=reduction)(x, y)

     def __init__(self, reduction: str = "mean") -> None:
         super().__init__(partial(MSE._loss, reduction), False)

+
 class MAE(MaskedLoss):

+    @staticmethod
     def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return torch.nn.L1Loss(reduction=reduction)(x, y)
-
+
     def __init__(self, reduction: str = "mean") -> None:
         super().__init__(partial(MAE._loss, reduction), False)

+
 class PSNR(MaskedLoss):

-
+    @staticmethod
+    def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         mse = torch.mean((x - y).pow(2))
-        psnr = 10 * torch.log10(dynamic_range**2 / mse)
-        return psnr
+        psnr = 10 * torch.log10(dynamic_range**2 / mse)
+        return psnr

-    def __init__(self, dynamic_range:
-        dynamic_range = dynamic_range if dynamic_range else 1024+3071
+    def __init__(self, dynamic_range: float | None = None) -> None:
+        dynamic_range = dynamic_range if dynamic_range else 1024 + 3071
         super().__init__(partial(PSNR._loss, dynamic_range), False)
-
+
+
 class SSIM(MaskedLoss):
-
+
+    @staticmethod
     def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return structural_similarity(
-
-
-
+        return structural_similarity(
+            x[0][0].detach().cpu().numpy(),
+            y[0][0].cpu().numpy(),
+            data_range=dynamic_range,
+            gradient=False,
+            full=False,
+        )
+
+    def __init__(self, dynamic_range: float | None = None) -> None:
+        dynamic_range = dynamic_range if dynamic_range else 1024 + 3000
         super().__init__(partial(SSIM._loss, dynamic_range), True)


 class LPIPS(MaskedLoss):

-
-
-
-
-
-
+    @staticmethod
+    def normalize(tensor: torch.Tensor) -> torch.Tensor:
+        return (tensor - torch.min(tensor)) / (torch.max(tensor) - torch.min(tensor)) * 2 - 1
+
+    @staticmethod
+    def preprocessing(tensor: torch.Tensor) -> torch.Tensor:
+        return tensor.repeat((1, 3, 1, 1)).to(0)
+
+    @staticmethod
     def _loss(loss_fn_alex, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-
-
+        dataset_patch = ModelPatch([1, 64, 64])
+        dataset_patch.load(x.shape[2:])

-
+        patch_iterator = dataset_patch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
         loss = 0
-        with tqdm(
-
+        with tqdm(
+            iterable=enumerate(patch_iterator),
+            leave=False,
+            total=dataset_patch.get_size(0),
+        ) as batch_iter:
+            for _, patch_input in batch_iter:
                 real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
                 loss += loss_fn_alex(real, fake).item()
-        return loss/
+        return loss / dataset_patch.get_size(0)

     def __init__(self, model: str = "alex") -> None:
         import lpips
-
+
+        super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model).to(0)), True)
+

 class Dice(Criterion):
-
-    def __init__(self,
+
+    def __init__(self, labels: list[int] | None = None) -> None:
         super().__init__()
-        self.
-
-    def flatten(self, tensor
+        self.labels = labels
+
+    def flatten(self, tensor: torch.Tensor) -> torch.Tensor:
         return tensor.permute((1, 0) + tuple(range(2, tensor.dim()))).contiguous().view(tensor.size(1), -1)

-    def dice_per_channel(self,
-
+    def dice_per_channel(self, tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        tensor = self.flatten(tensor)
         target = self.flatten(target)
-        return (2
-
-    def forward(self, output: torch.Tensor, *targets
-        target = F.interpolate(
-
-
-
-
-
-
-
+        return (2.0 * (tensor * target).sum() + 1e-6) / (tensor.sum() + target.sum() + 1e-6)
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = F.interpolate(targets[0], output.shape[2:], mode="nearest")
+        result = {}
+        loss = torch.tensor(0, dtype=torch.float32).to(output.device)
+        labels = self.labels if self.labels is not None else torch.unique(target)
+        for label in labels:
+            tp = target == label
+            if tp.any().item():
+                if output.shape[1] > 1:
+                    pp = output[:, label].unsqueeze(1)
+                else:
+                    pp = output == label
+                loss_tmp = self.dice_per_channel(pp.float(), tp.float())
+                loss += loss_tmp
+                result[label] = loss_tmp.item()
+            else:
+                result[label] = np.nan
+        return 1 - loss / len(labels), result
+

 class GradientImages(Criterion):

     def __init__(self):
         super().__init__()
-
+
     @staticmethod
-    def
+    def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         dx = image[:, :, 1:, :] - image[:, :, :-1, :]
         dy = image[:, :, :, 1:] - image[:, :, :, :-1]
         return dx, dy

     @staticmethod
-    def
+    def _image_gradient_3d(
+        image: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         dx = image[:, :, 1:, :, :] - image[:, :, :-1, :, :]
         dy = image[:, :, :, 1:, :] - image[:, :, :, :-1, :]
         dz = image[:, :, :, :, 1:] - image[:, :, :, :, :-1]
         return dx, dy, dz
-
-    def forward(self, output: torch.Tensor, *targets
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target_0 = targets[0]
         if len(output.shape) == 5:
-            dx, dy, dz = GradientImages.
+            dx, dy, dz = GradientImages._image_gradient_3d(output)
             if target_0 is not None:
-                dx_tmp, dy_tmp, dz_tmp = GradientImages.
+                dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient_3d(target_0)
                 dx -= dx_tmp
                 dy -= dy_tmp
                 dz -= dz_tmp
             return dx.norm() + dy.norm() + dz.norm()
         else:
-            dx, dy = GradientImages.
+            dx, dy = GradientImages._image_gradient_2d(output)
             if target_0 is not None:
-                dx_tmp, dy_tmp = GradientImages.
+                dx_tmp, dy_tmp = GradientImages._image_gradient_2d(target_0)
                 dx -= dx_tmp
                 dy -= dy_tmp
             return dx.norm() + dy.norm()
-
+
+
 class BCE(Criterion):

-    def __init__(self, target
+    def __init__(self, target: float = 0) -> None:
         super().__init__()
         self.loss = torch.nn.BCEWithLogitsLoss()
-        self.register_buffer(
+        self.register_buffer("target", torch.tensor(target).type(torch.float32))

-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target = self._buffers["target"]
         return self.loss(output, target.to(output.device).expand_as(output))

+
 class PatchGanLoss(Criterion):

-    def __init__(self, target
+    def __init__(self, target: float = 0) -> None:
         super().__init__()
         self.loss = torch.nn.MSELoss()
-        self.register_buffer(
+        self.register_buffer("target", torch.tensor(target).type(torch.float32))

-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target = self._buffers["target"]
-        return self.loss(output, (torch.ones_like(output)*target).to(output.device))
+        return self.loss(output, (torch.ones_like(output) * target).to(output.device))
+

 class WGP(Criterion):

     def __init__(self) -> None:
         super().__init__()
-
-    def forward(self, output: torch.Tensor, *targets
-        return torch.mean((output - 1)**2)
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        return torch.mean((output - 1) ** 2)
+

 class Gram(Criterion):

-
-
-
-
+    @staticmethod
+    def compute_gram(tensor: torch.Tensor):
+        (b, ch, w) = tensor.size()
+        with torch.amp.autocast("cuda", enabled=False):
+            return tensor.bmm(tensor.transpose(1, 2)).div(ch * w)

     def __init__(self) -> None:
         super().__init__()
-        self.loss = torch.nn.L1Loss(reduction=
+        self.loss = torch.nn.L1Loss(reduction="sum")

-    def forward(self, output: torch.Tensor,
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = targets[0]
         if len(output.shape) > 3:
             output = output.view(output.shape[0], output.shape[1], int(np.prod(output.shape[2:])))
         if len(target.shape) > 3:
             target = target.view(target.shape[0], target.shape[1], int(np.prod(target.shape[2:])))
-        return self.loss(Gram.
+        return self.loss(Gram.compute_gram(output), Gram.compute_gram(target))
+

 class PerceptualLoss(Criterion):
-
-    class Module
-
+
+    class Module:
+
         @config()
         def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
             self.losses = losses
-            self.
+            self.konfai_args = os.environ["KONFAI_CONFIG_PATH"] if "KONFAI_CONFIG_PATH" in os.environ else ""

-        def
+        def get_loss(self) -> dict[torch.nn.Module, float]:
             result: dict[torch.nn.Module, float] = {}
-            for loss,
-            module, name =
-            result[config(self.
+            for loss, loss_value in self.losses.items():
+                module, name = get_module(loss, "konfai.metric.measure")
+                result[config(self.konfai_args)(getattr(importlib.import_module(module), name))(config=None)] = (
+                    loss_value
+                )
             return result
-
-    def __init__(
+
+    def __init__(
+        self,
+        model_loader: ModelLoader = ModelLoader(),
+        path_model: str = "name",
+        modules: dict[str, Module] = {
+            "UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch_nn_L1Loss": 1})
+        },
+        shape: list[int] = [128, 128, 128],
+    ) -> None:
         super().__init__()
         self.path_model = path_model
-        if self.path_model not in
-            self.model =
+        if self.path_model not in models_register:
+            self.model = model_loader.get_model(
+                train=False,
+                konfai_args=os.environ["KONFAI_CONFIG_PATH"].split("PerceptualLoss")[0] + "PerceptualLoss.Model",
+                konfai_without=[
+                    "optimizer",
+                    "schedulers",
+                    "nb_batch_per_step",
+                    "init_type",
+                    "init_gain",
+                    "outputs_criterions",
+                    "drop_p",
+                ],
+            )
             if path_model.startswith("https"):
                 state_dict = torch.hub.load_state_dict_from_url(path_model)
-                state_dict = {"Model": {self.model.
+                state_dict = {"Model": {self.model.get_name(): state_dict["model"]}}
             else:
                 state_dict = torch.load(path_model, weights_only=True)
             self.model.load(state_dict)
-
+            models_register[self.path_model] = self.model
         else:
-            self.model =
+            self.model = models_register[self.path_model]

         self.shape = shape
-        self.mode = "trilinear" if
+        self.mode = "trilinear" if len(shape) == 3 else "bilinear"
         self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
         for name, losses in modules.items():
-            self.modules_loss[name.replace(":", ".")] = losses.
+            self.modules_loss[name.replace(":", ".")] = losses.get_loss()

         self.model.eval()
         self.model.requires_grad_(False)
         self.models: dict[int, torch.nn.Module] = {}

-    def preprocessing(self,
-        #if not all([
-        #
-        #
-        #
-        #
-        #
-
-
-
-
-
-
-
+    def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
+        # if not all([tensor.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
+        #     tensor = F.interpolate(tensor, mode=self.mode,
+        #                            size=tuple(self.shape), align_corners=False).type(torch.float32)
+        # if tensor.shape[1] != self.model.in_channels:
+        #     tensor = tensor.repeat(tuple([1,self.model.in_channels] + [1 for _ in range(len(self.shape))]))
+        # tensor = (tensor - torch.min(tensor))/(torch.max(tensor)-torch.min(tensor))
+        # tensor = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor)
+        return tensor
+
+    def _compute(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
+        output_preprocessing = self.preprocessing(output)
+        targets_preprocessing = [self.preprocessing(target) for target in targets]
+        for zipped_output in zip([output_preprocessing], *[[target] for target in targets_preprocessing]):
             output = zipped_output[0]
             targets = zipped_output[1:]

-            for zipped_layers in list(
-
-
-
-
+            for zipped_layers in list(
+                zip(
+                    self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()),
+                    *[
+                        self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy())
+                        for target in targets
+                    ],
+                )
+            ):
+                output_layer = zipped_layers[0][1].view(
+                    zipped_layers[0][1].shape[0],
+                    zipped_layers[0][1].shape[1],
+                    int(np.prod(zipped_layers[0][1].shape[2:])),
+                )
+                for (loss_function, loss_value), target_layer in zip(
+                    self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:]
+                ):
+                    target_layer = target_layer[1].view(
+                        target_layer[1].shape[0],
+                        target_layer[1].shape[1],
+                        int(np.prod(target_layer[1].shape[2:])),
+                    )
+                    loss = (
+                        loss
+                        + loss_value * loss_function(output_layer.float(), target_layer.float()) / output_layer.shape[0]
+                    )
         return loss
-
-    def forward(self, output: torch.Tensor, *targets
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         if output.device.index not in self.models:
             del os.environ["device"]
             self.models[output.device.index] = Network.to(copy.deepcopy(self.model).eval(), output.device.index).eval()
-        loss = torch.zeros((1), requires_grad
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         if len(output.shape) == 5 and len(self.shape) == 2:
             for i in range(output.shape[2]):
-                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])/output.shape[2]
+                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets]) / output.shape[2]
         else:
             loss = self._compute(output, targets)
         return loss.to(output)

+
 class KLDivergence(Criterion):
-
-    def __init__(self, shape: list[int], dim
+
+    def __init__(self, shape: list[int], dim: int = 100, mu: float = 0, std: float = 1) -> None:
         super().__init__()
-        self.
+        self.latent_dim = dim
         self.mu = torch.Tensor([mu])
         self.std = torch.Tensor([std])
         self.modelDim = 3
         self.shape = shape
         self.loss = torch.nn.KLDivLoss()
-
-    def init(self, model
+
+    def init(self, model: Network, output_group: str, target_group: str) -> str:
         super().init(model, output_group, target_group)
         model._compute_channels_trace(model, model.in_channels, None, None)

@@ -338,52 +438,35 @@ class KLDivergence(Criterion):

         modules = last_module._modules.copy()
         last_module._modules.clear()
-
+
         for name, value in modules.items():
             last_module._modules[name] = value
             if name == output_group.split(".")[-1]:
-                last_module.add_module(
-
+                last_module.add_module(
+                    "LatentDistribution",
+                    LatentDistribution(shape=self.shape, latent_dim=self.latent_dim),
+                )
+        return ".".join(output_group.split(".")[:-1]) + ".LatentDistribution.Concat"

-    def forward(self, output: torch.Tensor, targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         mu = output[:, 0, :]
         log_std = output[:, 1, :]
-        return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim
-
-    """
-    def forward(self, input : torch.Tensor) -> torch.Tensor:
-        mu = input[:, 0, :]
-        log_std = input[:, 1, :]
+        return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim=1), dim=0)

-        z = input[:, 2, :]

-        q = torch.distributions.Normal(mu, log_std)
-
-        target_mu = torch.ones((self.latentDim)).to(input.device)*self.mu.to(input.device)
-        target_std = torch.ones((self.latentDim)).to(input.device)*self.std.to(input.device)
-
-        p = torch.distributions.Normal(target_mu, target_std)
-
-        log_pz = p.log_prob(z)
-        log_qzx = q.log_prob(z)
-
-        kl = (log_pz - log_qzx)
-        kl = kl.sum(-1)
-        return kl
-    """
-
 class Accuracy(Criterion):

     def __init__(self) -> None:
         super().__init__()
-        self.n
-        self.corrects = torch.zeros(
+        self.n: int = 0
+        self.corrects = torch.zeros(1)

-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target_0 = targets[0]
         self.n += output.shape[0]
         self.corrects += (torch.argmax(torch.softmax(output, dim=1), dim=1) == target_0).sum().float().cpu()
-        return self.corrects/self.n
+        return self.corrects / self.n
+

 class TripletLoss(Criterion):

@@ -391,9 +474,10 @@ class TripletLoss(Criterion):
         super().__init__()
         self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)

-    def forward(self, output: torch.Tensor) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         return self.triplet_loss(output[0], output[1], output[2])

+
 class L1LossRepresentation(Criterion):

     def __init__(self) -> None:
@@ -403,28 +487,31 @@ class L1LossRepresentation(Criterion):
     def _variance(self, features: torch.Tensor) -> torch.Tensor:
         return torch.mean(torch.clamp(1 - torch.var(features, dim=0), min=0))

-    def forward(self, output: torch.Tensor) -> torch.Tensor:
-        return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        return self.loss(output[0], output[1]) + self._variance(output[0]) + self._variance(output[1])
+

 class FocalLoss(Criterion):

-    def __init__(
+    def __init__(
+        self,
+        gamma: float = 2.0,
+        alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1],
+        reduction: str = "mean",
+    ):
         super().__init__()
         raw_alpha = torch.tensor(alpha, dtype=torch.float32)
         self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
         self.gamma = gamma
         self.reduction = reduction

-    def forward(self, output: torch.Tensor, *targets:
-        target = F.interpolate(
-
-            output.shape[2:],
-            mode="nearest").long()
-
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = F.interpolate(targets[0], output.shape[2:], mode="nearest").long()
+
         logpt = F.log_softmax(output, dim=1)
         pt = torch.exp(logpt)

-        logpt = logpt.gather(1, target)
+        logpt = logpt.gather(1, target)
         pt = pt.gather(1, target)

         at = self.alpha.to(target.device)[target].unsqueeze(1)
@@ -435,7 +522,7 @@ class FocalLoss(Criterion):
         elif self.reduction == "sum":
             return loss.sum()
         return loss
-
+

 class FID(Criterion):

@@ -449,19 +536,26 @@ class FID(Criterion):

         def forward(self, x: torch.Tensor) -> torch.Tensor:
             return self.model(x)
-
+
     def __init__(self) -> None:
         super().__init__()
         self.inception_model = FID.InceptionV3().cuda()
-
+
+    @staticmethod
     def preprocess_images(image: torch.Tensor) -> torch.Tensor:
-        return F.normalize(
+        return F.normalize(
+            F.resize(image, (299, 299)).repeat((1, 3, 1, 1)),
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ).cuda()

+    @staticmethod
     def get_features(images: torch.Tensor, model: torch.nn.Module) -> np.ndarray:
         with torch.no_grad():
             features = model(images).cpu().numpy()
         return features

+    @staticmethod
     def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
         mu1 = np.mean(real_features, axis=0)
         sigma1 = np.cov(real_features, rowvar=False)
@@ -475,7 +569,7 @@ class FID(Criterion):

         return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

-    def forward(self, output: torch.Tensor, *targets
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         real_images = FID.preprocess_images(targets[0].squeeze(0).permute([1, 0, 2, 3]))
         generated_images = FID.preprocess_images(output.squeeze(0).permute([1, 0, 2, 3]))

@@ -484,8 +578,15 @@ class FID(Criterion):

         return FID.calculate_fid(real_features, generated_features)

+
 class MutualInformationLoss(torch.nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        num_bins: int = 23,
+        sigma_ratio: float = 0.5,
+        smooth_nr: float = 1e-7,
+        smooth_dr: float = 1e-7,
+    ) -> None:
         super().__init__()
         bin_centers = torch.linspace(0.0, 1.0, num_bins)
         sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
@@ -495,7 +596,9 @@ class MutualInformationLoss(torch.nn.Module):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)

-    def parzen_windowing(
+    def parzen_windowing(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
         target_weight, target_probability = self.parzen_windowing_gaussian(target)
         return pred_weight, pred_probability, target_weight, target_probability
@@ -503,82 +606,95 @@ class MutualInformationLoss(torch.nn.Module):
     def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         img = torch.clamp(img, 0, 1)
         img = img.reshape(img.shape[0], -1, 1)  # (batch, num_sample, 1)
-        weight = torch.exp(
+        weight = torch.exp(
+            -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2
+        )  # (batch, num_sample, num_bin)
         weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bin)
         probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bin)
         return weight, probability

-    def forward(self, pred: torch.Tensor,
-        wa, pa, wb, pb = self.parzen_windowing(pred,
+    def forward(self, pred: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        wa, pa, wb, pb = self.parzen_windowing(pred, targets[0])  # (batch, num_sample, num_bin), (batch, 1, num_bin)
         pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1])  # (batch, num_bins, num_bins)
         papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa))  # (batch, num_bins, num_bins)
-        mi = torch.sum(
+        mi = torch.sum(
+            pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr),
+            dim=(1, 2),
+        )  # (batch)
         return torch.mean(mi).neg()  # average over the batch and channel ndims

-class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis

-
+class IMPACTSynth(Criterion):  # Feature-Oriented Comparison for Unpaired Synthesis
+
+    def __init__(
+        self, model_name: str, shape: list[int] = [0, 0], in_channels: int = 1, weights: list[float] = [1]
+    ) -> None:
         super().__init__()
         if model_name is None:
             return
-        self.shape = shape
         self.in_channels = in_channels
-        self.
-
-
-
-
-
-        self.model: torch.nn.Module = torch.jit.load(self.model_path)
-        self.dim = len(
-        self.shape =
+        self.loss = torch.nn.L1Loss()
+        self.weights = weights
+        self.model_path = download_url(
+            model_name,
+            "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/",
+        )
+        self.model: torch.nn.Module = torch.jit.load(self.model_path)  # nosec B614
+        self.dim = len(shape)
+        self.shape = shape if all(s > 0 for s in shape) else None
         self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}

         try:
-            dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim)))
+            dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim))).to(0)
             out = self.model(dummy_input)
             if not isinstance(out, (list, tuple)):
                 raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
-            if len(
-                raise ValueError(f"Mismatch between number of weights ({len(
+            if len(weights) != len(out):
+                raise ValueError(f"Mismatch between number of weights ({len(weights)}) and model outputs ({len(out)}).")
         except Exception as e:
             msg = (
                 f"[Model Sanity Check Failed]\n"
                 f"Input shape attempted: {dummy_input.shape}\n"
-                f"Expected output length: {len(
+                f"Expected output length: {len(weights)}\n"
                 f"Error: {type(e).__name__}: {e}"
             )
             raise RuntimeError(msg) from e
         self.model = None
-
-    def preprocessing(self,
-        if self.shape is not None and not all(
-
-
-
-
-
+
+    def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.shape is not None and not all(
+            tensor.shape[-i - 1] == size for i, size in enumerate(reversed(self.shape[2:]))
+        ):
+            tensor = torch.nn.functional.interpolate(
+                tensor, mode=self.mode, size=tuple(self.shape), align_corners=False
+            ).type(torch.float32)
+        if tensor.shape[1] != self.in_channels:
+            tensor = tensor.repeat(tuple([1, 3] + [1 for _ in range(self.dim)]))
+        return tensor
+
     def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
-        loss = torch.zeros((1), requires_grad
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         output = self.preprocessing(output)
         targets = [self.preprocessing(target) for target in targets]
         self.model.to(output.device)
-        for zipped_output in zip(self.model(output), *[self.model(target) for target in targets]):
-
-
-
-
-
+        for zipped_output in zip(self.weights, self.model(output), *[self.model(target) for target in targets]):
+            weight = zipped_output[0]
+            if weight == 0:
+                continue
+            output_feature = zipped_output[1]
+            targets_features = zipped_output[1:]
+            for target_feature in targets_features:
+                loss += weight * self.loss(output_feature, target_feature)
         return loss
-
-    def forward(self, output: torch.Tensor, *targets
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         if self.model is None:
-            self.model = torch.jit.load(self.model_path)
-        loss = torch.zeros((1), requires_grad
+            self.model = torch.jit.load(self.model_path)  # nosec B614
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         if len(output.shape) == 5 and self.dim == 2:
             for i in range(output.shape[2]):
-                loss
-                loss/=output.shape[2]
+                loss += self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
+                loss /= output.shape[2]
         else:
-            loss = self._compute(output, targets)
-            return loss.to(output)
+            loss = self._compute(output, list(targets))
+        return loss.to(output)
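The snippet below is an illustrative usage sketch, not code from the package: it only shows how the reworked metrics above would be called after this release, inferred from the new signatures in the diff (MaskedLoss-based metrics treat any extra target tensors as masks and return a pair of values, and Dice returns a per-label dictionary alongside the loss). Shapes and values are made up for the example.

```python
import torch

from konfai.metric.measure import MSE, Dice

pred = torch.rand(2, 1, 64, 64)  # assumed (batch, channel, H, W) layout
ref = torch.rand(2, 1, 64, 64)
mask = (torch.rand(2, 1, 64, 64) > 0.5).float()

# Extra targets after the first are multiplied together and used as a mask;
# the second return value averages only the batches where the mask is non-empty.
loss, masked_avg = MSE()(pred, ref, mask)

# Dice now also returns a per-label dictionary of scores.
seg_pred = torch.rand(2, 2, 64, 64)
seg_ref = torch.randint(0, 2, (2, 1, 64, 64)).float()
dice_loss, per_label = Dice(labels=[0, 1])(seg_pred, seg_ref)
```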