quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/models.py ADDED
@@ -0,0 +1,524 @@
+ from __future__ import annotations
+
+ import math
+ import warnings
+ from collections.abc import Callable
+ from typing import Union, cast
+
+ import numpy as np
+ import timm
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+ from pytorch_grad_cam import GradCAM
+ from scipy import ndimage
+ from sklearn.linear_model._base import ClassifierMixin
+ from timm.models.layers import DropPath
+ from timm.models.vision_transformer import Mlp
+ from torch import nn
+
+ from quadra.models.evaluation import (
+     BaseEvaluationModel,
+     ONNXEvaluationModel,
+     TorchEvaluationModel,
+     TorchscriptEvaluationModel,
+ )
+ from quadra.utils import utils
+ from quadra.utils.vit_explainability import VitAttentionGradRollout
+
+ log = utils.get_logger(__name__)
+
+
+ def net_hat(input_size: int, output_size: int) -> torch.nn.Sequential:
+     """Create a linear layer with the given number of input and output neurons.
+
+     Args:
+         input_size: Number of input neurons.
+         output_size: Number of output neurons.
+
+     Returns:
+         A sequential containing a single Linear layer mapping the input neurons to the output neurons.
+     """
+     return torch.nn.Sequential(torch.nn.Linear(input_size, output_size))
+
+
+ def create_net_hat(dims: list[int], act_fun: Callable = torch.nn.ReLU, dropout_p: float = 0) -> torch.nn.Sequential:
+     """Create a sequence of linear layers with activation functions and dropout.
+
+     Args:
+         dims: Dimensions of the hidden layers and of the output.
+         act_fun: Activation function to use between layers. Defaults to ReLU.
+         dropout_p: Dropout probability. Defaults to 0.
+
+     Returns:
+         A sequence of linear layers with the dimensions specified by the input; each hidden linear layer is
+         followed by an activation function and optionally preceded by a dropout layer with the given probability.
+     """
+     components: list[nn.Module] = []
+     for i, _ in enumerate(dims[:-2]):
+         if dropout_p > 0:
+             components.append(torch.nn.Dropout(dropout_p))
+         components.append(net_hat(dims[i], dims[i + 1]))
+         components.append(act_fun())
+     components.append(net_hat(dims[-2], dims[-1]))
+     components.append(L2Norm())
+     return torch.nn.Sequential(*components)
+
+
+ class L2Norm(torch.nn.Module):
+     """Normalize the input to unit L2 norm along the feature dimension."""
+
+     def forward(self, x: torch.Tensor):
+         return x / torch.norm(x, p=2, dim=1, keepdim=True)
+
+
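For reference, a minimal usage sketch of `create_net_hat`; the dimensions below are illustrative, not defaults taken from quadra:

    import torch
    from quadra.utils.models import create_net_hat

    # 512 -> 256 -> 128 projection head: hidden linear layers are followed by ReLU
    # (optionally preceded by dropout) and the output is L2-normalized by L2Norm.
    head = create_net_hat(dims=[512, 256, 128], dropout_p=0.1)
    embeddings = head(torch.randn(4, 512))
    print(embeddings.shape)        # torch.Size([4, 128])
    print(embeddings.norm(dim=1))  # all ones, thanks to the final L2Norm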
+ def init_weights(m):
+     """Basic weight initialization."""
+     classname = m.__class__.__name__
+     if classname.find("Conv2d") != -1 or classname.find("ConvTranspose2d") != -1:
+         nn.init.kaiming_uniform_(m.weight)
+         nn.init.zeros_(m.bias)
+     elif classname.find("BatchNorm") != -1:
+         nn.init.normal_(m.weight, 1.0, 0.02)
+         nn.init.zeros_(m.bias)
+     elif classname.find("Linear") != -1:
+         nn.init.xavier_normal_(m.weight)
+         m.bias.data.fill_(0)
+
+
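`init_weights` is shaped for use with `nn.Module.apply`, which visits every submodule; a small sketch (the toy model is an assumption for illustration only):

    from torch import nn
    from quadra.utils.models import init_weights

    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3),  # Kaiming uniform init
        nn.BatchNorm2d(16),               # weights ~ N(1.0, 0.02), zero bias
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(16 * 30 * 30, 10),      # Xavier normal init
    )
    model.apply(init_weights)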
+ def get_feature(
+     feature_extractor: torch.nn.Module | BaseEvaluationModel,
+     dl: torch.utils.data.DataLoader,
+     iteration_over_training: int = 1,
+     gradcam: bool = False,
+     classifier: ClassifierMixin | None = None,
+     input_shape: tuple[int, int, int] | None = None,
+     limit_batches: int | None = None,
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
+     """Given a dataloader and a PyTorch model, extract features with the model and return features and labels.
+
+     Args:
+         feature_extractor: Pretrained PyTorch backbone.
+         dl: PyTorch dataloader.
+         iteration_over_training: Extract features iteration_over_training times for each image
+             (best if used with augmentation).
+         gradcam: Whether to compute gradcams. Note that this slows down the function.
+         classifier: Scikit-learn classifier.
+         input_shape: [H, W, C] backbone input shape, needed by the classifier's pytorch wrapper.
+         limit_batches: Limit the number of batches to be processed.
+
+     Returns:
+         Tuple containing:
+             features: Model features
+             labels: Input labels
+             grayscale_cams: Gradcam output maps, None if the gradcam arg is False
+     """
+     if isinstance(feature_extractor, (TorchEvaluationModel, TorchscriptEvaluationModel)):
+         # If we are working with torch based evaluation models we need to extract the model
+         feature_extractor = feature_extractor.model
+     elif isinstance(feature_extractor, ONNXEvaluationModel):
+         gradcam = False
+
+     feature_extractor.eval()
+
+     # Setup gradcam
+     if gradcam:
+         if not hasattr(feature_extractor, "features_extractor"):
+             gradcam = False
+         elif isinstance(feature_extractor.features_extractor, timm.models.resnet.ResNet):
+             target_layers = [feature_extractor.features_extractor.layer4[-1]]
+             cam = GradCAM(
+                 model=feature_extractor,
+                 target_layers=target_layers,
+                 use_cuda=torch.cuda.is_available(),
+             )
+             for p in feature_extractor.features_extractor.layer4[-1].parameters():
+                 p.requires_grad = True
+         elif is_vision_transformer(feature_extractor.features_extractor):
+             grad_rollout = VitAttentionGradRollout(
+                 feature_extractor.features_extractor,
+                 classifier=classifier,
+                 example_input=None if input_shape is None else torch.randn(1, *input_shape),
+             )
+         else:
+             gradcam = False
+
+         if not gradcam:
+             log.warning("Gradcam not implemented for this backbone, it will not be computed")
+
+     # Extract features from data
+
+     for iteration in range(iteration_over_training):
+         for i, b in enumerate(tqdm.tqdm(dl)):
+             x1, y1 = b
+
+             if hasattr(feature_extractor, "parameters"):
+                 # Move input to the correct device and dtype
+                 parameter = next(feature_extractor.parameters())
+                 x1 = x1.to(parameter.device).to(parameter.dtype)
+             elif isinstance(feature_extractor, BaseEvaluationModel):
+                 x1 = x1.to(feature_extractor.device).to(feature_extractor.model_dtype)
+
+             if gradcam:
+                 y_hat = cast(
+                     Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1).detach()
+                 )
+                 # mypy can't detect that gradcam is true only if we have a features_extractor
+                 if is_vision_transformer(feature_extractor.features_extractor):  # type: ignore[union-attr]
+                     grayscale_cam_low_res = grad_rollout(
+                         input_tensor=x1, targets_list=y1
+                     )  # TODO: We are using labels (y1) but it would be better to use preds
+                     orig_shape = grayscale_cam_low_res.shape
+                     new_shape = (orig_shape[0], x1.shape[2], x1.shape[3])
+                     zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
+                     grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
+                 else:
+                     grayscale_cam = cam(input_tensor=x1, targets=None)
+                 feature_extractor.zero_grad(set_to_none=True)  # type: ignore[union-attr]
+             else:
+                 with torch.no_grad():
+                     y_hat = cast(Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1))
+                 grayscale_cams = None
+
+             if isinstance(y_hat, (list, tuple)):
+                 y_hat = y_hat[0].cpu()
+             else:
+                 y_hat = y_hat.cpu()
+
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             if i == 0 and iteration == 0:
+                 features = torch.cat([y_hat], dim=0)
+                 labels = np.concatenate([y1])
+                 if gradcam:
+                     grayscale_cams = grayscale_cam
+             else:
+                 features = torch.cat([features, y_hat], dim=0)
+                 labels = np.concatenate([labels, y1], axis=0)
+                 if gradcam:
+                     grayscale_cams = np.concatenate([grayscale_cams, grayscale_cam], axis=0)
+
+             if limit_batches is not None and (i + 1) >= limit_batches:
+                 break
+
+     return features.detach().numpy(), labels, grayscale_cams
+
+
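A sketch of calling `get_feature` without gradcams; the random tensors stand in for a real dataset and the timm backbone is only illustrative:

    import timm
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from quadra.utils.models import get_feature

    # Dummy data standing in for a real image dataset: 64 RGB images with integer labels
    dataset = TensorDataset(torch.randn(64, 3, 224, 224), torch.randint(0, 2, (64,)))
    backbone = timm.create_model("resnet18", pretrained=False, num_classes=0)  # pooled features only

    features, labels, cams = get_feature(backbone, DataLoader(dataset, batch_size=16))
    # features: (64, 512) numpy array, labels: (64,), cams is None since gradcam defaults to False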
+ def is_vision_transformer(model: torch.nn.Module) -> bool:
+     """Verify if a pytorch module is a Vision Transformer.
+
+     This check is primarily needed for gradcam computation in classification tasks.
+
+     Args:
+         model: Model
+     """
+     return type(model).__name__ == "VisionTransformer"
+
+
+ def _no_grad_trunc_normal_(tensor: torch.Tensor, mean: float, std: float, a: float, b: float):
+     """Cut & paste from PyTorch official master until it's in a few official releases - RW.
+     Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff
+         b: the maximum cutoff
+     """
+
+     def norm_cdf(x: float):
+         """Computes standard normal cumulative distribution function."""
+         return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             (
+                 "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                 "The distribution of values may be incorrect."
+             ),
+             stacklevel=2,
+         )
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         l = norm_cdf((a - mean) / std)
+         u = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [l, u], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.0))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0):
+     """Fill the tensor with values drawn from a truncated normal distribution
+     (wraps `_no_grad_trunc_normal_`, which runs under `torch.no_grad()`).
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff
+         b: the maximum cutoff
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
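`trunc_normal_` mirrors the timm/DINO initializer; note that `a` and `b` are absolute cutoffs, not multiples of `std`. A small sketch of the typical ViT-style use:

    import torch
    from quadra.utils.models import trunc_normal_

    # Initialize a learnable cls token: values ~ N(0, 0.02),
    # resampled/clamped into the absolute range [-2, 2]
    cls_token = torch.nn.Parameter(torch.zeros(1, 1, 768))
    trunc_normal_(cls_token, std=0.02)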
+ def clip_gradients(model: nn.Module, clip: float) -> list[float]:
+     """Clip the gradients of a model's parameters.
+
+     Args:
+         model: The model
+         clip: The clip value.
+
+     Returns:
+         The norms of the gradients
+     """
+     norms = []
+     for _, p in model.named_parameters():
+         if p.grad is not None:
+             param_norm = p.grad.data.norm(2)
+             norms.append(param_norm.item())
+             clip_coef = clip / (param_norm + 1e-6)
+             if clip_coef < 1:
+                 p.grad.data.mul_(clip_coef)
+     return norms
+
+
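Unlike `torch.nn.utils.clip_grad_norm_`, which rescales by the global norm, `clip_gradients` clips each parameter's gradient independently (as in DINO). A self-contained sketch of where it sits in a training step, using a toy model:

    import torch
    from torch import nn
    from quadra.utils.models import clip_gradients

    model = nn.Linear(16, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    loss = nn.functional.mse_loss(model(torch.randn(8, 16)), torch.randn(8, 4))
    optimizer.zero_grad()
    loss.backward()
    grad_norms = clip_gradients(model, clip=3.0)  # per-parameter L2 norms, clipped in place
    optimizer.step()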
+ # TODO: do not use this implementation for new models
+
+
+ class AttentionExtractor(torch.nn.Module):
+     """General attention extractor.
+
+     Args:
+         model: Backbone model which contains the attention layer.
+         attention_layer_name: Attention layer for extracting attention maps.
+             Defaults to "attn_drop".
+     """
+
+     def __init__(self, model: torch.nn.Module, attention_layer_name: str = "attn_drop"):
+         super().__init__()
+         self.model = model
+         modules = [module for module_name, module in self.model.named_modules() if attention_layer_name in module_name]
+         if modules:
+             modules[-1].register_forward_hook(self.get_attention)
+         self.attentions = torch.zeros((1, 0))
+
+     def clear(self):
+         """Clear the grabbed attentions."""
+         self.attentions = torch.zeros((1, 0))
+
+     def get_attention(self, module: nn.Module, input_tensor: torch.Tensor, output: torch.Tensor):  # pylint: disable=unused-argument
+         """Hook registered on the attention layer to grab the attention maps."""
+         self.attentions = output.detach().clone().cpu()
+
+     @staticmethod
+     def process_attention_maps(attentions: torch.Tensor, img_width: int, img_height: int) -> torch.Tensor:
+         """Preprocess attention maps for visualization.
+
+         Args:
+             attentions: grabbed attentions
+             img_width: image width
+             img_height: image height
+
+         Returns:
+             torch.Tensor: preprocessed attentions, with the same shape as the image
+                 from which the attentions have been computed
+         """
+         if len(attentions.shape) == 4:
+             # vit
+             # batch, heads, N, N (class attention layer)
+             attentions = attentions[:, :, 0, 1:]  # batch, heads, height-1
+         else:
+             # xcit
+             # batch, heads, N
+             attentions = attentions[:, :, 1:]  # batch, heads, dim-1
+         nh = attentions.shape[1]
+         patch_size = int(math.sqrt(img_width * img_height / attentions.shape[-1]))
+         w_featmap = img_width // patch_size
+         h_featmap = img_height // patch_size
+
+         # we keep only the output patch attention, we don't want the cls token
+         attentions = attentions.reshape(attentions.shape[0], nh, w_featmap, h_featmap)
+         attentions = F.interpolate(attentions, scale_factor=patch_size, mode="nearest")
+         return attentions
+
+     def forward(self, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         self.clear()
+         out = self.model(t)
+         return (out, self.attentions)  # torch.jit.trace does not complain
+
+
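A sketch of wrapping a timm ViT with `AttentionExtractor`; the forward hook grabs the output of the last `attn_drop` module on every forward pass (the model name and sizes are illustrative):

    import timm
    import torch
    from quadra.utils.models import AttentionExtractor

    vit = timm.create_model("vit_small_patch16_224", pretrained=False)
    extractor = AttentionExtractor(vit)
    out, attentions = extractor(torch.randn(1, 3, 224, 224))  # attentions: (1, 6, 197, 197)
    maps = AttentionExtractor.process_attention_maps(attentions, img_width=224, img_height=224)
    # maps: (1, 6, 224, 224), one upsampled attention map per head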
+ # TODO: do not use this implementation for new models
+
+
+ class PositionalEncoding1D(torch.nn.Module):
+     """Standard sine-cosine positional encoding from https://arxiv.org/abs/2010.11929.
+
+     Args:
+         d_model: Embedding dimension
+         temperature: Temperature for the positional encoding. Defaults to 10000.0.
+         dropout: Dropout rate. Defaults to 0.0.
+         max_len: Maximum length of the sequence. Defaults to 5000.
+     """
+
+     def __init__(self, d_model: int, temperature: float = 10000.0, dropout: float = 0.0, max_len: int = 5000):
+         super().__init__()
+         self.dropout: torch.nn.Dropout | torch.nn.Identity
+         if dropout > 0:
+             self.dropout = torch.nn.Dropout(p=dropout)
+         else:
+             self.dropout = torch.nn.Identity()
+
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(temperature) / d_model))
+         self.pe = torch.zeros(max_len, 1, d_model)
+         self.pe[:, 0, 0::2] = torch.sin(position * div_term)
+         self.pe[:, 0, 1::2] = torch.cos(position * div_term)
+         self.pe = self.pe.permute(1, 0, 2)
+         self.pe = torch.nn.Parameter(self.pe)
+         self.pe.requires_grad = False
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass of the positional encoding.
+
+         Args:
+             x: torch tensor [batch_size, seq_len, embedding_dim].
+         """
+         x = x + self.pe[:, : x.size(1), :]
+         return self.dropout(x)
+
+
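`PositionalEncoding1D` stores the sin/cos table as a frozen `Parameter`, so it follows the module across devices without being trained; a small sketch with illustrative sizes:

    import torch
    from quadra.utils.models import PositionalEncoding1D

    pe = PositionalEncoding1D(d_model=128)
    tokens = torch.randn(8, 50, 128)  # batch, seq_len, embedding_dim
    tokens = pe(tokens)               # adds the encodings for the first 50 positions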
+ class LSABlock(torch.nn.Module):
+     """Local Self Attention Block from https://arxiv.org/abs/2112.13492.
+
+     Args:
+         dim: embedding dimension
+         num_heads: number of attention heads
+         mlp_ratio: ratio of mlp hidden dim to embedding dim
+         qkv_bias: enable bias for qkv if True
+         drop: dropout rate
+         attn_drop: attention dropout rate
+         drop_path: stochastic depth rate
+         act_layer: activation layer
+         norm_layer: normalization layer
+         mask_diagonal: whether to mask the Q^T x K diagonal with -infinity so that self
+             relationships between tokens are not counted. Defaults to True.
+         learnable_temperature: whether to use a learnable temperature as specified in
+             https://arxiv.org/abs/2112.13492. Defaults to True.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         drop_path: float = 0.0,
+         act_layer: type[nn.Module] = torch.nn.GELU,
+         norm_layer: type[torch.nn.LayerNorm] = torch.nn.LayerNorm,
+         mask_diagonal: bool = True,
+         learnable_temperature: bool = True,
+     ):
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = LocalSelfAttention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+             mask_diagonal=mask_diagonal,
+             learnable_temperature=learnable_temperature,
+         )
+         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+         self.drop_path = DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+     def forward(self, x):
+         x = x + self.drop_path(self.attn(self.norm1(x)))
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class LocalSelfAttention(torch.nn.Module):
+     """Local Self Attention from https://arxiv.org/abs/2112.13492.
+
+     Args:
+         dim: embedding dimension.
+         num_heads: number of attention heads.
+         qkv_bias: enable bias for qkv if True.
+         attn_drop: attention dropout rate.
+         proj_drop: projection dropout rate.
+         mask_diagonal: whether to mask the Q^T x K diagonal with -infinity so that self
+             relationships between tokens are not counted. Defaults to True.
+         learnable_temperature: whether to use a learnable temperature as specified in
+             https://arxiv.org/abs/2112.13492. Defaults to True.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+         mask_diagonal: bool = True,
+         learnable_temperature: bool = True,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.mask_diagonal = mask_diagonal
+         if learnable_temperature:
+             self.register_parameter("scale", torch.nn.Parameter(torch.tensor(head_dim**-0.5, requires_grad=True)))
+         else:
+             self.scale = head_dim**-0.5
+
+         self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = torch.nn.Dropout(attn_drop)
+         self.proj = torch.nn.Linear(dim, dim)
+         self.proj_drop = torch.nn.Dropout(proj_drop)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Computes the local self attention.
+
+         Args:
+             x: input tensor
+
+         Returns:
+             Output of the local self attention.
+         """
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+         attn = (q @ k.transpose(-2, -1)) * self.scale
+         if self.mask_diagonal:
+             attn[torch.eye(N, device=attn.device, dtype=torch.bool).repeat(B, self.num_heads, 1, 1)] = -float("inf")
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
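A shape-level sketch of `LSABlock`, which pairs `LocalSelfAttention` (diagonal masking plus a learnable temperature) with a standard timm `Mlp`; the sizes are illustrative:

    import torch
    from quadra.utils.models import LSABlock

    block = LSABlock(dim=192, num_heads=3)
    tokens = torch.randn(2, 65, 192)  # batch, tokens (e.g. 64 patches + cls), dim
    out = block(tokens)
    print(out.shape)  # torch.Size([2, 65, 192]): the block preserves the input shape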
quadra/utils/patch/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from .dataset import generate_patch_dataset, get_image_mask_association
+ from .metrics import compute_patch_metrics, reconstruct_patch
+ from .model import RleEncoder, save_classification_result
+ from .visualization import plot_patch_reconstruction, plot_patch_results
+
+ __all__ = [
+     "generate_patch_dataset",
+     "reconstruct_patch",
+     "save_classification_result",
+     "plot_patch_reconstruction",
+     "plot_patch_results",
+     "get_image_mask_association",
+     "compute_patch_metrics",
+     "RleEncoder",
+ ]