konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


konfai/main.py ADDED
@@ -0,0 +1,43 @@
+ import argparse
+ import os
+ from torch.cuda import device_count
+ import torch.multiprocessing as mp
+ import submitit
+ from konfai.utils.utils import setupAPI, TensorBoard, Log
+
+ def main():
+     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+     with setupAPI(parser) as distributedObject:
+         with Log(distributedObject.name):
+             world_size = device_count()
+             if world_size == 0:
+                 world_size = 1
+             distributedObject.setup(world_size)
+             with TensorBoard(distributedObject.name):
+                 mp.spawn(distributedObject, nprocs=world_size)
+
+
+ def cluster():
+     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+     # Cluster manager arguments
+     cluster_args = parser.add_argument_group('Cluster manager arguments')
+     cluster_args.add_argument('--name', type=str, help='Task name', required=True)
+     cluster_args.add_argument('--num-nodes', '--num_nodes', default=1, type=int, help='Number of nodes')
+     cluster_args.add_argument('--memory', type=int, default=16, help='Amount of memory per node in GB')
+     cluster_args.add_argument('--time-limit', '--time_limit', type=int, default=1440, help='Job time limit in minutes')
+     cluster_args.add_argument('--resubmit', action='store_true', help='Automatically resubmit the job just before timeout')
+
+     with setupAPI(parser) as distributedObject:
+         args = parser.parse_args()
+         config = vars(args)
+         os.environ["DL_API_OVERWRITE"] = "True"
+         os.environ["DL_API_CLUSTER"] = "True"
+
+         n_gpu = len(config["gpu"].split(","))
+         distributedObject.setup(n_gpu*int(config["num_nodes"]))
+
+         executor = submitit.AutoExecutor(folder="./Cluster/")
+         executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
+         with TensorBoard(distributedObject.name):
+             executor.submit(distributedObject)
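
Note on the spawn call above: torch.multiprocessing.spawn invokes its first argument once per child process, passing the process rank as the first positional argument, so the object yielded by setupAPI is presumably callable as distributedObject(rank). A minimal sketch of that contract, with DummyDistributed as a hypothetical stand-in for KonfAI's distributed object:

import torch.multiprocessing as mp

class DummyDistributed:
    """Hypothetical stand-in for the object yielded by setupAPI."""

    def setup(self, world_size: int) -> None:
        self.world_size = world_size

    def __call__(self, rank: int) -> None:
        # mp.spawn calls this once per process, with the rank as first argument.
        print(f"running rank {rank} of {self.world_size}")

if __name__ == "__main__":
    obj = DummyDistributed()
    obj.setup(2)
    mp.spawn(obj, nprocs=2)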
@@ -0,0 +1,488 @@
+ from abc import ABC, abstractmethod
+ import copy
+ import importlib
+ import os
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+
+ from typing import Callable, Union
+ from functools import partial
+ from skimage.metrics import structural_similarity
+
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.utils.utils import _getModule
+ from KonfAI.konfai.data.HDF5 import ModelPatch
+ from KonfAI.konfai.network.blocks import LatentDistribution
+ from KonfAI.konfai.network.network import ModelLoader, Network
+
+ modelsRegister = {}
+
+ class Criterion(torch.nn.Module, ABC):
+
+     def __init__(self) -> None:
+         super().__init__()
+
+     def init(self, model: torch.nn.Module, output_group: str, target_group: str) -> str:
+         return output_group
+
+     @abstractmethod
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         pass
+
+ class MaskedLoss(Criterion):
+
+     def __init__(self, loss: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], mode_image_masked: bool) -> None:
+         super().__init__()
+         self.loss = loss
+         self.mode_image_masked = mode_image_masked
+
+     def getMask(self, targets: list[torch.Tensor]) -> torch.Tensor:
+         # Multiply all extra targets together into a single label mask.
+         result = None
+         if len(targets) > 0:
+             result = targets[0]
+             for mask in targets[1:]:
+                 result = result*mask
+         return result
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         loss = torch.tensor(0, dtype=torch.float32).to(output.device)
+         mask = self.getMask(targets[1:])
+         for batch in range(output.shape[0]):
+             if mask is not None:
+                 if self.mode_image_masked:
+                     for i in torch.unique(mask[batch]):
+                         if i != 0:
+                             loss += self.loss(output[batch, ...]*torch.where(mask[batch] == i, 1, 0), targets[0][batch, ...]*torch.where(mask[batch] == i, 1, 0))
+                 else:
+                     for i in torch.unique(mask[batch]):
+                         if i != 0:
+                             index = mask[batch, ...] == i
+                             loss += self.loss(torch.masked_select(output[batch, ...], index), torch.masked_select(targets[0][batch, ...], index))
+             else:
+                 loss += self.loss(output[batch, ...], targets[0][batch, ...])
+         return loss/output.shape[0]
+
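In the MaskedLoss.forward signature above, the first extra target is the reference image and any further targets are multiplied together into a label mask; when a mask is present, the wrapped loss is evaluated separately for each non-zero label and summed, then the total is averaged over the batch. A small usage sketch with the MSE subclass defined just below (shapes are illustrative):

import torch

mse = MSE()  # a MaskedLoss wrapping torch.nn.MSELoss

output = torch.rand(2, 1, 8, 8)
target = torch.rand(2, 1, 8, 8)
mask = torch.randint(0, 3, (2, 1, 8, 8))  # label 0 is ignored

plain = mse(output, target)         # per-batch MSE, no masking
masked = mse(output, target, mask)  # MSE restricted to labels 1 and 2
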
+ class MSE(MaskedLoss):
+
+     @staticmethod
+     def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         return torch.nn.MSELoss(reduction=reduction)(x, y)
+
+     def __init__(self, reduction: str = "mean") -> None:
+         super().__init__(partial(MSE._loss, reduction), False)
+
+ class MAE(MaskedLoss):
+
+     @staticmethod
+     def _loss(reduction: str, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         return torch.nn.L1Loss(reduction=reduction)(x, y)
+
+     def __init__(self, reduction: str = "mean") -> None:
+         super().__init__(partial(MAE._loss, reduction), False)
+
+ class PSNR(MaskedLoss):
+
+     @staticmethod
+     def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         mse = torch.mean((x - y).pow(2))
+         return 10 * torch.log10(dynamic_range**2 / mse)
+
+     def __init__(self, dynamic_range: Union[float, None] = None) -> None:
+         dynamic_range = dynamic_range if dynamic_range else 1024+3000
+         super().__init__(partial(PSNR._loss, dynamic_range), False)
+
+ class SSIM(MaskedLoss):
+
+     @staticmethod
+     def _loss(dynamic_range: float, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         return structural_similarity(x[0][0].detach().cpu().numpy(), y[0][0].cpu().numpy(), data_range=dynamic_range, gradient=False, full=False)
+
+     def __init__(self, dynamic_range: Union[float, None] = None) -> None:
+         dynamic_range = dynamic_range if dynamic_range else 1024+3000
+         super().__init__(partial(SSIM._loss, dynamic_range), True)
+
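PSNR above is the usual 10·log10(R²/MSE), with the dynamic range R defaulting to 1024+3000 = 4024 (presumably a CT Hounsfield-unit range); SSIM defers to skimage.metrics.structural_similarity on the first batch element. A quick sanity check of the PSNR formula:

import torch

x = torch.zeros(1, 1, 8, 8)
y = torch.ones(1, 1, 8, 8)   # MSE(x, y) == 1.0
print(PSNR()(x, y))          # 10*log10(4024**2) ≈ 72.1 dB
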
+ class LPIPS(MaskedLoss):
+
+     @staticmethod
+     def normalize(input: torch.Tensor) -> torch.Tensor:
+         # Rescale to [-1, 1], the range expected by the lpips networks.
+         return (input-torch.min(input))/(torch.max(input)-torch.min(input))*2-1
+
+     @staticmethod
+     def preprocessing(input: torch.Tensor) -> torch.Tensor:
+         # Repeat a single-channel image to the 3 channels expected by the lpips networks.
+         return input.repeat((1, 3, 1, 1))
+
+     @staticmethod
+     def _loss(loss_fn_alex, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         datasetPatch = ModelPatch([1, 64, 64])
+         datasetPatch.load(x.shape[2:])
+
+         patchIterator = datasetPatch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
+         loss = 0
+         with tqdm.tqdm(iterable=enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
+             for i, patch_input in batch_iter:
+                 real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
+                 loss += loss_fn_alex(real, fake).item()
+         return loss/datasetPatch.getSize(0)
+
+     def __init__(self, model: str = "alex") -> None:
+         import lpips
+         super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model)), True)
+
+ class Dice(Criterion):
+
+     def __init__(self, smooth: float = 1e-6) -> None:
+         super().__init__()
+         self.smooth = smooth
+
+     def flatten(self, tensor: torch.Tensor) -> torch.Tensor:
+         # (batch, channel, ...) -> (channel, batch*spatial)
+         return tensor.permute((1, 0) + tuple(range(2, tensor.dim()))).contiguous().view(tensor.size(1), -1)
+
+     def dice_per_channel(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         input = self.flatten(input)
+         target = self.flatten(target)
+         return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         target = targets[0]
+         # One-hot encode the integer label map to match the output's channel layout.
+         target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
+         return 1-torch.mean(self.dice_per_channel(output, target))
+
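Dice expects an output with one channel per class and an integer label map as target; forward one-hot encodes the target on the fly to match the output's channel layout. A minimal sketch:

import torch

dice = Dice(smooth=1e-6)
probs = torch.softmax(torch.rand(2, 3, 8, 8), dim=1)  # 3-class prediction
labels = torch.randint(0, 3, (2, 1, 8, 8))            # integer label map
loss = dice(probs, labels)                            # 1 - Dice overlap
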
+ class GradientImages(Criterion):
+
+     def __init__(self):
+         super().__init__()
+
+     @staticmethod
+     def _image_gradient2D(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         dx = image[:, :, 1:, :] - image[:, :, :-1, :]
+         dy = image[:, :, :, 1:] - image[:, :, :, :-1]
+         return dx, dy
+
+     @staticmethod
+     def _image_gradient3D(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         dx = image[:, :, 1:, :, :] - image[:, :, :-1, :, :]
+         dy = image[:, :, :, 1:, :] - image[:, :, :, :-1, :]
+         dz = image[:, :, :, :, 1:] - image[:, :, :, :, :-1]
+         return dx, dy, dz
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         target_0 = targets[0]
+         if len(output.shape) == 5:
+             dx, dy, dz = GradientImages._image_gradient3D(output)
+             if target_0 is not None:
+                 dx_tmp, dy_tmp, dz_tmp = GradientImages._image_gradient3D(target_0)
+                 dx -= dx_tmp
+                 dy -= dy_tmp
+                 dz -= dz_tmp
+             return dx.norm() + dy.norm() + dz.norm()
+         else:
+             dx, dy = GradientImages._image_gradient2D(output)
+             if target_0 is not None:
+                 dx_tmp, dy_tmp = GradientImages._image_gradient2D(target_0)
+                 dx -= dx_tmp
+                 dy -= dy_tmp
+             return dx.norm() + dy.norm()
+
+ class BCE(Criterion):
+
+     def __init__(self, target: float = 0) -> None:
+         super().__init__()
+         self.loss = torch.nn.BCEWithLogitsLoss()
+         self.register_buffer('target', torch.tensor(target).type(torch.float32))
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         target = self._buffers["target"]
+         return self.loss(output, target.to(output.device).expand_as(output))
+
+ class PatchGanLoss(Criterion):
+
+     def __init__(self, target: float = 0) -> None:
+         super().__init__()
+         self.loss = torch.nn.MSELoss()
+         self.register_buffer('target', torch.tensor(target).type(torch.float32))
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         target = self._buffers["target"]
+         return self.loss(output, (torch.ones_like(output)*target).to(output.device))
+
+ class WGP(Criterion):
+
+     def __init__(self) -> None:
+         super().__init__()
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         return torch.mean((output - 1)**2)
+
+ class Gram(Criterion):
+
+     @staticmethod
+     def computeGram(input: torch.Tensor) -> torch.Tensor:
+         (b, ch, w) = input.size()
+         with torch.amp.autocast('cuda', enabled=False):
+             return input.bmm(input.transpose(1, 2)).div(ch*w)
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.loss = torch.nn.L1Loss(reduction='sum')
+
+     def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         if len(output.shape) > 3:
+             output = output.view(output.shape[0], output.shape[1], int(np.prod(output.shape[2:])))
+         if len(target.shape) > 3:
+             target = target.view(target.shape[0], target.shape[1], int(np.prod(target.shape[2:])))
+         return self.loss(Gram.computeGram(output), Gram.computeGram(target))
+
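computeGram forms the channel Gram matrix G = X·Xᵀ/(ch·w) of a feature map flattened to (batch, channels, pixels), the standard building block of style losses; autocast is disabled presumably because the large bmm is prone to float16 overflow. Shape-wise:

import torch

features = torch.rand(2, 16, 32, 32).view(2, 16, -1)  # (b, ch, h*w)
gram = Gram.computeGram(features)                     # (2, 16, 16), scaled by 1/(ch*w)
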
+ class PerceptualLoss(Criterion):
+
+     class Module():
+
+         @config(None)
+         def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
+             self.losses = losses
+             self.DL_args = os.environ['DEEP_LEARNING_API_CONFIG_PATH'] if "DEEP_LEARNING_API_CONFIG_PATH" in os.environ else ""
+
+         def getLoss(self) -> dict[torch.nn.Module, float]:
+             result: dict[torch.nn.Module, float] = {}
+             for loss, l in self.losses.items():
+                 module, name = _getModule(loss, "measure")
+                 result[config(self.DL_args)(getattr(importlib.import_module(module), name))(config=None)] = l
+             return result
+
+     def __init__(self, modelLoader: ModelLoader = ModelLoader(), path_model: str = "name", modules: dict[str, Module] = {"UNetBlock_0.DownConvBlock.Activation_1": Module({"Gram": 1, "torch_nn_L1Loss": 1})}, shape: list[int] = [128, 128, 128]) -> None:
+         super().__init__()
+         self.path_model = path_model
+         if self.path_model not in modelsRegister:
+             self.model = modelLoader.getModel(train=False, DL_args=os.environ['DEEP_LEARNING_API_CONFIG_PATH'].split("PerceptualLoss")[0]+"PerceptualLoss.Model", DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain", "outputsCriterions", "drop_p"])
+             if path_model.startswith("https"):
+                 state_dict = torch.hub.load_state_dict_from_url(path_model)
+                 state_dict = {"Model": {self.model.getName(): state_dict["model"]}}
+             else:
+                 state_dict = torch.load(path_model, weights_only=True)
+             self.model.load(state_dict)
+             modelsRegister[self.path_model] = self.model
+         else:
+             self.model = modelsRegister[self.path_model]
+
+         self.shape = shape
+         self.mode = "trilinear" if len(shape) == 3 else "bilinear"
+         self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
+         for name, losses in modules.items():
+             self.modules_loss[name.replace(":", ".")] = losses.getLoss()
+
+         self.model.eval()
+         self.model.requires_grad_(False)
+         self.models: dict[int, torch.nn.Module] = {}
+
+     def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
+         #if not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
+         #    input = F.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
+         #if input.shape[1] != self.model.in_channels:
+         #    input = input.repeat(tuple([1, self.model.in_channels] + [1 for _ in range(len(self.shape))]))
+         #input = (input - torch.min(input))/(torch.max(input)-torch.min(input))
+         #input = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input)
+         return input
+
+     def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
+         loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
+         output = self.preprocessing(output)
+         targets = [self.preprocessing(target) for target in targets]
+         for zipped_output in zip([output], *[[target] for target in targets]):
+             output = zipped_output[0]
+             targets = zipped_output[1:]
+
+             for zipped_layers in list(zip(self.models[output.device.index].get_layers([output], set(self.modules_loss.keys()).copy()), *[self.models[output.device.index].get_layers([target], set(self.modules_loss.keys()).copy()) for target in targets])):
+                 output_layer = zipped_layers[0][1].view(zipped_layers[0][1].shape[0], zipped_layers[0][1].shape[1], int(np.prod(zipped_layers[0][1].shape[2:])))
+                 for (loss_function, l), target_layer in zip(self.modules_loss[zipped_layers[0][0]].items(), zipped_layers[1:]):
+                     target_layer = target_layer[1].view(target_layer[1].shape[0], target_layer[1].shape[1], int(np.prod(target_layer[1].shape[2:])))
+                     loss = loss+l*loss_function(output_layer.float(), target_layer.float())/output_layer.shape[0]
+         return loss
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         if output.device.index not in self.models:
+             del os.environ["device"]
+             self.models[output.device.index] = Network.to(copy.deepcopy(self.model).eval(), output.device.index).eval()
+         loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
+         if len(output.shape) == 5 and len(self.shape) == 2:
+             # 3D input with a 2D perceptual model: average the loss over slices.
+             for i in range(output.shape[2]):
+                 loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])/output.shape[2]
+         else:
+             loss = self._compute(output, targets)
+         return loss.to(output)
+
+ class KLDivergence(Criterion):
+
+     def __init__(self, shape: list[int], dim: int = 100, mu: float = 0, std: float = 1) -> None:
+         super().__init__()
+         self.latentDim = dim
+         self.mu = torch.Tensor([mu])
+         self.std = torch.Tensor([std])
+         self.modelDim = 3
+         self.shape = shape
+         self.loss = torch.nn.KLDivLoss()
+
+     def init(self, model: Network, output_group: str, target_group: str) -> str:
+         super().init(model, output_group, target_group)
+         model._compute_channels_trace(model, model.in_channels, None, None)
+
+         last_module = model
+         for name in output_group.split(".")[:-1]:
+             last_module = last_module[name]
+
+         modules = last_module._modules.copy()
+         last_module._modules.clear()
+
+         # Re-insert the modules, appending a LatentDistribution right after the target module.
+         for name, value in modules.items():
+             last_module._modules[name] = value
+             if name == output_group.split(".")[-1]:
+                 last_module.add_module("LatentDistribution", LatentDistribution(shape=self.shape, latentDim=self.latentDim))
+         return ".".join(output_group.split(".")[:-1])+".LatentDistribution.Concat"
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         mu = output[:, 0, :]
+         log_std = output[:, 1, :]
+         return torch.mean(-0.5 * torch.sum(1 + log_std - mu**2 - torch.exp(log_std), dim=1), dim=0)
+
+     """
+     def forward(self, input: torch.Tensor) -> torch.Tensor:
+         mu = input[:, 0, :]
+         log_std = input[:, 1, :]
+
+         z = input[:, 2, :]
+
+         q = torch.distributions.Normal(mu, log_std)
+
+         target_mu = torch.ones((self.latentDim)).to(input.device)*self.mu.to(input.device)
+         target_std = torch.ones((self.latentDim)).to(input.device)*self.std.to(input.device)
+
+         p = torch.distributions.Normal(target_mu, target_std)
+
+         log_pz = p.log_prob(z)
+         log_qzx = q.log_prob(z)
+
+         kl = (log_pz - log_qzx)
+         kl = kl.sum(-1)
+         return kl
+     """
+
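The active forward in KLDivergence is the closed-form KL divergence between the predicted diagonal Gaussian and a standard normal: reading the second channel as log σ², KL(N(μ, σ²) ‖ N(0, I)) = −½ Σᵢ (1 + log σᵢ² − μᵢ² − σᵢ²), averaged over the batch. The docstring above it keeps a sampling-based variant built on torch.distributions for reference.
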
+ class Accuracy(Criterion):
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.n: int = 0
+         self.corrects = torch.zeros((1))
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         # Stateful: accumulates a running accuracy across successive calls.
+         target_0 = targets[0]
+         self.n += output.shape[0]
+         self.corrects += (torch.argmax(torch.softmax(output, dim=1), dim=1) == target_0).sum().float().cpu()
+         return self.corrects/self.n
+
+ class TripletLoss(Criterion):
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.triplet_loss = torch.nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
+
+     def forward(self, output: torch.Tensor) -> torch.Tensor:
+         # output stacks (anchor, positive, negative) along its first dimension.
+         return self.triplet_loss(output[0], output[1], output[2])
+
+ class L1LossRepresentation(Criterion):
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.loss = torch.nn.L1Loss()
+
+     def _variance(self, features: torch.Tensor) -> torch.Tensor:
+         # Hinge on the per-dimension variance to discourage representation collapse.
+         return torch.mean(torch.clamp(1 - torch.var(features, dim=0), min=0))
+
+     def forward(self, output: torch.Tensor) -> torch.Tensor:
+         return self.loss(output[0], output[1]) + self._variance(output[0]) + self._variance(output[1])
+
+ """import torch
+ import torch.nn as nn
+ from torchvision.models import inception_v3, Inception_V3_Weights
+ from torchvision.transforms import functional as F
+ from scipy import linalg
+ import numpy as np
+
+ class FID(Criterion):
+
+     class InceptionV3(nn.Module):
+
+         def __init__(self) -> None:
+             super().__init__()
+             self.model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
+             self.model.fc = nn.Identity()
+             self.model.eval()
+
+         def forward(self, x: torch.Tensor) -> torch.Tensor:
+             return self.model(x)
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.inception_model = FID.InceptionV3().cuda()
+
+     def preprocess_images(image: torch.Tensor) -> torch.Tensor:
+         return F.normalize(F.resize(image, (299, 299)).repeat((1, 3, 1, 1)), mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).cuda()
+
+     def get_features(images: torch.Tensor, model: torch.nn.Module) -> np.ndarray:
+         with torch.no_grad():
+             features = model(images).cpu().numpy()
+         return features
+
+     def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
+         # Calculate mean and covariance statistics
+         mu1 = np.mean(real_features, axis=0)
+         sigma1 = np.cov(real_features, rowvar=False)
+         mu2 = np.mean(generated_features, axis=0)
+         sigma2 = np.cov(generated_features, rowvar=False)
+
+         # Calculate FID score
+         diff = mu1 - mu2
+         covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+         if np.iscomplexobj(covmean):
+             covmean = covmean.real
+
+         return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
+
+     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+         real_images = FID.preprocess_images(targets[0].squeeze(0).permute([1, 0, 2, 3]))
+         generated_images = FID.preprocess_images(output.squeeze(0).permute([1, 0, 2, 3]))
+
+         real_features = FID.get_features(real_images, self.inception_model)
+         generated_features = FID.get_features(generated_images, self.inception_model)
+
+         return FID.calculate_fid(real_features, generated_features)
+ """
+
+ """class MutualInformationLoss(torch.nn.Module):
+     def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
+         super().__init__()
+         bin_centers = torch.linspace(0.0, 1.0, num_bins)
+         sigma = torch.mean(bin_centers[1:] - bin_centers[:-1]) * sigma_ratio
+         self.num_bins = num_bins
+         self.preterm = 1 / (2 * sigma**2)
+         self.bin_centers = bin_centers[None, None, ...]
+         self.smooth_nr = float(smooth_nr)
+         self.smooth_dr = float(smooth_dr)
+
+     def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
+         target_weight, target_probability = self.parzen_windowing_gaussian(target)
+         return pred_weight, pred_probability, target_weight, target_probability
+
+     def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         img = torch.clamp(img, 0, 1)
+         img = img.reshape(img.shape[0], -1, 1)  # (batch, num_sample, 1)
+         weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2)  # (batch, num_sample, num_bin)
+         weight = weight / torch.sum(weight, dim=-1, keepdim=True)  # (batch, num_sample, num_bin)
+         probability = torch.mean(weight, dim=-2, keepdim=True)  # (batch, 1, num_bin)
+         return weight, probability
+
+     def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         wa, pa, wb, pb = self.parzen_windowing(pred, target)  # (batch, num_sample, num_bin), (batch, 1, num_bin)
+         pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1])  # (batch, num_bins, num_bins)
+         papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa))  # (batch, num_bins, num_bins)
+         mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2))  # (batch)
+         return torch.mean(mi).neg()  # average over the batch and channel ndims"""
@@ -0,0 +1,49 @@
+ import torch
+ import numpy as np
+ from abc import abstractmethod
+ from KonfAI.konfai.utils.config import config
+
+ class Scheduler():
+
+     @config("Scheduler")
+     def __init__(self, start_value: float) -> None:
+         self.baseValue = float(start_value)
+         self.it = 0
+
+     def step(self, it: int):
+         self.it = it
+
+     @abstractmethod
+     def get_value(self) -> float:
+         pass
+
+ class Constant(Scheduler):
+
+     @config("Constant")
+     def __init__(self, value: float = 1):
+         super().__init__(value)
+
+     def get_value(self) -> float:
+         return self.baseValue
+
+ class CosineAnnealing(Scheduler):
+
+     @config("CosineAnnealing")
+     def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
+         super().__init__(start_value)
+         self.eta_min = eta_min
+         self.T_max = T_max
+
+     def get_value(self):
+         return self.eta_min + (self.baseValue - self.eta_min) * (1 + np.cos(self.it * torch.pi / self.T_max)) / 2
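
These schedulers expose a step/get_value pair rather than wrapping a torch optimizer, so the training loop is expected to push the iteration index in and read the scheduled value out. A minimal sketch, assuming the @config decorator still permits direct construction:

sched = CosineAnnealing(start_value=1.0, eta_min=1e-5, T_max=100)
for it in range(100):
    sched.step(it)
    weight = sched.get_value()  # cosine decay from 1.0 toward eta_min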