quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,422 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import os
|
|
5
|
+
import random
|
|
6
|
+
from collections.abc import Callable, Iterable
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import albumentations
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
import numpy as np
|
|
12
|
+
import torch
|
|
13
|
+
from albumentations.augmentations.transforms import Normalize
|
|
14
|
+
from albumentations.core.composition import TransformsSeqType
|
|
15
|
+
from albumentations.core.transforms_interface import NoOp
|
|
16
|
+
from albumentations.pytorch.transforms import ToTensorV2
|
|
17
|
+
from matplotlib.colors import ListedColormap
|
|
18
|
+
from matplotlib.lines import Line2D
|
|
19
|
+
from matplotlib.pyplot import get_cmap
|
|
20
|
+
from mpl_toolkits.axes_grid1 import ImageGrid
|
|
21
|
+
from omegaconf import DictConfig, ListConfig
|
|
22
|
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
|
23
|
+
|
|
24
|
+
from quadra.utils import utils
|
|
25
|
+
|
|
26
|
+
# Module-level logger (obtained via the project's ``utils.get_logger`` helper) used by the plotting functions.
log = utils.get_logger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class UnNormalize:
    """Invert a per-channel mean/std normalization on a tensor image.

    Every channel ``c`` is mapped back with ``x * std[c] + mean[c]``, which is the
    inverse of the usual ``(x - mean[c]) / std[c]`` normalization.
    """

    def __init__(self, mean, std):
        # Per-channel statistics that were used by the forward normalization.
        self.mean = mean
        self.std = std

    def __call__(self, tensor: torch.Tensor, make_copy=True) -> torch.Tensor:
        """Undo the normalization of ``tensor``.

        Args:
            tensor: Image tensor whose first dimension indexes channels, e.g. (C, H, W).
            make_copy: If True, work on a detached clone and leave ``tensor`` untouched;
                otherwise ``tensor`` is modified in place.

        Returns:
            The unnormalized tensor (``tensor`` itself when ``make_copy`` is False).
        """
        result = tensor.detach().clone() if make_copy else tensor
        # zip stops at the shortest input, so surplus channels or statistics are left alone.
        for channel, channel_mean, channel_std in zip(result, self.mean, self.std):
            channel.mul_(channel_std).add_(channel_mean)
        return result
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def create_grid_figure(
    images: Iterable[Iterable[np.ndarray]],
    nrows: int,
    ncols: int,
    file_path: str,
    bounds: list[tuple[float, float]],
    row_names: Iterable[str] | None = None,
    fig_size: tuple[int, int] = (12, 8),
):
    """Create a grid figure with images and save it to ``file_path``.

    The figure is rendered with the non-interactive ``Agg`` backend. The previously
    active backend is restored and the figure is closed afterwards, even when plotting
    or saving raises.

    Args:
        images: Rows of images to plot; ``images[i][j]`` is drawn in cell (i, j).
            Images shaped (1, H, W) are squeezed to (H, W) before plotting.
        nrows: Number of rows in the grid.
        ncols: Number of columns in the grid.
        file_path: Path to save the figure.
        bounds: Per-row ``(vmin, vmax)`` display range; ``bounds[i]`` applies to row ``i``.
        row_names: Optional labels written on the y axis of the first column. Defaults to None.
        fig_size: Figure size. Defaults to (12, 8).
    """
    default_plt_backend = plt.get_backend()
    plt.switch_backend("Agg")
    fig = None
    try:
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=fig_size, squeeze=False)
        for i, row in enumerate(images):
            for j, image in enumerate(row):
                image_to_plot = image[0] if len(image.shape) == 3 and image.shape[0] == 1 else image
                axes[i][j].imshow(image_to_plot, vmin=bounds[i][0], vmax=bounds[i][1])
                axes[i][j].get_xaxis().set_ticks([])
                axes[i][j].get_yaxis().set_ticks([])
        if row_names is not None:
            # Distinct loop name so the axes array is not shadowed.
            for row_ax, name in zip(axes[:, 0], row_names):
                row_ax.set_ylabel(name, rotation=90)

        plt.tight_layout()
        plt.savefig(file_path, bbox_inches="tight", dpi=300, facecolor="white", transparent=False)
    finally:
        # Always release the figure and restore the caller's backend.
        if fig is not None:
            plt.close(fig)
        plt.switch_backend(default_plt_backend)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def create_visualization_dataset(dataset: torch.utils.data.Dataset):
    """Create a copy of ``dataset`` whose transforms yield images suitable for plotting.

    The dataset is deep-copied. In its ``transform`` attribute, every ``Normalize`` and
    ``ToTensorV2`` transform is replaced with ``NoOp``. The search recurses into
    albumentations compositions, lists/tuples and dictionaries. The original dataset is
    not modified.

    Args:
        dataset: Dataset exposing a ``transform`` attribute.

    Returns:
        The deep-copied dataset with the converted transforms.

    Raises:
        ValueError: If ``dataset`` is not a torch ``Dataset`` or its ``transform`` is None.
    """

    def convert_transforms(transforms: Any):
        """Recursively replace ``Normalize`` and ``ToTensorV2`` transforms with ``NoOp``.

        Compositions and dictionaries are updated in place; lists/tuples are rebuilt as lists.
        """
        if isinstance(transforms, albumentations.BaseCompose):
            transforms.transforms = convert_transforms(transforms.transforms)
        # ``TransformsSeqType`` is a subscripted typing alias: ``isinstance`` raises
        # ``TypeError`` on it, so only concrete sequence types are checked here.
        if isinstance(transforms, (list, tuple, ListConfig)):
            transforms = [convert_transforms(t) for t in transforms]
        if isinstance(transforms, (dict, DictConfig)):
            for tname, t in transforms.items():
                transforms[tname] = convert_transforms(t)
        if isinstance(transforms, (Normalize, ToTensorV2)):
            return NoOp(p=1)
        return transforms

    # TODO: create a dataset base class exposing a ``transform`` attribute so this
    # check can use ``isinstance`` on that class instead.
    if not isinstance(dataset, torch.utils.data.Dataset):
        raise ValueError(f"The dataset type {dataset} is not supported")

    new_dataset = copy.deepcopy(dataset)
    # Convert an independent copy so other references inside the copied dataset are unaffected.
    transform = copy.deepcopy(dataset.transform)  # type: ignore[attr-defined]
    if transform is None:
        raise ValueError(f"The dataset transform {type(transform)} is not supported")
    new_dataset.transform = convert_transforms(transform)  # type: ignore[attr-defined]
    return new_dataset
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def show_mask_on_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Overlay a mask on an image by summing them and rescaling to ``uint8``.

    Both inputs are scaled from [0, 255] to [0, 1] and added together. The sum is then
    divided by its maximum, so the brightest pixel maps to 255.

    Args:
        image (np.ndarray): The image, with values in [0, 255].
        mask (np.ndarray): The mask; must be broadcast-compatible with ``image``.

    Returns:
        np.ndarray: The image with the mask, as ``uint8``. If the sum is zero everywhere
        (e.g. a black image with an empty mask), an all-zero image is returned instead of
        dividing by zero.
    """
    image = image.astype(np.float32) / 255
    mask = mask.astype(np.float32) / 255
    out = mask + image
    peak = np.max(out)
    # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
    if peak != 0:
        out = out / peak
    return (255 * out).astype(np.uint8)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def reconstruct_multiclass_mask(
    mask: np.ndarray,
    image_shape: tuple[int, ...],
    color_map: dict[str, tuple[int, ...]],
    ignore_class: int | None = None,
    ground_truth_mask: np.ndarray | None = None,
) -> np.ndarray:
    """Reconstruct a colour mask from a single channel class-index mask.

    Args:
        mask (np.ndarray): A single channel mask of class indices.
        image_shape (Tuple[int, ...]): Shape of the output, e.g. (H, W, 3).
        color_map (Dict[str, Tuple[int, ...]]): Colour for each class, keyed by ``str(class_index)``.
        ignore_class (Optional[int], optional): Class left black in the output. Defaults to None.
        ground_truth_mask (Optional[np.ndarray], optional): If given together with
            ``ignore_class``, pixels whose ground truth equals ``ignore_class`` are also
            set to black. Defaults to None.

    Returns:
        np.ndarray: Float array of shape ``image_shape`` holding the class colours.
    """
    output_mask = np.zeros(image_shape)
    for c in np.unique(mask):
        if ignore_class is not None and c == ignore_class:
            continue
        output_mask[mask == c] = color_map[str(c)]

    if ignore_class is not None and ground_truth_mask is not None:
        # Broadcast 0 so this works for any number of channels.
        output_mask[ground_truth_mask == ignore_class] = 0

    return output_mask
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def plot_multiclass_prediction(
    image: np.ndarray,
    prediction_image: np.ndarray,
    ground_truth_image: np.ndarray,
    class_to_idx: dict[str, int],
    plot_original: bool = True,
    ignore_class: int | None = 0,
    image_height: int = 10,
    save_path: str | None = None,
    color_map: str = "tab20",
) -> None:
    """Plot an image side by side with its ground-truth and predicted class masks.

    Each panel is overlaid on the (cropped) input image with ``show_mask_on_image``. A
    legend maps colours to class names. When ``ignore_class`` is set, an extra panel
    shows the prediction with pixels whose ground truth is ``ignore_class`` blacked out.

    Args:
        image: The image to plot, shaped (H, W, C); cropped to the prediction size.
        prediction_image: The predicted class-index mask.
        ground_truth_image: The ground truth class-index mask.
        class_to_idx: The class name to class index mapping.
        plot_original: Whether to plot the original image.
        ignore_class: The class to ignore.
        image_height: The height of each panel of the output figure.
        save_path: The path to save the figure; if given, the figure is closed after saving.
        color_map: The matplotlib color map to use. Defaults to "tab20".
    """
    image = image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1], :]
    class_idxs = list(class_to_idx.values())
    cm = get_cmap(color_map)
    # RGB colour in [0, 255] per class, keyed by str(class_index) as reconstruct_multiclass_mask expects.
    cmap = {str(c): tuple(int(i * 255) for i in cm(c / len(class_idxs))[:-1]) for c in class_idxs}
    output_images = []
    titles = []
    if plot_original:
        output_images.append(image)
        titles.append("Original Image")

    ground_truth_mask = reconstruct_multiclass_mask(ground_truth_image, image.shape, cmap, ignore_class=ignore_class)
    output_images.append(ground_truth_mask)
    titles.append("Ground Truth Mask")

    prediction_mask = reconstruct_multiclass_mask(
        prediction_image,
        image.shape,
        cmap,
        ignore_class=ignore_class,
    )
    output_images.append(prediction_mask)
    titles.append("Prediction Mask")
    if ignore_class is not None:
        prediction_mask = reconstruct_multiclass_mask(
            prediction_image, image.shape, cmap, ignore_class=ignore_class, ground_truth_mask=ground_truth_image
        )
        prediction_title = f"Prediction Mask \n (Ignoring Ground Truth Class: {ignore_class})"
        output_images.append(prediction_mask)
        titles.append(prediction_title)

    _, axs = plt.subplots(
        ncols=len(output_images),
        nrows=1,
        figsize=(len(output_images) * image_height, image_height),
        squeeze=False,
        facecolor="white",
    )
    # enumerate is required: iterating the list directly would try to unpack each image array.
    for i, output_image in enumerate(output_images):
        axs[0, i].imshow(show_mask_on_image(image, output_image))
        axs[0, i].set_title(titles[i])
        axs[0, i].axis("off")
    custom_lines = [Line2D([0], [0], color=tuple(i / 255.0 for i in cmap[str(c)]), lw=4) for c in class_idxs]
    custom_labels = list(class_to_idx.keys())
    axs[0, -1].legend(custom_lines, custom_labels, loc="center left", bbox_to_anchor=(1.01, 0.81), borderaxespad=0)
    if save_path is not None:
        plt.savefig(save_path, bbox_inches="tight")
        plt.close()
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def _label_name(label: Any, idx_to_class: dict | None) -> Any:
    """Return the class name mapped to ``label``, or ``label`` itself when no mapping applies."""
    if idx_to_class is None:
        return label
    try:
        return idx_to_class[label]
    except (KeyError, TypeError):
        return label


def _to_displayable_image(image: Any) -> np.ndarray:
    """Convert a dataset sample into an integer array laid out for ``imshow``.

    Tensors are moved to numpy, images whose max is <= 1 are rescaled to [0, 255], and
    channel-first (1, H, W) / (3, H, W) arrays become (H, W) / (H, W, 3).
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    if image.max() <= 1:
        image = image * 255
    image = image.astype(int)
    if len(image.shape) == 3:
        if image.shape[0] == 1:
            image = image[0]
        elif image.shape[0] == 3:
            image = image.transpose((1, 2, 0))
    return image


