quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/modules/ssl/common.py
@@ -0,0 +1,285 @@
+ from __future__ import annotations
+
+ import torch
+ from torch import nn
+
+ from quadra.utils.models import trunc_normal_
+
+
+ class ProjectionHead(torch.nn.Module):
+     """Base class for all projection and prediction heads.
+
+     Args:
+         blocks:
+             List of tuples, each denoting one block of the projection head MLP.
+             Each tuple reads (linear_layer, batch_norm_layer, non_linearity_layer);
+             both `batch_norm_layer` and `non_linearity_layer` may be None.
+     """
+
+     def __init__(self, blocks: list[tuple[torch.nn.Module | None, ...]]):
+         super().__init__()
+
+         layers: list[nn.Module] = []
+         for linear, batch_norm, non_linearity in blocks:
+             if linear:
+                 layers.append(linear)
+             if batch_norm:
+                 layers.append(batch_norm)
+             if non_linearity:
+                 layers.append(non_linearity)
+         self.layers = torch.nn.Sequential(*layers)
+
+     def forward(self, x: torch.Tensor):
+         return self.layers(x)
+
+
+ class ExpanderReducer(ProjectionHead):
+     """Expander followed by a reducer."""
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__(
+             [
+                 (
+                     torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (
+                     torch.nn.Linear(hidden_dim, output_dim, bias=False),
+                     torch.nn.BatchNorm1d(output_dim, affine=False),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+             ]
+         )
+
+
+ class BarlowTwinsProjectionHead(ProjectionHead):
+     """Projection head used for Barlow Twins.
+     "The projector network has three linear layers, each with 8192 output
+     units. The first two layers of the projector are followed by a batch
+     normalization layer and rectified linear units." https://arxiv.org/abs/2103.03230.
+     """
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__(
+             [
+                 (
+                     torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (
+                     torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+             ]
+         )
+
+
+ class SimCLRProjectionHead(ProjectionHead):
+     """Projection head used for SimCLR.
+     "We use a MLP with one hidden layer to obtain zi = g(h) = W_2 * σ(W_1 * h)
+     where σ is a ReLU non-linearity." https://arxiv.org/abs/2002.05709.
+     """
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__(
+             [
+                 (
+                     torch.nn.Linear(input_dim, hidden_dim),
+                     None,
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (torch.nn.Linear(hidden_dim, output_dim), None, None),
+             ]
+         )
+
+
+ class SimCLRPredictionHead(ProjectionHead):
+     """Prediction head used for SimCLR.
+     "We set g(h) = W(2)σ(W(1)h), with the same input and output dimensionality (i.e. 2048)."
+     https://arxiv.org/abs/2002.05709.
+     """
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__(
+             [
+                 (
+                     torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+             ]
+         )
+
+
+ class SimSiamProjectionHead(ProjectionHead):
+     """Projection head used for SimSiam.
+     "The projection MLP (in f) has BN applied to each fully-connected (fc)
+     layer, including its output fc. Its output fc has no ReLU. The hidden fc is
+     2048-d. This MLP has 3 layers." https://arxiv.org/abs/2011.10566.
+     """
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__(
+             [
+                 (
+                     torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (
+                     torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim, affine=False),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (
+                     torch.nn.Linear(hidden_dim, output_dim, bias=False),
+                     torch.nn.BatchNorm1d(output_dim, affine=False),
+                     None,
+                 ),
+             ]
+         )
+
+
+ class SimSiamPredictionHead(ProjectionHead):
+     """Prediction head used for SimSiam.
+     "The prediction MLP (h) has BN applied to its hidden fc layers. Its output
+     fc does not have BN (...) or ReLU. This MLP has 2 layers." https://arxiv.org/abs/2011.10566.
+     """
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__(
+             [
+                 (
+                     torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+             ]
+         )
+
+
+ class BYOLPredictionHead(ProjectionHead):
+     """Prediction head used for BYOL."""
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__(
+             [
+                 (
+                     torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+             ]
+         )
+
+
+ class BYOLProjectionHead(ProjectionHead):
+     """Projection head used for BYOL."""
+
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__(
+             [
+                 (
+                     torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                     torch.nn.BatchNorm1d(hidden_dim),
+                     torch.nn.ReLU(inplace=True),
+                 ),
+                 (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+             ]
+         )
+
+
+ class DinoProjectionHead(nn.Module):
+     """Projection head used for DINO. Batch normalization is applied only
+     when `use_bn` is True.
+
+     Args:
+         input_dim: Input dimension for the MLP head.
+         output_dim: Output dimension (projection dimension) for the MLP head.
+         hidden_dim: Hidden dimension.
+         use_bn: Whether to apply batch normalization after each hidden linear
+             layer. Defaults to False.
+         norm_last_layer: If True, the weight norm scale of the last layer is
+             frozen (kept fixed at 1). Defaults to True.
+         num_layers: Number of layers used in the MLP. Defaults to 3.
+         bottleneck_dim: Bottleneck dimension. Defaults to 256.
+     """
+
+     def __init__(
+         self,
+         input_dim: int,
+         output_dim: int,
+         hidden_dim: int,
+         use_bn: bool = False,
+         norm_last_layer: bool = True,
+         num_layers: int = 3,
+         bottleneck_dim: int = 256,
+     ):
+         super().__init__()
+         num_layers = max(num_layers, 1)
+         self.mlp: nn.Linear | nn.Sequential
+         if num_layers == 1:
+             self.mlp = nn.Linear(input_dim, bottleneck_dim)
+         else:
+             layers: list[nn.Module] = [nn.Linear(input_dim, hidden_dim)]
+             if use_bn:
+                 layers.append(nn.BatchNorm1d(hidden_dim))
+             layers.append(nn.GELU())
+             for _ in range(num_layers - 2):
+                 layers.append(nn.Linear(hidden_dim, hidden_dim))
+                 if use_bn:
+                     layers.append(nn.BatchNorm1d(hidden_dim))
+                 layers.append(nn.GELU())
+             layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+             self.mlp = nn.Sequential(*layers)
+         self.apply(self._init_weights)
+         self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, output_dim, bias=False))
+         self.last_layer.weight_g.data.fill_(1)
+         if norm_last_layer:
+             self.last_layer.weight_g.requires_grad = False
+
+     def _init_weights(self, m):
+         """Initialize the weights of the projection head."""
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=0.02)
+         if isinstance(m, nn.Linear) and m.bias is not None:
+             nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         x = self.mlp(x)
+         x = nn.functional.normalize(x, dim=-1, p=2)
+         x = self.last_layer(x)
+         return x
+
+
+ class MultiCropModel(nn.Module):
+     """MultiCrop model for the DINO augmentation scheme.
+
+     It takes 2 global crops and N (possibly zero) local crops as a single tensor.
+
+     Args:
+         backbone: Backbone model.
+         head: Head model.
+     """
+
+     def __init__(self, backbone: nn.Module, head: nn.Module):
+         super().__init__()
+         self.backbone = backbone
+         self.head = head
+
+     def forward(self, x):
+         n_crops = len(x)
+         # (n_samples * n_crops, 3, size, size)
+         concatenated = torch.cat(x, dim=0)
+         # (n_samples * n_crops, in_dim)
+         embedding = self.backbone(concatenated)
+         logits = self.head(embedding)  # (n_samples * n_crops, out_dim)
+         chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)
+
+         return chunks
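
For orientation, the sketch below shows how the projection heads in this hunk might be used. The `quadra.modules.ssl.common` import path is inferred from the file list above (this +285-line hunk matches `quadra/modules/ssl/common.py`); the dimensions are illustrative assumptions, not values from the diff.

```python
# Hypothetical usage sketch, not part of the diff. The import path is inferred
# from the file list; feature and projection dimensions are assumptions.
import torch

from quadra.modules.ssl.common import BarlowTwinsProjectionHead, SimCLRProjectionHead

backbone_features = torch.randn(32, 2048)  # e.g. pooled ResNet-50 features

simclr_head = SimCLRProjectionHead(input_dim=2048, hidden_dim=2048, output_dim=128)
barlow_head = BarlowTwinsProjectionHead(input_dim=2048, hidden_dim=8192, output_dim=8192)

z_simclr = simclr_head(backbone_features)  # (32, 128)
z_barlow = barlow_head(backbone_features)  # (32, 8192)
```
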
quadra/modules/ssl/dino.py
@@ -0,0 +1,186 @@
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from typing import Any
+
+ import sklearn
+ import torch
+ from pytorch_lightning.core.optimizer import LightningOptimizer
+ from torch import nn
+ from torch.optim import Optimizer
+
+ from quadra.modules.ssl import BYOL
+ from quadra.utils.models import clip_gradients
+ from quadra.utils.utils import get_logger
+
+ log = get_logger(__name__)
+
+
+ class Dino(BYOL):
+     """DINO pytorch-lightning module.
+
+     Args:
+         student: Student model.
+         teacher: Teacher model.
+         student_projection_mlp: Student projection MLP.
+         teacher_projection_mlp: Teacher projection MLP.
+         criterion: Loss function.
+         freeze_last_layer: Number of epochs during which the last layer of the
+             student projection MLP is kept frozen. Defaults to 1.
+         classifier: Standard sklearn classifier.
+         optimizer: Optimizer of the training. If None, a default Adam is used.
+         lr_scheduler: Learning-rate scheduler. If None, a default ReduceLROnPlateau is used.
+         lr_scheduler_interval: Interval at which the lr scheduler is updated.
+         teacher_momentum: Momentum of the teacher parameters.
+         teacher_momentum_cosine_decay: Whether to use cosine decay for the teacher momentum.
+     """
+
+ def __init__(
38
+ self,
39
+ student: nn.Module,
40
+ teacher: nn.Module,
41
+ student_projection_mlp: nn.Module,
42
+ teacher_projection_mlp: nn.Module,
43
+ criterion: nn.Module,
44
+ freeze_last_layer: int = 1,
45
+ classifier: sklearn.base.ClassifierMixin | None = None,
46
+ optimizer: Optimizer | None = None,
47
+ lr_scheduler: object | None = None,
48
+ lr_scheduler_interval: str | None = "epoch",
49
+ teacher_momentum: float = 0.9995,
50
+ teacher_momentum_cosine_decay: bool | None = True,
51
+ ):
52
+ super().__init__(
53
+ student=student,
54
+ teacher=teacher,
55
+ student_projection_mlp=student_projection_mlp,
56
+ student_prediction_mlp=nn.Identity(),
57
+ teacher_projection_mlp=teacher_projection_mlp,
58
+ criterion=criterion,
59
+ teacher_momentum=teacher_momentum,
60
+ teacher_momentum_cosine_decay=teacher_momentum_cosine_decay,
61
+ classifier=classifier,
62
+ optimizer=optimizer,
63
+ lr_scheduler=lr_scheduler,
64
+ lr_scheduler_interval=lr_scheduler_interval,
65
+ )
66
+ self.freeze_last_layer = freeze_last_layer
67
+
68
+ def initialize_teacher(self):
69
+ """Initialize teacher from the state dict of the student one,
70
+ checking also that student model requires greadient correctly.
71
+ """
72
+ self.teacher_projection_mlp.load_state_dict(self.student_projection_mlp.state_dict())
73
+ for p in self.teacher_projection_mlp.parameters():
74
+ p.requires_grad = False
75
+
76
+ self.teacher.load_state_dict(self.model.state_dict())
77
+ for p in self.teacher.parameters():
78
+ p.requires_grad = False
79
+
80
+ all_frozen = True
81
+ for p in self.model.parameters():
82
+ all_frozen = all_frozen and (not p.requires_grad)
83
+
84
+ if all_frozen:
85
+ log.warning(
86
+ "All parameters of the student model are frozen, the model will not be trained, automatically"
87
+ " unfreezing all the layers"
88
+ )
89
+
90
+ for p in self.model.parameters():
91
+ p.requires_grad = True
92
+
93
+ for name, p in self.student_projection_mlp.named_parameters():
94
+ if name != "last_layer.weight_g":
95
+ assert p.requires_grad is True
96
+
97
+ self.teacher_initialized = True
98
+
99
+ def student_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
100
+ """Student forward on the multicrop imges.
101
+
102
+ Args:
103
+ x: List of torch.Tensor containing multicropped augmented images
104
+
105
+ Returns:
106
+ torch.Tensor: a tensor of shape NxBxD, where N is the number crops
107
+ corresponding to the length of the input list `x`, B is the batch size
108
+ and D is the output dimension
109
+ """
110
+ n_crops = len(x)
111
+ concatenated = torch.cat(x, dim=0) # (n_samples * n_crops, C, H, W)
112
+ embedding = self.model(concatenated) # (n_samples * n_crops, in_dim)
113
+ logits = self.student_projection_mlp(embedding) # (n_samples * n_crops, out_dim)
114
+ chunks = logits.chunk(n_crops) # n_crops * (n_samples, out_dim)
115
+ return chunks
116
+
117
+ def teacher_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
118
+ """Teacher forward on the multicrop imges.
119
+
120
+ Args:
121
+ x: List of torch.Tensor containing multicropped augmented images
122
+
123
+ Returns:
124
+ torch.Tensor: a tensor of shape NxBxD, where N is the number crops
125
+ corresponding to the length of the input list `x`, B is the batch size
126
+ and D is the output dimension
127
+ """
128
+ n_crops = len(x)
129
+ concatenated = torch.cat(x, dim=0) # (n_samples * n_crops, C, H, W)
130
+ embedding = self.teacher(concatenated) # (n_samples * n_crops, in_dim)
131
+ logits = self.teacher_projection_mlp(embedding) # (n_samples * n_crops, out_dim)
132
+ chunks = logits.chunk(n_crops) # n_crops * (n_samples, out_dim)
133
+ return chunks
134
+
135
+ def cancel_gradients_last_layer(self, epoch: int, freeze_last_layer: int):
136
+ """Zero out the gradient of the last layer, as specified in the paper.
137
+
138
+ Args:
139
+ epoch: current epoch
140
+ freeze_last_layer: maximum freeze epoch: if `epoch` >= `freeze_last_layer`
141
+ then the gradient of the last layer will not be freezed
142
+ """
143
+ if epoch >= freeze_last_layer:
144
+ return
145
+ for n, p in self.student_projection_mlp.named_parameters():
146
+ if "last_layer" in n:
147
+ p.grad = None
148
+
149
+ def training_step(self, batch: tuple[list[torch.Tensor], torch.Tensor], *args: Any) -> torch.Tensor:
150
+ images, _ = batch
151
+ with torch.no_grad():
152
+ teacher_output = self.teacher_multicrop_forward(images[:2])
153
+
154
+ student_output = self.student_multicrop_forward(images)
155
+ loss = self.criterion(self.current_epoch, student_output, teacher_output)
156
+
157
+ self.log(name="loss", value=loss, on_step=True, on_epoch=True, prog_bar=True)
158
+ return loss
159
+
160
+ def configure_gradient_clipping(
161
+ self,
162
+ optimizer: Optimizer,
163
+ gradient_clip_val: int | float | None = None,
164
+ gradient_clip_algorithm: str | None = None,
165
+ ):
166
+ """Configure gradient clipping for the optimizer."""
167
+ if gradient_clip_algorithm is not None and gradient_clip_val is not None:
168
+ clip_gradients(self.model, gradient_clip_val)
169
+ clip_gradients(self.student_projection_mlp, gradient_clip_val)
170
+ self.cancel_gradients_last_layer(self.current_epoch, self.freeze_last_layer)
171
+
172
+ def optimizer_step(
173
+ self,
174
+ epoch: int,
175
+ batch_idx: int,
176
+ optimizer: Optimizer | LightningOptimizer,
177
+ optimizer_closure: Callable[[], Any] | None = None,
178
+ ) -> None:
179
+ """Override optimizer step to update the teacher parameters."""
180
+ super().optimizer_step(
181
+ epoch,
182
+ batch_idx,
183
+ optimizer,
184
+ optimizer_closure=optimizer_closure,
185
+ )
186
+ self.update_teacher()
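
To make the data flow of the two multicrop forwards concrete, here is a minimal sketch of the batch `training_step` expects: the teacher sees only the two global crops, the student sees every crop. Crop counts and shapes are illustrative assumptions; note that `torch.cat(x, dim=0)` requires all crops in the list to share one spatial resolution.

```python
# Illustrative sketch (assumed crop counts and shapes) of a multicrop batch
# as consumed by Dino.training_step. All crops must share one resolution,
# since the forwards concatenate them along the batch dimension.
import torch

batch_size, crop_size = 8, 224
global_crops = [torch.randn(batch_size, 3, crop_size, crop_size) for _ in range(2)]
local_crops = [torch.randn(batch_size, 3, crop_size, crop_size) for _ in range(6)]
images = global_crops + local_crops  # list of 8 tensors

# Inside training_step:
#   teacher_output = self.teacher_multicrop_forward(images[:2])  # 2 chunks, each (8, out_dim)
#   student_output = self.student_multicrop_forward(images)      # 8 chunks, each (8, out_dim)
#   loss = self.criterion(self.current_epoch, student_output, teacher_output)
```
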
quadra/modules/ssl/hyperspherical.py
@@ -0,0 +1,206 @@
+ from __future__ import annotations
+
+ from collections.abc import Callable
+ from enum import Enum
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, optim
+
+ from quadra.losses.ssl import hyperspherical as loss
+ from quadra.modules.base import BaseLightningModule
+
+
+ class AlignLoss(Enum):
+     """Align loss enum."""
+
+     L2 = 1
+     COSINE = 2
+
+
+ class TLHyperspherical(BaseLightningModule):
+     """Hyperspherical model: maps features extracted from a pretrained backbone
+     onto a hypersphere.
+
+     Args:
+         model: Feature extractor as a pytorch `torch.nn.Module`.
+         optimizer: Optimizer of the training. If None, a default Adam is used.
+         lr_scheduler: Learning-rate scheduler. If None, a default
+             ReduceLROnPlateau is used.
+         align_weight: Weight for the align component of the hyperspherical
+             loss. Defaults to 1.
+         unifo_weight: Weight for the uniform component of the hyperspherical
+             loss. Defaults to 1.
+         classifier_weight: Weight for the classifier component of the
+             hyperspherical loss. Defaults to 1.
+         align_loss_type: Which type of align loss to use.
+             Defaults to AlignLoss.L2.
+         classifier_loss: Whether to compute a classifier loss to 'enhance'
+             the hyperspherical loss with the classification loss.
+             If True, `model.classifier` must be defined.
+             Defaults to False.
+         num_classes: Number of classes for a classification problem.
+             Defaults to None.
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         optimizer: optim.Optimizer | None = None,
+         lr_scheduler: object | None = None,
+         align_weight: float = 1,
+         unifo_weight: float = 1,
+         classifier_weight: float = 1,
+         align_loss_type: AlignLoss = AlignLoss.L2,
+         classifier_loss: bool = False,
+         num_classes: int | None = None,
+     ):
+         super().__init__(model, optimizer, lr_scheduler)
+         self.align_loss_fun: (
+             Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
+             | Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+         )
+         self.align_weight = align_weight
+         self.unifo_weight = unifo_weight
+         self.classifier_weight = classifier_weight
+         self.align_loss_type = align_loss_type
+         if align_loss_type == AlignLoss.L2:
+             self.align_loss_fun = loss.align_loss
+         elif align_loss_type == AlignLoss.COSINE:
+             self.align_loss_fun = loss.cosine_align_loss
+         else:
+             raise ValueError("The align loss must be one of 'AlignLoss.L2' (L2 distance) or 'AlignLoss.COSINE'")
+
+         if classifier_loss and model.classifier is None:
+             raise AssertionError("Classifier is not defined")
+
+         self.classifier_loss = classifier_loss
+         self.num_classes = num_classes
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         # pylint: disable=unused-argument
+         im_x, im_y, target = batch
+         emb_x, emb_y = self(torch.cat([im_x, im_y])).chunk(2)
+
+         align_loss = 0.0
+         if self.align_weight > 0:
+             align_loss = self.align_loss_fun(emb_x, emb_y)
+
+         unifo_loss = 0.0
+         if self.unifo_weight > 0:
+             unifo_loss = (loss.uniform_loss(emb_x) + loss.uniform_loss(emb_y)) / 2
+
+         classifier_loss = 0.0
+         if self.classifier_loss:
+             pred = self.model.classifier(emb_x)
+             classifier_loss = F.cross_entropy(pred, target)
+
+         total_loss = (
+             self.align_weight * align_loss + self.unifo_weight * unifo_loss + self.classifier_weight * classifier_loss
+         )
+
+         self.log(
+             "t_loss",
+             total_loss,
+             on_epoch=True,
+             logger=True,
+             prog_bar=False,
+         )
+         self.log(
+             "t_align",
+             align_loss,
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=False,
+         )
+         self.log(
+             "t_classifier",
+             classifier_loss,
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=True,
+         )
+         self.log(
+             "t_unif",
+             unifo_loss,
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=False,
+         )
+         return {"loss": total_loss}
+
+     def train_epoch_end(self, outputs):
+         avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+
+         return {"loss": avg_loss}
+
+     def validation_step(self, batch, batch_idx):
+         # pylint: disable=unused-argument
+         im_x, im_y, target = batch
+         emb_x, emb_y = self(torch.cat([im_x, im_y])).chunk(2)
+
+         align_loss = 0.0
+         if self.align_weight > 0:
+             align_loss = self.align_loss_fun(emb_x, emb_y)
+
+         unifo_loss = 0.0
+         if self.unifo_weight > 0:
+             unifo_loss = (loss.uniform_loss(emb_x) + loss.uniform_loss(emb_y)) / 2
+
+         classifier_loss = 0.0
+         if self.classifier_loss:
+             pred = self.model.classifier(emb_x)
+             classifier_loss = F.cross_entropy(pred, target)
+
+         total_loss = (
+             self.align_weight * align_loss + self.unifo_weight * unifo_loss + self.classifier_weight * classifier_loss
+         )
+
+         self.log(
+             "val_loss",
+             total_loss,
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=False,
+         )
+         self.log(
+             "v_classifier",
+             classifier_loss,
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=True,
+         )
+         self.log(
+             "v_align",
+             align_loss,
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=False,
+         )
+         self.log(
+             "v_unif",
+             unifo_loss,
+             on_epoch=True,
+             on_step=False,
+             logger=True,
+             prog_bar=False,
+         )
+         return {"val_loss": total_loss}
+
+     def on_validation_epoch_end(self, outputs):
+         avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
+
+         return {"val_loss": avg_loss}