quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/modules/classification/base.py
@@ -0,0 +1,331 @@
+ from __future__ import annotations
+
+ from typing import Any, cast
+
+ import numpy as np
+ import timm
+ import torch
+ import torchmetrics
+ import torchmetrics.functional as TMF
+ from pytorch_grad_cam import GradCAM
+ from scipy import ndimage
+ from torch import nn, optim
+
+ from quadra.models.classification import BaseNetworkBuilder
+ from quadra.modules.base import BaseLightningModule
+ from quadra.utils.models import is_vision_transformer
+ from quadra.utils.utils import get_logger
+ from quadra.utils.vit_explainability import VitAttentionGradRollout
+
+ log = get_logger(__name__)
+
+
+ class ClassificationModule(BaseLightningModule):
+     """Lightning module for classification tasks.
+
+     Args:
+         model: Feature extractor as a PyTorch `torch.nn.Module`.
+         criterion: The loss to be applied as a PyTorch `torch.nn.Module`.
+         optimizer: Optimizer used for training. Defaults to None.
+         lr_scheduler: PyTorch learning rate scheduler.
+             If None a default ReduceLROnPlateau is used.
+             Defaults to None.
+         lr_scheduler_interval: The learning rate scheduler interval.
+             Defaults to "epoch".
+         gradcam: Whether to compute gradcam during the prediction step. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         criterion: nn.Module,
+         optimizer: None | optim.Optimizer = None,
+         lr_scheduler: None | object = None,
+         lr_scheduler_interval: str | None = "epoch",
+         gradcam: bool = False,
+     ):
+         super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
+
+         self.criterion = criterion
+         self.gradcam = gradcam
+         self.train_acc = torchmetrics.Accuracy()
+         self.val_acc = torchmetrics.Accuracy()
+         self.test_acc = torchmetrics.Accuracy()
+         self.cam: GradCAM | None = None
+         self.grad_rollout: VitAttentionGradRollout | None = None
+
+         if not isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and not is_vision_transformer(
+             cast(BaseNetworkBuilder, self.model).features_extractor
+         ):
+             log.warning(
+                 "Backbone not compatible with gradcam. Only timm ResNets, timm ViTs and TorchHub dinoViTs supported",
+             )
+             self.gradcam = False
+
+         self.original_requires_grads: list[bool] = []
+
+     def forward(self, x: torch.Tensor):
+         return self.model(x)
+
+     def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+         # pylint: disable=unused-argument
+         im, target = batch
+         outputs = self(im)
+         loss = self.criterion(outputs, target)
+
+         self.log_dict(
+             {"train_loss": loss},
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=True,
+         )
+         self.log_dict(
+             {"train_acc": self.train_acc(outputs.argmax(1), target)},
+             on_step=False,
+             on_epoch=True,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
+
+     def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+         # pylint: disable=unused-argument
+         im, target = batch
+         outputs = self(im)
+         loss = self.criterion(outputs, target)
+
+         self.log_dict(
+             {"val_loss": loss},
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=True,
+         )
+         self.log_dict(
+             {"val_acc": self.val_acc(outputs.argmax(1), target)},
+             on_step=False,
+             on_epoch=True,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
+
+     def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+         # pylint: disable=unused-argument
+         im, target = batch
+         outputs = self(im)
+
+         loss = self.criterion(outputs, target)
+
+         self.log_dict(
+             {"test_loss": loss},
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=False,
+         )
+         self.log_dict(
+             {"test_acc": self.test_acc(outputs.argmax(1), target)},
+             on_step=False,
+             on_epoch=True,
+             logger=True,
+             prog_bar=False,
+         )
+
+     def prepare_gradcam(self) -> None:
+         """Instantiate gradcam handlers."""
+         if isinstance(self.model.features_extractor, timm.models.resnet.ResNet):
+             target_layers = [cast(BaseNetworkBuilder, self.model).features_extractor.layer4[-1]]
+
+             # Get the model's current device
+             device = next(self.model.parameters()).device
+
+             self.cam = GradCAM(
+                 model=self.model,
+                 target_layers=target_layers,
+                 use_cuda=device.type == "cuda",
+             )
+             # Activating gradients
+             for p in self.model.features_extractor.layer4[-1].parameters():
+                 p.requires_grad = True
+         elif is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor):
+             self.grad_rollout = VitAttentionGradRollout(self.model)
+         else:
+             log.warning("Gradcam not implemented for this backbone, it won't be computed")
+             self.original_requires_grads.clear()
+             self.gradcam = False
+
+     def on_predict_start(self) -> None:
+         """If gradcam is enabled, prepare gradcam and save the requires_grad state of the parameters."""
+         if self.gradcam:
+             # Saving params requires_grad state
+             for p in self.model.parameters():
+                 self.original_requires_grads.append(p.requires_grad)
+             self.prepare_gradcam()
+
+         return super().on_predict_start()
+
+     def on_predict_end(self) -> None:
+         """If gradcam was computed, reset requires_grad values to their original state."""
+         if self.gradcam:
+             # Get back to the initial state
+             for i, p in enumerate(self.model.parameters()):
+                 p.requires_grad = self.original_requires_grads[i]
+
+             # We are using the GradCAM package only for resnets at the moment
+             if isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and self.cam is not None:
+                 # Needed to solve jitting bug
+                 self.cam.activations_and_grads.release()
+             elif (
+                 is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor)
+                 and self.grad_rollout is not None
+             ):
+                 for handle in self.grad_rollout.f_hook_handles:
+                     handle.remove()
+                 for handle in self.grad_rollout.b_hook_handles:
+                     handle.remove()
+
+         return super().on_predict_end()
+
+     # pylint: disable=unused-argument
+     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+         """Prediction step.
+
+         Args:
+             batch: Tuple composed of (image, target)
+             batch_idx: Batch index
+             dataloader_idx: Dataloader index
+
+         Returns:
+             Tuple containing:
+                 predicted_classes: Indices of the predicted classes
+                 grayscale_cam: Grayscale gradcams (None when gradcam is disabled)
+                 probabilities: Maximum softmax probabilities of the predicted classes
+         """
+         im, _ = batch
+         outputs = self(im)
+         probs = torch.softmax(outputs, dim=1)
+         predicted_classes = torch.max(probs, dim=1).indices.tolist()
+         if self.gradcam:
+             # inference_mode is set to False because gradcam needs gradients
+             with torch.inference_mode(False):
+                 im = im.clone()
+
+                 if isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and self.cam:
+                     grayscale_cam = self.cam(input_tensor=im, targets=None)
+                 elif (
+                     is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor) and self.grad_rollout
+                 ):
+                     grayscale_cam_low_res = self.grad_rollout(input_tensor=im, targets_list=predicted_classes)
+                     orig_shape = grayscale_cam_low_res.shape
+                     new_shape = (orig_shape[0], im.shape[2], im.shape[3])
+                     zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
+                     grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
+         else:
+             grayscale_cam = None
+         return predicted_classes, grayscale_cam, torch.max(probs, dim=1)[0].tolist()
+
+
+ class MultilabelClassificationModule(BaseLightningModule):
+     """Lightning module for multilabel classification tasks.
+
+     Args:
+         model: Feature extractor as a PyTorch `torch.nn.Module`.
+         criterion: The loss to be applied as a PyTorch `torch.nn.Module`.
+         optimizer: Optimizer used for training. Defaults to None.
+         lr_scheduler: PyTorch learning rate scheduler.
+             If None a default ReduceLROnPlateau is used.
+             Defaults to None.
+         lr_scheduler_interval: The learning rate scheduler interval.
+             Defaults to "epoch".
+         gradcam: Whether to compute gradcam during the prediction step. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         model: nn.Sequential,
+         criterion: nn.Module,
+         optimizer: None | optim.Optimizer = None,
+         lr_scheduler: None | object = None,
+         lr_scheduler_interval: str | None = "epoch",
+         gradcam: bool = False,
+     ):
+         super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
+         self.criterion = criterion
+         self.gradcam = gradcam
+
+         # TODO: can we use gradcam with more backbones?
+         if self.gradcam:
+             if not isinstance(model[0].features_extractor, timm.models.resnet.ResNet):
+                 log.warning(
+                     "Backbone must be compatible with gradcam, at the moment only ResNets supported, disabling gradcam"
+                 )
+                 self.gradcam = False
+             else:
+                 target_layers = [model[0].features_extractor.layer4[-1]]
+                 self.cam = GradCAM(model=model, target_layers=target_layers, use_cuda=torch.cuda.is_available())
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         # pylint: disable=unused-argument
+         im, target = batch
+         outputs = self(im)
+         with torch.no_grad():
+             outputs_sig = torch.sigmoid(outputs)
+         loss = self.criterion(outputs, target)
+
+         self.log_dict(
+             {
+                 "t_loss": loss,
+                 "t_map": TMF.label_ranking_average_precision(outputs_sig, target.bool()),
+                 "t_f1": TMF.f1_score(outputs_sig, target.bool(), average="samples"),
+             },
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         # pylint: disable=unused-argument
+         im, target = batch
+         outputs = self(im)
+         with torch.no_grad():
+             outputs_sig = torch.sigmoid(outputs)
+         loss = self.criterion(outputs, target)
+
+         self.log_dict(
+             {
+                 "val_loss": loss,
+                 "val_map": TMF.label_ranking_average_precision(outputs_sig, target.bool()),
+                 "val_f1": TMF.f1_score(outputs_sig, target.bool(), average="samples"),
+             },
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         # pylint: disable=unused-argument
+         im, target = batch
+         outputs = self(im)
+         with torch.no_grad():
+             outputs_sig = torch.sigmoid(outputs)
+         loss = self.criterion(outputs, target)
+
+         self.log_dict(
+             {
+                 "test_loss": loss,
+                 "test_map": TMF.label_ranking_average_precision(outputs_sig, target.bool()),
+                 "test_f1": TMF.f1_score(outputs_sig, target.bool(), average="samples"),
+             },
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=False,
+         )
+         return loss
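
To make `ClassificationModule` concrete, here is a minimal usage sketch, assuming quadra's pinned dependencies (e.g. a torchmetrics version where `Accuracy()` takes no arguments). The `WrappedBackbone` below is a hypothetical stand-in that only exists to expose the `features_extractor` attribute the module inspects; in a real run the model comes from quadra's Hydra configs and `BaseNetworkBuilder`.

import timm
import torch
from torch import nn, optim

from quadra.modules.classification.base import ClassificationModule


class WrappedBackbone(nn.Module):
    """Hypothetical stand-in exposing the `features_extractor` attribute."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # A timm ResNet passes the module's gradcam compatibility check
        self.features_extractor = timm.create_model("resnet18", num_classes=0)
        self.classifier = nn.Linear(self.features_extractor.num_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features_extractor(x))


model = WrappedBackbone()
module = ClassificationModule(
    model=model,
    criterion=nn.CrossEntropyLoss(),
    optimizer=optim.SGD(model.parameters(), lr=1e-3),
    gradcam=True,  # stays enabled because the backbone is a timm ResNet
)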
quadra/modules/ssl/__init__.py
@@ -0,0 +1,17 @@
+ from .barlowtwins import BarlowTwins
+ from .byol import BYOL
+ from .dino import Dino
+ from .idmm import IDMM
+ from .simclr import SimCLR
+ from .simsiam import SimSIAM
+ from .vicreg import VICReg
+
+ __all__ = [
+     "BarlowTwins",
+     "BYOL",
+     "Dino",
+     "IDMM",
+     "SimCLR",
+     "SimSIAM",
+     "VICReg",
+ ]
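
With these re-exports in place, downstream code can import the SSL modules directly from the subpackage, e.g. `from quadra.modules.ssl import BYOL, SimCLR`.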
quadra/modules/ssl/barlowtwins.py
@@ -0,0 +1,59 @@
+ from __future__ import annotations
+
+ import sklearn
+ import torch
+ from torch import nn, optim
+
+ from quadra.modules.base import SSLModule
+
+
+ class BarlowTwins(SSLModule):
+     """BarlowTwins model.
+
+     Args:
+         model: Network module used to extract features.
+         projection_mlp: Module used to project the extracted features.
+         criterion: SSL loss to be applied.
+         classifier: Standard sklearn classifier. Defaults to None.
+         optimizer: Optimizer used for training. If None a default Adam is used. Defaults to None.
+         lr_scheduler: Learning rate scheduler. If None a default ReduceLROnPlateau is used. Defaults to None.
+         lr_scheduler_interval: Interval at which the lr scheduler is updated. Defaults to "epoch".
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         projection_mlp: nn.Module,
+         criterion: nn.Module,
+         classifier: sklearn.base.ClassifierMixin | None = None,
+         optimizer: optim.Optimizer | None = None,
+         lr_scheduler: object | None = None,
+         lr_scheduler_interval: str | None = "epoch",
+     ):
+         super().__init__(model, criterion, classifier, optimizer, lr_scheduler, lr_scheduler_interval)
+         # self.save_hyperparameters()
+         self.projection_mlp = projection_mlp
+         self.criterion = criterion
+
+     def forward(self, x):
+         x = self.model(x)
+         z = self.projection_mlp(x)
+         return z
+
+     def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+         # pylint: disable=unused-argument
+         # Compute loss
+         (im_x, im_y), _ = batch
+         z1 = self(im_x)
+         z2 = self(im_y)
+         loss = self.criterion(z1, z2)
+
+         self.log(
+             "loss",
+             loss,
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
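
The `training_step` above pushes two augmented views of the same batch through the backbone and projector, then feeds both embeddings to the criterion. Below is a self-contained sketch of that flow with stand-in modules and a from-scratch Barlow Twins style loss; quadra's actual loss lives in `quadra.losses.ssl` and may differ in details, so everything here is illustrative only.

import torch
from torch import nn

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128))  # stand-in feature extractor
projector = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 64))  # stand-in projection MLP


def barlow_loss(z1: torch.Tensor, z2: torch.Tensor, lam: float = 5e-3) -> torch.Tensor:
    # Cross-correlation matrix of the batch-normalized embeddings; the loss
    # pulls its diagonal toward 1 (invariance) and its off-diagonal toward 0
    # (redundancy reduction).
    z1 = (z1 - z1.mean(0)) / (z1.std(0) + 1e-6)
    z2 = (z2 - z2.mean(0)) / (z2.std(0) + 1e-6)
    c = (z1.T @ z2) / z1.shape[0]
    off_diag = c - torch.diag(torch.diag(c))
    return ((torch.diag(c) - 1) ** 2).sum() + lam * (off_diag**2).sum()


im_x, im_y = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)  # two views of the same images
loss = barlow_loss(projector(backbone(im_x)), projector(backbone(im_y)))
loss.backward()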
quadra/modules/ssl/byol.py
@@ -0,0 +1,172 @@
+ from __future__ import annotations
+
+ import math
+ from collections.abc import Callable, Sized
+ from typing import Any
+
+ import sklearn
+ import torch
+ from pytorch_lightning.core.optimizer import LightningOptimizer
+ from torch import nn
+ from torch.optim import Optimizer
+
+ from quadra.modules.base import SSLModule
+
+
+ class BYOL(SSLModule):
+     """BYOL module, inspired by https://arxiv.org/abs/2006.07733.
+
+     Args:
+         student: Student model.
+         teacher: Teacher model.
+         student_projection_mlp: Student projection MLP.
+         student_prediction_mlp: Student prediction MLP.
+         teacher_projection_mlp: Teacher projection MLP.
+         criterion: Loss function.
+         classifier: Standard sklearn classifier.
+         optimizer: Optimizer used for training. If None a default Adam is used.
+         lr_scheduler: Learning rate scheduler. If None a default ReduceLROnPlateau is used.
+         lr_scheduler_interval: Interval at which the lr scheduler is updated.
+         teacher_momentum: Momentum of the teacher parameters.
+         teacher_momentum_cosine_decay: Whether to use cosine decay for the teacher momentum. Default: True
+     """
+
+     def __init__(
+         self,
+         student: nn.Module,
+         teacher: nn.Module,
+         student_projection_mlp: nn.Module,
+         student_prediction_mlp: nn.Module,
+         teacher_projection_mlp: nn.Module,
+         criterion: nn.Module,
+         classifier: sklearn.base.ClassifierMixin | None = None,
+         optimizer: Optimizer | None = None,
+         lr_scheduler: object | None = None,
+         lr_scheduler_interval: str | None = "epoch",
+         teacher_momentum: float = 0.9995,
+         teacher_momentum_cosine_decay: bool | None = True,
+     ):
+         super().__init__(
+             model=student,
+             criterion=criterion,
+             classifier=classifier,
+             optimizer=optimizer,
+             lr_scheduler=lr_scheduler,
+             lr_scheduler_interval=lr_scheduler_interval,
+         )
+         # Student model
+         self.max_steps: int
+         self.student_projection_mlp = student_projection_mlp
+         self.student_prediction_mlp = student_prediction_mlp
+
+         # Teacher model
+         self.teacher = teacher
+         self.teacher_projection_mlp = teacher_projection_mlp
+         self.teacher_initialized = False
+         self.teacher_momentum = teacher_momentum
+         self.teacher_momentum_cosine_decay = teacher_momentum_cosine_decay
+
+         self.initialize_teacher()
+
+     def initialize_teacher(self):
+         """Initialize the teacher from the state dict of the student,
+         also checking that the student parameters correctly require gradients.
+         """
+         self.teacher_projection_mlp.load_state_dict(self.student_projection_mlp.state_dict())
+         for p in self.teacher_projection_mlp.parameters():
+             p.requires_grad = False
+
+         self.teacher.load_state_dict(self.model.state_dict())
+         for p in self.teacher.parameters():
+             p.requires_grad = False
+
+         for p in self.student_projection_mlp.parameters():
+             assert p.requires_grad is True
+         for p in self.student_prediction_mlp.parameters():
+             assert p.requires_grad is True
+
+         self.teacher_initialized = True
+
+     def update_teacher(self):
+         """Update the teacher parameters with an exponential moving average of the
+         student parameters, that is: theta_t = theta_t * tau + theta_s * (1 - tau),
+         where theta_s and theta_t are the parameters of the student and the teacher
+         model and tau is the teacher momentum. If `self.teacher_momentum_cosine_decay`
+         is True, the teacher momentum follows a cosine schedule from
+         `self.teacher_momentum` to 1: tau = 1 - (1 - tau_0) * (cos(pi * t / T) + 1) / 2,
+         where t is the current step and T is the maximum number of steps.
+         """
+         with torch.no_grad():
+             if self.teacher_momentum_cosine_decay:
+                 teacher_momentum = (
+                     1
+                     - (1 - self.teacher_momentum)
+                     * (math.cos(math.pi * self.trainer.global_step / self.max_steps) + 1)
+                     / 2
+                 )
+             else:
+                 teacher_momentum = self.teacher_momentum
+             self.log("teacher_momentum", teacher_momentum, prog_bar=True)
+             for student_ps, teacher_ps in zip(
+                 list(self.model.parameters()) + list(self.student_projection_mlp.parameters()),
+                 list(self.teacher.parameters()) + list(self.teacher_projection_mlp.parameters()),
+             ):
+                 teacher_ps.data = teacher_ps.data * teacher_momentum + (1 - teacher_momentum) * student_ps.data
+
+     def on_train_start(self) -> None:
+         if isinstance(self.trainer.train_dataloader, Sized) and isinstance(self.trainer.max_epochs, int):
+             self.max_steps = len(self.trainer.train_dataloader) * self.trainer.max_epochs
+         else:
+             raise ValueError("BYOL requires `max_epochs` to be set and `train_dataloader` to be initialized.")
+
+     def training_step(self, batch: tuple[list[torch.Tensor], torch.Tensor], *args: Any) -> torch.Tensor:
+         [image1, image2], _ = batch
+
+         online_pred_one = self.student_prediction_mlp(self.student_projection_mlp(self.model(image1)))
+         online_pred_two = self.student_prediction_mlp(self.student_projection_mlp(self.model(image2)))
+
+         with torch.no_grad():
+             target_proj_one = self.teacher_projection_mlp(self.teacher(image1))
+             target_proj_two = self.teacher_projection_mlp(self.teacher(image2))
+
+         loss_one = self.criterion(online_pred_one, target_proj_two.detach())
+         loss_two = self.criterion(online_pred_two, target_proj_one.detach())
+         loss = loss_one + loss_two
+
+         self.log(name="loss", value=loss, on_step=True, on_epoch=True, prog_bar=True)
+         return loss
+
+     def optimizer_step(
+         self,
+         epoch: int,
+         batch_idx: int,
+         optimizer: Optimizer | LightningOptimizer,
+         optimizer_closure: Callable[[], Any] | None = None,
+     ) -> None:
+         """Override the optimizer step to also update the teacher parameters."""
+         super().optimizer_step(
+             epoch,
+             batch_idx,
+             optimizer,
+             optimizer_closure=optimizer_closure,
+         )
+         self.update_teacher()
+
+     def calculate_accuracy(self, batch):
+         """Calculate accuracy on the given batch."""
+         images, labels = batch
+         embedding = self.model(images).detach().cpu().numpy()
+         predictions = self.classifier.predict(embedding)
+         labels = labels.detach()
+         acc = self.val_acc(torch.tensor(predictions, device=self.device), labels)
+
+         return acc
+
+     def on_test_epoch_start(self) -> None:
+         self.fit_estimator()
+
+     def test_step(self, batch, *args: list[Any]) -> torch.Tensor:
+         """Calculate accuracy on the test set for the given batch."""
+         acc = self.calculate_accuracy(batch)
+         self.log(name="test_acc", value=acc, on_step=False, on_epoch=True, prog_bar=True)
+         return acc
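
The cosine schedule in `update_teacher` anneals the teacher momentum from its initial value to 1 over training, so the teacher effectively freezes toward the end. A quick numeric check of the docstring's formula (the step counts below are arbitrary):

import math

tau0, T = 0.9995, 10_000  # initial momentum and total training steps
for t in (0, 5_000, 10_000):
    tau = 1 - (1 - tau0) * (math.cos(math.pi * t / T) + 1) / 2
    print(t, round(tau, 6))
# 0 -> 0.9995 (starts at tau0), 5000 -> 0.99975, 10000 -> 1.0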