quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,45 @@
1
+ import torch
2
+ from torch.nn.functional import cosine_similarity
3
+
4
+
5
def cosine_align_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the mean cosine distance between paired rows of ``x`` and ``y``.

    The cosine distance of a pair is ``1 - cos(x_i, y_i)``, where the cosine
    similarity is taken along dimension 1.

    Args:
        x: First feature tensor.
        y: Second feature tensor, paired element-wise with ``x``.

    Returns:
        Scalar tensor holding the averaged cosine distance.
    """
    similarity = cosine_similarity(x, y, dim=1)
    distance = 1 - similarity
    return distance.mean()
17
+
18
+
19
# Source: https://arxiv.org/pdf/2005.10242.pdf
def align_loss(x: torch.Tensor, y: torch.Tensor, alpha: int = 2) -> torch.Tensor:
    """Alignment loss: mean L2 distance between paired rows, raised to ``alpha``.

    Args:
        x: First feature tensor.
        y: Second feature tensor, paired row-wise with ``x``.
        alpha: Exponent applied to each row-wise L2 distance.

    Returns:
        Scalar tensor ``mean(||x_i - y_i||_2 ** alpha)``.
    """
    difference = x - y
    row_distances = difference.norm(p=2, dim=1)
    return row_distances.pow(alpha).mean()
33
+
34
+
35
def uniform_loss(x: torch.Tensor, t: float = 2.0) -> torch.Tensor:
    """Uniformity loss ``log(mean(exp(-t * ||x_i - x_j||_2^2)))`` over all row pairs ``i < j``.

    The log of the mean of exponentials is evaluated with ``logsumexp``. As a
    result, widely separated features (large ``t * dist**2``) do not underflow
    ``exp`` to zero and turn the loss into ``-inf``.

    Args:
        x: 2D feature tensor with one sample per row (``torch.pdist`` requires 2D input).
        t: Temperature scaling the squared pairwise distances.

    Returns:
        Scalar uniformity loss.

    Raises:
        ValueError: If ``x`` has fewer than two rows, so that no pair exists.
    """
    import math  # pylint: disable=import-outside-toplevel

    sq_dists = torch.pdist(x, p=2).pow(2)
    num_pairs = sq_dists.numel()
    if num_pairs == 0:
        # The mean over an empty set of pairs is undefined (previously silently NaN).
        raise ValueError("uniform_loss requires at least two samples in `x`")
    # log(mean(exp(v))) == logsumexp(v) - log(N), without underflow of exp(v).
    return torch.logsumexp(sq_dists * -t, dim=0) - math.log(num_pairs)
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
def idmm_loss(
    p1: torch.Tensor,
    y1: torch.Tensor,
    smoothing: float = 0.1,
) -> torch.Tensor:
    """Compute the IDMM loss described in https://arxiv.org/abs/2201.10728.

    The loss is a label-smoothed cross entropy between the predictions and the
    instance labels.

    Args:
        p1: Prediction scores (logits) for `z1`.
        y1: Instance labels for `z1`.
        smoothing: Label smoothing factor forwarded to ``F.cross_entropy``.
            Defaults to 0.1.

    Returns:
        IDMM loss
    """
    return F.cross_entropy(input=p1, target=y1, label_smoothing=smoothing)
23
+
24
+
25
class IDMMLoss(torch.nn.Module):
    """Module wrapper around ``idmm_loss`` (https://arxiv.org/abs/2201.10728).

    Args:
        smoothing: Label smoothing factor forwarded to ``idmm_loss``.
    """

    def __init__(self, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(
        self,
        p1: torch.Tensor,
        y1: torch.Tensor,
    ) -> torch.Tensor:
        """Return the IDMM loss for a batch.

        Args:
            p1: Prediction scores (logits) for `z1`.
            y1: Instance labels for `z1`.

        Returns:
            IDMM loss
        """
        smoothing = self.smoothing
        return idmm_loss(p1, y1, smoothing)
@@ -0,0 +1,67 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from quadra.utils.utils import AllGatherSyncFunction
7
+
8
+
9
def simclr_loss(
    features1: torch.Tensor,
    features2: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """SimCLR (NT-Xent style) loss described in https://arxiv.org/pdf/2002.05709.pdf.

    Both views are L2-normalised. When ``torch.distributed`` is initialised, the
    views are gathered across processes with ``AllGatherSyncFunction``, so the
    denominator also uses samples from the other processes.

    Args:
        features1: First augmented features (i.e. T(features)).
        features2: Second augmented features (i.e. T'(features)).
        temperature: Temperature dividing the cosine similarities.

    Returns:
        SimCLR loss
    """
    z1 = F.normalize(features1, dim=-1)
    z2 = F.normalize(features2, dim=-1)

    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    if distributed:
        z1_all = AllGatherSyncFunction.apply(z1)
        z2_all = AllGatherSyncFunction.apply(z2)
    else:
        z1_all, z2_all = z1, z2

    local = torch.cat((z1, z2), dim=0)  # [2B, d]
    gathered = torch.cat((z1_all, z2_all), dim=0)  # [2B * DIST_SIZE, d]

    # exp(cosine / temperature) for every (local, gathered) pair: [2B, 2B * DIST_SIZE]
    similarity = torch.exp(local.mm(gathered.t()) / temperature)

    # Every row contains its own self-similarity exp(1 / temperature), because
    # z_i . z_i = 1 for normalised vectors. Remove it, then clamp for numerical stability.
    self_term = math.e ** (1 / temperature)
    denominator = (similarity.sum(dim=-1) - self_term).clamp(min=1e-6)

    # Positive pair similarity, duplicated so that each of the 2B rows has its positive.
    positive = torch.exp((z1 * z2).sum(dim=-1) / temperature).repeat(2)

    return -torch.log(positive / (denominator + 1e-6)).mean()
52
+
53
+
54
class SimCLRLoss(torch.nn.Module):
    """Module wrapper around ``simclr_loss``.

    Args:
        temperature: Temperature forwarded to ``simclr_loss``.
    """

    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return the SimCLR loss between the two batches of views ``x1`` and ``x2``."""
        return simclr_loss(features1=x1, features2=x2, temperature=self.temperature)
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
def simsiam_loss(
    p1: torch.Tensor,
    p2: torch.Tensor,
    z1: torch.Tensor,
    z2: torch.Tensor,
) -> torch.Tensor:
    """SimSIAM loss described in https://arxiv.org/abs/2011.10566.

    Computes the symmetrised negative cosine similarity
    ``-(cos(p1, stopgrad(z2)) + cos(p2, stopgrad(z1))) / 2``. As in the paper, the
    projections ``z1``/``z2`` are detached, so no gradient flows through the target
    branch. Detaching an already-detached tensor is a no-op, so callers that
    detach themselves are unaffected.

    Args:
        p1: First `predicted` features (i.e. h(f(T(x))))
        p2: Second `predicted` features (i.e. h(f(T'(x))))
        z1: First `projected` features (i.e. f(T(x)))
        z2: Second `projected` features (i.e. f(T'(x)))

    Returns:
        SimSIAM loss, in [-1, 1]; -1 when predictions and targets are perfectly aligned.
    """
    # Stop-gradient on the targets is what prevents representation collapse in SimSiam.
    sim_12 = F.cosine_similarity(p1, z2.detach()).mean()
    sim_21 = F.cosine_similarity(p2, z1.detach()).mean()
    return -(sim_12 + sim_21) * 0.5
23
+
24
+
25
class SimSIAMLoss(torch.nn.Module):
    """Module wrapper around ``simsiam_loss``; it holds no parameters or state."""

    def forward(self, p1: torch.Tensor, p2: torch.Tensor, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Return ``simsiam_loss(p1, p2, z1, z2)``."""
        loss = simsiam_loss(p1, p2, z1, z2)
        return loss
@@ -0,0 +1,76 @@
1
+ import torch
2
+
3
+
4
def vicreg_loss(
    z1: torch.Tensor,
    z2: torch.Tensor,
    lambd: float,
    mu: float,
    nu: float = 1,
    gamma: float = 1,
) -> torch.Tensor:
    """VICReg loss described in https://arxiv.org/abs/2105.04906.

    The loss is ``lambd * invariance + mu * variance + nu * covariance``, where:

    * invariance is the MSE between ``z1`` and ``z2``;
    * variance is, for each branch, the mean hinge ``relu(gamma - std)`` over the
      embedding dimensions (std computed as ``sqrt(var + 1e-4)``), summed over both branches;
    * covariance is, for each branch, the sum of squared off-diagonal entries of the
      covariance matrix divided by ``dim``, summed over both branches.

    Note: the variance term is computed on ``z1``/``z2`` exactly as given. Pre-standardising
    them (``(z - z.mean(0)) / z.std(0)``) makes that term vanish when ``gamma <= 1``.

    Args:
        z1: Embeddings of the first augmented view, shape ``(batch, dim)``.
        z2: Embeddings of the second augmented view, shape ``(batch, dim)``.
        lambd: Weight of the invariance (similarity) term.
        mu: Weight of the variance term.
        nu: Weight of the covariance term. Default: 1
        gamma: Target standard deviation for the variance hinge. Default: 1

    Returns:
        VICReg loss
    """
    n = z1.size(0)
    d = z1.size(1)

    def _variance_penalty(z: torch.Tensor) -> torch.Tensor:
        # Hinge on the per-dimension std; the small epsilon keeps sqrt differentiable at 0.
        std = torch.sqrt(z.var(dim=0) + 0.0001)
        return torch.nn.functional.relu(gamma - std).mean()

    def _covariance_penalty(z: torch.Tensor) -> torch.Tensor:
        centered = z - z.mean(dim=0)
        cov = (centered.T @ centered) / (n - 1)
        # Standard trick selecting the off-diagonal entries of a (d, d) matrix:
        # dropping the last element and viewing as (d - 1, d + 1) puts every
        # diagonal element in column 0.
        off_diagonal = cov.flatten()[:-1].view(d - 1, d + 1)[:, 1:]
        return off_diagonal.pow(2).sum() / d

    var_loss = _variance_penalty(z1) + _variance_penalty(z2)
    sim_loss = torch.nn.functional.mse_loss(z1, z2)
    cov_loss = _covariance_penalty(z1) + _covariance_penalty(z2)

    return lambd * sim_loss + mu * var_loss + nu * cov_loss
49
+
50
+
51
class VICRegLoss(torch.nn.Module):
    """VICReg (Variance-Invariance-Covariance Regularization) loss module.

    This is a thin wrapper around ``vicreg_loss``, which computes
    ``lambd * invariance + mu * variance + nu * covariance``.

    Args:
        lambd: Weight of the invariance (similarity, MSE) term.
        mu: Weight of the variance term.
        nu: Weight of the covariance term. Default: 1.
        gamma: Target standard deviation used by the variance hinge. Default: 1.
    """

    def __init__(
        self,
        lambd: float,
        mu: float,
        nu: float = 1,
        gamma: float = 1,
    ):
        super().__init__()
        self.lambd = lambd
        self.mu = mu
        self.nu = nu
        self.gamma = gamma

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Computes VICReg loss between the embeddings of the two views."""
        return vicreg_loss(z1, z2, self.lambd, self.mu, self.nu, self.gamma)
quadra/main.py ADDED
@@ -0,0 +1,49 @@
1
+ import time
2
+
3
+ import hydra
4
+ import matplotlib
5
+ from omegaconf import DictConfig
6
+ from pytorch_lightning import seed_everything
7
+
8
+ from quadra.tasks.base import Task
9
+ from quadra.utils.resolver import register_resolvers
10
+ from quadra.utils.utils import get_logger, load_envs, setup_opencv
11
+ from quadra.utils.validator import validate_config
12
+
13
# Module-level setup, executed on import and therefore before `main()` runs and
# Hydra composes the config. Judging by their names, these load environment
# variables and register custom config resolvers used by the YAML configs —
# NOTE(review): confirm in quadra.utils.
load_envs()
register_resolvers()


# Select the non-interactive "Agg" matplotlib backend: figures can be rendered
# and saved without a display.
matplotlib.use("Agg")
log = get_logger(__name__)
19
+
20
+
21
@hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.3.0")
def main(config: DictConfig):
    """Main entry function for any of the tasks.

    Hydra composes ``config`` from ``configs/config.yaml`` plus any command-line
    overrides. The function then builds the task described by ``config.task`` and
    executes it.

    Args:
        config: Composed Hydra configuration.
    """
    if config.validate:
        # perf_counter is monotonic and high-resolution, unlike time.time(), so the
        # measured duration is not distorted by system clock adjustments.
        start = time.perf_counter()
        validate_config(config)
        elapsed = time.perf_counter() - start
        log.info("Config validation took %f seconds", elapsed)

    from quadra.utils import utils  # pylint: disable=import-outside-toplevel

    utils.extras(config)

    # Prints the resolved configuration to the console
    if config.get("print_config"):
        utils.print_config(config, resolve=True)

    # Set seed for random number generators in pytorch, numpy and python.random
    seed_everything(config.core.seed, workers=True)
    setup_opencv()

    # Run specified task using the configuration composition.
    # `_recursive_=False` stops Hydra from instantiating nested config nodes, so the
    # task receives them as plain config.
    task: Task = hydra.utils.instantiate(config.task, config, _recursive_=False)
    task.execute()
45
+
46
+
47
if __name__ == "__main__":
    # `main` is wrapped by `hydra.main`, which supplies the `config` argument itself,
    # hence the argument-less call (and the pylint suppression below).
    # pylint: disable=no-value-for-parameter
    main()
@@ -0,0 +1,3 @@
1
+ from .segmentation import segmentation_props
2
+
3
+ __all__ = ["segmentation_props"]
@@ -0,0 +1,251 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import cast
4
+
5
+ import numpy as np
6
+ import torch
7
+ from scipy.optimize import linear_sum_assignment
8
+ from skimage.measure import label, regionprops # pylint: disable=no-name-in-module
9
+
10
+ from quadra.utils.evaluation import dice
11
+
12
+
13
+ def _pad_to_shape(a: np.ndarray, shape: tuple, constant_values: int = 0) -> np.ndarray:
14
+ """Pad lower - right with 0s
15
+ Args:
16
+ a: numpy array to pad
17
+ shape: shape of the resulting np.array
18
+ constant_values: value to pad.
19
+
20
+ Returns:
21
+ Padded array
22
+ """
23
+ y_, x_ = shape
24
+ y, x = a.shape
25
+ y_pad = y_ - y
26
+ x_pad = x_ - x
27
+ return np.pad(
28
+ a,
29
+ (
30
+ (0, y_pad),
31
+ (0, x_pad),
32
+ ),
33
+ mode="constant",
34
+ constant_values=constant_values,
35
+ )
36
+
37
+
38
+ def _get_iou(bboxes1: np.ndarray, bboxes2: np.ndarray, approx_iou: bool = False) -> np.ndarray:
39
+ """Intersect over union
40
+ Args:
41
+ bboxes1: extracted bounding boxes
42
+ bboxes2: ground truth
43
+ approx_iou: flag to approximate.
44
+
45
+ Returns:
46
+ Intersect over union array
47
+ """
48
+ x11, y11, x12, y12 = np.split(bboxes1, 4, axis=1)
49
+ x21, y21, x22, y22 = np.split(bboxes2, 4, axis=1)
50
+
51
+ # determine the (x, y)-coordinates of the intersection rectangle
52
+ xA = np.maximum(x11, x21.T)
53
+ yA = np.maximum(y11, y21.T)
54
+ xB = np.minimum(x12, x22.T)
55
+ yB = np.minimum(y12, y22.T)
56
+
57
+ # compute the area of intersection rectangle
58
+ inter_area = np.maximum((xB - xA), 0) * np.maximum((yB - yA), 0)
59
+
60
+ # compute the area of both the prediction and ground-truth rectangles
61
+ box_a_area = (x12 - x11) * (y12 - y11)
62
+ box_b_area = (x22 - x21) * (y22 - y21)
63
+
64
+ if approx_iou:
65
+ iou = inter_area / box_b_area.T
66
+ else:
67
+ iou = inter_area / (box_a_area + box_b_area.T - inter_area)
68
+
69
+ return iou
70
+
71
+
72
def _get_dice_matrix(
    labels_pred: np.ndarray,
    n_labels_pred: int,
    labels_gt: np.ndarray,
    n_labels_gt: int,
) -> np.ndarray:
    """Build the matrix of ``dice`` scores between every predicted and ground-truth component.

    Components are identified by label values ``1..n`` in the label images.
    Entry ``[i, j]`` compares predicted label ``i + 1`` with ground-truth label ``j + 1``.

    Args:
        labels_pred: Label image of the prediction.
        n_labels_pred: Number of predicted components.
        labels_gt: Label image of the ground truth.
        n_labels_gt: Number of ground-truth components.

    Returns:
        ``(n_labels_pred, n_labels_gt)`` float matrix of ``dice`` values.
    """

    def _as_batch(binary: np.ndarray) -> torch.Tensor:
        # Float tensor with two leading singleton dims, the layout passed to `dice`
        # (presumably batch and channel).
        return torch.Tensor(binary).unsqueeze(0).unsqueeze(0)

    # The boolean ground-truth masks do not depend on the prediction, so build them once.
    gt_masks = [labels_gt == j + 1 for j in range(n_labels_gt)]
    matrix = np.zeros((n_labels_pred, n_labels_gt))
    for i in range(n_labels_pred):
        pred_mask = labels_pred == i + 1
        for j, gt_mask in enumerate(gt_masks):
            matrix[i, j] = dice(_as_batch(pred_mask), _as_batch(gt_mask), reduction="none")
    return matrix
99
+
100
+
101
def segmentation_props(
    pred: np.ndarray, mask: np.ndarray
) -> tuple[float, float, float, float, list[float], float, int, int, int, int]:
    """Match the connected components of a prediction against those of a ground-truth mask.

    Both images are split into 8-connected components (``connectivity=2``). A matrix
    ``a[i, j]`` is then built with ``dice`` between the i-th predicted and the j-th
    ground-truth component. The code treats these values as a dissimilarity
    (presumably ``1 - Dice``):

    1. entries above 0.9 are clipped to 1.0;
    2. the matrix is padded with 1.0 to a square;
    3. a Linear Sum Assignment (LSA) finds the 1-to-1 assignment with minimum total cost.

    Each assigned pair is then classified as follows:

    * True Positive: both are real components and ``a[i, j] <= 0.9``.
    * False Positive and False Negative: both are real components but ``a[i, j] > 0.9``.
    * False Positive: a prediction assigned to a padding column.
    * False Negative: a ground-truth component assigned to a padding row.

    If only one of the two images has components, all of them are FPs (or FNs).
    Component areas are the exact pixel counts reported by ``regionprops``.

    Args:
        pred: Prediction of a segmentation model as a 2D binary image.
        mask: Ground-truth mask as a 2D binary image.

    Returns:
        A 10-tuple ``(global_dice, lsa_dice, lsa_iou, fp_area, fp_hist, fn_area,
        tp_num, fp_num, fn_num, n_labels_mask)``:

        * global_dice: ``dice`` computed on the whole ``pred`` vs ``mask``.
        * lsa_dice: Sum of ``a[i, j]`` over True Positives (divide by ``tp_num`` for the average).
        * lsa_iou: Sum over True Positives of the IoU between the bounding boxes of the
          matched components (divide by ``tp_num`` for the average).
        * fp_area: Total area (pixels) of False Positive components.
        * fp_hist: Area of each False Positive component.
        * fn_area: Total area (pixels) of False Negative components.
        * tp_num: Number of True Positives.
        * fp_num: Number of False Positives.
        * fn_num: Number of False Negatives.
        * n_labels_mask: Number of connected components in ``mask``.
    """
    labels_pred, n_labels_pred = label(pred, connectivity=2, return_num=True, background=0)
    labels_mask, n_labels_mask = label(mask, connectivity=2, return_num=True, background=0)

    labels_pred = cast(np.ndarray, labels_pred)
    labels_mask = cast(np.ndarray, labels_mask)
    n_labels_pred = cast(int, n_labels_pred)
    n_labels_mask = cast(int, n_labels_mask)

    # regionprops lists regions in increasing label order, and `label` numbers
    # components 1..n. Hence index i corresponds to label i + 1, the same convention
    # used by _get_dice_matrix.
    props_pred = regionprops(labels_pred)
    props_mask = regionprops(labels_mask)
    pred_bbox = np.array([prop.bbox for prop in props_pred])
    mask_bbox = np.array([prop.bbox for prop in props_mask])
    # Exact pixel count of each component. Summing the image over a component's
    # bounding box would also count pixels of other components lying inside that box.
    pred_area = [float(prop.area) for prop in props_pred]
    mask_area = [float(prop.area) for prop in props_mask]

    global_dice = float(
        dice(
            torch.Tensor(pred).unsqueeze(0).unsqueeze(0),
            torch.Tensor(mask).unsqueeze(0).unsqueeze(0),
        ).item()
    )
    lsa_iou = 0.0
    lsa_dice = 0.0
    tp_num = 0
    fp_num = 0
    fn_num = 0
    fp_area = 0.0
    fn_area = 0.0
    fp_hist: list[float] = []

    if n_labels_pred > 0 and n_labels_mask > 0:
        dice_mat = _get_dice_matrix(labels_pred, n_labels_pred, labels_mask, n_labels_mask)
        # Poor matches (cost > 0.9) get the same cost as a dummy assignment.
        dice_mat = np.where(dice_mat <= 0.9, dice_mat, 1.0)
        iou_mat = _get_iou(pred_bbox, mask_bbox, approx_iou=False)
        max_dim = max(dice_mat.shape)
        # Pad with dummy cost-1.0 rows/columns so the LSA matrix is square:
        # unmatched predictions land on dummy columns (FP) and unmatched GTs on dummy rows (FN).
        dice_mat = _pad_to_shape(dice_mat, (max_dim, max_dim), 1)
        rows, cols = linear_sum_assignment(dice_mat, maximize=False)
        for row, col in zip(rows, cols):
            is_real_pred = row < n_labels_pred
            is_real_gt = col < n_labels_mask
            if is_real_pred and not is_real_gt:
                # More preds than GTs --> False Positive
                fp_num += 1
                fp_area += pred_area[row]
                fp_hist.append(pred_area[row])
            elif is_real_gt and not is_real_pred:
                # More GTs than preds --> False Negative
                fn_num += 1
                fn_area += mask_area[col]
            elif dice_mat[row, col] <= 0.9:
                # Real True Positive: a prediction assigned to a GT within the 0.9 threshold
                tp_num += 1
                lsa_iou += iou_mat[row, col]
                lsa_dice += dice_mat[row, col]
            else:
                # Assigned but too dissimilar: counts as both a FP and a FN
                fp_num += 1
                fp_area += pred_area[row]
                fp_hist.append(pred_area[row])
                fn_num += 1
                fn_area += mask_area[col]
    else:
        # At most one side has components: every prediction is a FP and every GT is a FN.
        for area in pred_area:
            fp_num += 1
            fp_area += area
            fp_hist.append(area)
        for area in mask_area:
            fn_num += 1
            fn_area += area

    return (
        global_dice,
        float(lsa_dice),
        float(lsa_iou),
        fp_area,
        fp_hist,
        fn_area,
        tp_num,
        fp_num,
        fn_num,
        n_labels_mask,
    )
File without changes