quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/models.py ADDED
@@ -0,0 +1,523 @@
+ from __future__ import annotations
+
+ import math
+ import warnings
+ from collections.abc import Callable
+ from typing import Union, cast
+
+ import numpy as np
+ import timm
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+ from pytorch_grad_cam import GradCAM
+ from scipy import ndimage
+ from sklearn.linear_model._base import ClassifierMixin
+ from timm.models.layers import DropPath
+ from timm.models.vision_transformer import Mlp
+ from torch import nn
+
+ from quadra.models.evaluation import (
+     BaseEvaluationModel,
+     ONNXEvaluationModel,
+     TorchEvaluationModel,
+     TorchscriptEvaluationModel,
+ )
+ from quadra.utils import utils
+ from quadra.utils.vit_explainability import VitAttentionGradRollout
+
+ log = utils.get_logger(__name__)
+
+
+ def net_hat(input_size: int, output_size: int) -> torch.nn.Sequential:
+     """Create a linear layer with the given numbers of input and output neurons.
+
+     Args:
+         input_size: Number of input neurons.
+         output_size: Number of output neurons.
+
+     Returns:
+         A sequential containing a single Linear layer taking input neurons and producing output neurons.
+     """
+     return torch.nn.Sequential(torch.nn.Linear(input_size, output_size))
+
+
+ def create_net_hat(dims: list[int], act_fun: Callable = torch.nn.ReLU, dropout_p: float = 0) -> torch.nn.Sequential:
+     """Create a sequence of linear layers with activation functions and dropout.
+
+     Args:
+         dims: Dimensions of the hidden layers and output.
+         act_fun: Activation function to use between layers. Defaults to ReLU.
+         dropout_p: Dropout probability. Defaults to 0.
+
+     Returns:
+         A sequential of linear layers with the given dimensions, where each hidden linear layer is
+         preceded by an optional dropout layer and followed by an activation function; the final
+         linear layer is followed by an L2 normalization.
+     """
+     components: list[nn.Module] = []
+     for i, _ in enumerate(dims[:-2]):
+         if dropout_p > 0:
+             components.append(torch.nn.Dropout(dropout_p))
+         components.append(net_hat(dims[i], dims[i + 1]))
+         components.append(act_fun())
+     components.append(net_hat(dims[-2], dims[-1]))
+     components.append(L2Norm())
+     return torch.nn.Sequential(*components)
+
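
A minimal usage sketch (editorial, not part of the package diff) for create_net_hat; the dimensions below are illustrative:

    import torch
    from quadra.utils.models import create_net_hat

    head = create_net_hat(dims=[512, 256, 128], dropout_p=0.1)
    features = torch.randn(4, 512)
    embeddings = head(features)  # shape (4, 128); rows are L2-normalized by the final L2Norm
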
+
+ class L2Norm(torch.nn.Module):
+     """Compute L2 Norm."""
+
+     def forward(self, x: torch.Tensor):
+         return x / torch.norm(x, p=2, dim=1, keepdim=True)
+
+
+ def init_weights(m):
+     """Basic weight initialization."""
+     classname = m.__class__.__name__
+     if classname.find("Conv2d") != -1 or classname.find("ConvTranspose2d") != -1:
+         nn.init.kaiming_uniform_(m.weight)
+         nn.init.zeros_(m.bias)
+     elif classname.find("BatchNorm") != -1:
+         nn.init.normal_(m.weight, 1.0, 0.02)
+         nn.init.zeros_(m.bias)
+     elif classname.find("Linear") != -1:
+         nn.init.xavier_normal_(m.weight)
+         m.bias.data.fill_(0)
+
+
+ def get_feature(
+     feature_extractor: torch.nn.Module | BaseEvaluationModel,
+     dl: torch.utils.data.DataLoader,
+     iteration_over_training: int = 1,
+     gradcam: bool = False,
+     classifier: ClassifierMixin | None = None,
+     input_shape: tuple[int, int, int] | None = None,
+     limit_batches: int | None = None,
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
+     """Given a dataloader and a PyTorch model, extract features with the model and return features and labels.
+
+     Args:
+         feature_extractor: Pretrained PyTorch backbone or evaluation model.
+         dl: PyTorch dataloader.
+         iteration_over_training: Extract features iteration_over_training times for each image
+             (best used with augmentation).
+         gradcam: Whether to compute gradcams. Note that this slows down the function.
+         classifier: Scikit-learn classifier.
+         input_shape: [H, W, C] backbone input shape, needed by the classifier's PyTorch wrapper.
+         limit_batches: Limit the number of batches to be processed.
+
+     Returns:
+         Tuple containing:
+             features: Model features.
+             labels: Input labels.
+             grayscale_cams: Gradcam output maps, None if the gradcam arg is False.
+     """
+     if isinstance(feature_extractor, (TorchEvaluationModel, TorchscriptEvaluationModel)):
+         # If we are working with torch based evaluation models we need to extract the underlying model
+         feature_extractor = feature_extractor.model
+     elif isinstance(feature_extractor, ONNXEvaluationModel):
+         gradcam = False
+
+     feature_extractor.eval()
+
+     # Setup gradcam
+     if gradcam:
+         if not hasattr(feature_extractor, "features_extractor"):
+             gradcam = False
+         elif isinstance(feature_extractor.features_extractor, timm.models.resnet.ResNet):
+             target_layers = [feature_extractor.features_extractor.layer4[-1]]
+             cam = GradCAM(
+                 model=feature_extractor,
+                 target_layers=target_layers,
+             )
+             for p in feature_extractor.features_extractor.layer4[-1].parameters():
+                 p.requires_grad = True
+         elif is_vision_transformer(feature_extractor.features_extractor):
+             grad_rollout = VitAttentionGradRollout(
+                 feature_extractor.features_extractor,
+                 classifier=classifier,
+                 example_input=None if input_shape is None else torch.randn(1, *input_shape),
+             )
+         else:
+             gradcam = False
+
+         if not gradcam:
+             log.warning("Gradcam not implemented for this backbone, it will not be computed")
+
+     # Extract features from data
+     for iteration in range(iteration_over_training):
+         for i, b in enumerate(tqdm.tqdm(dl)):
+             x1, y1 = b
+
+             if hasattr(feature_extractor, "parameters"):
+                 # Move input to the correct device and dtype
+                 parameter = next(feature_extractor.parameters())
+                 x1 = x1.to(parameter.device).to(parameter.dtype)
+             elif isinstance(feature_extractor, BaseEvaluationModel):
+                 x1 = x1.to(feature_extractor.device).to(feature_extractor.model_dtype)
+
+             if gradcam:
+                 y_hat = cast(
+                     Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1).detach()
+                 )
+                 # mypy can't detect that gradcam is true only if we have a features_extractor
+                 if is_vision_transformer(feature_extractor.features_extractor):  # type: ignore[union-attr]
+                     # TODO: We are using labels (y1) here, but it would be better to use predictions
+                     grayscale_cam_low_res = grad_rollout(input_tensor=x1, targets_list=y1)
+                     orig_shape = grayscale_cam_low_res.shape
+                     new_shape = (orig_shape[0], x1.shape[2], x1.shape[3])
+                     zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
+                     grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
+                 else:
+                     grayscale_cam = cam(input_tensor=x1, targets=None)
+                 feature_extractor.zero_grad(set_to_none=True)  # type: ignore[union-attr]
+             else:
+                 with torch.no_grad():
+                     y_hat = cast(Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1))
+                 grayscale_cams = None
+
+             if isinstance(y_hat, (list, tuple)):
+                 y_hat = y_hat[0].cpu()
+             else:
+                 y_hat = y_hat.cpu()
+
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             if i == 0 and iteration == 0:
+                 features = torch.cat([y_hat], dim=0)
+                 labels = np.concatenate([y1])
+                 if gradcam:
+                     grayscale_cams = grayscale_cam
+             else:
+                 features = torch.cat([features, y_hat], dim=0)
+                 labels = np.concatenate([labels, y1], axis=0)
+                 if gradcam:
+                     grayscale_cams = np.concatenate([grayscale_cams, grayscale_cam], axis=0)
+
+             if limit_batches is not None and (i + 1) >= limit_batches:
+                 break
+
+     return features.detach().numpy(), labels, grayscale_cams
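
A hedged usage sketch (editorial, not part of the package diff), assuming a plain torch backbone and a standard (image, label) dataloader; all names below are illustrative:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from quadra.utils.models import get_feature

    backbone = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 64))
    loader = DataLoader(TensorDataset(torch.randn(16, 3, 32, 32), torch.randint(0, 2, (16,))), batch_size=8)
    features, labels, cams = get_feature(backbone, loader)
    # features: (16, 64) array, labels: (16,), cams: None because gradcam defaults to False
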
+
+
+ def is_vision_transformer(model: torch.nn.Module) -> bool:
+     """Verify whether a PyTorch module is a Vision Transformer.
+
+     This check is primarily needed for gradcam computation in classification tasks.
+
+     Args:
+         model: Model to check.
+     """
+     return type(model).__name__ == "VisionTransformer"
+
+
+ def _no_grad_trunc_normal_(tensor: torch.Tensor, mean: float, std: float, a: float, b: float):
+     """Cut & paste from PyTorch official master until it's in a few official releases - RW.
+     Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     """
+
+     def norm_cdf(x: float):
+         """Computes standard normal cumulative distribution function."""
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             (
+                 "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                 "The distribution of values may be incorrect."
+             ),
+             stacklevel=2,
+         )
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.0))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0):
+     """Call `_no_grad_trunc_normal_`, which fills the tensor in-place under `torch.no_grad()`.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
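
A small sketch (editorial, not part of the package diff) of the in-place fill; sampled values are guaranteed to lie within [a, b]:

    import torch
    from quadra.utils.models import trunc_normal_

    w = torch.empty(3, 5)
    trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)
    assert w.min() >= -2.0 and w.max() <= 2.0
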
+ def clip_gradients(model: nn.Module, clip: float) -> list[float]:
+     """Clip the gradients of each model parameter individually and collect their norms.
+
+     Args:
+         model: The model.
+         clip: The clip value.
+
+     Returns:
+         The L2 norms of the per-parameter gradients, computed before clipping.
+     """
+     norms = []
+     for _, p in model.named_parameters():
+         if p.grad is not None:
+             param_norm = p.grad.data.norm(2)
+             norms.append(param_norm.item())
+             clip_coef = clip / (param_norm + 1e-6)
+             if clip_coef < 1:
+                 p.grad.data.mul_(clip_coef)
+     return norms
+
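
Unlike torch.nn.utils.clip_grad_norm_, which rescales all gradients by a single global norm, clip_gradients clips each parameter tensor independently. A hedged sketch (editorial, not part of the package diff):

    import torch
    from quadra.utils.models import clip_gradients

    model = torch.nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()
    norms = clip_gradients(model, clip=1.0)  # per-parameter L2 norms, taken before clipping
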
+
+ # TODO: do not use this implementation for new models
+
+
+ class AttentionExtractor(torch.nn.Module):
+     """General attention extractor.
+
+     Args:
+         model: Backbone model which contains the attention layer.
+         attention_layer_name: Name of the attention layer from which attention maps are extracted.
+             Defaults to "attn_drop".
+     """
+
+     def __init__(self, model: torch.nn.Module, attention_layer_name: str = "attn_drop"):
+         super().__init__()
+         self.model = model
+         modules = [module for module_name, module in self.model.named_modules() if attention_layer_name in module_name]
+         if modules:
+             modules[-1].register_forward_hook(self.get_attention)
+         self.attentions = torch.zeros((1, 0))
+
+     def clear(self):
+         """Clear the grabbed attentions."""
+         self.attentions = torch.zeros((1, 0))
+
+     def get_attention(self, module: nn.Module, input_tensor: torch.Tensor, output: torch.Tensor):  # pylint: disable=unused-argument
+         """Method to be registered as a forward hook to grab attentions."""
+         self.attentions = output.detach().clone().cpu()
+
+     @staticmethod
+     def process_attention_maps(attentions: torch.Tensor, img_width: int, img_height: int) -> torch.Tensor:
+         """Preprocess attention maps to be visualized.
+
+         Args:
+             attentions: grabbed attentions
+             img_width: image width
+             img_height: image height
+
+         Returns:
+             torch.Tensor: preprocessed attentions, with the same spatial shape as the image from
+                 which the attentions have been computed
+         """
+         if len(attentions.shape) == 4:
+             # vit
+             # batch, heads, N, N (class attention layer)
+             attentions = attentions[:, :, 0, 1:]  # batch, heads, N-1
+         else:
+             # xcit
+             # batch, heads, N
+             attentions = attentions[:, :, 1:]  # batch, heads, N-1
+         nh = attentions.shape[1]
+         patch_size = int(math.sqrt(img_width * img_height / attentions.shape[-1]))
+         w_featmap = img_width // patch_size
+         h_featmap = img_height // patch_size
+
+         # we keep only the output patch attention; we don't want the cls token
+         attentions = attentions.reshape(attentions.shape[0], nh, w_featmap, h_featmap)
+         attentions = F.interpolate(attentions, scale_factor=patch_size, mode="nearest")
+         return attentions
+
+     def forward(self, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         self.clear()
+         out = self.model(t)
+         return (out, self.attentions)  # torch.jit.trace does not complain
+
+
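
A hedged usage sketch (editorial, not part of the package diff), assuming a timm ViT whose attention layers actually invoke the "attn_drop" module (non-fused attention); timm versions using fused attention may bypass the hook:

    import timm
    import torch
    from quadra.utils.models import AttentionExtractor

    vit = timm.create_model("vit_tiny_patch16_224", pretrained=False)
    extractor = AttentionExtractor(vit, attention_layer_name="attn_drop")
    out, attentions = extractor(torch.randn(1, 3, 224, 224))
    maps = AttentionExtractor.process_attention_maps(attentions, img_width=224, img_height=224)
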
+ # TODO: do not use this implementation for new models
+
+
+ class PositionalEncoding1D(torch.nn.Module):
+     """Standard sine-cosine positional encoding from https://arxiv.org/abs/2010.11929.
+
+     Args:
+         d_model: Embedding dimension
+         temperature: Temperature for the positional encoding. Defaults to 10000.0.
+         dropout: Dropout rate. Defaults to 0.0.
+         max_len: Maximum length of the sequence. Defaults to 5000.
+     """
+
+     def __init__(self, d_model: int, temperature: float = 10000.0, dropout: float = 0.0, max_len: int = 5000):
+         super().__init__()
+         self.dropout: torch.nn.Dropout | torch.nn.Identity
+         if dropout > 0:
+             self.dropout = torch.nn.Dropout(p=dropout)
+         else:
+             self.dropout = torch.nn.Identity()
+
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(temperature) / d_model))
+         self.pe = torch.zeros(max_len, 1, d_model)
+         self.pe[:, 0, 0::2] = torch.sin(position * div_term)
+         self.pe[:, 0, 1::2] = torch.cos(position * div_term)
+         self.pe = self.pe.permute(1, 0, 2)
+         self.pe = torch.nn.Parameter(self.pe)
+         self.pe.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass of the positional encoding.
+
+         Args:
+             x: torch tensor [batch_size, seq_len, embedding_dim].
+         """
+         x = x + self.pe[:, : x.size(1), :]
+         return self.dropout(x)
+
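
A short sketch (editorial, not part of the package diff): the module adds fixed sin/cos terms of geometrically increasing wavelength to the token embeddings, leaving the shape unchanged:

    import torch
    from quadra.utils.models import PositionalEncoding1D

    pos_enc = PositionalEncoding1D(d_model=64)
    tokens = torch.randn(2, 10, 64)  # batch, seq_len, embedding_dim
    encoded = pos_enc(tokens)  # same shape, positions 0..9 encoded
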
+
+ class LSABlock(torch.nn.Module):
+     """Local Self Attention Block from https://arxiv.org/abs/2112.13492.
+
+     Args:
+         dim: embedding dimension
+         num_heads: number of attention heads
+         mlp_ratio: ratio of mlp hidden dim to embedding dim
+         qkv_bias: enable bias for qkv if True
+         drop: dropout rate
+         attn_drop: attention dropout rate
+         drop_path: stochastic depth rate
+         act_layer: activation layer
+         norm_layer: normalization layer
+         mask_diagonal: whether to mask the diagonal of Q^T x K with -infinity so as not to
+             count self-relationships between tokens. Defaults to True.
+         learnable_temperature: whether to use a learnable temperature as specified in
+             https://arxiv.org/abs/2112.13492. Defaults to True.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         drop_path: float = 0.0,
+         act_layer: type[nn.Module] = torch.nn.GELU,
+         norm_layer: type[torch.nn.LayerNorm] = torch.nn.LayerNorm,
+         mask_diagonal: bool = True,
+         learnable_temperature: bool = True,
+     ):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = LocalSelfAttention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             mask_diagonal=mask_diagonal,
+             learnable_temperature=learnable_temperature,
+         )
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x):
+         x = x + self.drop_path(self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class LocalSelfAttention(torch.nn.Module):
+     """Local Self Attention from https://arxiv.org/abs/2112.13492.
+
+     Args:
+         dim: embedding dimension.
+         num_heads: number of attention heads.
+         qkv_bias: enable bias for qkv if True.
+         attn_drop: attention dropout rate.
+         proj_drop: projection dropout rate.
+         mask_diagonal: whether to mask the diagonal of Q^T x K with -infinity
+             so as not to count self-relationships between tokens. Defaults to True.
+         learnable_temperature: whether to use a learnable temperature as specified in
+             https://arxiv.org/abs/2112.13492. Defaults to True.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+         mask_diagonal: bool = True,
+         learnable_temperature: bool = True,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.mask_diagonal = mask_diagonal
+         if learnable_temperature:
+             self.register_parameter("scale", torch.nn.Parameter(torch.tensor(head_dim**-0.5, requires_grad=True)))
+         else:
+             self.scale = head_dim**-0.5
+
+         self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = torch.nn.Dropout(attn_drop)
+         self.proj = torch.nn.Linear(dim, dim)
+         self.proj_drop = torch.nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Compute the local self attention.
+
+         Args:
+             x: input tensor
+
+         Returns:
+             Output of the local self attention.
+         """
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         if self.mask_diagonal:
+             attn[torch.eye(N, device=attn.device, dtype=torch.bool).repeat(B, self.num_heads, 1, 1)] = -float("inf")
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
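
A hedged sketch (editorial, not part of the package diff) applying a single LSA block to a patch-token sequence; with mask_diagonal=True each token attends only to the other tokens:

    import torch
    from quadra.utils.models import LSABlock

    block = LSABlock(dim=64, num_heads=4)
    tokens = torch.randn(2, 16, 64)  # batch, tokens, dim
    out = block(tokens)  # same shape as the input
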
quadra/utils/patch/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from .dataset import generate_patch_dataset, get_image_mask_association
+ from .metrics import compute_patch_metrics, reconstruct_patch
+ from .model import RleEncoder, save_classification_result
+ from .visualization import plot_patch_reconstruction, plot_patch_results
+
+ __all__ = [
+     "generate_patch_dataset",
+     "reconstruct_patch",
+     "save_classification_result",
+     "plot_patch_reconstruction",
+     "plot_patch_results",
+     "get_image_mask_association",
+     "compute_patch_metrics",
+     "RleEncoder",
+ ]