quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,449 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import cv2
|
|
7
|
+
import numpy as np
|
|
8
|
+
import pandas as pd
|
|
9
|
+
from scipy import ndimage
|
|
10
|
+
from skimage.measure import label, regionprops # pylint: disable=no-name-in-module
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
|
|
13
|
+
from quadra.utils import utils
|
|
14
|
+
from quadra.utils.patch.dataset import PatchDatasetFileFormat, compute_patch_info, compute_patch_info_from_patch_dim
|
|
15
|
+
|
|
16
|
+
# Module-level logger obtained through quadra's `utils.get_logger` helper, named after this module.
log = utils.get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_sorted_patches_by_image(test_results: pd.DataFrame, img_name: str) -> pd.DataFrame:
    """Gets the patches of a given image sorted by patch number.

    The patch number is the last ``_``-separated token of each ``sample`` basename with its file extension
    removed, e.g. ``"path/img_12.png" -> 12``. Any patch extension is supported, not only ``.png``.

    Args:
        test_results: Pandas dataframe containing test results like the one produced by SklearnClassificationTrainer.
            It must contain a ``sample`` column (patch paths) and a ``filename`` column holding the name of the
            source image without extension.
        img_name: name of the image used to filter the results; its extension, if any, is ignored.

    Returns:
        test results filtered by image name and sorted by patch number
    """
    img_patches = test_results[test_results["filename"] == os.path.splitext(img_name)[0]]
    patch_numbers = np.array(
        [
            int(os.path.splitext(os.path.basename(sample).split("_")[-1])[0])
            for sample in img_patches["sample"].tolist()
        ],
        dtype=np.int64,
    )
    # A stable sort keeps the dataframe order should a patch number ever be duplicated
    return img_patches.iloc[np.argsort(patch_numbers, kind="stable")]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def compute_patch_metrics(
    test_img_info: list[PatchDatasetFileFormat],
    test_results: pd.DataFrame,
    overlap: float,
    idx_to_class: dict,
    patch_num_h: int | None = None,
    patch_num_w: int | None = None,
    patch_w: int | None = None,
    patch_h: int | None = None,
    return_polygon: bool = False,
    patch_reconstruction_method: str = "priority",
    annotated_good: list[int] | None = None,
) -> tuple[int, int, int, list[dict]]:
    """Compute the region-level metrics of a patch dataset.

    For every test image the per-patch predictions are reassembled into a full-size mask. When a ground truth
    mask is available, the connected regions of the prediction are matched against the connected regions of
    the ground truth.

    Args:
        test_img_info: List of observation paths and mask paths
        test_results: Pandas dataframe containing the results of an SklearnClassificationTrainer utility. It must
            contain a ``sample`` column (patch paths) and a ``pred_label`` column. NOTE: a ``filename`` column is
            added to (or overwritten in) this dataframe in place.
        overlap: Percentage of overlap between the patches
        idx_to_class: Dict mapping an index to the corresponding class name
        patch_num_h: Number of vertical patches (required if patch_w and patch_h are None)
        patch_num_w: Number of horizontal patches (required if patch_w and patch_h are None)
        patch_w: Patch width (required if patch_num_h and patch_num_w are None)
        patch_h: Patch height (required if patch_num_h and patch_num_w are None)
        return_polygon: if set to true convert the reconstructed mask into polygons, otherwise return the mask
        patch_reconstruction_method: How to compute the label of overlapping patches, can either be:
            priority: Assign the top priority label (i.e the one with greater index) to overlapping regions
            major_voting: Assign the most present label among the patches label overlapping a pixel
        annotated_good: List of indices of annotations to be treated as good.

    If both the number of patches and the patch size are given, the number of patches is used and a
    ``UserWarning`` is emitted.

    Returns:
        Tuple containing:
            false_region_bad: Number of false bad regions detected in the dataset
            false_region_good: Number of missed defects
            true_region_bad: Number of correctly identified defects
            reconstructions: If polygon is true this is a List of dict containing
                {
                    "image_path": image_path,
                    "mask_path": mask_path,
                    "file_name": observation_name,
                    "prediction": [{
                        "label": predicted_label,
                        "points": List of dict coordinates "x" and "y" representing the points of a polygon that
                        surrounds an image area covered by patches of label = predicted_label
                    }]
                }
            else its a list of dict containing
                {
                    "image_path": image_path,
                    "mask_path": mask_path,
                    "file_name": observation_name,
                    "prediction": numpy array containing the reconstructed mask
                }

    Raises:
        ValueError: if the reconstruction method is unknown, if neither the number of patches nor the patch size
            is given, if an image or a mask cannot be read, or if the number of patches found in ``test_results``
            for an image does not match its patch grid.
    """
    if patch_reconstruction_method not in ("priority", "major_voting"):
        raise ValueError("Patch reconstruction method not recognized, valid values are priority, major_voting")

    has_patch_num = patch_num_h is not None and patch_num_w is not None
    has_patch_dim = patch_h is not None and patch_w is not None

    # The error case is *neither* pair being provided; providing both is allowed (see warning below).
    if not has_patch_num and not has_patch_dim:
        raise ValueError("Either number of patches or patch size is required for reconstruction")

    if has_patch_num and has_patch_dim:
        warnings.warn(
            "Both number of patches and patch dimension are specified, using number of patches by default",
            UserWarning,
            stacklevel=2,
        )

    log.info("Computing patch metrics!")

    false_region_bad = 0
    false_region_good = 0
    true_region_bad = 0
    reconstructions = []
    # Name of the source image of each patch: drop the trailing "_<patch_idx>.<ext>" token of the basename
    test_results["filename"] = test_results["sample"].apply(
        lambda x: "_".join(os.path.basename(x).replace("#DISCARD#", "").split("_")[0:-1])
    )

    for info in tqdm(test_img_info):
        img_path = info.image_path
        mask_path = info.mask_path

        img_json_entry = {
            "image_path": img_path,
            "mask_path": mask_path,
            "file_name": os.path.basename(img_path),
            "prediction": None,
        }

        test_img = cv2.imread(img_path)
        if test_img is None:
            raise ValueError(f"Unable to read image {img_path}")

        img_name = os.path.basename(img_path)
        h, w = test_img.shape[0], test_img.shape[1]

        gt_img = None
        if mask_path is not None and os.path.exists(mask_path):
            gt_img = cv2.imread(mask_path, 0)
            if gt_img is None:
                raise ValueError(f"Unable to read mask {mask_path}")
            if gt_img.shape != (h, w):
                # Ensure that the mask has the same size as the image by padding it with zeros
                log.warning("Found mask with different size than the image, padding it with zeros!")
                # Crop first so that a mask larger than the image cannot produce a negative pad width
                gt_img = gt_img[:h, :w]
                gt_img = np.pad(gt_img, ((0, h - gt_img.shape[0]), (0, w - gt_img.shape[1])))

        # The patch grid is resolved into per-image locals and never written back to the arguments, so that
        # images of different sizes each get their own grid when only the patch size is given.
        if patch_num_h is not None and patch_num_w is not None:
            num_h, num_w = patch_num_h, patch_num_w
            patch_size, step = compute_patch_info(h, w, num_h, num_w, overlap)
        elif patch_h is not None and patch_w is not None:
            (num_h, num_w), step = compute_patch_info_from_patch_dim(h, w, patch_h, patch_w, overlap)
            patch_size = (patch_h, patch_w)
        else:
            raise ValueError(
                "Either number of patches or patch size is required for reconstruction, this should not happen"
                " at this stage"
            )

        img_patches = get_sorted_patches_by_image(test_results, img_name)
        if len(img_patches) != num_h * num_w:
            raise ValueError(
                f"Expected {num_h * num_w} patches for image {img_name} but found {len(img_patches)} in test results"
            )
        # Copy so that the in-place edit below never writes into (possibly read-only) dataframe memory
        pred = img_patches["pred_label"].to_numpy(copy=True).reshape(num_h, num_w)

        # Treat annotated good predictions as background, this is an optimistic assumption that assumes that the
        # remaining background is good, but it is not always true so maybe on non annotated areas we are missing
        # defects and it would be necessary to handle this in a different way.
        if annotated_good is not None:
            pred[np.isin(pred, annotated_good)] = 0

        output_mask, predicted_defect = reconstruct_patch(
            input_img_shape=test_img.shape,
            patch_size=patch_size,
            pred=pred,
            patch_num_h=num_h,
            patch_num_w=num_w,
            idx_to_class=idx_to_class,
            step=step,
            return_polygon=return_polygon,
            method=patch_reconstruction_method,
        )

        img_json_entry["prediction"] = predicted_defect if return_polygon else output_mask
        reconstructions.append(img_json_entry)

        if gt_img is not None:
            if annotated_good is not None:
                gt_img[np.isin(gt_img, annotated_good)] = 0

            image_false_bad, image_false_good, image_true_bad = _count_region_matches(output_mask, gt_img)
            false_region_bad += image_false_bad
            false_region_good += image_false_good
            true_region_bad += image_true_bad

    return false_region_bad, false_region_good, true_region_bad, reconstructions


