konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/metric/measure.py CHANGED
@@ -1,334 +1,434 @@
-from abc import ABC
+import copy
 import importlib
-import numpy as np
-import torch
-
-from torchvision.models import inception_v3, Inception_V3_Weights
-from torchvision.transforms import functional as F
-from scipy import linalg
-
-import torch.nn.functional as F
 import os
-
-from typing import Callable, Union
+from abc import ABC, abstractmethod
+from collections.abc import Callable
 from functools import partial
-import torch.nn.functional as F
+
+import numpy as np
+import torch
+import torch.nn.functional as F # noqa: N812
+from scipy import linalg
 from skimage.metrics import structural_similarity
-import copy
-from abc import abstractmethod
+from torchvision.models import Inception_V3_Weights, inception_v3
+from tqdm import tqdm
 
-from konfai.utils.config import config
-from konfai.utils.utils import _getModule, download_url
 from konfai.data.patching import ModelPatch
 from konfai.network.blocks import LatentDistribution
 from konfai.network.network import ModelLoader, Network
-from tqdm import tqdm
+from konfai.utils.config import config
+from konfai.utils.utils import download_url, get_module
+
+models_register = {}
 
-modelsRegister = {}
 
 class Criterion(torch.nn.Module, ABC):
 
     def __init__(self) -> None:
         super().__init__()
 
-    def init(self, model : torch.nn.Module, output_group : str, target_group : str) -> str:
+    def init(self, model: torch.nn.Module, output_group: str, target_group: str) -> str:
         return output_group
 
     @abstractmethod
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         pass
 
+
 class MaskedLoss(Criterion):
 
-    def __init__(self, loss : Callable[[torch.Tensor, torch.Tensor], torch.Tensor], mode_image_masked: bool) -> None:
+    def __init__(
+        self,
+        loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        mode_image_masked: bool,
+    ) -> None:
         super().__init__()
         self.loss = loss
         self.mode_image_masked = mode_image_masked
 
-    def getMask(self, targets: list[torch.Tensor]) -> torch.Tensor:
+    def get_mask(self, targets: list[torch.Tensor]) -> torch.Tensor | None:
         result = None
         if len(targets) > 0:
             result = targets[0]
             for mask in targets[1:]:
-                result = result*mask
+                result = result * mask
         return result
 
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> tuple[torch.Tensor, float]:
         loss = torch.tensor(0, dtype=torch.float32).to(output.device)
-        mask = self.getMask(targets[1:])
-        for batch in range(output.shape[0]):
-            if mask is not None:
+        true_loss = 0
+        true_nb = 0
+        mask = self.get_mask(list(targets[1:]))
+        if mask is not None:
+            for batch in range(output.shape[0]):
                 if self.mode_image_masked:
-                    for i in torch.unique(mask[batch]):
-                        if i != 0:
-                            loss += self.loss(output[batch, ...]*torch.where(mask == i, 1, 0), targets[0][batch, ...]*torch.where(mask == i, 1, 0))
+                    loss_b = self.loss(
+                        output[batch, ...] * torch.where(mask == 1, 1, 0),
+                        targets[0][batch, ...] * torch.where(mask == 1, 1, 0),
+                    )
                 else:
-                    for i in torch.unique(mask[batch]):
-                        if i != 0:
-                            index = mask[batch, ...] == i
-                            loss += self.loss(torch.masked_select(output[batch, ...], index), torch.masked_select(targets[0][batch, ...], index))
-            else:
-                loss += self.loss(output[batch, ...], targets[0][batch, ...])
-        return loss/output.shape[0]
-
+                    index = mask[batch, ...] == 1
+                    loss_b = self.loss(
+                        torch.masked_select(output[batch, ...], index),
+                        torch.masked_select(targets[0][batch, ...], index),
+                    )
+
+                loss += loss_b
+                if torch.any(mask[batch] == 1):
+                    true_loss += loss_b.item()
+                    true_nb += 1
+        else:
+            loss_tmp = self.loss(output, targets[0])
+            loss += loss_tmp
+            true_loss += loss_tmp.item()
+            true_nb += 1
+        return loss / output.shape[0], np.nan if true_nb == 0 else true_loss / true_nb
+
+
 class MSE(MaskedLoss):
 
+    @staticmethod
     def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return torch.nn.MSELoss(reduction=reduction)(x, y)
 
     def __init__(self, reduction: str = "mean") -> None:
         super().__init__(partial(MSE._loss, reduction), False)
 
+
 class MAE(MaskedLoss):
 
+    @staticmethod
     def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         return torch.nn.L1Loss(reduction=reduction)(x, y)
-
+
     def __init__(self, reduction: str = "mean") -> None:
         super().__init__(partial(MAE._loss, reduction), False)
 
+
 class PSNR(MaskedLoss):
 
-    def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         mse = torch.mean((x - y).pow(2))
-        psnr = 10 * torch.log10(dynamic_range**2 / mse)
-        return psnr
+        psnr = 10 * torch.log10(dynamic_range**2 / mse)
+        return psnr
 
-    def __init__(self, dynamic_range: Union[float, None] = None) -> None:
-        dynamic_range = dynamic_range if dynamic_range else 1024+3071
+    def __init__(self, dynamic_range: float | None = None) -> None:
+        dynamic_range = dynamic_range if dynamic_range else 1024 + 3071
         super().__init__(partial(PSNR._loss, dynamic_range), False)
-
+
+
 class SSIM(MaskedLoss):
-
+
+    @staticmethod
     def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        return structural_similarity(x[0][0].detach().cpu().numpy(), y[0][0].cpu().numpy(), data_range=dynamic_range, gradient=False, full=False)
-
-    def __init__(self, dynamic_range: Union[float, None] = None) -> None:
-        dynamic_range = dynamic_range if dynamic_range else 1024+3000
+        return structural_similarity(
+            x[0][0].detach().cpu().numpy(),
+            y[0][0].cpu().numpy(),
+            data_range=dynamic_range,
+            gradient=False,
+            full=False,
+        )
+
+    def __init__(self, dynamic_range: float | None = None) -> None:
+        dynamic_range = dynamic_range if dynamic_range else 1024 + 3000
         super().__init__(partial(SSIM._loss, dynamic_range), True)
 
 
 class LPIPS(MaskedLoss):
 
-    def normalize(input: torch.Tensor) -> torch.Tensor:
-        return (input-torch.min(input))/(torch.max(input)-torch.min(input))*2-1
-
-    def preprocessing(input: torch.Tensor) -> torch.Tensor:
-        return input.repeat((1,3,1,1))
-
+    @staticmethod
+    def normalize(tensor: torch.Tensor) -> torch.Tensor:
+        return (tensor - torch.min(tensor)) / (torch.max(tensor) - torch.min(tensor)) * 2 - 1
+
+    @staticmethod
+    def preprocessing(tensor: torch.Tensor) -> torch.Tensor:
+        return tensor.repeat((1, 3, 1, 1)).to(0)
+
+    @staticmethod
    def _loss(loss_fn_alex, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-        datasetPatch = ModelPatch([1, 64, 64])
-        datasetPatch.load(x.shape[2:])
+        dataset_patch = ModelPatch([1, 64, 64])
+        dataset_patch.load(x.shape[2:])
 
-        patchIterator = datasetPatch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
+        patch_iterator = dataset_patch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
         loss = 0
-        with tqdm(iterable = enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
-            for i, patch_input in batch_iter:
+        with tqdm(
+            iterable=enumerate(patch_iterator),
+            leave=False,
+            total=dataset_patch.get_size(0),
+        ) as batch_iter:
+            for _, patch_input in batch_iter:
                 real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
                 loss += loss_fn_alex(real, fake).item()
-        return loss/datasetPatch.getSize(0)
+        return loss / dataset_patch.get_size(0)
 
     def __init__(self, model: str = "alex") -> None:
         import lpips
-        super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model)), True)
+
+        super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model).to(0)), True)
+
 
 class Dice(Criterion):
-
-    def __init__(self, smooth : float = 1e-6) -> None:
+
+    def __init__(self, labels: list[int] | None = None) -> None:
         super().__init__()
-        self.smooth = smooth
-
-    def flatten(self, tensor : torch.Tensor) -> torch.Tensor:
+        self.labels = labels
+
+    def flatten(self, tensor: torch.Tensor) -> torch.Tensor:
         return tensor.permute((1, 0) + tuple(range(2, tensor.dim()))).contiguous().view(tensor.size(1), -1)
 
-    def dice_per_channel(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        input = self.flatten(input)
+    def dice_per_channel(self, tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        tensor = self.flatten(tensor)
         target = self.flatten(target)
-        return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)
-
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
-        target = F.interpolate(
-            targets[0],
-            output.shape[2:],
-            mode="nearest")
-        if output.shape[1] == 1:
-            output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
-        target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
-        return 1-torch.mean(self.dice_per_channel(output, target))
+        return (2.0 * (tensor * target).sum() + 1e-6) / (tensor.sum() + target.sum() + 1e-6)
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = F.interpolate(targets[0], output.shape[2:], mode="nearest")
+        result = {}
+        loss = torch.tensor(0, dtype=torch.float32).to(output.device)
+        labels = self.labels if self.labels is not None else torch.unique(target)
+        for label in labels:
+            tp = target == label
+            if tp.any().item():
+                if output.shape[1] > 1:
+                    pp = output[:, label].unsqueeze(1)
+                else:
+                    pp = output == label
+                loss_tmp = self.dice_per_channel(pp.float(), tp.float())
+                loss += loss_tmp
+                result[label] = loss_tmp.item()
+            else:
+                result[label] = np.nan
+        return 1 - loss / len(labels), result
+
 
 class GradientImages(Criterion):
 
     def __init__(self):
         super().__init__()
-
+
     @staticmethod
-    def _image_gradient2D(image : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         dx = image[:, :, 1:, :] - image[:, :, :-1, :]
         dy = image[:, :, :, 1:] - image[:, :, :, :-1]
         return dx, dy
 
     @staticmethod
-    def _image_gradient3D(image : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _image_gradient_3d(
+        image: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         dx = image[:, :, 1:, :, :] - image[:, :, :-1, :, :]
         dy = image[:, :, :, 1:, :] - image[:, :, :, :-1, :]
         dz = image[:, :, :, :, 1:] - image[:, :, :, :, :-1]
         return dx, dy, dz
-
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target_0 = targets[0]
         if len(output.shape) == 5:
-            dx, dy, dz = GradientImages._image_gradient3D(output)
+            dx, dy, dz = GradientImages._image_gradient_3d(output)
             if target_0 is not None:
-                dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient3D(target_0)
+                dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient_3d(target_0)
                 dx -= dx_tmp
                 dy -= dy_tmp
                 dz -= dz_tmp
             return dx.norm() + dy.norm() + dz.norm()
         else:
-            dx, dy = GradientImages._image_gradient2D(output)
+            dx, dy = GradientImages._image_gradient_2d(output)
             if target_0 is not None:
-                dx_tmp, dy_tmp = GradientImages._image_gradient2D(target_0)
+                dx_tmp, dy_tmp = GradientImages._image_gradient_2d(target_0)
                 dx -= dx_tmp
                 dy -= dy_tmp
             return dx.norm() + dy.norm()
-
+
+
 class BCE(Criterion):
 
-    def __init__(self, target : float = 0) -> None:
+    def __init__(self, target: float = 0) -> None:
         super().__init__()
         self.loss = torch.nn.BCEWithLogitsLoss()
-        self.register_buffer('target', torch.tensor(target).type(torch.float32))
+        self.register_buffer("target", torch.tensor(target).type(torch.float32))
 
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target = self._buffers["target"]
         return self.loss(output, target.to(output.device).expand_as(output))
 
+
 class PatchGanLoss(Criterion):
 
-    def __init__(self, target : float = 0) -> None:
+    def __init__(self, target: float = 0) -> None:
         super().__init__()
         self.loss = torch.nn.MSELoss()
-        self.register_buffer('target', torch.tensor(target).type(torch.float32))
+        self.register_buffer("target", torch.tensor(target).type(torch.float32))
 
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target = self._buffers["target"]
-        return self.loss(output, (torch.ones_like(output)*target).to(output.device))
+        return self.loss(output, (torch.ones_like(output) * target).to(output.device))
+
 
 class WGP(Criterion):
 
     def __init__(self) -> None:
         super().__init__()
-
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
-        return torch.mean((output - 1)**2)
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        return torch.mean((output - 1) ** 2)
+
 
 class Gram(Criterion):
 
-    def computeGram(input : torch.Tensor):
-        (b, ch, w) = input.size()
-        with torch.amp.autocast('cuda', enabled=False):
-            return input.bmm(input.transpose(1, 2)).div(ch*w)
+    @staticmethod
+    def compute_gram(tensor: torch.Tensor):
+        (b, ch, w) = tensor.size()
+        with torch.amp.autocast("cuda", enabled=False):
+            return tensor.bmm(tensor.transpose(1, 2)).div(ch * w)
 
     def __init__(self) -> None:
         super().__init__()
-        self.loss = torch.nn.L1Loss(reduction='sum')
+        self.loss = torch.nn.L1Loss(reduction="sum")
 
-    def forward(self, output: torch.Tensor, target : torch.Tensor) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = targets[0]
         if len(output.shape) > 3:
             output = output.view(output.shape[0], output.shape[1], int(np.prod(output.shape[2:])))
         if len(target.shape) > 3:
             target = target.view(target.shape[0], target.shape[1], int(np.prod(target.shape[2:])))
-        return self.loss(Gram.computeGram(output), Gram.computeGram(target))
+        return self.loss(Gram.compute_gram(output), Gram.compute_gram(target))
+
 
 class PerceptualLoss(Criterion):
-
-    class Module():
-
+
+    class Module:
+
         @config()
         def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
             self.losses = losses
-            self.DL_args = os.environ['KONFAI_CONFIG_PATH'] if "KONFAI_CONFIG_PATH" in os.environ else ""
+            self.konfai_args = os.environ["KONFAI_CONFIG_PATH"] if "KONFAI_CONFIG_PATH" in os.environ else ""
 
-        def getLoss(self) -> dict[torch.nn.Module, float]:
+        def get_loss(self) -> dict[torch.nn.Module, float]:
             result: dict[torch.nn.Module, float] = {}
-            for loss, l in self.losses.items():
-                module, name = _getModule(loss, "konfai.metric.measure")
-                result[config(self.DL_args)(getattr(importlib.import_module(module), name))(config=None)] = l
+            for loss, loss_value in self.losses.items():
+                module, name = get_module(loss, "konfai.metric.measure")
+                result[config(self.konfai_args)(getattr(importlib.import_module(module), name))(config=None)] = (
+                    loss_value
+                )
             return result
-
-    def __init__(self, modelLoader : ModelLoader = ModelLoader(), path_model : str = "name", modules : dict[str, Module] = {"UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch_nn_L1Loss": 1})}, shape: list[int] = [128, 128, 128]) -> None:
+
+    def __init__(
+        self,
+        model_loader: ModelLoader = ModelLoader(),
+        path_model: str = "name",
+        modules: dict[str, Module] = {
+            "UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch_nn_L1Loss": 1})
+        },
+        shape: list[int] = [128, 128, 128],
+    ) -> None:
         super().__init__()
         self.path_model = path_model
-        if self.path_model not in modelsRegister:
-            self.model = modelLoader.getModel(train=False, DL_args=os.environ['KONFAI_CONFIG_PATH'].split("PerceptualLoss")[0]+"PerceptualLoss.Model", DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain", "outputsCriterions", "drop_p"])
+        if self.path_model not in models_register:
+            self.model = model_loader.get_model(
+                train=False,
+                konfai_args=os.environ["KONFAI_CONFIG_PATH"].split("PerceptualLoss")[0] + "PerceptualLoss.Model",
+                konfai_without=[
+                    "optimizer",
+                    "schedulers",
+                    "nb_batch_per_step",
+                    "init_type",
+                    "init_gain",
+                    "outputs_criterions",
+                    "drop_p",
+                ],
+            )
             if path_model.startswith("https"):
                 state_dict = torch.hub.load_state_dict_from_url(path_model)
-                state_dict = {"Model": {self.model.getName() : state_dict["model"]}}
+                state_dict = {"Model": {self.model.get_name(): state_dict["model"]}}
             else:
                 state_dict = torch.load(path_model, weights_only=True)
             self.model.load(state_dict)
-            modelsRegister[self.path_model] = self.model
+            models_register[self.path_model] = self.model
         else:
-            self.model = modelsRegister[self.path_model]
+            self.model = models_register[self.path_model]
 
         self.shape = shape
-        self.mode = "trilinear" if len(shape) == 3 else "bilinear"
+        self.mode = "trilinear" if len(shape) == 3 else "bilinear"
         self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
         for name, losses in modules.items():
-            self.modules_loss[name.replace(":", ".")] = losses.getLoss()
+            self.modules_loss[name.replace(":", ".")] = losses.get_loss()
 
         self.model.eval()
         self.model.requires_grad_(False)
         self.models: dict[int, torch.nn.Module] = {}
 
-    def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
-        #if not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
-        #    input = F.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
-        #if input.shape[1] != self.model.in_channels:
-        #    input = input.repeat(tuple([1,self.model.in_channels] + [1 for _ in range(len(self.shape))]))
-        #input = (input - torch.min(input))/(torch.max(input)-torch.min(input))
-        #input = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input)
-        return input
-
-    def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
-        loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
-        output = self.preprocessing(output)
-        targets = [self.preprocessing(target) for target in targets]
-        for zipped_output in zip([output], *[[target] for target in targets]):
+    def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
+        # if not all([tensor.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
+        #     tensor = F.interpolate(tensor, mode=self.mode,
+        #                            size=tuple(self.shape), align_corners=False).type(torch.float32)
+        # if tensor.shape[1] != self.model.in_channels:
+        #     tensor = tensor.repeat(tuple([1,self.model.in_channels] + [1 for _ in range(len(self.shape))]))
+        # tensor = (tensor - torch.min(tensor))/(torch.max(tensor)-torch.min(tensor))
+        # tensor = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor)
+        return tensor
+
+    def _compute(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
+        output_preprocessing = self.preprocessing(output)
+        targets_preprocessing = [self.preprocessing(target) for target in targets]
+        for zipped_output in zip([output_preprocessing], *[[target] for target in targets_preprocessing]):
             output = zipped_output[0]
             targets = zipped_output[1:]
 
-            for zipped_layers in list(zip(self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()), *[self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy()) for target in targets])):
-                output_layer = zipped_layers[0][1].view(zipped_layers[0][1].shape[0], zipped_layers[0][1].shape[1], int(np.prod(zipped_layers[0][1].shape[2:])))
-                for (loss_function, l), target_layer in zip(self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:]):
-                    target_layer = target_layer[1].view(target_layer[1].shape[0], target_layer[1].shape[1], int(np.prod(target_layer[1].shape[2:])))
-                    loss = loss+l*loss_function(output_layer.float(), target_layer.float())/output_layer.shape[0]
+            for zipped_layers in list(
+                zip(
+                    self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()),
+                    *[
+                        self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy())
+                        for target in targets
+                    ],
+                )
+            ):
+                output_layer = zipped_layers[0][1].view(
+                    zipped_layers[0][1].shape[0],
+                    zipped_layers[0][1].shape[1],
+                    int(np.prod(zipped_layers[0][1].shape[2:])),
+                )
+                for (loss_function, loss_value), target_layer in zip(
+                    self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:]
+                ):
+                    target_layer = target_layer[1].view(
+                        target_layer[1].shape[0],
+                        target_layer[1].shape[1],
+                        int(np.prod(target_layer[1].shape[2:])),
+                    )
+                    loss = (
+                        loss
+                        + loss_value * loss_function(output_layer.float(), target_layer.float()) / output_layer.shape[0]
+                    )
         return loss
-
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         if output.device.index not in self.models:
             del os.environ["device"]
             self.models[output.device.index] = Network.to(copy.deepcopy(self.model).eval(), output.device.index).eval()
-        loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         if len(output.shape) == 5 and len(self.shape) == 2:
             for i in range(output.shape[2]):
-                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])/output.shape[2]
+                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets]) / output.shape[2]
         else:
             loss = self._compute(output, targets)
         return loss.to(output)
 
+
 class KLDivergence(Criterion):
-
-    def __init__(self, shape: list[int], dim : int = 100, mu : float = 0, std : float = 1) -> None:
+
+    def __init__(self, shape: list[int], dim: int = 100, mu: float = 0, std: float = 1) -> None:
         super().__init__()
-        self.latentDim = dim
+        self.latent_dim = dim
         self.mu = torch.Tensor([mu])
         self.std = torch.Tensor([std])
         self.modelDim = 3
         self.shape = shape
         self.loss = torch.nn.KLDivLoss()
-
-    def init(self, model : Network, output_group : str, target_group : str) -> str:
+
+    def init(self, model: Network, output_group: str, target_group: str) -> str:
         super().init(model, output_group, target_group)
         model._compute_channels_trace(model, model.in_channels, None, None)
 
@@ -338,52 +438,35 @@ class KLDivergence(Criterion):
 
         modules = last_module._modules.copy()
         last_module._modules.clear()
-
+
         for name, value in modules.items():
             last_module._modules[name] = value
             if name == output_group.split(".")[-1]:
-                last_module.add_module("LatentDistribution", LatentDistribution(shape = self.shape, latentDim=self.latentDim))
-                return ".".join(output_group.split(".")[:-1])+".LatentDistribution.Concat"
+                last_module.add_module(
+                    "LatentDistribution",
+                    LatentDistribution(shape=self.shape, latent_dim=self.latent_dim),
+                )
+                return ".".join(output_group.split(".")[:-1]) + ".LatentDistribution.Concat"
 
-    def forward(self, output: torch.Tensor, targets : list[torch.Tensor]) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         mu = output[:, 0, :]
         log_std = output[:, 1, :]
-        return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim = 1), dim = 0)
-
-    """
-    def forward(self, input : torch.Tensor) -> torch.Tensor:
-        mu = input[:, 0, :]
-        log_std = input[:, 1, :]
+        return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim=1), dim=0)
 
-        z = input[:, 2, :]
 
-        q = torch.distributions.Normal(mu, log_std)
-
-        target_mu = torch.ones((self.latentDim)).to(input.device)*self.mu.to(input.device)
-        target_std = torch.ones((self.latentDim)).to(input.device)*self.std.to(input.device)
-
-        p = torch.distributions.Normal(target_mu, target_std)
-
-        log_pz = p.log_prob(z)
-        log_qzx = q.log_prob(z)
-
-        kl = (log_pz - log_qzx)
-        kl = kl.sum(-1)
-        return kl
-    """
-
 class Accuracy(Criterion):
 
     def __init__(self) -> None:
         super().__init__()
-        self.n : int = 0
-        self.corrects = torch.zeros((1))
+        self.n: int = 0
+        self.corrects = torch.zeros(1)
 
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target_0 = targets[0]
         self.n += output.shape[0]
         self.corrects += (torch.argmax(torch.softmax(output, dim=1), dim=1) == target_0).sum().float().cpu()
-        return self.corrects/self.n
+        return self.corrects / self.n
+
 
 class TripletLoss(Criterion):
 
@@ -391,9 +474,10 @@ class TripletLoss(Criterion):
         super().__init__()
         self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
 
-    def forward(self, output: torch.Tensor) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         return self.triplet_loss(output[0], output[1], output[2])
 
+
 class L1LossRepresentation(Criterion):
 
     def __init__(self) -> None:
@@ -403,28 +487,31 @@ class L1LossRepresentation(Criterion):
     def _variance(self, features: torch.Tensor) -> torch.Tensor:
         return torch.mean(torch.clamp(1 - torch.var(features, dim=0), min=0))
 
-    def forward(self, output: torch.Tensor) -> torch.Tensor:
-        return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        return self.loss(output[0], output[1]) + self._variance(output[0]) + self._variance(output[1])
+
 
 class FocalLoss(Criterion):
 
-    def __init__(self, gamma: float = 2.0, alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1], reduction: str = "mean"):
+    def __init__(
+        self,
+        gamma: float = 2.0,
+        alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1],
+        reduction: str = "mean",
+    ):
         super().__init__()
         raw_alpha = torch.tensor(alpha, dtype=torch.float32)
         self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
         self.gamma = gamma
         self.reduction = reduction
 
-    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
-        target = F.interpolate(
-            targets[0],
-            output.shape[2:],
-            mode="nearest").long()
-
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        target = F.interpolate(targets[0], output.shape[2:], mode="nearest").long()
+
         logpt = F.log_softmax(output, dim=1)
         pt = torch.exp(logpt)
 
-        logpt = logpt.gather(1, target)
+        logpt = logpt.gather(1, target)
         pt = pt.gather(1, target)
 
         at = self.alpha.to(target.device)[target].unsqueeze(1)
@@ -435,7 +522,7 @@ class FocalLoss(Criterion):
         elif self.reduction == "sum":
             return loss.sum()
         return loss
-
+
 
 class FID(Criterion):
 
@@ -449,19 +536,26 @@ class FID(Criterion):
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
             return self.model(x)
-
+
     def __init__(self) -> None:
         super().__init__()
         self.inception_model = FID.InceptionV3().cuda()
-
+
+    @staticmethod
     def preprocess_images(image: torch.Tensor) -> torch.Tensor:
-        return F.normalize(F.resize(image, (299, 299)).repeat((1,3,1,1)), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).cuda()
+        return F.normalize(
+            F.resize(image, (299, 299)).repeat((1, 3, 1, 1)),
+            mean=[0.485, 0.456, 0.406],
+            std=[0.229, 0.224, 0.225],
+        ).cuda()
 
+    @staticmethod
     def get_features(images: torch.Tensor, model: torch.nn.Module) -> np.ndarray:
         with torch.no_grad():
             features = model(images).cpu().numpy()
         return features
 
+    @staticmethod
     def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
         mu1 = np.mean(real_features, axis=0)
         sigma1 = np.cov(real_features, rowvar=False)
@@ -475,7 +569,7 @@ class FID(Criterion):
 
         return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
 
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         real_images = FID.preprocess_images(targets[0].squeeze(0).permute([1, 0, 2, 3]))
         generated_images = FID.preprocess_images(output.squeeze(0).permute([1, 0, 2, 3]))
 
@@ -484,8 +578,15 @@ class FID(Criterion):
 
         return FID.calculate_fid(real_features, generated_features)
 
+
 class MutualInformationLoss(torch.nn.Module):
-    def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
+    def __init__(
+        self,
+        num_bins: int = 23,
+        sigma_ratio: float = 0.5,
+        smooth_nr: float = 1e-7,
+        smooth_dr: float = 1e-7,
+    ) -> None:
         super().__init__()
         bin_centers = torch.linspace(0.0, 1.0, num_bins)
         sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
@@ -495,7 +596,9 @@ class MutualInformationLoss(torch.nn.Module):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
 
-    def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def parzen_windowing(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
         target_weight, target_probability = self.parzen_windowing_gaussian(target)
         return pred_weight, pred_probability, target_weight, target_probability
@@ -503,82 +606,95 @@ class MutualInformationLoss(torch.nn.Module):
     def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         img = torch.clamp(img, 0, 1)
         img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
-        weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2) # (batch, num_sample, num_bin)
+        weight = torch.exp(
+            -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2
+        ) # (batch, num_sample, num_bin)
         weight = weight / torch.sum(weight, dim=-1, keepdim=True) # (batch, num_sample, num_bin)
         probability = torch.mean(weight, dim=-2, keepdim=True) # (batch, 1, num_bin)
         return weight, probability
 
-    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        wa, pa, wb, pb = self.parzen_windowing(pred, target) # (batch, num_sample, num_bin), (batch, 1, num_bin)
+    def forward(self, pred: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+        wa, pa, wb, pb = self.parzen_windowing(pred, targets[0]) # (batch, num_sample, num_bin), (batch, 1, num_bin)
         pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
         papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
-        mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
+        mi = torch.sum(
+            pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr),
+            dim=(1, 2),
+        ) # (batch)
         return torch.mean(mi).neg() # average over the batch and channel ndims
 
-class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
 
-    def __init__(self, model_name: str, shape: list[int] = [0,0], in_channels: int = 1, losses: dict[str, list[float]] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
+class IMPACTSynth(Criterion): # Feature-Oriented Comparison for Unpaired Synthesis
+
+    def __init__(
+        self, model_name: str, shape: list[int] = [0, 0], in_channels: int = 1, weights: list[float] = [1]
+    ) -> None:
         super().__init__()
         if model_name is None:
             return
-        self.shape = shape
         self.in_channels = in_channels
-        self.losses: dict[torch.nn.Module, list[float]] = {}
-        for loss, weights in losses.items():
-            module, name = _getModule(loss, "konfai.metric.measure")
-            self.losses[config(os.environ['KONFAI_CONFIG_PATH'])(getattr(importlib.import_module(module), name))(config=None)] = weights
-
-        self.model_path = download_url(model_name, "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/")
-        self.model: torch.nn.Module = torch.jit.load(self.model_path)
-        self.dim = len(self.shape)
-        self.shape = self.shape if all([s > 0 for s in self.shape]) else None
+        self.loss = torch.nn.L1Loss()
+        self.weights = weights
+        self.model_path = download_url(
+            model_name,
+            "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/",
+        )
+        self.model: torch.nn.Module = torch.jit.load(self.model_path) # nosec B614
+        self.dim = len(shape)
+        self.shape = shape if all(s > 0 for s in shape) else None
         self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
 
         try:
-            dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim)))
+            dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim))).to(0)
             out = self.model(dummy_input)
             if not isinstance(out, (list, tuple)):
                 raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
-            if len(self.weight) != len(out):
-                raise ValueError(f"Mismatch between number of weights ({len(self.weight)}) and model outputs ({len(out)}).")
+            if len(weights) != len(out):
+                raise ValueError(f"Mismatch between number of weights ({len(weights)}) and model outputs ({len(out)}).")
         except Exception as e:
             msg = (
                 f"[Model Sanity Check Failed]\n"
                 f"Input shape attempted: {dummy_input.shape}\n"
-                f"Expected output length: {len(self.weight)}\n"
+                f"Expected output length: {len(weights)}\n"
                 f"Error: {type(e).__name__}: {e}"
             )
             raise RuntimeError(msg) from e
         self.model = None
-
-    def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
-        if self.shape is not None and not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
-            input = torch.nn.functional.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
-        if input.shape[1] != self.in_channels:
-            input = input.repeat(tuple([1,3] + [1 for _ in range(self.dim)]))
-        return input
-
+
+    def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
+        if self.shape is not None and not all(
+            tensor.shape[-i - 1] == size for i, size in enumerate(reversed(self.shape[2:]))
+        ):
+            tensor = torch.nn.functional.interpolate(
+                tensor, mode=self.mode, size=tuple(self.shape), align_corners=False
+            ).type(torch.float32)
+        if tensor.shape[1] != self.in_channels:
+            tensor = tensor.repeat(tuple([1, 3] + [1 for _ in range(self.dim)]))
+        return tensor
+
     def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
-        loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         output = self.preprocessing(output)
         targets = [self.preprocessing(target) for target in targets]
         self.model.to(output.device)
-        for zipped_output in zip(self.model(output), *[self.model(target) for target in targets]):
-            output_features = zipped_output[0]
-            targets_features = zipped_output[1:]
-            for target_features, (loss_function, weights) in zip(targets_features, self.losses.items()):
-                for output_feature, target_feature, weight in zip(output_features ,target_features, weights):
-                    loss += weights*loss_function(output_feature, target_feature)
+        for zipped_output in zip(self.weights, self.model(output), *[self.model(target) for target in targets]):
+            weight = zipped_output[0]
+            if weight == 0:
+                continue
+            output_feature = zipped_output[1]
+            targets_features = zipped_output[1:]
+            for target_feature in targets_features:
+                loss += weight * self.loss(output_feature, target_feature)
         return loss
-
-    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         if self.model is None:
-            self.model = torch.jit.load(self.model_path)
-        loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
+            self.model = torch.jit.load(self.model_path) # nosec B614
+        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         if len(output.shape) == 5 and self.dim == 2:
             for i in range(output.shape[2]):
-                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
-                loss/=output.shape[2]
+                loss += self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
+            loss /= output.shape[2]
         else:
-            loss = self._compute(output, targets)
-            return loss.to(output)
+            loss = self._compute(output, list(targets))
+        return loss.to(output)
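
Note on the API-visible changes in this file: MaskedLoss.forward now returns a (loss, value) tuple instead of a bare tensor, and Dice takes an explicit label list in place of a smoothing term. Below is a minimal usage sketch of those two new signatures; the tensor shapes, the binary mask, and the label values are illustrative assumptions, not taken from the package.

# Minimal usage sketch (assumptions: CPU tensors, konfai 1.2.0 installed;
# shapes and label values here are illustrative only).
import torch
from konfai.metric.measure import MSE, Dice

output = torch.rand((2, 1, 32, 32))  # prediction
target = torch.rand((2, 1, 32, 32))  # reference image
mask = (target > 0.5).float()        # binary mask: 1 marks the scored region

# MaskedLoss subclasses now return the batch-mean loss tensor plus a float
# averaged only over samples whose mask is non-empty (np.nan when none are).
loss, masked_mean = MSE()(output, target, mask)

# Dice now iterates over explicit labels (torch.unique(target) when None) and
# returns the score together with a per-label dict (np.nan for absent labels).
seg = torch.randint(0, 3, (2, 1, 32, 32)).float()
ref = torch.randint(0, 3, (2, 1, 32, 32)).float()
dice, per_label = Dice(labels=[1, 2])(seg, ref)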