quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
import cv2
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import numpy as np
|
|
8
|
+
from matplotlib.cm import get_cmap
|
|
9
|
+
from matplotlib.colors import Colormap
|
|
10
|
+
from matplotlib.lines import Line2D
|
|
11
|
+
from matplotlib.pyplot import Figure
|
|
12
|
+
|
|
13
|
+
from quadra.utils import utils
|
|
14
|
+
|
|
15
|
+
# Module-level logger obtained from quadra's ``utils.get_logger`` helper, named after this module.
log = utils.get_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def plot_patch_reconstruction(
    reconstruction: dict,
    idx_to_class: dict[int, str],
    class_to_idx: dict[str, int],
    ignore_classes: list[int] | None = None,
    is_polygon: bool = True,
) -> Figure:
    """Helper function for plotting the patch reconstruction.

    Args:
        reconstruction: Dict following this structure
            {
                "image_path": str,
                "mask_path": str | None,
                "prediction": [
                    {
                        "label": str,
                        "points": [{"x": int, "y": int}]
                    },
                    ...
                ]
            } if is_polygon else
            {
                "image_path": str,
                "mask_path": str | None,
                "prediction": np.ndarray
            }
        idx_to_class: Dictionary mapping indices to label names
        class_to_idx: Dictionary mapping class names to indices
        ignore_classes: Eventually the classes to not plot
        is_polygon: Boolean indicating if the prediction is a polygon or a mask.

    Returns:
        Matplotlib plot showing predicted patch regions and eventually gt

    Raises:
        FileNotFoundError: If the image at ``reconstruction["image_path"]`` cannot be read.
    """
    # 10 classes + good fit tab10; switch to the larger tab20 palette beyond that.
    cmap_name = "tab20" if len(idx_to_class) > 11 else "tab10"
    cmap = plt.get_cmap(cmap_name)

    image_path = reconstruction["image_path"]
    test_img = cv2.imread(image_path)
    if test_img is None:
        # cv2.imread does not raise on failure; fail early with a readable message instead of inside cvtColor.
        raise FileNotFoundError(f"Unable to read image {image_path}")
    test_img = cv2.cvtColor(test_img, cv2.COLOR_BGR2RGB)

    gt_img = None
    mask_path = reconstruction["mask_path"]
    if mask_path is not None and os.path.isfile(mask_path):
        gt_img = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

    if is_polygon:
        # Rasterize every predicted polygon into a single class-index mask the size of the image.
        out = np.zeros((test_img.shape[0], test_img.shape[1]), dtype=np.uint8)
        for region in reconstruction["prediction"]:
            points = [[item["x"], item["y"]] for item in region["points"]]
            out = cv2.drawContours(
                out,
                np.array([points], np.int32),
                -1,
                class_to_idx[region["label"]],
                thickness=cv2.FILLED,
            )  # type: ignore[call-overload]
    else:
        out = reconstruction["prediction"]

    fig = plot_patch_results(
        image=test_img,
        prediction_image=out,
        ground_truth_image=gt_img,
        plot_original=True,
        ignore_classes=ignore_classes,
        save_path=None,
        class_to_idx=class_to_idx,
        cmap=cmap,
    )

    return fig
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def show_mask_on_image(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Plot mask on top of the original image.

    Both inputs are rescaled from [0, 255] to [0, 1] and summed. The sum is then renormalized by its maximum so the
    brightest pixel maps to 255.

    Args:
        image: Image with values in [0, 255].
        mask: Mask broadcastable to ``image`` with values in [0, 255].

    Returns:
        The blended image as ``uint8``. When the sum is identically zero (e.g. black image and empty mask), it is
        returned as zeros instead of dividing by zero and casting NaNs to ``uint8``.
    """
    out = mask.astype(np.float32) / 255 + image.astype(np.float32) / 255
    max_value = np.max(out)
    if max_value != 0:
        out = out / max_value
    return np.uint8(255 * out)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def create_rgb_mask(
|
|
106
|
+
mask: np.ndarray,
|
|
107
|
+
color_map: dict,
|
|
108
|
+
ignore_classes: list[int] | None = None,
|
|
109
|
+
ground_truth_mask: np.ndarray | None = None,
|
|
110
|
+
):
|
|
111
|
+
"""Convert index mask to RGB mask."""
|
|
112
|
+
output_mask = np.zeros([mask.shape[0], mask.shape[1], 3])
|
|
113
|
+
for c in np.unique(mask):
|
|
114
|
+
if ignore_classes is not None and c in ignore_classes:
|
|
115
|
+
continue
|
|
116
|
+
|
|
117
|
+
output_mask[mask == c] = color_map[str(c)]
|
|
118
|
+
if ignore_classes is not None and ground_truth_mask is not None:
|
|
119
|
+
output_mask[np.isin(ground_truth_mask, ignore_classes)] = [0, 0, 0]
|
|
120
|
+
|
|
121
|
+
return output_mask
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def plot_patch_results(
    image: np.ndarray,
    prediction_image: np.ndarray,
    ground_truth_image: np.ndarray | None,
    class_to_idx: dict[str, int],
    plot_original: bool = True,
    ignore_classes: list[int] | None = None,
    image_height: int = 10,
    save_path: str | None = None,
    cmap: Colormap | None = None,
) -> Figure:
    """Function used to plot the image predicted.

    One panel is drawn per output: optionally the original image, the ground truth mask (if given), the prediction
    mask, and (if a ground truth is given) the prediction with pixels whose ground truth is an ignored class blanked
    out. Every panel is overlaid on ``image``.

    Args:
        prediction_image: The prediction image (mask of class indices)
        image: The original image to plot; it is cropped to the spatial size of ``prediction_image`` and indexed
            as ``(H, W, C)``
        ground_truth_image: The ground truth image (mask of class indices), or None
        class_to_idx: Dictionary mapping class names to indices
        plot_original: Boolean to plot the original image
        ignore_classes: The classes to ignore, default is [0]
        image_height: The height of the output figure
        save_path: The path to save the figure. When given, the figure is saved and then closed.
        cmap: The colormap to use. If None, tab20 is used

    Returns:
        The matplotlib figure
    """
    if ignore_classes is None:
        ignore_classes = [0]

    if cmap is None:
        cmap = plt.get_cmap("tab20")

    image = image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1], :]
    idx_to_class = {v: k for k, v in class_to_idx.items()}

    # Only non-ignored classes get a color and a legend entry.
    visible_class_to_idx = {k: v for k, v in class_to_idx.items() if v not in ignore_classes}
    class_idxs = list(visible_class_to_idx.values())

    # Sample one 0-255 RGB color per visible class (alpha dropped); keys are str(idx) as create_rgb_mask expects.
    color_map = {
        str(c): tuple(int(channel * 255) for channel in cmap(c / len(class_idxs))[:-1]) for c in class_idxs
    }

    output_images = []
    titles = []

    if plot_original:
        output_images.append(image)
        titles.append("Original Image")

    if ground_truth_image is not None:
        ground_truth_image = ground_truth_image[0 : prediction_image.shape[0], 0 : prediction_image.shape[1]]
        output_images.append(create_rgb_mask(ground_truth_image, color_map, ignore_classes=ignore_classes))
        titles.append("Ground Truth Mask")

    output_images.append(create_rgb_mask(prediction_image, color_map, ignore_classes=ignore_classes))
    titles.append("Prediction Mask")

    if ground_truth_image is not None:
        prediction_mask = create_rgb_mask(
            prediction_image, color_map, ignore_classes=ignore_classes, ground_truth_mask=ground_truth_image
        )
        # Ignored indices may be absent from class_to_idx (e.g. the default [0]); fall back to the raw index.
        ignored_classes_str = [idx_to_class.get(c, str(c)) for c in ignore_classes]
        output_images.append(prediction_mask)
        titles.append(f"Prediction Mask \n (Ignoring Ground Truth Class: {ignored_classes_str})")

    fig, axs = plt.subplots(
        ncols=len(output_images),
        nrows=1,
        figsize=(len(output_images) * image_height, image_height),
        squeeze=False,
        facecolor="white",
    )

    for i, (output_image, title) in enumerate(zip(output_images, titles)):
        ax = axs[0, i]
        ax.imshow(show_mask_on_image(image, output_image))
        ax.set_title(title)
        ax.axis("off")

    custom_lines = [
        Line2D([0], [0], color=tuple(channel / 255.0 for channel in color_map[str(c)]), lw=4) for c in class_idxs
    ]
    custom_labels = list(visible_class_to_idx.keys())
    axs[0, -1].legend(custom_lines, custom_labels, loc="center left", bbox_to_anchor=(1.01, 0.81), borderaxespad=0)
    if save_path is not None:
        # Act on this figure explicitly rather than on pyplot's implicit "current" figure.
        fig.savefig(save_path, bbox_inches="tight")
        plt.close(fig)

    return fig
|
quadra/utils/resolver.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from hydra.core.hydra_config import HydraConfig
|
|
6
|
+
from omegaconf import OmegaConf
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def multirun_subdir_beautify(subdir: str) -> str:
    """Change the subdir name to be more readable and usable.

    In MULTIRUN mode this function replaces / with | to avoid creating undesired subdirectories, and it removes the
    left part of each equals sign to avoid overly long names. In any other mode the subdir is returned unchanged.

    Args:
        subdir: The subdir name.

    Returns:
        The beautified subdir name.

    Examples:
        >>> multirun_subdir_beautify("experiment=pippo/anomaly/padim,trainer.batch_size=32")
        'pippo|anomaly|padim,32'
    """
    hydra_cfg = HydraConfig.get()
    if hydra_cfg.mode is None or hydra_cfg.mode.name == "RUN":
        return subdir
    # Remove slashes to avoid creating multiple subdirs
    tokens = subdir.replace("/", "|").split(",")
    values = []
    for token in tokens:
        key, sep, value = token.partition("=")
        # Keep everything after the first "=" so values that contain "=" are not truncated. A token without "="
        # is a fragment of a value that contained "," (e.g. ``a=[1,2]``); keep it verbatim instead of crashing.
        # TODO: fragments that themselves contain "=" (e.g. ``a={x=1,y=2}``) are still shortened incorrectly.
        values.append((value if sep else key).replace(" ", ""))

    return ",".join(values)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def as_tuple(*args: Any) -> tuple[Any, ...]:
    """Return the resolver's positional arguments packed in a tuple."""
    # ``*args`` already arrives as a tuple, so it can be handed back directly.
    return args
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def register_resolvers() -> None:
    """Register quadra's custom OmegaConf resolvers.

    After this call, configs can use ``${multirun_subdir_beautify:...}`` and ``${as_tuple:...}``.
    """
    resolvers = {
        "multirun_subdir_beautify": multirun_subdir_beautify,
        "as_tuple": as_tuple,
    }
    for resolver_name, resolver_fn in resolvers.items():
        OmegaConf.register_new_resolver(resolver_name, resolver_fn)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import skimage
|
|
3
|
+
from skimage.morphology import medial_axis
|
|
4
|
+
|
|
5
|
+
from quadra.utils import utils
|
|
6
|
+
|
|
7
|
+
# Module-level logger obtained from quadra's ``utils.get_logger`` helper, named after this module.
log = utils.get_logger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def smooth_mask(mask: np.ndarray) -> np.ndarray:
    """Smooths for segmentation.

    Each connected foreground component of ``mask`` is handled as follows:
    1. It is replaced by its distance transform (from ``medial_axis(..., return_distance=True)``).
    2. The distance map is gamma-compressed with exponent ``1 / 2.2``.
    3. The result is min-max normalized to [0, 1].

    The per-component maps are summed, and the sum is multiplied by ``mask``.

    Args:
        mask: Input mask; non-zero pixels are treated as foreground (presumably binary — TODO confirm with callers)

    Returns:
        Smoothed mask with the same shape as ``mask``
    """
    labeled_mask = skimage.measure.label(mask)
    output_mask = np.zeros_like(mask).astype(np.float32)
    # Label 0 is the background: skip it. It never contributes on foreground pixels, and when the mask is entirely
    # foreground it would produce a 0/0 normalization that turns the whole output into NaN.
    for component_label in range(1, int(np.max(labeled_mask)) + 1):
        component_mask = labeled_mask == component_label
        _, distance = medial_axis(component_mask, return_distance=True)
        component_mask_norm = distance ** (1 / 2.2)
        min_value = np.min(component_mask_norm)
        value_range = np.max(component_mask_norm) - min_value
        if value_range == 0:
            # Constant map: nothing to normalize, and dividing would yield NaN.
            continue
        output_mask += (component_mask_norm - min_value) / value_range
    output_mask = output_mask * mask
    return output_mask
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .dataset import * # noqa: F403
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from .anomaly import AnomalyDatasetArguments, anomaly_dataset, base_anomaly_dataset
|
|
2
|
+
from .classification import (
|
|
3
|
+
ClassificationDatasetArguments,
|
|
4
|
+
ClassificationMultilabelDatasetArguments,
|
|
5
|
+
ClassificationPatchDatasetArguments,
|
|
6
|
+
base_classification_dataset,
|
|
7
|
+
base_multilabel_classification_dataset,
|
|
8
|
+
base_patch_classification_dataset,
|
|
9
|
+
classification_dataset,
|
|
10
|
+
classification_patch_dataset,
|
|
11
|
+
multilabel_classification_dataset,
|
|
12
|
+
)
|
|
13
|
+
from .imagenette import imagenette_dataset
|
|
14
|
+
from .segmentation import (
|
|
15
|
+
SegmentationDatasetArguments,
|
|
16
|
+
base_binary_segmentation_dataset,
|
|
17
|
+
base_multiclass_segmentation_dataset,
|
|
18
|
+
segmentation_dataset,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# Public API of this package: these fixtures and argument dataclasses are the only names exported by
# ``from .dataset import *``.
__all__ = [
    "anomaly_dataset",
    "classification_dataset",
    "AnomalyDatasetArguments",
    "ClassificationDatasetArguments",
    "ClassificationPatchDatasetArguments",
    "classification_patch_dataset",
    "segmentation_dataset",
    "SegmentationDatasetArguments",
    "multilabel_classification_dataset",
    "ClassificationMultilabelDatasetArguments",
    "base_anomaly_dataset",
    "imagenette_dataset",
    "base_classification_dataset",
    "base_patch_classification_dataset",
    "base_binary_segmentation_dataset",
    "base_multiclass_segmentation_dataset",
    "base_multilabel_classification_dataset",
]
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import shutil
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import cv2
|
|
9
|
+
import pytest
|
|
10
|
+
|
|
11
|
+
from quadra.utils.tests.helpers import _random_image
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class AnomalyDatasetArguments:
    """Sizes of the synthetic anomaly dataset built by the anomaly fixtures.

    Attributes:
        train_samples: How many images to generate for the train split.
        val_samples: Image counts for the validation split, as a ``(good, bad)`` pair.
        test_samples: Image counts for the test split, as a ``(good, bad)`` pair.
    """

    # Only "good" images are generated for training.
    train_samples: int
    # (good, bad) counts.
    val_samples: tuple[int, int]
    # (good, bad) counts.
    test_samples: tuple[int, int]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _build_anomaly_dataset(
    tmp_path: Path, dataset_arguments: AnomalyDatasetArguments
) -> tuple[str, AnomalyDatasetArguments]:
    """Generate anomaly dataset in the standard mvtec format.

    The generated layout is::

        anomaly_dataset/
            train/good/train_{i}.png
            val/good/val_{i}.png
            val/bad/val_{i}.png
            test/good/test_{i}.png
            test/bad/test_{i}.png

    Args:
        tmp_path: path to temporary directory; ``anomaly_dataset`` is created inside it and must not already exist
        dataset_arguments: dataset arguments

    Returns:
        Tuple with the path to the anomaly dataset (as a string) and the dataset arguments used
    """
    anomaly_dataset_path = tmp_path / "anomaly_dataset"
    anomaly_dataset_path.mkdir()

    # (split, category, number of images), listed in the order the images are generated.
    folders = [
        ("train", "good", dataset_arguments.train_samples),
        ("val", "good", dataset_arguments.val_samples[0]),
        ("val", "bad", dataset_arguments.val_samples[1]),
        ("test", "good", dataset_arguments.test_samples[0]),
        ("test", "bad", dataset_arguments.test_samples[1]),
    ]

    # Create every folder up front so that splits with zero samples still exist on disk.
    for split, category, _ in folders:
        (anomaly_dataset_path / split / category).mkdir(parents=True)

    for split, category, num_samples in folders:
        folder = anomaly_dataset_path / split / category
        for i in range(num_samples):
            image = _random_image()
            cv2.imwrite(str(folder / f"{split}_{i}.png"), image)

    return str(anomaly_dataset_path), dataset_arguments
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@pytest.fixture
def anomaly_dataset(tmp_path: Path, dataset_arguments: AnomalyDatasetArguments) -> tuple[str, AnomalyDatasetArguments]:
    """Fixture used to dynamically generate an anomaly dataset.

    Images are created by ``_random_image`` called with its default arguments. ``dataset_arguments`` is resolved by
    pytest as a fixture that is not defined here. NOTE(review): tests presumably provide it through a fixture
    override or indirect parametrization — confirm.

    Args:
        tmp_path: path to temporary directory
        dataset_arguments: dataset arguments

    Yields:
        Path to the anomaly dataset and the dataset arguments used

    Note:
        This is a generator fixture; the return annotation describes the yielded value.
    """
    yield _build_anomaly_dataset(tmp_path, dataset_arguments)
    # Teardown: remove the generated files once the test is done.
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@pytest.fixture(params=[AnomalyDatasetArguments(train_samples=10, val_samples=(1, 1), test_samples=(1, 1))])
def base_anomaly_dataset(tmp_path: Path, request: Any) -> tuple[str, AnomalyDatasetArguments]:
    """Generate base anomaly dataset with the following parameters:
        - train_samples: 10
        - val_samples: (1, 1)
        - test_samples: (1, 1).

    Args:
        tmp_path: Path to temporary directory
        request: Pytest SubRequest object

    Yields:
        Path to anomaly dataset and dataset arguments
    """
    yield _build_anomaly_dataset(tmp_path, request.param)
    # Teardown: remove the generated files once the test is done.
    if tmp_path.exists():
        shutil.rmtree(tmp_path)
|