def _count_region_matches(output_mask: np.ndarray, gt_img: np.ndarray) -> tuple[int, int, int]:
    """Match connected regions of a predicted mask against those of a ground truth mask for one image.

    A predicted region is a false bad region if it does not touch any ground truth pixel. A ground truth region
    is a true bad region if it touches at least one predicted pixel, a false good region otherwise.

    Args:
        output_mask: Reconstructed prediction mask, 0 is background.
        gt_img: Ground truth mask, 0 is background.

    Returns:
        (false_region_bad, false_region_good, true_region_bad) for this image.
    """
    # Labels stay in label()'s native integer dtype: a uint8 cast would wrap past 255 regions
    pred_regions, n_pred_regions = label(output_mask, return_num=True)
    # Distinct predicted labels found under ground truth pixels (0 is background and is not counted)
    n_pred_hit = int(np.count_nonzero(np.unique(pred_regions[gt_img > 0])))

    gt_regions, n_gt_regions = label(gt_img, return_num=True)
    n_gt_hit = int(np.count_nonzero(np.unique(gt_regions[output_mask > 0])))

    return int(n_pred_regions) - n_pred_hit, int(n_gt_regions) - n_gt_hit, n_gt_hit
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def reconstruct_patch(
    input_img_shape: tuple[int, ...],
    patch_size: tuple[int, int],
    pred: np.ndarray,
    patch_num_h: int,
    patch_num_w: int,
    idx_to_class: dict,
    step: tuple[int, int],
    return_polygon: bool = True,
    method: str = "priority",
) -> tuple[np.ndarray, list[dict]]:
    """Rebuild a full-size prediction mask (and optionally its polygons) from a grid of patch predictions.

    This is a thin dispatcher: ``method`` selects how pixels covered by several patches get their label.

    Args:
        input_img_shape: The size of the reconstructed image
        patch_size: Array defining the patch size
        pred: Numpy array containing reconstructed prediction (patch_num_h x patch_num_w)
        patch_num_h: Number of vertical patches
        patch_num_w: Number of horizontal patches
        idx_to_class: Dictionary mapping indices to labels
        step: Array defining the step size to be used for reconstruction
        return_polygon: If true compute predicted polygons. Defaults to True.
        method: Reconstruction method to be used. Currently supported: "priority" and "major_voting"

    Returns:
        (reconstructed_prediction_image, predictions) where predictions is an array of objects
            [{
                "label": Predicted_label,
                "points": List of dict coordinates "x" and "y" representing the points of a polygon that
                    surrounds an image area covered by patches of label = predicted_label
            }]

    Raises:
        ValueError: if ``method`` is not one of the supported strategies.
    """
    if method not in ("priority", "major_voting"):
        raise ValueError(f"Invalid reconstruction method {method}")

    reconstructor = _reconstruct_patch_priority if method == "priority" else _reconstruct_patch_major_voting
    return reconstructor(
        input_img_shape,
        patch_size,
        pred,
        patch_num_h,
        patch_num_w,
        idx_to_class,
        step,
        return_polygon,
    )
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _reconstruct_patch_priority(
    input_img_shape: tuple[int, ...],
    patch_size: tuple[int, int],
    pred: np.ndarray,
    patch_num_h: int,
    patch_num_w: int,
    idx_to_class: dict,
    step: tuple[int, int],
    return_polygon: bool = True,
) -> tuple[np.ndarray, list[dict]]:
    """Reconstruct patch polygons using the priority method.

    Every class index ``c >= 1`` found in ``pred`` paints the footprint of all its patches with ``c``. Where
    footprints of different classes overlap, the largest class index wins. Patches whose nominal position would
    cross the bottom/right image border are shifted back so that they end exactly on that border.

    Args:
        input_img_shape: Shape of the original image; only the first two entries (height, width) are used.
        patch_size: Patch (height, width).
        pred: Per-patch predicted class indices, indexed as ``pred[row, col]``.
        patch_num_h: Number of patch rows.
        patch_num_w: Number of patch columns.
        idx_to_class: Mapping from class index to class name, used to label the polygons.
        step: Vertical and horizontal distance between the origins of consecutive patches.
        return_polygon: If True also convert every class region into polygons.

    Returns:
        The ``uint8`` reconstructed mask and the list of ``{"label", "points"}`` polygon entries (empty when
        ``return_polygon`` is False or nothing was predicted).
    """
    img_h, img_w = input_img_shape[0], input_img_shape[1]
    tile_h, tile_w = patch_size[0], patch_size[1]
    final_mask = np.zeros([img_h, img_w], dtype=np.uint8)

    for class_idx in range(1, pred.max() + 1):
        tile = np.full((tile_h, tile_w), class_idx, dtype=np.uint8)
        is_class = pred == class_idx
        if not is_class.any():
            continue

        class_layer = np.zeros([img_h, img_w], dtype=np.uint8)
        for row in range(patch_num_h):
            for col in range(patch_num_w):
                if not is_class[row, col]:
                    continue
                # Clamp so that replicated border patches end exactly on the image border
                top = min(step[0] * row, img_h - tile_h)
                left = min(step[1] * col, img_w - tile_w)
                class_layer[top : top + tile_h, left : left + tile_w] = tile

        # Larger class indices take priority on overlapping pixels
        final_mask = np.maximum(class_layer, final_mask)

    predicted_defect: list[dict] = []
    if return_polygon and final_mask.sum() != 0:
        for lab in np.unique(final_mask):
            if lab == 0:
                continue
            class_polygons = from_mask_to_polygon((final_mask == lab).astype(np.uint8))
            predicted_defect.extend({"label": idx_to_class.get(lab), "points": pol} for pol in class_polygons)

    return final_mask, predicted_defect
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
def _reconstruct_patch_major_voting(
    input_img_shape: tuple[int, ...],
    patch_size: tuple[int, int],
    pred: np.ndarray,
    patch_num_h: int,
    patch_num_w: int,
    idx_to_class: dict,
    step: tuple[int, int],
    return_polygon: bool = True,
) -> tuple[np.ndarray, list[dict]]:
    """Reconstruct patch polygons using the major voting method.

    Every patch votes for its predicted class on all the pixels it covers; votes for class 0 are not counted.
    Each pixel gets the class ``>= 1`` with the most votes. Ties go to the larger class index, and pixels
    without any vote stay background (0). Patches whose nominal position would cross the bottom/right image
    border are shifted back so that they end exactly on that border.

    Args:
        input_img_shape: Shape of the original image; only the first two entries (height, width) are used.
        patch_size: Patch (height, width).
        pred: Per-patch predicted class indices, indexed as ``pred[row, col]``.
        patch_num_h: Number of patch rows.
        patch_num_w: Number of patch columns.
        idx_to_class: Mapping from class index to class name, used to label the polygons.
        step: Vertical and horizontal distance between the origins of consecutive patches.
        return_polygon: If True also convert every class region into polygons.

    Returns:
        The reconstructed mask (dtype ``np.intp``) and the list of ``{"label", "points"}`` polygon entries (empty
        when ``return_polygon`` is False or nothing was predicted).
    """
    predicted_defect: list[dict] = []
    img_h, img_w = input_img_shape[0], input_img_shape[1]
    tile_h, tile_w = patch_size[0], patch_size[1]

    # A pixel receives at most one vote per patch: use the smallest unsigned dtype able to count all of them.
    # Fixed uint8 counters silently wrapped around past 255 overlapping votes, corrupting the majority.
    vote_dtype = np.min_scalar_type(max(int(patch_num_h) * int(patch_num_w), 1))

    final_mask = np.zeros([img_h, img_w], dtype=np.intp)
    best_votes = np.zeros([img_h, img_w], dtype=vote_dtype)

    for class_idx in range(1, pred.max() + 1):
        is_class = pred == class_idx
        if not is_class.any():
            continue

        class_votes = np.zeros([img_h, img_w], dtype=vote_dtype)
        for row in range(patch_num_h):
            for col in range(patch_num_w):
                if not is_class[row, col]:
                    continue
                # Move replicated patches prediction in the correct position of the original image if needed
                top = min(step[0] * row, img_h - tile_h)
                left = min(step[1] * col, img_w - tile_w)
                class_votes[top : top + tile_h, left : left + tile_w] += 1

        # Classes are visited in increasing order, so ">=" hands ties to the larger class index;
        # "> 0" keeps pixels that got no vote as background.
        wins = (class_votes > 0) & (class_votes >= best_votes)
        final_mask[wins] = class_idx
        best_votes[wins] = class_votes[wins]

    if final_mask.sum() != 0 and return_polygon:
        for lab in np.unique(final_mask):
            if lab == 0:
                continue

            polygon = from_mask_to_polygon((final_mask == lab).astype(np.uint8))

            for pol in polygon:
                class_entry = {
                    "label": idx_to_class.get(lab),
                    "points": pol,
                }

                predicted_defect.append(class_entry)

    return final_mask, predicted_defect
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def from_mask_to_polygon(mask_img: np.ndarray) -> list:
    """Convert a mask of pattern to a list of polygon vertices.

    Regions that contain holes cannot be described by a single outer contour. Whenever holes are found,
    the first image row of every hole is set to zero, which opens the hole towards the outside. The whole
    row is zeroed, so any other foreground pixels on that row are cut as well. The cut is applied to a
    copy, and the caller's array is never modified.

    Args:
        mask_img: masked patch reconstruction image; a single channel array in a dtype accepted by
            ``cv2.findContours`` (e.g. ``uint8``) where non-zero pixels are foreground

    Returns:
        a list of lists containing the integer pixel coordinates of the polygons describing each region of
        the mask (an empty list when the mask contains no contour):
        [
            [
                {
                    "x": 1,
                    "y": 2
                },
                {
                    "x": 2,
                    "y": 3
                }
            ], ...
        ].
    """
    # find vertices of polygon: points -> list of array of dim n_vertex, 1, 2(x,y)
    polygon_points, hier = cv2.findContours(mask_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_TC89_L1)

    # OpenCV returns no hierarchy at all when the mask contains no contour
    if hier is None:
        return []

    # Each hierarchy row is [next, previous, first_child, parent]. With RETR_TREE, a contour that has a
    # parent (parent index >= 0) lies inside another contour, which only happens when the mask has holes.
    if (hier[..., 3] >= 0).any():
        holes = ndimage.binary_fill_holes(mask_img).astype(int)
        holes -= mask_img
        holes = (holes > 0).astype(np.uint8)
        if holes.sum() > 0:
            # Cut a copy so that the caller's mask is left untouched
            mask_img = mask_img.copy()
            for hole in regionprops(label(holes)):
                # bbox is (min_row, min_col, max_row, max_col); zero the first row of the hole to open it
                min_row = hole.bbox[0]
                mask_img[min_row] = 0

            # Holes are now open, so a flat list of contours describes every region
            polygon_points, _ = cv2.findContours(mask_img, cv2.RETR_LIST, cv2.CHAIN_APPROX_TC89_L1)

    # Each contour has shape (n_vertex, 1, 2) holding (x, y) coordinates
    return [[{"x": int(point[0, 0]), "y": int(point[0, 1])} for point in pol] for pol in polygon_points]
