konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.
Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/metric/measure.py CHANGED
@@ -1,334 +1,428 @@
- from abc import ABC
+ import copy
  import importlib
- import numpy as np
- import torch
-
- from torchvision.models import inception_v3, Inception_V3_Weights
- from torchvision.transforms import functional as F
- from scipy import linalg
-
- import torch.nn.functional as F
  import os
-
- from typing import Callable, Union
+ from abc import ABC, abstractmethod
+ from collections.abc import Callable
  from functools import partial
- import torch.nn.functional as F
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F  # noqa: N812
+ from scipy import linalg
  from skimage.metrics import structural_similarity
- import copy
- from abc import abstractmethod
+ from torchvision.models import Inception_V3_Weights, inception_v3
+ from tqdm import tqdm

- from konfai.utils.config import config
- from konfai.utils.utils import _getModule, download_url
  from konfai.data.patching import ModelPatch
  from konfai.network.blocks import LatentDistribution
  from konfai.network.network import ModelLoader, Network
- from tqdm import tqdm
+ from konfai.utils.config import config
+ from konfai.utils.utils import download_url, get_module
+
+ models_register = {}

- modelsRegister = {}

  class Criterion(torch.nn.Module, ABC):

      def __init__(self) -> None:
          super().__init__()

-     def init(self, model : torch.nn.Module, output_group : str, target_group : str) -> str:
+     def init(self, model: torch.nn.Module, output_group: str, target_group: str) -> str:
          return output_group

      @abstractmethod
-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          pass

+
  class MaskedLoss(Criterion):

-     def __init__(self, loss : Callable[[torch.Tensor, torch.Tensor], torch.Tensor], mode_image_masked: bool) -> None:
+     def __init__(
+         self,
+         loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+         mode_image_masked: bool,
+     ) -> None:
          super().__init__()
          self.loss = loss
          self.mode_image_masked = mode_image_masked

-     def getMask(self, targets: list[torch.Tensor]) -> torch.Tensor:
+     def get_mask(self, *targets: torch.Tensor) -> torch.Tensor:
          result = None
          if len(targets) > 0:
              result = targets[0]
              for mask in targets[1:]:
-                 result = result*mask
+                 result = result * mask
          return result

-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          loss = torch.tensor(0, dtype=torch.float32).to(output.device)
-         mask = self.getMask(targets[1:])
+         mask = self.get_mask(targets[1:])
          for batch in range(output.shape[0]):
              if mask is not None:
                  if self.mode_image_masked:
                      for i in torch.unique(mask[batch]):
                          if i != 0:
-                             loss += self.loss(output[batch, ...]*torch.where(mask == i, 1, 0), targets[0][batch, ...]*torch.where(mask == i, 1, 0))
+                             loss += self.loss(
+                                 output[batch, ...] * torch.where(mask == i, 1, 0),
+                                 targets[0][batch, ...] * torch.where(mask == i, 1, 0),
+                             )
                  else:
                      for i in torch.unique(mask[batch]):
                          if i != 0:
                              index = mask[batch, ...] == i
-                             loss += self.loss(torch.masked_select(output[batch, ...], index), torch.masked_select(targets[0][batch, ...], index))
+                             loss += self.loss(
+                                 torch.masked_select(output[batch, ...], index),
+                                 torch.masked_select(targets[0][batch, ...], index),
+                             )
              else:
                  loss += self.loss(output[batch, ...], targets[0][batch, ...])
-         return loss/output.shape[0]
-
+         return loss / output.shape[0]
+
+
  class MSE(MaskedLoss):

+     @staticmethod
      def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
          return torch.nn.MSELoss(reduction=reduction)(x, y)

      def __init__(self, reduction: str = "mean") -> None:
          super().__init__(partial(MSE._loss, reduction), False)

+
  class MAE(MaskedLoss):

+     @staticmethod
      def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
          return torch.nn.L1Loss(reduction=reduction)(x, y)
-
+
      def __init__(self, reduction: str = "mean") -> None:
          super().__init__(partial(MAE._loss, reduction), False)

+
  class PSNR(MaskedLoss):

-     def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+     @staticmethod
+     def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
          mse = torch.mean((x - y).pow(2))
-         psnr = 10 * torch.log10(dynamic_range**2 / mse)
-         return psnr
+         psnr = 10 * torch.log10(dynamic_range**2 / mse)
+         return psnr

-     def __init__(self, dynamic_range: Union[float, None] = None) -> None:
-         dynamic_range = dynamic_range if dynamic_range else 1024+3071
+     def __init__(self, dynamic_range: float | None = None) -> None:
+         dynamic_range = dynamic_range if dynamic_range else 1024 + 3071
          super().__init__(partial(PSNR._loss, dynamic_range), False)
-
+
+
  class SSIM(MaskedLoss):
-
+
+     @staticmethod
      def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-         return structural_similarity(x[0][0].detach().cpu().numpy(), y[0][0].cpu().numpy(), data_range=dynamic_range, gradient=False, full=False)
-
-     def __init__(self, dynamic_range: Union[float, None] = None) -> None:
-         dynamic_range = dynamic_range if dynamic_range else 1024+3000
+         return structural_similarity(
+             x[0][0].detach().cpu().numpy(),
+             y[0][0].cpu().numpy(),
+             data_range=dynamic_range,
+             gradient=False,
+             full=False,
+         )
+
+     def __init__(self, dynamic_range: float | None = None) -> None:
+         dynamic_range = dynamic_range if dynamic_range else 1024 + 3000
          super().__init__(partial(SSIM._loss, dynamic_range), True)


  class LPIPS(MaskedLoss):

-     def normalize(input: torch.Tensor) -> torch.Tensor:
-         return (input-torch.min(input))/(torch.max(input)-torch.min(input))*2-1
-
-     def preprocessing(input: torch.Tensor) -> torch.Tensor:
-         return input.repeat((1,3,1,1))
-
+     @staticmethod
+     def normalize(tensor: torch.Tensor) -> torch.Tensor:
+         return (tensor - torch.min(tensor)) / (torch.max(tensor) - torch.min(tensor)) * 2 - 1
+
+     @staticmethod
+     def preprocessing(tensor: torch.Tensor) -> torch.Tensor:
+         return tensor.repeat((1, 3, 1, 1))
+
+     @staticmethod
      def _loss(loss_fn_alex, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-         datasetPatch = ModelPatch([1, 64, 64])
-         datasetPatch.load(x.shape[2:])
+         dataset_patch = ModelPatch([1, 64, 64])
+         dataset_patch.load(x.shape[2:])

-         patchIterator = datasetPatch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
+         patch_iterator = dataset_patch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
          loss = 0
-         with tqdm(iterable = enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
-             for i, patch_input in batch_iter:
+         with tqdm(
+             iterable=enumerate(patch_iterator),
+             leave=False,
+             total=dataset_patch.get_size(0),
+         ) as batch_iter:
+             for _, patch_input in batch_iter:
                  real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
                  loss += loss_fn_alex(real, fake).item()
-         return loss/datasetPatch.getSize(0)
+         return loss / dataset_patch.get_size(0)

      def __init__(self, model: str = "alex") -> None:
          import lpips
+
          super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model)), True)

+
  class Dice(Criterion):
-
-     def __init__(self, smooth : float = 1e-6) -> None:
+
+     def __init__(self, smooth: float = 1e-6) -> None:
          super().__init__()
          self.smooth = smooth
-
-     def flatten(self, tensor : torch.Tensor) -> torch.Tensor:
+
+     def flatten(self, tensor: torch.Tensor) -> torch.Tensor:
          return tensor.permute((1, 0) + tuple(range(2, tensor.dim()))).contiguous().view(tensor.size(1), -1)

-     def dice_per_channel(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-         input = self.flatten(input)
+     def dice_per_channel(self, tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         tensor = self.flatten(tensor)
          target = self.flatten(target)
-         return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)
+         return (2.0 * (tensor * target).sum() + self.smooth) / (tensor.sum() + target.sum() + self.smooth)

-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
-         target = F.interpolate(
-             targets[0],
-             output.shape[2:],
-             mode="nearest")
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         target = F.interpolate(targets[0], output.shape[2:], mode="nearest")
          if output.shape[1] == 1:
-             output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
-         target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
-         return 1-torch.mean(self.dice_per_channel(output, target))
+             output = (
+                 F.one_hot(
+                     output.type(torch.int64),
+                     num_classes=int(torch.max(output).item() + 1),
+                 )
+                 .permute(0, len(target.shape), *[i + 1 for i in range(len(target.shape) - 1)])
+                 .float()
+             )
+         target = (
+             F.one_hot(target.type(torch.int64), num_classes=output.shape[1])
+             .permute(0, len(target.shape), *[i + 1 for i in range(len(target.shape) - 1)])
+             .float()
+             .squeeze(2)
+         )
+         return 1 - torch.mean(self.dice_per_channel(output, target))
+

  class GradientImages(Criterion):

      def __init__(self):
          super().__init__()
-
+
      @staticmethod
-     def _image_gradient2D(image : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+     def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
          dx = image[:, :, 1:, :] - image[:, :, :-1, :]
          dy = image[:, :, :, 1:] - image[:, :, :, :-1]
          return dx, dy

      @staticmethod
-     def _image_gradient3D(image : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     def _image_gradient_3d(
+         image: torch.Tensor,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          dx = image[:, :, 1:, :, :] - image[:, :, :-1, :, :]
          dy = image[:, :, :, 1:, :] - image[:, :, :, :-1, :]
          dz = image[:, :, :, :, 1:] - image[:, :, :, :, :-1]
          return dx, dy, dz
-
-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          target_0 = targets[0]
          if len(output.shape) == 5:
-             dx, dy, dz = GradientImages._image_gradient3D(output)
+             dx, dy, dz = GradientImages._image_gradient_3d(output)
              if target_0 is not None:
-                 dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient3D(target_0)
+                 dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient_3d(target_0)
                  dx -= dx_tmp
                  dy -= dy_tmp
                  dz -= dz_tmp
              return dx.norm() + dy.norm() + dz.norm()
          else:
-             dx, dy = GradientImages._image_gradient2D(output)
+             dx, dy = GradientImages._image_gradient_2d(output)
              if target_0 is not None:
-                 dx_tmp, dy_tmp = GradientImages._image_gradient2D(target_0)
+                 dx_tmp, dy_tmp = GradientImages._image_gradient_2d(target_0)
                  dx -= dx_tmp
                  dy -= dy_tmp
              return dx.norm() + dy.norm()
-
+
+
  class BCE(Criterion):

-     def __init__(self, target : float = 0) -> None:
+     def __init__(self, target: float = 0) -> None:
          super().__init__()
          self.loss = torch.nn.BCEWithLogitsLoss()
-         self.register_buffer('target', torch.tensor(target).type(torch.float32))
+         self.register_buffer("target", torch.tensor(target).type(torch.float32))

-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          target = self._buffers["target"]
          return self.loss(output, target.to(output.device).expand_as(output))

+
  class PatchGanLoss(Criterion):

-     def __init__(self, target : float = 0) -> None:
+     def __init__(self, target: float = 0) -> None:
          super().__init__()
          self.loss = torch.nn.MSELoss()
-         self.register_buffer('target', torch.tensor(target).type(torch.float32))
+         self.register_buffer("target", torch.tensor(target).type(torch.float32))

-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          target = self._buffers["target"]
-         return self.loss(output, (torch.ones_like(output)*target).to(output.device))
+         return self.loss(output, (torch.ones_like(output) * target).to(output.device))
+

  class WGP(Criterion):

      def __init__(self) -> None:
          super().__init__()
-
-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
-         return torch.mean((output - 1)**2)
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         return torch.mean((output - 1) ** 2)
+

  class Gram(Criterion):

-     def computeGram(input : torch.Tensor):
-         (b, ch, w) = input.size()
-         with torch.amp.autocast('cuda', enabled=False):
-             return input.bmm(input.transpose(1, 2)).div(ch*w)
+     @staticmethod
+     def compute_gram(tensor: torch.Tensor):
+         (b, ch, w) = tensor.size()
+         with torch.amp.autocast("cuda", enabled=False):
+             return tensor.bmm(tensor.transpose(1, 2)).div(ch * w)

      def __init__(self) -> None:
          super().__init__()
-         self.loss = torch.nn.L1Loss(reduction='sum')
+         self.loss = torch.nn.L1Loss(reduction="sum")

-     def forward(self, output: torch.Tensor, target : torch.Tensor) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         target = targets[0]
          if len(output.shape) > 3:
              output = output.view(output.shape[0], output.shape[1], int(np.prod(output.shape[2:])))
          if len(target.shape) > 3:
              target = target.view(target.shape[0], target.shape[1], int(np.prod(target.shape[2:])))
-         return self.loss(Gram.computeGram(output), Gram.computeGram(target))
+         return self.loss(Gram.compute_gram(output), Gram.compute_gram(target))
+

  class PerceptualLoss(Criterion):
-
-     class Module():
-
+
+     class Module:
+
          @config()
          def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
              self.losses = losses
-             self.DL_args = os.environ['KONFAI_CONFIG_PATH'] if "KONFAI_CONFIG_PATH" in os.environ else ""
+             self.konfai_args = os.environ["KONFAI_CONFIG_PATH"] if "KONFAI_CONFIG_PATH" in os.environ else ""

-         def getLoss(self) -> dict[torch.nn.Module, float]:
+         def get_loss(self) -> dict[torch.nn.Module, float]:
              result: dict[torch.nn.Module, float] = {}
-             for loss, l in self.losses.items():
-                 module, name = _getModule(loss, "konfai.metric.measure")
-                 result[config(self.DL_args)(getattr(importlib.import_module(module), name))(config=None)] = l
+             for loss, loss_value in self.losses.items():
+                 module, name = get_module(loss, "konfai.metric.measure")
+                 result[config(self.konfai_args)(getattr(importlib.import_module(module), name))(config=None)] = (
+                     loss_value
+                 )
              return result
-
-     def __init__(self, modelLoader : ModelLoader = ModelLoader(), path_model : str = "name", modules : dict[str, Module] = {"UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch_nn_L1Loss": 1})}, shape: list[int] = [128, 128, 128]) -> None:
+
+     def __init__(
+         self,
+         model_loader: ModelLoader = ModelLoader(),
+         path_model: str = "name",
+         modules: dict[str, Module] = {
+             "UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch_nn_L1Loss": 1})
+         },
+         shape: list[int] = [128, 128, 128],
+     ) -> None:
          super().__init__()
          self.path_model = path_model
-         if self.path_model not in modelsRegister:
-             self.model = modelLoader.getModel(train=False, DL_args=os.environ['KONFAI_CONFIG_PATH'].split("PerceptualLoss")[0]+"PerceptualLoss.Model", DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain", "outputsCriterions", "drop_p"])
+         if self.path_model not in models_register:
+             self.model = model_loader.get_model(
+                 train=False,
+                 konfai_args=os.environ["KONFAI_CONFIG_PATH"].split("PerceptualLoss")[0] + "PerceptualLoss.Model",
+                 konfai_without=[
+                     "optimizer",
+                     "schedulers",
+                     "nb_batch_per_step",
+                     "init_type",
+                     "init_gain",
+                     "outputs_criterions",
+                     "drop_p",
+                 ],
+             )
              if path_model.startswith("https"):
                  state_dict = torch.hub.load_state_dict_from_url(path_model)
-                 state_dict = {"Model": {self.model.getName() : state_dict["model"]}}
+                 state_dict = {"Model": {self.model.get_name(): state_dict["model"]}}
              else:
                  state_dict = torch.load(path_model, weights_only=True)
              self.model.load(state_dict)
-             modelsRegister[self.path_model] = self.model
+             models_register[self.path_model] = self.model
          else:
-             self.model = modelsRegister[self.path_model]
+             self.model = models_register[self.path_model]

          self.shape = shape
-         self.mode = "trilinear" if len(shape) == 3 else "bilinear"
+         self.mode = "trilinear" if len(shape) == 3 else "bilinear"
          self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
          for name, losses in modules.items():
-             self.modules_loss[name.replace(":", ".")] = losses.getLoss()
+             self.modules_loss[name.replace(":", ".")] = losses.get_loss()

          self.model.eval()
          self.model.requires_grad_(False)
          self.models: dict[int, torch.nn.Module] = {}

-     def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
-         #if not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
-         #    input = F.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
-         #if input.shape[1] != self.model.in_channels:
-         #    input = input.repeat(tuple([1,self.model.in_channels] + [1 for _ in range(len(self.shape))]))
-         #input = (input - torch.min(input))/(torch.max(input)-torch.min(input))
-         #input = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input)
-         return input
-
-     def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
-         loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
-         output = self.preprocessing(output)
-         targets = [self.preprocessing(target) for target in targets]
-         for zipped_output in zip([output], *[[target] for target in targets]):
+     def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
+         # if not all([tensor.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
+         #     tensor = F.interpolate(tensor, mode=self.mode,
+         #         size=tuple(self.shape), align_corners=False).type(torch.float32)
+         # if tensor.shape[1] != self.model.in_channels:
+         #     tensor = tensor.repeat(tuple([1,self.model.in_channels] + [1 for _ in range(len(self.shape))]))
+         # tensor = (tensor - torch.min(tensor))/(torch.max(tensor)-torch.min(tensor))
+         # tensor = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(tensor)
+         return tensor
+
+     def _compute(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
+         output_preprocessing = self.preprocessing(output)
+         targets_preprocessing = [self.preprocessing(target) for target in targets]
+         for zipped_output in zip([output_preprocessing], *[[target] for target in targets_preprocessing]):
              output = zipped_output[0]
              targets = zipped_output[1:]

-             for zipped_layers in list(zip(self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()), *[self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy()) for target in targets])):
-                 output_layer = zipped_layers[0][1].view(zipped_layers[0][1].shape[0], zipped_layers[0][1].shape[1], int(np.prod(zipped_layers[0][1].shape[2:])))
-                 for (loss_function, l), target_layer in zip(self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:]):
-                     target_layer = target_layer[1].view(target_layer[1].shape[0], target_layer[1].shape[1], int(np.prod(target_layer[1].shape[2:])))
-                     loss = loss+l*loss_function(output_layer.float(), target_layer.float())/output_layer.shape[0]
+             for zipped_layers in list(
+                 zip(
+                     self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()),
+                     *[
+                         self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy())
+                         for target in targets
+                     ],
+                 )
+             ):
+                 output_layer = zipped_layers[0][1].view(
+                     zipped_layers[0][1].shape[0],
+                     zipped_layers[0][1].shape[1],
+                     int(np.prod(zipped_layers[0][1].shape[2:])),
+                 )
+                 for (loss_function, loss_value), target_layer in zip(
+                     self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:]
+                 ):
+                     target_layer = target_layer[1].view(
+                         target_layer[1].shape[0],
+                         target_layer[1].shape[1],
+                         int(np.prod(target_layer[1].shape[2:])),
+                     )
+                     loss = (
+                         loss
+                         + loss_value * loss_function(output_layer.float(), target_layer.float()) / output_layer.shape[0]
+                     )
          return loss
-
-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          if output.device.index not in self.models:
              del os.environ["device"]
              self.models[output.device.index] = Network.to(copy.deepcopy(self.model).eval(), output.device.index).eval()
-         loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
+         loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
          if len(output.shape) == 5 and len(self.shape) == 2:
              for i in range(output.shape[2]):
-                 loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])/output.shape[2]
+                 loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets]) / output.shape[2]
          else:
              loss = self._compute(output, targets)
          return loss.to(output)

+
  class KLDivergence(Criterion):
-
-     def __init__(self, shape: list[int], dim : int = 100, mu : float = 0, std : float = 1) -> None:
+
+     def __init__(self, shape: list[int], dim: int = 100, mu: float = 0, std: float = 1) -> None:
          super().__init__()
-         self.latentDim = dim
+         self.latent_dim = dim
          self.mu = torch.Tensor([mu])
          self.std = torch.Tensor([std])
          self.modelDim = 3
          self.shape = shape
          self.loss = torch.nn.KLDivLoss()
-
-     def init(self, model : Network, output_group : str, target_group : str) -> str:
+
+     def init(self, model: Network, output_group: str, target_group: str) -> str:
          super().init(model, output_group, target_group)
          model._compute_channels_trace(model, model.in_channels, None, None)

@@ -338,52 +432,35 @@ class KLDivergence(Criterion):

          modules = last_module._modules.copy()
          last_module._modules.clear()
-
+
          for name, value in modules.items():
              last_module._modules[name] = value
              if name == output_group.split(".")[-1]:
-                 last_module.add_module("LatentDistribution", LatentDistribution(shape = self.shape, latentDim=self.latentDim))
-                 return ".".join(output_group.split(".")[:-1])+".LatentDistribution.Concat"
+                 last_module.add_module(
+                     "LatentDistribution",
+                     LatentDistribution(shape=self.shape, latent_dim=self.latent_dim),
+                 )
+                 return ".".join(output_group.split(".")[:-1]) + ".LatentDistribution.Concat"

-     def forward(self, output: torch.Tensor, targets : list[torch.Tensor]) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          mu = output[:, 0, :]
          log_std = output[:, 1, :]
-         return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim = 1), dim = 0)
-
-     """
-     def forward(self, input : torch.Tensor) -> torch.Tensor:
-         mu = input[:, 0, :]
-         log_std = input[:, 1, :]
-
-         z = input[:, 2, :]
+         return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim=1), dim=0)

-         q = torch.distributions.Normal(mu, log_std)

-         target_mu = torch.ones((self.latentDim)).to(input.device)*self.mu.to(input.device)
-         target_std = torch.ones((self.latentDim)).to(input.device)*self.std.to(input.device)
-
-         p = torch.distributions.Normal(target_mu, target_std)
-
-         log_pz = p.log_prob(z)
-         log_qzx = q.log_prob(z)
-
-         kl = (log_pz - log_qzx)
-         kl = kl.sum(-1)
-         return kl
-     """
-

  class Accuracy(Criterion):

      def __init__(self) -> None:
          super().__init__()
-         self.n : int = 0
-         self.corrects = torch.zeros((1))
+         self.n: int = 0
+         self.corrects = torch.zeros(1)

-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          target_0 = targets[0]
          self.n += output.shape[0]
          self.corrects += (torch.argmax(torch.softmax(output, dim=1), dim=1) == target_0).sum().float().cpu()
-         return self.corrects/self.n
+         return self.corrects / self.n
+

  class TripletLoss(Criterion):
@@ -391,9 +468,10 @@ class TripletLoss(Criterion):
          super().__init__()
          self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)

-     def forward(self, output: torch.Tensor) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          return self.triplet_loss(output[0], output[1], output[2])

+
  class L1LossRepresentation(Criterion):

      def __init__(self) -> None:
@@ -403,28 +481,31 @@ class L1LossRepresentation(Criterion):
      def _variance(self, features: torch.Tensor) -> torch.Tensor:
          return torch.mean(torch.clamp(1 - torch.var(features, dim=0), min=0))

-     def forward(self, output: torch.Tensor) -> torch.Tensor:
-         return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         return self.loss(output[0], output[1]) + self._variance(output[0]) + self._variance(output[1])
+

  class FocalLoss(Criterion):

-     def __init__(self, gamma: float = 2.0, alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1], reduction: str = "mean"):
+     def __init__(
+         self,
+         gamma: float = 2.0,
+         alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1],
+         reduction: str = "mean",
+     ):
          super().__init__()
          raw_alpha = torch.tensor(alpha, dtype=torch.float32)
          self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
          self.gamma = gamma
          self.reduction = reduction

-     def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
-         target = F.interpolate(
-             targets[0],
-             output.shape[2:],
-             mode="nearest").long()
-
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         target = F.interpolate(targets[0], output.shape[2:], mode="nearest").long()
+
          logpt = F.log_softmax(output, dim=1)
          pt = torch.exp(logpt)

-         logpt = logpt.gather(1, target)
+         logpt = logpt.gather(1, target)
          pt = pt.gather(1, target)

          at = self.alpha.to(target.device)[target].unsqueeze(1)
@@ -435,7 +516,7 @@ class FocalLoss(Criterion):
          elif self.reduction == "sum":
              return loss.sum()
          return loss
-
+

  class FID(Criterion):

@@ -449,19 +530,26 @@ class FID(Criterion):

          def forward(self, x: torch.Tensor) -> torch.Tensor:
              return self.model(x)
-
+
      def __init__(self) -> None:
          super().__init__()
          self.inception_model = FID.InceptionV3().cuda()
-
+
+     @staticmethod
      def preprocess_images(image: torch.Tensor) -> torch.Tensor:
-         return F.normalize(F.resize(image, (299, 299)).repeat((1,3,1,1)), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).cuda()
+         return F.normalize(
+             F.resize(image, (299, 299)).repeat((1, 3, 1, 1)),
+             mean=[0.485, 0.456, 0.406],
+             std=[0.229, 0.224, 0.225],
+         ).cuda()

+     @staticmethod
      def get_features(images: torch.Tensor, model: torch.nn.Module) -> np.ndarray:
          with torch.no_grad():
              features = model(images).cpu().numpy()
          return features

+     @staticmethod
      def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
          mu1 = np.mean(real_features, axis=0)
          sigma1 = np.cov(real_features, rowvar=False)
@@ -475,7 +563,7 @@ class FID(Criterion):

          return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          real_images = FID.preprocess_images(targets[0].squeeze(0).permute([1, 0, 2, 3]))
          generated_images = FID.preprocess_images(output.squeeze(0).permute([1, 0, 2, 3]))

@@ -484,8 +572,15 @@ class FID(Criterion):

          return FID.calculate_fid(real_features, generated_features)

+
  class MutualInformationLoss(torch.nn.Module):
-     def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
+     def __init__(
+         self,
+         num_bins: int = 23,
+         sigma_ratio: float = 0.5,
+         smooth_nr: float = 1e-7,
+         smooth_dr: float = 1e-7,
+     ) -> None:
          super().__init__()
          bin_centers = torch.linspace(0.0, 1.0, num_bins)
          sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
@@ -495,7 +590,9 @@ class MutualInformationLoss(torch.nn.Module):
          self.smooth_nr = float(smooth_nr)
          self.smooth_dr = float(smooth_dr)

-     def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+     def parzen_windowing(
+         self, pred: torch.Tensor, target: torch.Tensor
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
          pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
          target_weight, target_probability = self.parzen_windowing_gaussian(target)
          return pred_weight, pred_probability, target_weight, target_probability
@@ -503,35 +600,51 @@ class MutualInformationLoss(torch.nn.Module):
      def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
          img = torch.clamp(img, 0, 1)
          img = img.reshape(img.shape[0], -1, 1)  # (batch, num_sample, 1)
-         weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2) # (batch, num_sample, num_bin)
+         weight = torch.exp(
+             -self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2
+         )  # (batch, num_sample, num_bin)
          weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bin)
          probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bin)
          return weight, probability

-     def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-         wa, pa, wb, pb = self.parzen_windowing(pred, target) # (batch, num_sample, num_bin), (batch, 1, num_bin)
+     def forward(self, pred: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         wa, pa, wb, pb = self.parzen_windowing(pred, targets[0])  # (batch, num_sample, num_bin), (batch, 1, num_bin)
          pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1])  # (batch, num_bins, num_bins)
          papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa))  # (batch, num_bins, num_bins)
-         mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
+         mi = torch.sum(
+             pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr),
+             dim=(1, 2),
+         )  # (batch)
          return torch.mean(mi).neg()  # average over the batch and channel ndims

- class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis

-     def __init__(self, model_name: str, shape: list[int] = [0,0], in_channels: int = 1, losses: dict[str, list[float]] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
+ class IMPACTSynth(Criterion):  # Feature-Oriented Comparison for Unpaired Synthesis
+
+     def __init__(
+         self,
+         model_name: str,
+         shape: list[int] = [0, 0],
+         in_channels: int = 1,
+         losses: dict[str, list[float]] = {"Gram": [1], "torch_nn_L1Loss": [1]},
+     ) -> None:
          super().__init__()
          if model_name is None:
              return
-         self.shape = shape
          self.in_channels = in_channels
          self.losses: dict[torch.nn.Module, list[float]] = {}
          for loss, weights in losses.items():
-             module, name = _getModule(loss, "konfai.metric.measure")
-             self.losses[config(os.environ['KONFAI_CONFIG_PATH'])(getattr(importlib.import_module(module), name))(config=None)] = weights
-
-         self.model_path = download_url(model_name, "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/")
-         self.model: torch.nn.Module = torch.jit.load(self.model_path)
-         self.dim = len(self.shape)
-         self.shape = self.shape if all([s > 0 for s in self.shape]) else None
+             module, name = get_module(loss, "konfai.metric.measure")
+             self.losses[
+                 config(os.environ["KONFAI_CONFIG_PATH"])(getattr(importlib.import_module(module), name))(config=None)
+             ] = weights
+
+         self.model_path = download_url(
+             model_name,
+             "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/",
+         )
+         self.model: torch.nn.Module = torch.jit.load(self.model_path)  # nosec B614
+         self.dim = len(shape)
+         self.shape = shape if all(s > 0 for s in shape) else None
          self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}

          try:
@@ -540,7 +653,9 @@ class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
              if not isinstance(out, (list, tuple)):
                  raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
              if len(self.weight) != len(out):
-                 raise ValueError(f"Mismatch between number of weights ({len(self.weight)}) and model outputs ({len(out)}).")
+                 raise ValueError(
+                     f"Mismatch between number of weights ({len(self.weight)}) and model outputs ({len(out)})."
+                 )
          except Exception as e:
              msg = (
                  f"[Model Sanity Check Failed]\n"
@@ -550,35 +665,39 @@ class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
              )
              raise RuntimeError(msg) from e
          self.model = None
-
-     def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
-         if self.shape is not None and not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
-             input = torch.nn.functional.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
-         if input.shape[1] != self.in_channels:
-             input = input.repeat(tuple([1,3] + [1 for _ in range(self.dim)]))
-         return input
-
+
+     def preprocessing(self, tensor: torch.Tensor) -> torch.Tensor:
+         if self.shape is not None and not all(
+             tensor.shape[-i - 1] == size for i, size in enumerate(reversed(self.shape[2:]))
+         ):
+             tensor = torch.nn.functional.interpolate(
+                 tensor, mode=self.mode, size=tuple(self.shape), align_corners=False
+             ).type(torch.float32)
+         if tensor.shape[1] != self.in_channels:
+             tensor = tensor.repeat(tuple([1, 3] + [1 for _ in range(self.dim)]))
+         return tensor
+
      def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
-         loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
+         loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
          output = self.preprocessing(output)
          targets = [self.preprocessing(target) for target in targets]
          self.model.to(output.device)
          for zipped_output in zip(self.model(output), *[self.model(target) for target in targets]):
              output_features = zipped_output[0]
-             targets_features = zipped_output[1:]
+             targets_features = zipped_output[1:]
              for target_features, (loss_function, weights) in zip(targets_features, self.losses.items()):
-                 for output_feature, target_feature, weight in zip(output_features ,target_features, weights):
-                     loss += weights*loss_function(output_feature, target_feature)
+                 for output_feature, target_feature, weight in zip(output_features, target_features, weights):
+                     loss += weight * loss_function(output_feature, target_feature)
          return loss
-
-     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
          if self.model is None:
-             self.model = torch.jit.load(self.model_path)
-         loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
+             self.model = torch.jit.load(self.model_path)  # nosec B614
+         loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
          if len(output.shape) == 5 and self.dim == 2:
              for i in range(output.shape[2]):
                  loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
-             loss/=output.shape[2]
+             loss /= output.shape[2]
          else:
-             loss = self._compute(output, targets)
-         return loss.to(output)
+             loss = self._compute(output, list(targets))
+         return loss.to(output)
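
A change that recurs throughout this file is the Criterion.forward signature: targets are now typed as variadic torch.Tensor arguments (forward(self, output, *targets)) instead of list[torch.Tensor]. The following is a minimal sketch of how such criteria are typically invoked, prediction first and target tensors as extra positional arguments; it is an illustrative assumption based only on the class definitions shown above (BCE and GradientImages), not an example shipped with konfai, and the tensor shapes are arbitrary.

import torch

from konfai.metric.measure import BCE, GradientImages

# Illustrative tensors only: a batch of two single-channel 2D images (shapes are assumed).
output = torch.rand(2, 1, 64, 64)
target = torch.rand(2, 1, 64, 64)

# Criteria are called with the prediction first, then target tensors as varargs.
gradient_loss = GradientImages()(output, target)  # gradient-difference loss between the two images
adversarial_loss = BCE(target=1)(output)          # BCEWithLogitsLoss against a constant target of 1

print(float(gradient_loss), float(adversarial_loss))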