quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,422 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import os
5
+ import random
6
+ from collections.abc import Callable, Iterable
7
+ from typing import Any
8
+
9
+ import albumentations
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+ from albumentations.augmentations.transforms import Normalize
14
+ from albumentations.core.composition import TransformsSeqType
15
+ from albumentations.core.transforms_interface import NoOp
16
+ from albumentations.pytorch.transforms import ToTensorV2
17
+ from matplotlib.colors import ListedColormap
18
+ from matplotlib.lines import Line2D
19
+ from matplotlib.pyplot import get_cmap
20
+ from mpl_toolkits.axes_grid1 import ImageGrid
21
+ from omegaconf import DictConfig, ListConfig
22
+ from pytorch_grad_cam.utils.image import show_cam_on_image
23
+
24
+ from quadra.utils import utils
25
+
26
# Module-level logger for this file, obtained through quadra's own logging helper.
log = utils.get_logger(__name__)
27
+
28
+
29
class UnNormalize:
    """Undo a per-channel ``(x - mean) / std`` normalization on a tensor image.

    ``mean`` and ``std`` are sequences with one entry per channel. They are paired with the
    tensor's first dimension, stopping at the shortest of the three.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor: torch.Tensor, make_copy=True) -> torch.Tensor:
        """Return ``tensor * std + mean`` computed channel by channel.

        Args:
            tensor (Tensor): Normalized image whose first dimension indexes channels, e.g. (C, H, W).
            make_copy (bool): When True, work on a detached clone and leave ``tensor`` untouched;
                when False, modify ``tensor`` in place and return it.

        Returns:
            Tensor: The unnormalized image.
        """
        restored = tensor.detach().clone() if make_copy else tensor
        for channel, channel_mean, channel_std in zip(restored, self.mean, self.std):
            # Inverse of the normalization step ``channel.sub_(mean).div_(std)``.
            channel.mul_(channel_std).add_(channel_mean)
        return restored
53
+
54
+
55
def create_grid_figure(
    images: Iterable[Iterable[np.ndarray]],
    nrows: int,
    ncols: int,
    file_path: str,
    bounds: list[tuple[float, float]],
    row_names: Iterable[str] | None = None,
    fig_size: tuple[int, int] = (12, 8),
):
    """Create a grid figure with images and save it to ``file_path``.

    The figure is rendered with the non-interactive ``Agg`` backend. The previously active
    matplotlib backend is restored and the figure closed afterwards, even if plotting or
    saving raises.

    Args:
        images: Rows of images to plot; ``images[i][j]`` is drawn in grid cell ``(i, j)``.
            Images of shape ``(1, H, W)`` are squeezed to ``(H, W)`` before plotting.
        nrows: Number of rows in the grid.
        ncols: Number of columns in the grid.
        file_path: Path to save the figure.
        bounds: Per-row intensity bounds; ``bounds[i][0]`` / ``bounds[i][1]`` are used as
            ``vmin`` / ``vmax`` for every image of row ``i``.
        row_names: Labels drawn on the y axis of the first column. Defaults to None.
        fig_size: Figure size. Defaults to (12, 8).
    """
    default_plt_backend = plt.get_backend()
    plt.switch_backend("Agg")
    fig = None
    try:
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=fig_size, squeeze=False)
        for i, row in enumerate(images):
            for j, image in enumerate(row):
                # Drop a leading singleton channel so grayscale images are shown as 2-D arrays.
                image_to_plot = image[0] if len(image.shape) == 3 and image.shape[0] == 1 else image
                axes[i][j].imshow(image_to_plot, vmin=bounds[i][0], vmax=bounds[i][1])
                axes[i][j].get_xaxis().set_ticks([])
                axes[i][j].get_yaxis().set_ticks([])
        if row_names is not None:
            for row_ax, name in zip(axes[:, 0], row_names):
                row_ax.set_ylabel(name, rotation=90)

        fig.tight_layout()
        fig.savefig(file_path, bbox_inches="tight", dpi=300, facecolor="white", transparent=False)
    finally:
        # Always release the figure and restore the caller's backend, even on failure.
        if fig is not None:
            plt.close(fig)
        plt.switch_backend(default_plt_backend)
92
+
93
+
94
def create_visualization_dataset(dataset: torch.utils.data.Dataset):
    """Create a visualization dataset by updating transforms.

    Returns a deep copy of ``dataset`` whose ``transform`` attribute is a deep copy of the
    original in which every albumentations ``Normalize`` and ``ToTensorV2`` has been replaced
    by ``NoOp(p=1)``. The replacement recurses into ``BaseCompose.transforms``, lists/tuples/
    ``ListConfig`` and dict/``DictConfig`` containers. The input dataset is not modified.

    Raises:
        ValueError: If ``dataset`` is not a ``torch.utils.data.Dataset`` or its transform is None.
    """

    def convert_transforms(transforms: Any):
        """Recursively replace ``Normalize``/``ToTensorV2`` with ``NoOp`` in (nested) transform containers."""
        if isinstance(transforms, albumentations.BaseCompose):
            transforms.transforms = convert_transforms(transforms.transforms)
        elif isinstance(transforms, (list, tuple, ListConfig)):
            # Check concrete runtime container types: albumentations' ``TransformsSeqType`` is a
            # subscripted typing alias, and isinstance() raises TypeError when given one.
            transforms = [convert_transforms(t) for t in transforms]
        elif isinstance(transforms, (dict, DictConfig)):
            for tname, t in transforms.items():
                transforms[tname] = convert_transforms(t)
        elif isinstance(transforms, (Normalize, ToTensorV2)):
            return NoOp(p=1)
        return transforms

    new_dataset = copy.deepcopy(dataset)
    # TODO: Create dataset class that has a transform attribute, we can then use isinstance
    if not isinstance(dataset, torch.utils.data.Dataset):
        raise ValueError(f"The dataset type {dataset} is not supported")

    transform = copy.deepcopy(dataset.transform)  # type: ignore[attr-defined]
    if transform is None:
        raise ValueError(f"The dataset transform {type(transform)} is not supported")

    new_dataset.transform = convert_transforms(transform)  # type: ignore[attr-defined]
    return new_dataset
122
+
123
+
124
def show_mask_on_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Show a mask on an image.

    Both inputs are scaled by 1/255, summed, rescaled so the maximum of the sum maps to 255,
    and returned as ``uint8``. If the sum is zero everywhere, an all-zero image is returned
    instead of dividing by zero.

    Args:
        image (np.ndarray): The image (values assumed in [0, 255]).
        mask (np.ndarray): The mask; must broadcast against ``image``.

    Returns:
        np.ndarray: The image with the mask, as ``uint8``.
    """
    image = image.astype(np.float32) / 255
    mask = mask.astype(np.float32) / 255
    out = mask + image
    peak = np.max(out)
    # Guard against 0/0, which would produce NaN and an undefined cast to uint8.
    if peak != 0:
        out = out / peak
    return (255 * out).astype(np.uint8)