|
|
@@ -0,0 +1,153 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
from label_studio_converter.brush import mask2rle
|
|
11
|
+
from omegaconf import DictConfig
|
|
12
|
+
from sklearn.metrics import ConfusionMatrixDisplay
|
|
13
|
+
from torch.utils.data import DataLoader
|
|
14
|
+
|
|
15
|
+
from quadra.utils import utils
|
|
16
|
+
from quadra.utils.patch.visualization import plot_patch_reconstruction
|
|
17
|
+
from quadra.utils.visualization import UnNormalize, plot_classification_results
|
|
18
|
+
|
|
19
|
+
# Module-level logger for this file, created through quadra's ``utils.get_logger`` helper.
log = utils.get_logger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def save_classification_result(
    results: pd.DataFrame,
    output_folder: str,
    confusion_matrix: pd.DataFrame | None,
    accuracy: float,
    test_dataloader: DataLoader,
    reconstructions: list[dict],
    config: DictConfig,
    output: DictConfig,
    ignore_classes: list[int] | None = None,
):
    """Save classification results.

    ``test_results.csv`` is always written to ``output_folder``. When ``confusion_matrix`` is given,
    ``test_confusion_matrix.png`` is written as well. When ``output.example`` is truthy, per-class example
    plots and a limited number of patch reconstruction plots are written to ``output_folder/example``.

    Args:
        results: Dataframe containing the classification results; its ``real_label`` and ``pred_label``
            columns are read, and a label of -1 marks a missing label
        output_folder: Folder where to save the results
        confusion_matrix: Confusion matrix; its column names, stripped of the ``"pred:"`` prefix, are used as
            display labels
        accuracy: Accuracy of the model as a fraction (it is multiplied by 100 for the plot title)
        test_dataloader: Dataloader used for testing; when ``output.example`` is set, its dataset must expose
            ``idx_to_class`` (and ``class_to_idx`` for the reconstruction plots)
        reconstructions: List of dictionaries containing polygons or masks under ``prediction`` and the
            source image path under ``file_name``
        config: Experiment configuration; ``config.transforms.mean``/``std`` are used to un-normalize images
        output: Output configuration; ``example`` enables the example plots, and the optional ``rows``,
            ``cols``, ``figsize`` and ``max_reconstructions`` (default 6) keys customize them
        ignore_classes: Eventual classes to ignore during reconstruction plot. Defaults to None.

    Raises:
        ValueError: If ``output.example`` is set and the dataset has no ``idx_to_class`` attribute.
    """
    # Save csv
    results.to_csv(os.path.join(output_folder, "test_results.csv"), index_label="index")

    if confusion_matrix is not None:
        # Save confusion matrix
        disp = ConfusionMatrixDisplay(
            confusion_matrix=np.array(confusion_matrix),
            display_labels=[x.replace("pred:", "") for x in confusion_matrix.columns.to_list()],
        )
        disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
        plt.title(f"Confusion Matrix (Accuracy: {(accuracy * 100):.2f}%)")
        plt.savefig(
            os.path.join(output_folder, "test_confusion_matrix.png"),
            bbox_inches="tight",
            pad_inches=0,
            dpi=300,
        )
        plt.close()

    if not output.example:
        return

    dataset = test_dataloader.dataset
    if not hasattr(dataset, "idx_to_class"):
        raise ValueError("The provided dataset does not have an attribute 'idx_to_class'")

    idx_to_class = dataset.idx_to_class

    example_folder = os.path.join(output_folder, "example")
    os.makedirs(example_folder, exist_ok=True)

    # Per-class example plots (correct / wrong predictions) need at least one ground truth label
    if (results["real_label"] != -1).any():
        unorm = UnNormalize(mean=config.transforms.mean, std=config.transforms.std)
        pred_labels = results["pred_label"].to_numpy()
        test_labels = results["real_label"].to_numpy()

        for class_idx in np.unique([results["real_label"], results["pred_label"]]):
            # -1 marks a missing label; ignored classes are not plotted
            if class_idx == -1 or (ignore_classes is not None and class_idx in ignore_classes):
                continue

            for what in ("con", "dis"):
                plot_classification_results(
                    dataset,
                    unorm=unorm,
                    pred_labels=pred_labels,
                    test_labels=test_labels,
                    class_name=idx_to_class[class_idx],
                    original_folder=example_folder,
                    idx_to_class=idx_to_class,
                    pred_class_to_plot=class_idx,
                    what=what,
                    rows=output.get("rows", 3),
                    cols=output.get("cols", 2),
                    figsize=output.get("figsize", (20, 20)),
                )

    # Only reconstructions that are actually saved count towards the limit, so empty predictions
    # do not consume the plotting budget
    max_reconstructions = output.get("max_reconstructions", 6)
    n_saved = 0
    for reconstruction in reconstructions:
        if n_saved >= max_reconstructions:
            break

        prediction = reconstruction["prediction"]
        # Predictions are either masks (numpy arrays) or lists of polygons
        is_polygon = not isinstance(prediction, np.ndarray)
        has_prediction = len(prediction) != 0 if is_polygon else prediction.sum() != 0
        if not has_prediction:
            continue

        to_plot = plot_patch_reconstruction(
            reconstruction,
            idx_to_class,
            class_to_idx=dataset.class_to_idx,  # type: ignore[attr-defined]
            ignore_classes=ignore_classes,
            is_polygon=is_polygon,
        )

        if to_plot:
            output_name = f"reconstruction_{os.path.splitext(os.path.basename(reconstruction['file_name']))[0]}.png"
            plt.savefig(os.path.join(example_folder, output_name), bbox_inches="tight", pad_inches=0)
            n_saved += 1

        plt.close()
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class RleEncoder(json.JSONEncoder):
    """JSON encoder that serializes numpy arrays as run-length encoded masks.

    Objects that are not numpy arrays are delegated to the standard ``json.JSONEncoder.default``,
    which raises ``TypeError`` for anything the standard encoder cannot handle.
    """

    def default(self, o: Any):
        """Return the RLE of ``o`` when it is a numpy array, otherwise defer to the base encoder."""
        if not isinstance(o, np.ndarray):
            return super().default(o)
        return mask2rle(o)
|