quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/modules/ssl/idmm.py
@@ -0,0 +1,98 @@
+ from __future__ import annotations
+
+ import sklearn
+ import timm
+ import torch
+ import torch.nn.functional as F
+
+ from quadra.modules.base import SSLModule
+
+
+ class IDMM(SSLModule):
+     """IDMM model.
+
+     Args:
+         model: backbone model
+         prediction_mlp: student prediction MLP
+         criterion: loss function
+         multiview_loss: whether to use the multiview loss as defined in https://arxiv.org/abs/2201.10728.
+             Defaults to True.
+         mixup_fn: the mixup/cutmix function to be applied to a batch of images.
+             Defaults to None.
+         classifier: Standard sklearn classifier
+         optimizer: optimizer of the training. If None a default Adam is used.
+         lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
+         lr_scheduler_interval: interval at which the lr scheduler is updated.
+     """
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         prediction_mlp: torch.nn.Module,
+         criterion: torch.nn.Module,
+         multiview_loss: bool = True,
+         mixup_fn: timm.data.Mixup | None = None,
+         classifier: sklearn.base.ClassifierMixin | None = None,
+         optimizer: torch.optim.Optimizer | None = None,
+         lr_scheduler: object | None = None,
+         lr_scheduler_interval: str | None = "epoch",
+     ):
+         super().__init__(
+             model,
+             criterion,
+             classifier,
+             optimizer,
+             lr_scheduler,
+             lr_scheduler_interval,
+         )
+         # self.save_hyperparameters()
+         self.prediction_mlp = prediction_mlp
+         self.mixup_fn = mixup_fn
+         self.multiview_loss = multiview_loss
+
+     def forward(self, x):
+         z = self.model(x)
+         p = self.prediction_mlp(z)
+         return z, p
+
+     def training_step(self, batch, batch_idx):
+         # pylint: disable=unused-argument
+         # Compute loss
+         if self.multiview_loss:
+             im_x, im_y, target = batch
+
+             # Contrastive loss
+             za, _ = self(im_x)
+             zb, _ = self(im_y)
+             za = F.normalize(za, dim=-1)
+             zb = F.normalize(zb, dim=-1)
+             s_aa = za.T @ za
+             s_ab = za.T @ zb
+             contrastive = (
+                 torch.log(torch.exp(s_aa).sum(-1))
+                 - torch.diagonal(s_aa)
+                 + torch.log(torch.exp(s_ab).sum(-1))
+                 - torch.diagonal(s_ab)
+             )
+
+             # Instance discrimination
+             if self.mixup_fn is not None:
+                 im_x, target = self.mixup_fn(im_x, target)
+             _, pred = self(im_x)
+             loss = self.criterion(pred, target) + contrastive.mean()
+         else:
+             im_x, target = batch
+             if self.mixup_fn is not None:
+                 im_x, target = self.mixup_fn(im_x, target)
+             pred = self(im_x)
+             loss = self.criterion(pred, target)
+
+         self.log(
+             "loss",
+             loss,
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
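
As a usage illustration only (not taken from the package's configs), the module above needs a backbone, a prediction MLP and a classification-style criterion. A minimal construction sketch, assuming the SSLModule base can be built without an optimizer as its defaults suggest; the backbone name, the 512-d feature width and the 10-class head are assumptions:

    import timm
    from torch import nn

    from quadra.modules.ssl.idmm import IDMM

    backbone = timm.create_model("resnet18", pretrained=False, num_classes=0)  # pooled 512-d features
    prediction_mlp = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 10))
    module = IDMM(
        model=backbone,
        prediction_mlp=prediction_mlp,
        criterion=nn.CrossEntropyLoss(),
        multiview_loss=True,  # batches are then (view_1, view_2, target)
    )
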
quadra/modules/ssl/simclr.py
@@ -0,0 +1,73 @@
+ from __future__ import annotations
+
+ import sklearn
+ import torch
+ from torch import nn
+
+ from quadra.modules.base import SSLModule
+
+
+ class SimCLR(SSLModule):
+     """SIMCLR class.
+
+     Args:
+         model: Feature extractor as pytorch `torch.nn.Module`
+         projection_mlp: projection head as
+             pytorch `torch.nn.Module`
+         criterion: SSL loss to be applied
+         classifier: Standard sklearn classifier. Defaults to None.
+         optimizer: optimizer of the training. If None a default Adam is used. Defaults to None.
+         lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used. Defaults to None.
+         lr_scheduler_interval: interval at which the lr scheduler is updated. Defaults to "epoch".
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         projection_mlp: nn.Module,
+         criterion: torch.nn.Module,
+         classifier: sklearn.base.ClassifierMixin | None = None,
+         optimizer: torch.optim.Optimizer | None = None,
+         lr_scheduler: object | None = None,
+         lr_scheduler_interval: str | None = "epoch",
+     ):
+         super().__init__(
+             model,
+             criterion,
+             classifier,
+             optimizer,
+             lr_scheduler,
+             lr_scheduler_interval,
+         )
+         self.projection_mlp = projection_mlp
+
+     def forward(self, x):
+         x = self.model(x)
+         x = self.projection_mlp(x)
+         return x
+
+     def training_step(
+         self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int
+     ) -> torch.Tensor:
+         """Args:
+             batch: The batch of data
+             batch_idx: The index of the batch.
+
+         Returns:
+             The computed loss
+         """
+         # pylint: disable=unused-argument
+         (im_x, im_y), _ = batch
+         emb_x = self(im_x)
+         emb_y = self(im_y)
+         loss = self.criterion(emb_x, emb_y)
+
+         self.log(
+             "loss",
+             loss,
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
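
For reference, training_step above consumes batches shaped ((view_1, view_2), labels) and calls the criterion as criterion(emb_x, emb_y); the package's actual SimCLR loss lives in quadra/losses/ssl/simclr.py and is not shown in this excerpt. A standalone sketch of that contract, with a placeholder backbone, projection head and cosine-based criterion (all illustrative, and assuming the SSLModule base can be constructed without an optimizer as its defaults suggest):

    import torch
    from torch import nn

    from quadra.modules.ssl.simclr import SimCLR


    class CosinePlaceholderLoss(nn.Module):
        # Stand-in with the (emb_x, emb_y) signature used by training_step above.
        def forward(self, emb_x: torch.Tensor, emb_y: torch.Tensor) -> torch.Tensor:
            return -nn.functional.cosine_similarity(emb_x, emb_y, dim=-1).mean()


    backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128))  # toy feature extractor
    projection = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 32))
    module = SimCLR(model=backbone, projection_mlp=projection, criterion=CosinePlaceholderLoss())

    view_1, view_2 = torch.randn(8, 3, 32, 32), torch.randn(8, 3, 32, 32)
    emb_x, emb_y = module(view_1), module(view_2)  # what training_step does with a batch
    loss = module.criterion(emb_x, emb_y)
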
quadra/modules/ssl/simsiam.py
@@ -0,0 +1,68 @@
+ from __future__ import annotations
+
+ import sklearn
+ import torch
+
+ from quadra.modules.base import SSLModule
+
+
+ class SimSIAM(SSLModule):
+     """SimSIAM model.
+
+     Args:
+         model: Feature extractor as pytorch `torch.nn.Module`
+         projection_mlp: optional projection head as pytorch `torch.nn.Module`
+         prediction_mlp: optional prediction head as pytorch `torch.nn.Module`
+         criterion: loss to be applied.
+         classifier: Standard sklearn classifier.
+         optimizer: optimizer of the training. If None a default Adam is used.
+         lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
+         lr_scheduler_interval: interval at which the lr scheduler is updated.
+     """
+
+     def __init__(
+         self,
+         model: torch.nn.Module,
+         projection_mlp: torch.nn.Module,
+         prediction_mlp: torch.nn.Module,
+         criterion: torch.nn.Module,
+         classifier: sklearn.base.ClassifierMixin | None = None,
+         optimizer: torch.optim.Optimizer | None = None,
+         lr_scheduler: object | None = None,
+         lr_scheduler_interval: str | None = "epoch",
+     ):
+         super().__init__(
+             model,
+             criterion,
+             classifier,
+             optimizer,
+             lr_scheduler,
+             lr_scheduler_interval,
+         )
+         # self.save_hyperparameters()
+         self.projection_mlp = projection_mlp
+         self.prediction_mlp = prediction_mlp
+
+     def forward(self, x):
+         x = self.model(x)
+         z = self.projection_mlp(x)
+         p = self.prediction_mlp(z)
+         return p, z.detach()
+
+     def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+         # pylint: disable=unused-argument
+         # Compute loss
+         (im_x, im_y), _ = batch
+         p1, z1 = self(im_x)
+         p2, z2 = self(im_y)
+         loss = self.criterion(p1, p2, z1, z2)
+
+         self.log(
+             "loss",
+             loss,
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
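
The criterion above is called as criterion(p1, p2, z1, z2), with the predictions and the detached projections of the two views. The packaged loss is in quadra/losses/ssl/simsiam.py (not shown here); a standard SimSiam-style negative cosine similarity with that call signature would look roughly like this sketch:

    import torch.nn.functional as F
    from torch import nn


    class NegativeCosineSimilarity(nn.Module):
        # Illustrative loss with the (p1, p2, z1, z2) signature used by SimSIAM.training_step;
        # not the package's own implementation.
        def forward(self, p1, p2, z1, z2):
            # Symmetrized negative cosine similarity; z1/z2 arrive already detached from forward().
            loss_a = -F.cosine_similarity(p1, z2, dim=-1).mean()
            loss_b = -F.cosine_similarity(p2, z1, dim=-1).mean()
            return 0.5 * (loss_a + loss_b)
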
quadra/modules/ssl/vicreg.py
@@ -0,0 +1,67 @@
+ from __future__ import annotations
+
+ import sklearn
+ import torch
+ from torch import nn, optim
+
+ from quadra.modules.base import SSLModule
+
+
+ class VICReg(SSLModule):
+     """VICReg model.
+
+     Args:
+         model: Network module used to extract features
+         projection_mlp: Module to project extracted features
+         criterion: SSL loss to be applied.
+         classifier: Standard sklearn classifier. Defaults to None.
+         optimizer: optimizer of the training. If None a default Adam is used. Defaults to None.
+         lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used. Defaults to None.
+         lr_scheduler_interval: interval at which the lr scheduler is updated. Defaults to "epoch".
+
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         projection_mlp: nn.Module,
+         criterion: nn.Module,
+         classifier: sklearn.base.ClassifierMixin | None = None,
+         optimizer: optim.Optimizer | None = None,
+         lr_scheduler: object | None = None,
+         lr_scheduler_interval: str | None = "epoch",
+     ):
+         super().__init__(
+             model,
+             criterion,
+             classifier,
+             optimizer,
+             lr_scheduler,
+             lr_scheduler_interval,
+         )
+         # self.save_hyperparameters()
+         self.projection_mlp = projection_mlp
+         self.criterion = criterion
+
+     def forward(self, x):
+         x = self.model(x)
+         z = self.projection_mlp(x)
+         return z
+
+     def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+         # pylint: disable=unused-argument
+         # Compute loss
+         (im_x, im_y), _ = batch
+         z1 = self(im_x)
+         z2 = self(im_y)
+         loss = self.criterion(z1, z2)
+
+         self.log(
+             "loss",
+             loss,
+             on_epoch=True,
+             on_step=True,
+             logger=True,
+             prog_bar=True,
+         )
+         return loss
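
The criterion here is called as criterion(z1, z2) on the two projected views; the packaged loss lives in quadra/losses/ssl/vicreg.py and is not part of this excerpt. For orientation, a compact sketch of the standard VICReg objective (invariance + variance + covariance terms, https://arxiv.org/abs/2105.04906) with that signature; the coefficients are the paper's defaults, not necessarily the package's:

    import torch
    import torch.nn.functional as F
    from torch import nn


    class VICRegLossSketch(nn.Module):
        # Illustrative only: invariance (MSE), variance hinge and off-diagonal covariance penalties.
        def __init__(self, sim_coeff: float = 25.0, std_coeff: float = 25.0, cov_coeff: float = 1.0):
            super().__init__()
            self.sim_coeff, self.std_coeff, self.cov_coeff = sim_coeff, std_coeff, cov_coeff

        def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
            n, d = z1.shape
            invariance = F.mse_loss(z1, z2)
            z1, z2 = z1 - z1.mean(dim=0), z2 - z2.mean(dim=0)
            std1 = torch.sqrt(z1.var(dim=0) + 1e-4)
            std2 = torch.sqrt(z2.var(dim=0) + 1e-4)
            variance = F.relu(1.0 - std1).mean() + F.relu(1.0 - std2).mean()
            cov1 = (z1.T @ z1) / (n - 1)
            cov2 = (z2.T @ z2) / (n - 1)
            covariance = (
                cov1.pow(2).sum() - torch.diagonal(cov1).pow(2).sum()
                + cov2.pow(2).sum() - torch.diagonal(cov2).pow(2).sum()
            ) / d
            return self.sim_coeff * invariance + self.std_coeff * variance + self.cov_coeff * covariance
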
quadra/optimizers/__init__.py
@@ -0,0 +1,4 @@
+ from .lars import LARS
+ from .sam import SAM
+
+ __all__ = ["LARS", "SAM"]
quadra/optimizers/lars.py
@@ -0,0 +1,153 @@
+ """References:
+     - https://arxiv.org/pdf/1708.03888.pdf
+     - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py.
+ """
+
+ from __future__ import annotations
+
+ from collections.abc import Callable
+
+ import torch
+ from torch.nn import Parameter
+ from torch.optim.optimizer import Optimizer, _RequiredParameter, required
+
+
+ class LARS(Optimizer):
+     r"""Extends SGD in PyTorch with LARS scaling from the paper
+     `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
+
+     Args:
+         params: iterable of parameters to optimize or dicts defining
+             parameter groups
+         lr: learning rate
+         momentum: momentum factor (default: 0)
+         weight_decay: weight decay (L2 penalty) (default: 0)
+         dampening: dampening for momentum (default: 0)
+         nesterov: enables Nesterov momentum (default: False)
+         trust_coefficient: trust coefficient for computing LR (default: 0.001)
+         eps: eps for division denominator (default: 1e-8).
+
+     Example:
+         >>> model = torch.nn.Linear(10, 1)
+         >>> input = torch.Tensor(10)
+         >>> target = torch.Tensor([1.])
+         >>> loss_fn = lambda input, target: (input - target) ** 2
+         >>> #
+         >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
+         >>> optimizer.zero_grad()
+         >>> loss_fn(model(input), target).backward()
+         >>> optimizer.step()
+
+     .. note::
+         The application of momentum in the SGD part is modified according to
+         the PyTorch standards. LARS scaling fits into the equation in the
+         following fashion.
+
+         .. math::
+             \begin{aligned}
+                 g_{t+1} & = \text{lars_lr} * (\beta * p_{t} + g_{t+1}), \\
+                 v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
+                 p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
+             \end{aligned}
+
+         where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
+         parameters, gradient, velocity, momentum, and weight decay respectively.
+         The :math:`lars_lr` is defined by Eq. 6 in the paper.
+         The Nesterov version is analogously modified.
+
+     .. warning::
+         Parameters with weight decay set to 0 will automatically be excluded from
+         layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
+         and BYOL.
+     """
+
+     def __init__(
+         self,
+         params: list[Parameter],
+         lr: _RequiredParameter = required,
+         momentum: float = 0,
+         dampening: float = 0,
+         weight_decay: float = 0,
+         nesterov: bool = False,
+         trust_coefficient: float = 0.001,
+         eps: float = 1e-8,
+     ):
+         if lr is not required and lr < 0.0:  # type: ignore[operator]
+             raise ValueError(f"Invalid learning rate: {lr}")
+         if momentum < 0.0:
+             raise ValueError(f"Invalid momentum value: {momentum}")
+         if weight_decay < 0.0:
+             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+         defaults = {
+             "lr": lr,
+             "momentum": momentum,
+             "dampening": dampening,
+             "weight_decay": weight_decay,
+             "nesterov": nesterov,
+         }
+         if nesterov and (momentum <= 0 or dampening != 0):
+             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+         self.eps = eps
+         self.trust_coefficient = trust_coefficient
+
+         super().__init__(params, defaults)
+
+     def __setstate__(self, state):
+         super().__setstate__(state)
+
+         for group in self.param_groups:
+             group.setdefault("nesterov", False)
+
+     @torch.no_grad()
+     def step(self, closure: Callable | None = None):
+         """Performs a single optimization step.
+
+         Args:
+             closure: A closure that reevaluates the model and returns the loss. Defaults to None.
+         """
+         loss = None
+         if closure is not None:
+             with torch.enable_grad():
+                 loss = closure()
+
+         # exclude scaling for params with 0 weight decay
+         for group in self.param_groups:
+             weight_decay = group["weight_decay"]
+             momentum = group["momentum"]
+             dampening = group["dampening"]
+             nesterov = group["nesterov"]
+
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+
+                 d_p = p.grad
+                 p_norm = torch.norm(p.data)
+                 g_norm = torch.norm(p.grad.data)
+
+                 # lars scaling + weight decay part
+                 if weight_decay != 0 and p_norm != 0 and g_norm != 0:
+                     lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps)
+                     lars_lr *= self.trust_coefficient
+
+                     d_p = d_p.add(p, alpha=weight_decay)
+                     d_p *= lars_lr
+
+                 # sgd part
+                 if momentum != 0:
+                     param_state = self.state[p]
+                     if "momentum_buffer" not in param_state:
+                         buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
+                     else:
+                         buf = param_state["momentum_buffer"]
+                         buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+                     if nesterov:
+                         d_p = d_p.add(buf, alpha=momentum)
+                     else:
+                         d_p = buf
+
+                 p.add_(d_p, alpha=-group["lr"])
+
+         return loss
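
Note that step applies the trust-ratio scaling only to parameter groups with a non-zero weight_decay, which is what the docstring's warning refers to. A sketch of the usual group split that relies on this behaviour; the "1-D tensors get no decay" heuristic and the hyperparameters are illustrative, not prescribed by the package:

    from torch import nn

    from quadra.optimizers import LARS

    model = nn.Sequential(nn.Linear(32, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 10))

    decay, no_decay = [], []
    for p in model.parameters():
        # Biases and normalization parameters are 1-D: keep them out of weight decay,
        # which also keeps them out of the layer-wise scaling in LARS.step.
        (no_decay if p.ndim == 1 else decay).append(p)

    optimizer = LARS(
        [
            {"params": decay, "weight_decay": 1e-6},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=0.1,
        momentum=0.9,
    )
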
quadra/optimizers/sam.py
@@ -0,0 +1,127 @@
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from typing import Any
+
+ import torch
+ from torch.nn import Parameter
+
+
+ class SAM(torch.optim.Optimizer):
+     """PyTorch implementation of the Sharpness-Aware Minimization paper: https://arxiv.org/abs/2010.01412
+     and https://arxiv.org/abs/2102.11600.
+     Taken from: https://github.com/davda54/sam.
+
+     Args:
+         params: model parameters.
+         base_optimizer: optimizer to use.
+         rho: Positive float value used to scale the gradients.
+         adaptive: Boolean flag indicating whether to use adaptive step update.
+         **kwargs: Additional parameters for the base optimizer.
+     """
+
+     def __init__(
+         self,
+         params: list[Parameter],
+         base_optimizer: torch.optim.Optimizer,
+         rho: float = 0.05,
+         adaptive: bool = True,
+         **kwargs: Any,
+     ):
+         assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
+
+         defaults = {"rho": rho, "adaptive": adaptive, **kwargs}
+         super().__init__(params, defaults)
+
+         if callable(base_optimizer):
+             self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
+         else:
+             self.base_optimizer = base_optimizer
+         self.rho = rho
+         self.adaptive = adaptive
+         self.param_groups = self.base_optimizer.param_groups
+
+     @torch.no_grad()
+     def first_step(self, zero_grad: bool = False) -> None:
+         """First step for SAM optimizer.
+
+         Args:
+             zero_grad: Boolean flag indicating whether to zero the gradients.
+
+         Returns:
+             None
+         """
+         grad_norm = self._grad_norm()
+         for group in self.param_groups:
+             scale = self.rho / (grad_norm + 1e-12)
+
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
+                 p.add_(e_w)  # climb to the local maximum "w + e(w)"
+                 self.state[p]["e_w"] = e_w
+
+         if zero_grad:
+             self.zero_grad()
+
+     @torch.no_grad()
+     def second_step(self, zero_grad: bool = False) -> None:
+         """Second step for SAM optimizer.
+
+         Args:
+             zero_grad: Boolean flag indicating whether to zero the gradients.
+
+         Returns:
+             None
+
+         """
+         for group in self.param_groups:
+             for p in group["params"]:
+                 if p.grad is None:
+                     continue
+                 p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
+
+         self.base_optimizer.step()  # do the actual "sharpness-aware" update
+
+         if zero_grad:
+             self.zero_grad()
+
+     @torch.no_grad()
+     def step(self, closure: Callable | None = None) -> None:  # type: ignore[override]
+         """Step for SAM optimizer.
+
+         Args:
+             closure: Optional closure that re-evaluates the model with gradients enabled.
+
+         Returns:
+             None
+
+         """
+         if closure is not None:
+             closure = torch.enable_grad()(closure)
+
+         self.first_step(zero_grad=True)
+         if closure is not None:
+             closure()
+         self.second_step(zero_grad=False)
+
+     def _grad_norm(self) -> torch.Tensor:
+         """Put everything on the same device, in case of model parallelism.
+         Returns:
+             Grad norm.
+         """
+         # put everything on the same device, in case of model parallelism
+         shared_device = self.param_groups[0]["params"][0].device
+         norm = torch.norm(
+             torch.stack(
+                 [
+                     ((torch.abs(p) if self.adaptive else 1.0) * p.grad).norm(p=2).to(shared_device)
+                     for group in self.param_groups
+                     for p in group["params"]
+                     if p.grad is not None
+                 ]
+             ),
+             p=2,
+         )
+         return norm
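
Because first_step climbs to w + e(w) and second_step restores w before applying the base optimizer, SAM is normally driven with two forward/backward passes per batch. A minimal manual-loop sketch under that assumption; the model, data and hyperparameters are placeholders:

    import torch
    from torch import nn

    from quadra.optimizers import SAM

    model = nn.Linear(16, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = SAM(model.parameters(), base_optimizer=torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

    inputs, targets = torch.randn(8, 16), torch.randint(0, 2, (8,))

    # First pass: gradients at w, then perturb to w + e(w).
    criterion(model(inputs), targets).backward()
    optimizer.first_step(zero_grad=True)

    # Second pass: gradients at the perturbed point, then the actual update of w.
    criterion(model(inputs), targets).backward()
    optimizer.second_step(zero_grad=True)
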
quadra/schedulers/__init__.py
@@ -0,0 +1,3 @@
+ from .warmup import CosineAnnealingWithLinearWarmUp
+
+ __all__ = ["CosineAnnealingWithLinearWarmUp"]
quadra/schedulers/base.py
@@ -0,0 +1,44 @@
+ from __future__ import annotations
+
+ from torch.optim import Optimizer
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class LearningRateScheduler(_LRScheduler):
+     """Provides the interface for learning rate schedulers.
+
+     Note:
+         Do not use this class directly, use one of the sub classes.
+     """
+
+     def __init__(self, optimizer: Optimizer, init_lr: tuple[float, ...]):
+         # pylint: disable=super-init-not-called
+         self.optimizer = optimizer
+         self.init_lr = init_lr
+
+     def step(self, *args, **kwargs):
+         """Base method, must be implemented by the sub classes."""
+         raise NotImplementedError
+
+     def set_lr(self, lr: tuple[float, ...]):
+         """Set the learning rate for the optimizer."""
+         if self.optimizer is not None:
+             for i, g in enumerate(self.optimizer.param_groups):
+                 if "fix_lr" in g and g["fix_lr"]:
+                     if len(lr) == 1:
+                         lr_to_set = self.init_lr[0]
+                     else:
+                         lr_to_set = self.init_lr[i]
+                 elif len(lr) == 1:
+                     lr_to_set = lr[0]
+                 else:
+                     lr_to_set = lr[i]
+                 g["lr"] = lr_to_set
+
+     def get_lr(self):
+         """Get the current learning rate if the optimizer is available."""
+         if self.optimizer is not None:
+             for g in self.optimizer.param_groups:
+                 return g["lr"]
+
+         return None
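
LearningRateScheduler is an abstract base: step raises NotImplementedError, and subclasses are expected to drive set_lr/get_lr themselves. A toy subclass sketch; the decay rule is illustrative and is unrelated to the package's CosineAnnealingWithLinearWarmUp:

    import torch
    from torch import nn

    from quadra.schedulers.base import LearningRateScheduler


    class ToyExponentialDecay(LearningRateScheduler):
        # Illustrative subclass: multiply every group's learning rate by `gamma` each step.
        def __init__(self, optimizer, init_lr=(0.1,), gamma: float = 0.95):
            super().__init__(optimizer, init_lr)
            self.gamma = gamma
            self.set_lr(init_lr)

        def step(self, *args, **kwargs):
            current = self.get_lr()
            if current is not None:
                self.set_lr((current * self.gamma,))


    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = ToyExponentialDecay(optimizer)
    scheduler.step()  # first group's lr: 0.1 -> 0.095
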