139
+
140
+
141
def reconstruct_multiclass_mask(
    mask: np.ndarray,
    image_shape: tuple[int, ...],
    color_map: dict[str, tuple[int, ...]],
    ignore_class: int | None = None,
    ground_truth_mask: np.ndarray | None = None,
) -> np.ndarray:
    """Reconstruct a colored multiclass mask from a single channel mask.

    Every pixel of class ``c`` in ``mask`` is painted with ``color_map[str(c)]``. Pixels of
    ``ignore_class`` stay zero.

    Args:
        mask (np.ndarray): A single channel mask of class indices.
        image_shape (Tuple[int, ...]): Shape of the output array, e.g. ``(H, W, C)``.
        color_map (dict): Mapping from the string form of each class index to its color
            (e.g. an RGB tuple). Note that this is indexed by ``str(class)``, not a matplotlib colormap.
        ignore_class (Optional[int], optional): The class to ignore. Defaults to None.
        ground_truth_mask (Optional[np.ndarray], optional): If given together with
            ``ignore_class``, pixels whose ground-truth class is ``ignore_class`` are zeroed
            in the output. Defaults to None.

    Returns:
        np.ndarray: Float array of shape ``image_shape`` holding the colored mask.

    Raises:
        KeyError: If ``mask`` contains a class with no entry in ``color_map``.
    """
    output_mask = np.zeros(image_shape)
    for c in np.unique(mask):
        if ignore_class is not None and c == ignore_class:
            continue
        output_mask[mask == c] = color_map[str(c)]

    if ignore_class is not None and ground_truth_mask is not None:
        # Broadcast zero over all channels so this works for any channel count, not only RGB.
        output_mask[ground_truth_mask == ignore_class] = 0

    return output_mask