def plot_classification_results(
    test_dataset: torch.utils.data.Dataset,
    pred_labels: np.ndarray,
    test_labels: np.ndarray,
    class_name: str,
    original_folder: str,
    gradcam_folder: str | None = None,
    grayscale_cams: np.ndarray | None = None,
    unorm: Callable[[torch.Tensor], torch.Tensor] | None = None,
    idx_to_class: dict | None = None,
    what: str | None = None,
    real_class_to_plot: int | None = None,
    pred_class_to_plot: int | None = None,
    rows: int | None = 1,
    cols: int = 4,
    figsize: tuple[int, int] = (20, 20),
    gradcam: bool = False,
) -> None:
    """Plot and save images extracted from classification. If gradcam is True, same images
    with a gradcam heatmap (layered on original image) will also be saved.

    Samples are selected by combining every given filter (true class, predicted class,
    concordance). The selection is shuffled, and at most ``rows * cols`` samples are drawn.
    Figures are saved as ``f"{what}cordant_{class_name}_<modality>.png"``.

    Args:
        test_dataset: Test dataset
        pred_labels: Predicted labels
        test_labels: Test labels
        class_name: Name of the examples' class
        original_folder: Folder where original examples will be saved
        gradcam_folder: Folder in which gradcam examples will be saved
        grayscale_cams: Grayscale gradcams (ordered as pred_labels and test_labels)
        unorm: Albumentations function to unormalize image
        idx_to_class: Dictionary of class conversion; labels missing from it are shown raw
        what: Can be "dis" or "con" to keep only discordant or concordant samples; None keeps both
        real_class_to_plot: Real class to plot.
        pred_class_to_plot: Pred class to plot.
        rows: How many rows in the plot there will be (None or 0 means as many as needed).
        cols: How many cols in the plot there will be.
        figsize: The figure size.
        gradcam: Whether to save also the gradcam version of the examples

    Raises:
        ValueError: If gradcam is requested without cams/folder, or no class filter is given.
        AssertionError: If ``what`` is not None, "con" or "dis".
    """
    if gradcam:
        if grayscale_cams is None:
            raise ValueError("gradcam is True but grayscale_cams is None")
        if gradcam_folder is None:
            raise ValueError("gradcam is True but gradcam_folder is None")

    if pred_class_to_plot is None and real_class_to_plot is None:
        raise ValueError("'real_class_to_plot' and 'pred_class_to_plot' must not be both None")

    pred_labels = np.asarray(pred_labels)
    test_labels = np.asarray(test_labels)

    # One boolean selection over the whole dataset, so sample indices always refer to
    # positions in ``test_dataset`` even when several filters are combined.
    selected = np.ones(len(test_labels), dtype=bool)
    if real_class_to_plot is not None:
        selected &= test_labels == real_class_to_plot
    if pred_class_to_plot is not None:
        selected &= pred_labels == pred_class_to_plot
    if what is not None:
        if what == "dis":
            selected &= pred_labels != test_labels
        elif what == "con":
            selected &= pred_labels == test_labels
        else:
            raise AssertionError(f"{what} not a valid plot type. Must be con or dis")

    sample_idx = np.flatnonzero(selected)
    pred_labels = pred_labels[selected]
    test_labels = test_labels[selected]
    if gradcam and grayscale_cams is not None:
        grayscale_cams = np.asarray(grayscale_cams)[selected]

    if len(sample_idx) == 0:
        log.info("Nothing to plot")
        return

    # Shuffle so that a different subset is shown when the grid is smaller than the selection.
    idx_random = random.sample(range(len(sample_idx)), len(sample_idx))
    sample_idx = sample_idx[idx_random]
    pred_labels = pred_labels[idx_random]
    test_labels = test_labels[idx_random]
    if gradcam and grayscale_cams is not None:
        grayscale_cams = grayscale_cams[idx_random]

    cordant_chunks = list(_chunks(sample_idx, cols))
    total_rows = len(cordant_chunks) if rows is None or rows == 0 else len(cordant_chunks[:rows])

    modality_list = ["original", "gradcam"] if gradcam else ["original"]
    for modality in modality_list:
        fig = plt.figure(figsize=figsize)
        grid = ImageGrid(
            fig,
            111,  # similar to subplot(111)
            nrows_ncols=(total_rows, cols),
            axes_pad=(0.2, 0.5),
        )
        for i, ax in enumerate(grid):
            # The grid may have more cells than selected samples.
            if i >= len(pred_labels):
                break
            pred_label = _label_name(pred_labels[i], idx_to_class)
            test_label = _label_name(test_labels[i], idx_to_class)

            ax.axis("off")
            ax.set_title(f"True: {str(test_label)}\nPred {str(pred_label)}")
            image, _ = test_dataset[sample_idx[i]]

            if unorm is not None:
                image = np.array(unorm(image))
            if modality == "gradcam" and grayscale_cams is not None:
                if isinstance(image, torch.Tensor):
                    image = image.cpu().numpy()
                rgb_cam = show_cam_on_image(
                    np.transpose(image, (1, 2, 0)), grayscale_cams[i], use_rgb=True, image_weight=0.7
                )
                ax.imshow(rgb_cam, cmap="gray")
            else:
                ax.imshow(_to_displayable_image(image), cmap="gray")

        for item in grid:
            item.axis("off")

        if modality == "gradcam" and gradcam_folder is not None:
            save_folder = gradcam_folder
        elif modality == "original":
            save_folder = original_folder
        else:
            log.warning("modality %s has no corresponding folder", modality)
            plt.close(fig)
            return

        fig.savefig(
            os.path.join(save_folder, f"{what}cordant_{class_name}_" + modality + ".png"),
            bbox_inches="tight",
            pad_inches=0,
        )
        plt.close(fig)
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def _chunks(lst, n):
|
|
420
|
+
"""Yield successive n-sized chunks from lst."""
|
|
421
|
+
for i in range(0, len(lst), n):
|
|
422
|
+
yield lst[i : i + n]
|