konfai 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/main.py
ADDED
@@ -0,0 +1,43 @@

import argparse
import os
from torch.cuda import device_count
import torch.multiprocessing as mp
import submitit
from konfai.utils.utils import setupAPI, TensorBoard, Log

def main():
    parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    with setupAPI(parser) as distributedObject:
        with Log(distributedObject.name):
            world_size = device_count()
            if world_size == 0:
                # No CUDA device visible: fall back to a single process.
                world_size = 1
            distributedObject.setup(world_size)
            with TensorBoard(distributedObject.name):
                mp.spawn(distributedObject, nprocs=world_size)


def cluster():
    parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Cluster manager arguments
    cluster_args = parser.add_argument_group('Cluster manager arguments')
    cluster_args.add_argument('--name', type=str, help='Task name', required=True)
    cluster_args.add_argument('--num-nodes', '--num_nodes', default=1, type=int, help='Number of nodes')
    cluster_args.add_argument('--memory', type=int, default=16, help='Amount of memory per node (GB)')
    cluster_args.add_argument('--time-limit', '--time_limit', type=int, default=1440, help='Job time limit in minutes')
    cluster_args.add_argument('--resubmit', action='store_true', help='Automatically resubmit job just before timeout')

    with setupAPI(parser) as distributedObject:
        args = parser.parse_args()
        config = vars(args)
        os.environ["DL_API_OVERWRITE"] = "True"
        os.environ["DL_API_CLUSTER"] = "True"

        n_gpu = len(config["gpu"].split(","))
        distributedObject.setup(n_gpu*int(config["num_nodes"]))

        executor = submitit.AutoExecutor(folder="./Cluster/")
        executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
        with TensorBoard(distributedObject.name):
            executor.submit(distributedObject)
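
Both entry points drive the same callable: `main()` spawns one process per visible GPU with `mp.spawn`, which invokes the callable with the process rank as its first argument, while `cluster()` submits it through submitit's Slurm-backed `AutoExecutor`. A minimal sketch of the spawn pattern (hypothetical `worker` function, no KonfAI imports):

    import torch.multiprocessing as mp
    from torch.cuda import device_count

    def worker(rank: int) -> None:
        # mp.spawn passes the process index as the first positional argument.
        print(f"process {rank} started")

    if __name__ == "__main__":
        world_size = max(device_count(), 1)  # fall back to one process without a GPU
        mp.spawn(worker, nprocs=world_size)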
konfai/metric/measure.py
ADDED

@@ -0,0 +1,488 @@

from abc import ABC, abstractmethod
import importlib
import numpy as np
import torch
import torch.nn.functional as F
import os
import tqdm

from typing import Callable, Union
from functools import partial
from skimage.metrics import structural_similarity
import copy

from konfai.utils.config import config
from konfai.utils.utils import _getModule
from konfai.data.HDF5 import ModelPatch
from konfai.network.blocks import LatentDistribution
from konfai.network.network import ModelLoader, Network

modelsRegister = {}

class Criterion(torch.nn.Module, ABC):

    def __init__(self) -> None:
        super().__init__()

    def init(self, model: torch.nn.Module, output_group: str, target_group: str) -> str:
        return output_group

    @abstractmethod
    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        pass

class MaskedLoss(Criterion):

    def __init__(self, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], mode_image_masked: bool) -> None:
        super().__init__()
        self.loss = loss
        self.mode_image_masked = mode_image_masked

    def getMask(self, targets: list[torch.Tensor]) -> torch.Tensor:
        # Multiply all auxiliary targets together into a single mask.
        result = None
        if len(targets) > 0:
            result = targets[0]
            for mask in targets[1:]:
                result = result*mask
        return result

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        loss = torch.tensor(0, dtype=torch.float32).to(output.device)
        mask = self.getMask(targets[1:])
        for batch in range(output.shape[0]):
            if mask is not None:
                if self.mode_image_masked:
                    # Zero out everything outside the current label, keeping image geometry.
                    for i in torch.unique(mask[batch]):
                        if i != 0:
                            loss += self.loss(output[batch, ...]*torch.where(mask == i, 1, 0), targets[0][batch, ...]*torch.where(mask == i, 1, 0))
                else:
                    # Compare only the voxels selected by the current label.
                    for i in torch.unique(mask[batch]):
                        if i != 0:
                            index = mask[batch, ...] == i
                            loss += self.loss(torch.masked_select(output[batch, ...], index), torch.masked_select(targets[0][batch, ...], index))
            else:
                loss += self.loss(output[batch, ...], targets[0][batch, ...])
        return loss/output.shape[0]
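
`MaskedLoss.getMask` multiplies all trailing targets into a single label mask, and `forward` then accumulates the wrapped loss per batch element and per non-zero label, either over masked images (`mode_image_masked`) or over selected voxels only. A minimal sketch of the `masked_select` branch (plain PyTorch, hypothetical tensors):

    import torch

    loss_fn = torch.nn.L1Loss()
    output, target = torch.rand(2, 1, 4, 4), torch.rand(2, 1, 4, 4)
    mask = torch.zeros(2, 1, 4, 4, dtype=torch.long)
    mask[:, :, 1:3, 1:3] = 1  # label 1 marks the region of interest

    loss = torch.tensor(0.0)
    for b in range(output.shape[0]):
        index = mask[b] == 1
        # Only voxels inside the mask contribute to the loss.
        loss += loss_fn(torch.masked_select(output[b], index), torch.masked_select(target[b], index))
    print(loss / output.shape[0])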

class MSE(MaskedLoss):

    def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.MSELoss(reduction=reduction)(x, y)

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__(partial(MSE._loss, reduction), False)

class MAE(MaskedLoss):

    def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.nn.L1Loss(reduction=reduction)(x, y)

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__(partial(MAE._loss, reduction), False)

class PSNR(MaskedLoss):

    def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        mse = torch.mean((x - y).pow(2))
        psnr = 10 * torch.log10(dynamic_range**2 / mse)
        return psnr

    def __init__(self, dynamic_range: Union[float, None] = None) -> None:
        dynamic_range = dynamic_range if dynamic_range else 1024+3000
        super().__init__(partial(PSNR._loss, dynamic_range), False)
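
`PSNR._loss` is the textbook peak signal-to-noise ratio, 10·log10(R²/MSE), with a default dynamic range R = 1024 + 3000 = 4024 (a CT-style intensity window, by the look of it). A quick numeric check of the formula (hypothetical tensors):

    import torch

    x = torch.zeros(1, 1, 8, 8)
    y = torch.full((1, 1, 8, 8), 2.0)
    mse = torch.mean((x - y) ** 2)            # = 4.0
    psnr = 10 * torch.log10(4024.0**2 / mse)
    print(psnr.item())                        # ~66.07 dB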

class SSIM(MaskedLoss):

    def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return structural_similarity(x[0][0].detach().cpu().numpy(), y[0][0].cpu().numpy(), data_range=dynamic_range, gradient=False, full=False)

    def __init__(self, dynamic_range: Union[float, None] = None) -> None:
        dynamic_range = dynamic_range if dynamic_range else 1024+3000
        super().__init__(partial(SSIM._loss, dynamic_range), True)


class LPIPS(MaskedLoss):

    def normalize(input: torch.Tensor) -> torch.Tensor:
        # Rescale to [-1, 1], the range the LPIPS backbone expects.
        return (input-torch.min(input))/(torch.max(input)-torch.min(input))*2-1

    def preprocessing(input: torch.Tensor) -> torch.Tensor:
        # Repeat grayscale input to the 3 channels the backbone expects.
        return input.repeat((1,3,1,1))

    def _loss(loss_fn_alex, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        datasetPatch = ModelPatch([1, 64, 64])
        datasetPatch.load(x.shape[2:])

        patchIterator = datasetPatch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
        loss = 0
        with tqdm.tqdm(iterable=enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
            for i, patch_input in batch_iter:
                real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
                loss += loss_fn_alex(real, fake).item()
        return loss/datasetPatch.getSize(0)

    def __init__(self, model: str = "alex") -> None:
        import lpips
        super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model)), True)
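
`LPIPS` tiles its inputs into 64x64 patches through KonfAI's `ModelPatch`, rescales each patch to [-1, 1], repeats the single channel to the three the AlexNet backbone expects, and averages the perceptual distance over patches. The underlying library call on one patch (requires the `lpips` package; hypothetical tensors):

    import lpips
    import torch

    loss_fn = lpips.LPIPS(net="alex")       # pretrained AlexNet backbone
    x = torch.rand(1, 1, 64, 64) * 2 - 1    # lpips expects inputs in [-1, 1]
    y = torch.rand(1, 1, 64, 64) * 2 - 1
    # Grayscale inputs are repeated to 3 channels before scoring.
    d = loss_fn(x.repeat(1, 3, 1, 1), y.repeat(1, 3, 1, 1))
    print(d.item())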

class Dice(Criterion):

    def __init__(self, smooth: float = 1e-6) -> None:
        super().__init__()
        self.smooth = smooth

    def flatten(self, tensor: torch.Tensor) -> torch.Tensor:
        # (N, C, ...) -> (C, N*...) so channels can be scored separately.
        return tensor.permute((1, 0) + tuple(range(2, tensor.dim()))).contiguous().view(tensor.size(1), -1)

    def dice_per_channel(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        input = self.flatten(input)
        target = self.flatten(target)
        return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        target = targets[0]
        target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
        return 1-torch.mean(self.dice_per_channel(output, target))
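
`Dice.forward` one-hot encodes the integer label map to match the channel layout of `output`, then returns 1 minus the mean Dice coefficient. The core overlap formula on a toy pair (hypothetical tensors):

    import torch

    smooth = 1e-6
    pred = torch.tensor([0., 1., 1., 0.])
    target = torch.tensor([0., 1., 0., 0.])
    dice = (2 * (pred * target).sum() + smooth) / (pred.sum() + target.sum() + smooth)
    print(dice)  # 2*1 / (2+1) ~= 0.667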

class GradientImages(Criterion):

    def __init__(self):
        super().__init__()

    @staticmethod
    def _image_gradient2D(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        dx = image[:, :, 1:, :] - image[:, :, :-1, :]
        dy = image[:, :, :, 1:] - image[:, :, :, :-1]
        return dx, dy

    @staticmethod
    def _image_gradient3D(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        dx = image[:, :, 1:, :, :] - image[:, :, :-1, :, :]
        dy = image[:, :, :, 1:, :] - image[:, :, :, :-1, :]
        dz = image[:, :, :, :, 1:] - image[:, :, :, :, :-1]
        return dx, dy, dz

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        target_0 = targets[0]
        if len(output.shape) == 5:
            dx, dy, dz = GradientImages._image_gradient3D(output)
            if target_0 is not None:
                dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient3D(target_0)
                dx -= dx_tmp
                dy -= dy_tmp
                dz -= dz_tmp
            return dx.norm() + dy.norm() + dz.norm()
        else:
            dx, dy = GradientImages._image_gradient2D(output)
            if target_0 is not None:
                dx_tmp, dy_tmp = GradientImages._image_gradient2D(target_0)
                dx -= dx_tmp
                dy -= dy_tmp
            return dx.norm() + dy.norm()
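
`GradientImages` penalizes the norm of forward finite differences, optionally after subtracting the target's differences, which rewards outputs whose edges line up with the reference. The 2D stencil in isolation (hypothetical tensor):

    import torch

    image = torch.arange(16.0).reshape(1, 1, 4, 4)
    dx = image[:, :, 1:, :] - image[:, :, :-1, :]  # differences along rows
    dy = image[:, :, :, 1:] - image[:, :, :, :-1]  # differences along columns
    print((dx.norm() + dy.norm()).item())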

class BCE(Criterion):

    def __init__(self, target: float = 0) -> None:
        super().__init__()
        self.loss = torch.nn.BCEWithLogitsLoss()
        self.register_buffer('target', torch.tensor(target).type(torch.float32))

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        target = self._buffers["target"]
        return self.loss(output, target.to(output.device).expand_as(output))

class PatchGanLoss(Criterion):

    def __init__(self, target: float = 0) -> None:
        super().__init__()
        self.loss = torch.nn.MSELoss()
        self.register_buffer('target', torch.tensor(target).type(torch.float32))

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        target = self._buffers["target"]
        return self.loss(output, (torch.ones_like(output)*target).to(output.device))

class WGP(Criterion):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        return torch.mean((output - 1)**2)
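
`BCE` and `PatchGanLoss` score discriminator outputs against a constant label (0 or 1) broadcast to the output shape, the usual real/fake objective, while `WGP` is a least-squares term pulling its input toward 1. The label-broadcast pattern on its own (hypothetical logits):

    import torch

    logits = torch.randn(4, 1, 30, 30)  # e.g. a PatchGAN discriminator map
    real_label = torch.tensor(1.0)
    loss = torch.nn.BCEWithLogitsLoss()(logits, real_label.expand_as(logits))
    print(loss.item())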

class Gram(Criterion):

    def computeGram(input: torch.Tensor):
        (b, ch, w) = input.size()
        # Gram products are computed in full precision (autocast disabled).
        with torch.amp.autocast('cuda', enabled=False):
            return input.bmm(input.transpose(1, 2)).div(ch*w)

    def __init__(self) -> None:
        super().__init__()
        self.loss = torch.nn.L1Loss(reduction='sum')

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if len(output.shape) > 3:
            output = output.view(output.shape[0], output.shape[1], int(np.prod(output.shape[2:])))
        if len(target.shape) > 3:
            target = target.view(target.shape[0], target.shape[1], int(np.prod(target.shape[2:])))
        return self.loss(Gram.computeGram(output), Gram.computeGram(target))
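
`Gram.computeGram` flattens spatial dimensions to (b, ch, w) and computes G = F·Fᵀ/(ch·w), the channel-correlation matrix classically used for style matching. A standalone shape check (hypothetical feature map):

    import torch

    features = torch.rand(2, 8, 16, 16)            # (b, ch, h, w)
    b, ch = features.shape[:2]
    flat = features.view(b, ch, -1)                # (b, ch, w)
    gram = flat.bmm(flat.transpose(1, 2)).div(ch * flat.shape[2])
    print(gram.shape)                              # torch.Size([2, 8, 8])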

class PerceptualLoss(Criterion):

    class Module():

        @config(None)
        def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
            self.losses = losses
            self.DL_args = os.environ['DEEP_LEARNING_API_CONFIG_PATH'] if "DEEP_LEARNING_API_CONFIG_PATH" in os.environ else ""

        def getLoss(self) -> dict[torch.nn.Module, float]:
            result: dict[torch.nn.Module, float] = {}
            for loss, l in self.losses.items():
                module, name = _getModule(loss, "measure")
                result[config(self.DL_args)(getattr(importlib.import_module(module), name))(config=None)] = l
            return result

    def __init__(self, modelLoader: ModelLoader = ModelLoader(), path_model: str = "name", modules: dict[str, Module] = {"UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch_nn_L1Loss": 1})}, shape: list[int] = [128, 128, 128]) -> None:
        super().__init__()
        self.path_model = path_model
        if self.path_model not in modelsRegister:
            self.model = modelLoader.getModel(train=False, DL_args=os.environ['DEEP_LEARNING_API_CONFIG_PATH'].split("PerceptualLoss")[0]+"PerceptualLoss.Model", DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain", "outputsCriterions", "drop_p"])
            if path_model.startswith("https"):
                state_dict = torch.hub.load_state_dict_from_url(path_model)
                state_dict = {"Model": {self.model.getName(): state_dict["model"]}}
            else:
                state_dict = torch.load(path_model, weights_only=True)
            self.model.load(state_dict)
            modelsRegister[self.path_model] = self.model
        else:
            self.model = modelsRegister[self.path_model]

        self.shape = shape
        self.mode = "trilinear" if len(shape) == 3 else "bilinear"
        self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
        for name, losses in modules.items():
            self.modules_loss[name.replace(":", ".")] = losses.getLoss()

        self.model.eval()
        self.model.requires_grad_(False)
        self.models: dict[int, torch.nn.Module] = {}

    def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
        #if not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
        #    input = F.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
        #if input.shape[1] != self.model.in_channels:
        #    input = input.repeat(tuple([1,self.model.in_channels] + [1 for _ in range(len(self.shape))]))
        #input = (input - torch.min(input))/(torch.max(input)-torch.min(input))
        #input = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input)
        return input

    def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
        output = self.preprocessing(output)
        targets = [self.preprocessing(target) for target in targets]
        for zipped_output in zip([output], *[[target] for target in targets]):
            output = zipped_output[0]
            targets = zipped_output[1:]

            for zipped_layers in list(zip(self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()), *[self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy()) for target in targets])):
                output_layer = zipped_layers[0][1].view(zipped_layers[0][1].shape[0], zipped_layers[0][1].shape[1], int(np.prod(zipped_layers[0][1].shape[2:])))
                for (loss_function, l), target_layer in zip(self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:]):
                    target_layer = target_layer[1].view(target_layer[1].shape[0], target_layer[1].shape[1], int(np.prod(target_layer[1].shape[2:])))
                    loss = loss+l*loss_function(output_layer.float(), target_layer.float())/output_layer.shape[0]
        return loss

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        if output.device.index not in self.models:
            del os.environ["device"]
            self.models[output.device.index] = Network.to(copy.deepcopy(self.model).eval(), output.device.index).eval()
        loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
        if len(output.shape) == 5 and len(self.shape) == 2:
            # 3D input with a 2D backbone: evaluate slice by slice.
            for i in range(output.shape[2]):
                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])/output.shape[2]
        else:
            loss = self._compute(output, targets)
        return loss.to(output)
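
`PerceptualLoss` runs output and targets through a frozen copy of the configured model, pulls named intermediate activations via the network's `get_layers`, and accumulates the configured per-layer losses (Gram plus L1 by default). The same feature-matching idea with plain forward hooks, no KonfAI machinery (hypothetical two-layer net):

    import torch
    import torch.nn.functional as F

    net = torch.nn.Sequential(
        torch.nn.Conv2d(1, 4, 3, padding=1), torch.nn.ReLU(),
        torch.nn.Conv2d(4, 4, 3, padding=1),
    ).eval().requires_grad_(False)

    features = {}
    # Capture the ReLU activation on every forward pass.
    net[1].register_forward_hook(lambda module, args, out: features.update(relu1=out))

    def perceptual_l1(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        net(x); fx = features["relu1"]
        net(y); fy = features["relu1"]
        return F.l1_loss(fx, fy)

    print(perceptual_l1(torch.rand(1, 1, 8, 8), torch.rand(1, 1, 8, 8)).item())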

class KLDivergence(Criterion):

    def __init__(self, shape: list[int], dim: int = 100, mu: float = 0, std: float = 1) -> None:
        super().__init__()
        self.latentDim = dim
        self.mu = torch.Tensor([mu])
        self.std = torch.Tensor([std])
        self.modelDim = 3
        self.shape = shape
        self.loss = torch.nn.KLDivLoss()

    def init(self, model: Network, output_group: str, target_group: str) -> str:
        super().init(model, output_group, target_group)
        model._compute_channels_trace(model, model.in_channels, None, None)

        # Walk down to the parent module of the output group...
        last_module = model
        for name in output_group.split(".")[:-1]:
            last_module = last_module[name]

        modules = last_module._modules.copy()
        last_module._modules.clear()

        # ...and re-insert a LatentDistribution head right after the target submodule.
        for name, value in modules.items():
            last_module._modules[name] = value
            if name == output_group.split(".")[-1]:
                last_module.add_module("LatentDistribution", LatentDistribution(shape=self.shape, latentDim=self.latentDim))
        return ".".join(output_group.split(".")[:-1])+".LatentDistribution.Concat"

    def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
        mu = output[:, 0, :]
        log_std = output[:, 1, :]  # used as the log-variance in the closed-form KL below
        return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim=1), dim=0)

    """
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        mu = input[:, 0, :]
        log_std = input[:, 1, :]

        z = input[:, 2, :]

        q = torch.distributions.Normal(mu, log_std)

        target_mu = torch.ones((self.latentDim)).to(input.device)*self.mu.to(input.device)
        target_std = torch.ones((self.latentDim)).to(input.device)*self.std.to(input.device)

        p = torch.distributions.Normal(target_mu, target_std)

        log_pz = p.log_prob(z)
        log_qzx = q.log_prob(z)

        kl = (log_pz - log_qzx)
        kl = kl.sum(-1)
        return kl
    """
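
The active `forward` is the closed form KL(N(μ, σ²) ‖ N(0, 1)) = -0.5·Σ(1 + log σ² - μ² - σ²); note the channel named `log_std` is really the log-variance in this formula. A quick consistency check against `torch.distributions` (hypothetical values):

    import torch
    from torch.distributions import Normal, kl_divergence

    mu = torch.tensor([0.5, -1.0])
    log_var = torch.tensor([0.2, -0.3])
    closed_form = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
    reference = kl_divergence(Normal(mu, torch.exp(0.5 * log_var)),
                              Normal(torch.zeros(2), torch.ones(2))).sum()
    print(closed_form.item(), reference.item())  # equal up to float error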

class Accuracy(Criterion):

    def __init__(self) -> None:
        super().__init__()
        self.n: int = 0
        self.corrects = torch.zeros((1))

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        target_0 = targets[0]
        self.n += output.shape[0]
        self.corrects += (torch.argmax(torch.softmax(output, dim=1), dim=1) == target_0).sum().float().cpu()
        return self.corrects/self.n

class TripletLoss(Criterion):

    def __init__(self) -> None:
        super().__init__()
        self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)

    def forward(self, output: torch.Tensor) -> torch.Tensor:
        # output packs (anchor, positive, negative).
        return self.triplet_loss(output[0], output[1], output[2])

class L1LossRepresentation(Criterion):

    def __init__(self) -> None:
        super().__init__()
        self.loss = torch.nn.L1Loss()

    def _variance(self, features: torch.Tensor) -> torch.Tensor:
        # Hinge on the batch variance to discourage representation collapse.
        return torch.mean(torch.clamp(1 - torch.var(features, dim=0), min=0))

    def forward(self, output: torch.Tensor) -> torch.Tensor:
        return self.loss(output[0], output[1]) + self._variance(output[0]) + self._variance(output[1])

"""import torch
import torch.nn as nn
from torchvision.models import inception_v3, Inception_V3_Weights
from torchvision.transforms import functional as F
from scipy import linalg
import numpy as np

class FID(Criterion):

    class InceptionV3(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
            self.model.fc = nn.Identity()
            self.model.eval()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.model(x)

    def __init__(self) -> None:
        super().__init__()
        self.inception_model = FID.InceptionV3().cuda()

    def preprocess_images(image: torch.Tensor) -> torch.Tensor:
        return F.normalize(F.resize(image, (299, 299)).repeat((1,3,1,1)), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).cuda()

    def get_features(images: torch.Tensor, model: torch.nn.Module) -> np.ndarray:
        with torch.no_grad():
            features = model(images).cpu().numpy()
        return features

    def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
        # Calculate mean and covariance statistics
        mu1 = np.mean(real_features, axis=0)
        sigma1 = np.cov(real_features, rowvar=False)
        mu2 = np.mean(generated_features, axis=0)
        sigma2 = np.cov(generated_features, rowvar=False)

        # Calculate FID score
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
        real_images = FID.preprocess_images(targets[0].squeeze(0).permute([1, 0, 2, 3]))
        generated_images = FID.preprocess_images(output.squeeze(0).permute([1, 0, 2, 3]))

        real_features = FID.get_features(real_images, self.inception_model)
        generated_features = FID.get_features(generated_images, self.inception_model)

        return FID.calculate_fid(real_features, generated_features)
"""

"""class MutualInformationLoss(torch.nn.Module):
    def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
        super().__init__()
        bin_centers = torch.linspace(0.0, 1.0, num_bins)
        sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
        self.num_bins = num_bins
        self.preterm = 1 / (2 * sigma**2)
        self.bin_centers = bin_centers[None, None, ...]
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)

    def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
        target_weight, target_probability = self.parzen_windowing_gaussian(target)
        return pred_weight, pred_probability, target_weight, target_probability

    def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        img = torch.clamp(img, 0, 1)
        img = img.reshape(img.shape[0], -1, 1)  # (batch, num_sample, 1)
        weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2)  # (batch, num_sample, num_bin)
        weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bin)
        probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bin)
        return weight, probability

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        wa, pa, wb, pb = self.parzen_windowing(pred, target)  # (batch, num_sample, num_bin), (batch, 1, num_bin)
        pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1])  # (batch, num_bins, num_bins)
        papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa))  # (batch, num_bins, num_bins)
        mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2))  # (batch)
        return torch.mean(mi).neg()  # average over the batch and channel ndims"""

konfai/metric/schedulers.py
ADDED

@@ -0,0 +1,49 @@

import torch
import numpy as np
from abc import abstractmethod
from konfai.utils.config import config

class Scheduler():

    @config("Scheduler")
    def __init__(self, start_value: float) -> None:
        self.baseValue = float(start_value)
        self.it = 0

    def step(self, it: int):
        self.it = it

    @abstractmethod
    def get_value(self) -> float:
        pass

class Constant(Scheduler):

    @config("Constant")
    def __init__(self, value: float = 1):
        super().__init__(value)

    def get_value(self) -> float:
        return self.baseValue

class CosineAnnealing(Scheduler):

    @config("CosineAnnealing")
    def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
        super().__init__(start_value)
        self.eta_min = eta_min
        self.T_max = T_max

    def get_value(self):
        return self.eta_min + (self.baseValue - self.eta_min) * (1 + np.cos(self.it * torch.pi / self.T_max)) / 2

# Note: the released file defines CosineAnnealing a second time, identically;
# the duplicate below shadows the class above without changing behavior.
class CosineAnnealing(Scheduler):

    @config("CosineAnnealing")
    def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
        super().__init__(start_value)
        self.eta_min = eta_min
        self.T_max = T_max

    def get_value(self):
        return self.eta_min + (self.baseValue - self.eta_min) * (1 + np.cos(self.it * torch.pi / self.T_max)) / 2