171
+
172
+
173
def plot_multiclass_prediction(
    image: np.ndarray,
    prediction_image: np.ndarray,
    ground_truth_image: np.ndarray,
    class_to_idx: dict[str, int],
    plot_original: bool = True,
    ignore_class: int | None = 0,
    image_height: int = 10,
    save_path: str | None = None,
    color_map: str = "tab20",
) -> None:
    """Plot ground truth and predicted multiclass masks overlaid on the image.

    The figure contains one panel per entry, in this order:

    - the original image, if ``plot_original``;
    - the ground-truth mask;
    - the prediction mask;
    - if ``ignore_class`` is not None, the prediction mask with pixels whose ground-truth
      class is ``ignore_class`` blanked out.

    A legend mapping colors to class names is attached to the last panel.

    Args:
        image: The image to plot, indexed as (H, W, C); it is cropped to the spatial size of
            ``prediction_image``.
        prediction_image: The prediction image (single channel class indices)
        ground_truth_image: The ground truth image (single channel class indices)
        class_to_idx: The class to idx mapping
        plot_original: Whether to plot the original image
        ignore_class: The class to ignore
        image_height: The height of the output figure
        save_path: The path to save the figure. When given, the figure is saved and closed;
            otherwise it is left open.
        color_map: The color map to use. Defaults to "tab20".
    """
    image = image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1], :]
    class_idxs = list(class_to_idx.values())
    cm = get_cmap(color_map)
    # Map each class index (as a string) to a 0-255 RGB tuple sampled from the colormap; alpha is dropped.
    cmap = {str(c): tuple(int(i * 255) for i in cm(c / len(class_idxs))[:-1]) for c in class_idxs}
    output_images = []
    titles = []
    if plot_original:
        output_images.append(image)
        titles.append("Original Image")

    ground_truth_mask = reconstruct_multiclass_mask(ground_truth_image, image.shape, cmap, ignore_class=ignore_class)
    output_images.append(ground_truth_mask)
    titles.append("Ground Truth Mask")

    prediction_mask = reconstruct_multiclass_mask(
        prediction_image,
        image.shape,
        cmap,
        ignore_class=ignore_class,
    )
    output_images.append(prediction_mask)
    titles.append("Prediction Mask")
    if ignore_class is not None:
        prediction_mask = reconstruct_multiclass_mask(
            prediction_image, image.shape, cmap, ignore_class=ignore_class, ground_truth_mask=ground_truth_image
        )
        prediction_title = f"Prediction Mask \n (Ignoring Ground Truth Class: {ignore_class})"
        output_images.append(prediction_mask)
        titles.append(prediction_title)

    fig, axs = plt.subplots(
        ncols=len(output_images),
        nrows=1,
        figsize=(len(output_images) * image_height, image_height),
        squeeze=False,
        facecolor="white",
    )
    # enumerate() is required here: iterating output_images directly would try to unpack each image array.
    for i, (output_image, title) in enumerate(zip(output_images, titles)):
        axs[0, i].imshow(show_mask_on_image(image, output_image))
        axs[0, i].set_title(title)
        axs[0, i].axis("off")
    custom_lines = [Line2D([0], [0], color=tuple(v / 255.0 for v in cmap[str(c)]), lw=4) for c in class_idxs]
    custom_labels = list(class_to_idx.keys())
    axs[0, -1].legend(custom_lines, custom_labels, loc="center left", bbox_to_anchor=(1.01, 0.81), borderaxespad=0)
    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")
        plt.close(fig)
244
+
245
+
246
def plot_classification_results(
    test_dataset: torch.utils.data.Dataset,
    pred_labels: np.ndarray,
    test_labels: np.ndarray,
    class_name: str,
    original_folder: str,
    gradcam_folder: str | None = None,
    grayscale_cams: np.ndarray | None = None,
    unorm: Callable[[torch.Tensor], torch.Tensor] | None = None,
    idx_to_class: dict | None = None,
    what: str | None = None,
    real_class_to_plot: int | None = None,
    pred_class_to_plot: int | None = None,
    rows: int | None = 1,
    cols: int = 4,
    figsize: tuple[int, int] = (20, 20),
    gradcam: bool = False,
) -> None:
    """Plot and save images extracted from classification. If gradcam is True, same images
    with a gradcam heatmap (layered on original image) will also be saved.

    Samples are selected by intersecting every active filter (``real_class_to_plot``,
    ``pred_class_to_plot`` and ``what``) and then shuffled with :mod:`random`. The first
    ``rows * cols`` of them are drawn on an image grid (all of them when ``rows`` is None or
    0). One PNG per modality is written as ``{what}cordant_{class_name}_{modality}.png``.

    Args:
        test_dataset: Test dataset; ``test_dataset[k]`` must return ``(image, label)`` and be
            aligned with ``pred_labels[k]`` / ``test_labels[k]``.
        pred_labels: Predicted labels
        test_labels: Test labels
        class_name: Name of the examples' class
        original_folder: Folder where original examples will be saved
        gradcam_folder: Folder in which gradcam examples will be saved
        grayscale_cams: Grayscale gradcams (ordered as pred_labels and test_labels)
        unorm: Albumentations function to unormalize image
        idx_to_class: Dictionary of class conversion. Labels missing from it (or all labels,
            when None) are shown as-is.
        what: "dis" keeps samples whose prediction differs from the label, "con" keeps the
            matching ones; None applies no concordance filter.
        real_class_to_plot: Real class to plot.
        pred_class_to_plot: Pred class to plot.
        rows: How many rows in the plot there will be.
        cols: How many cols in the plot there will be.
        figsize: The figure size.
        gradcam: Whether to save also the gradcam version of the examples

    Raises:
        ValueError: If ``gradcam`` is True without ``grayscale_cams`` or ``gradcam_folder``,
            or if both ``real_class_to_plot`` and ``pred_class_to_plot`` are None.
        AssertionError: If ``what`` is not None, "dis" or "con".
    """
    if gradcam:
        if grayscale_cams is None:
            raise ValueError("gradcam is True but grayscale_cams is None")
        if gradcam_folder is None:
            raise ValueError("gradcam is True but gradcam_folder is None")

    if pred_class_to_plot is None and real_class_to_plot is None:
        raise ValueError("'real_class_to_plot' and 'pred_class_to_plot' must not be both None")

    pred_labels = np.asarray(pred_labels)
    test_labels = np.asarray(test_labels)
    # Dataset indices of the candidates. They are filtered with the same mask as the labels so that
    # sample_idx[k] always refers to the sample described by pred_labels[k] / test_labels[k],
    # also when both class filters are active.
    sample_idx = np.arange(len(test_labels))
    cams = np.asarray(grayscale_cams) if gradcam and grayscale_cams is not None else None

    keep = np.ones(len(test_labels), dtype=bool)
    if real_class_to_plot is not None:
        keep &= test_labels == real_class_to_plot
    if pred_class_to_plot is not None:
        keep &= pred_labels == pred_class_to_plot
    if what is not None:
        if what == "dis":
            keep &= pred_labels != test_labels
        elif what == "con":
            keep &= pred_labels == test_labels
        else:
            raise AssertionError(f"{what} not a valid plot type. Must be con or dis")

    sample_idx = sample_idx[keep]
    pred_labels = pred_labels[keep]
    test_labels = test_labels[keep]
    if cams is not None:
        cams = cams[keep]

    if len(sample_idx) == 0:
        log.info("Nothing to plot")
        return

    # Shuffle so that a random subset of the matching samples is shown.
    order = random.sample(range(len(sample_idx)), len(sample_idx))
    sample_idx = sample_idx[order]
    pred_labels = pred_labels[order]
    test_labels = test_labels[order]
    if cams is not None:
        cams = cams[order]

    row_chunks = list(_chunks(sample_idx, cols))
    total_rows = len(row_chunks) if not rows else len(row_chunks[:rows])
    modality_list = ["original", "gradcam"] if gradcam else ["original"]
    last_idx = len(pred_labels) - 1

    for modality in modality_list:
        save_folder = gradcam_folder if modality == "gradcam" else original_folder
        if save_folder is None:
            log.warning("modality %s has no corresponding folder", modality)
            return

        fig = plt.figure(figsize=figsize)
        try:
            grid = ImageGrid(
                fig,
                111,  # similar to subplot(111)
                nrows_ncols=(total_rows, cols),
                axes_pad=(0.2, 0.5),
            )
            for i, ax in enumerate(grid):
                pred_label = _label_name(idx_to_class, pred_labels[i])
                test_label = _label_name(idx_to_class, test_labels[i])

                ax.axis("off")
                ax.set_title(f"True: {str(test_label)}\nPred {str(pred_label)}")
                image, _ = test_dataset[sample_idx[i]]

                if unorm is not None:
                    image = np.array(unorm(image))
                if modality == "gradcam" and cams is not None:
                    image_np = image.cpu().numpy() if isinstance(image, torch.Tensor) else np.asarray(image)
                    rgb_cam = show_cam_on_image(
                        np.transpose(image_np, (1, 2, 0)), cams[i], use_rgb=True, image_weight=0.7
                    )
                    ax.imshow(rgb_cam, cmap="gray")
                else:
                    ax.imshow(_to_displayable_image(image), cmap="gray")
                # The grid may have more cells than selected samples.
                if i == last_idx:
                    break

            for item in grid:
                item.axis("off")

            fig.savefig(
                os.path.join(save_folder, f"{what}cordant_{class_name}_" + modality + ".png"),
                bbox_inches="tight",
                pad_inches=0,
            )
        finally:
            plt.close(fig)


def _label_name(idx_to_class: dict | None, label: Any) -> Any:
    """Map a class index to its name, falling back to the raw label when no mapping applies."""
    if idx_to_class is None:
        return label
    try:
        return idx_to_class[label]
    except (KeyError, IndexError, TypeError):
        return label


def _to_displayable_image(image: Any) -> np.ndarray:
    """Convert a dataset image (tensor or array, CHW or HW) into an integer array for ``imshow``.

    Images whose maximum is <= 1 are scaled to [0, 255]. Single-channel CHW images become HW,
    and 3-channel CHW images become HWC.
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    if image.max() <= 1:
        image = image * 255
    image = image.astype(int)
    if image.ndim == 3:
        if image.shape[0] == 1:
            image = image[0]
        elif image.shape[0] == 3:
            image = image.transpose((1, 2, 0))
    return image
417
+
418
+
419
def _chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` of length ``n``; the final slice may be shorter."""
    yield from (lst[start : start + n] for start in range(0, len(lst), n))