quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4

@@ -0,0 +1,1433 @@
+from __future__ import annotations
+
+import glob
+import itertools
+import json
+import math
+import os
+import random
+import shutil
+import warnings
+from collections.abc import Callable
+from copy import deepcopy
+from dataclasses import dataclass
+from functools import partial
+from multiprocessing import Pool
+from typing import Any
+
+import cv2
+import h5py
+import numpy as np
+from scipy import ndimage
+from skimage.measure import label, regionprops  # pylint: disable=no-name-in-module
+from skimage.util import view_as_windows
+from skmultilearn.model_selection import iterative_train_test_split
+from tqdm import tqdm
+from tripy import earclip
+
+from quadra.utils import utils
+
+log = utils.get_logger(__name__)
+
+
+@dataclass
+class PatchDatasetFileFormat:
+    """Model representing the content of the patch dataset split_files field in the info.json file."""
+
+    image_path: str
+    mask_path: str | None = None
+
+
+@dataclass
+class PatchDatasetInfo:
+    """Model representing the content of the patch dataset info.json file."""
+
+    patch_size: tuple[int, int] | None
+    patch_number: tuple[int, int] | None
+    annotated_good: list[int] | None
+    overlap: float
+    train_files: list[PatchDatasetFileFormat]
+    val_files: list[PatchDatasetFileFormat]
+    test_files: list[PatchDatasetFileFormat]
+
+    @staticmethod
+    def _map_files(files: list[Any]):
+        """Convert a list of dict to a list of PatchDatasetFileFormat."""
+        mapped_files = []
+        for file in files:
+            current_file = file
+            if isinstance(file, dict):
+                current_file = PatchDatasetFileFormat(**current_file)
+            mapped_files.append(current_file)
+
+        return mapped_files
+
+    def __post_init__(self):
+        self.train_files = self._map_files(self.train_files)
+        self.val_files = self._map_files(self.val_files)
+        self.test_files = self._map_files(self.test_files)
+
+
+def get_image_mask_association(
+    data_folder: str,
+    mask_folder: str | None = None,
+    mask_extension: str = "",
+    warning_on_missing_mask: bool = True,
+) -> list[dict]:
+    """Function used to match images and masks from a folder or sub-folders.
+
+    Args:
+        data_folder: root data folder containing images or images and masks
+        mask_folder: Optional root directory used to search only the masks
+        mask_extension: extension used to identify the mask file, it's mandatory if mask_folder is not specified
+        warning_on_missing_mask: if set to True a warning will be raised if a mask is missing, disable if you know
+            that many images do not have a mask.
+
+    Returns:
+        List of dict like:
+            [
+                {
+                    'base_name': '161927.tiff',
+                    'path': 'test_dataset_patch/images/161927.tiff',
+                    'mask': 'test_dataset_patch/masks/161927_mask.tiff'
+                }, ...
+            ]
+    """
+    # get all the images from the data folder
+    data_images = glob.glob(os.path.join(data_folder, "**", "*"), recursive=True)
+
+    basenames = [os.path.splitext(os.path.basename(image))[0] for image in data_images]
+
+    if len(set(basenames)) != len(basenames):
+        raise ValueError("Found multiple images with the same name and different extension, this is not supported.")
+
+    log.info("Found: %d images in %s", len(data_images), data_folder)
+    # divide images and masks if in the same folder
+    # if mask folder is specified search masks in that folder
+    if mask_folder:
+        masks_images = []
+        for basename in basenames:
+            mask_path = os.path.join(mask_folder, f"{basename}{mask_extension}.*")
+            mask_path_list = glob.glob(mask_path)
+
+            if len(mask_path_list) == 1:
+                masks_images.append(mask_path_list[0])
+            elif warning_on_missing_mask:
+                log.warning("Mask for %s not found", basename)
+    else:
+        if mask_extension == "":
+            raise ValueError("If no mask folder is provided, mask extension is mandatory and cannot be empty.")
+
+        masks_images = [image for image in data_images if mask_extension in image]
+        data_images = [image for image in data_images if mask_extension not in image]
+
+    # build support dictionary
+    unique_images = [{"base_name": os.path.basename(image), "path": image, "mask": None} for image in data_images]
+
+    images_stem = [os.path.splitext(str(image["base_name"]))[0] + mask_extension for image in unique_images]
+    masks_stem = [os.path.splitext(os.path.basename(mask))[0] for mask in masks_images]
+
+    # search correspondence between files or folders
+    for i, image_stem in enumerate(images_stem):
+        if image_stem in masks_stem:
+            unique_images[i]["mask"] = masks_images[masks_stem.index(image_stem)]
+
+    log.info("Unique images with mask: %d", len([uni for uni in unique_images if uni.get("mask") is not None]))
+    log.info("Unique images with no mask: %d", len([uni for uni in unique_images if uni.get("mask") is None]))
+
+    return unique_images
+
+
+def compute_patch_info(
+    img_h: int,
+    img_w: int,
+    patch_num_h: int,
+    patch_num_w: int,
+    overlap: float = 0.0,
+) -> tuple[tuple[int, int], tuple[int, int]]:
+    """Compute the patch size and step size given the number of patches and the overlap.
+
+    Args:
+        img_h: height of the image
+        img_w: width of the image
+        patch_num_h: number of vertical patches
+        patch_num_w: number of horizontal patches
+        overlap: percentage of overlap between patches.
+
+    Returns:
+        Tuple containing:
+            patch_size: [size_y, size_x] Dimension of the patch
+            step_size: [step_y, step_x] Step size
+    """
+    patch_size_h = np.ceil(img_h / (1 + (patch_num_h - 1) - (patch_num_h - 1) * overlap)).astype(int)
+    step_h = patch_size_h - np.ceil(overlap * patch_size_h).astype(int)
+
+    patch_size_w = np.ceil(img_w / (1 + (patch_num_w - 1) - (patch_num_w - 1) * overlap)).astype(int)
+    step_w = patch_size_w - np.ceil(overlap * patch_size_w).astype(int)
+
+    # We want a combination of patch size and step that if the image is not divisible by the number of patches
+    # will try to fit the maximum number of patches in the image + ONLY 1 extra patch that will be taken from the end
+    # of the image.
+
+    counter = 0
+    original_patch_size_h = patch_size_h
+    original_patch_size_w = patch_size_w
+    original_step_h = step_h
+    original_step_w = step_w
+
+    while (patch_num_h - 1) * step_h + patch_size_h < img_h or (patch_num_h - 2) * step_h + patch_size_h > img_h:
+        counter += 1
+        if (patch_num_h - 1) * (step_h + 1) + patch_size_h < img_h:
+            step_h += 1
+        else:
+            patch_size_h += 1
+
+        if counter == 100:
+            # We probably entered an infinite loop, restart with smaller step size
+            step_h = original_step_h - 1
+            patch_size_h = original_patch_size_h
+            counter = 0
+
+    counter = 0
+    while (patch_num_w - 1) * step_w + patch_size_w < img_w or (patch_num_w - 2) * step_w + patch_size_w > img_w:
+        counter += 1
+        if (patch_num_w - 1) * (step_w + 1) + patch_size_w < img_w:
+            step_w += 1
+        else:
+            patch_size_w += 1
+
+        if counter == 100:
+            # We probably entered an infinite loop, restart with smaller step size
+            step_w = original_step_w - 1
+            patch_size_w = original_patch_size_w
+            counter = 0
+
+    return (patch_size_h, patch_size_w), (step_h, step_w)
+
+
+def compute_patch_info_from_patch_dim(
+    img_h: int,
+    img_w: int,
+    patch_height: int,
+    patch_width: int,
+    overlap: float = 0.0,
+) -> tuple[tuple[int, int], tuple[int, int]]:
+    """Compute patch info given the patch dimension.
+
+    Args:
+        img_h: height of the image
+        img_w: width of the image
+        patch_height: patch height
+        patch_width: patch width
+        overlap: overlap percentage [0, 1].
+
+    Returns:
+        Tuple of number of patches, step
+    """
+    assert 1 >= overlap >= 0, f"Invalid overlap. Must be between [0, 1], received {overlap}"
+    step_h = patch_height - int(overlap * patch_height)
+    step_w = patch_width - int(overlap * patch_width)
+
+    patch_num_h = np.ceil(((img_h - patch_height) / step_h) + 1).astype(int)
+    patch_num_w = np.ceil(((img_w - patch_width) / step_w) + 1).astype(int)
+
+    # Handle the case where the last patch does not cover the full image, I need to do this rather than np.ceil
+    # because I don't want to add a new patch if the last one already exceeds the image!
+    if ((patch_num_h - 1) * step_h) + patch_height < img_h:
+        patch_num_h += 1
+    if ((patch_num_w - 1) * step_w) + patch_width < img_w:
+        patch_num_w += 1
+
+    return (patch_num_h, patch_num_w), (step_h, step_w)
+
+
+def from_rgb_to_idx(img: np.ndarray, class_to_color: dict, class_to_idx: dict) -> np.ndarray:
+    """Args:
+        img: Rgb mask in which each different color is associated with a class
+        class_to_color: Dict "key": [R, G, B]
+        class_to_idx: Dict "key": class_idx.
+
+    Returns:
+        Grayscale mask in which each class is associated with a specific index
+    """
+    img = img.astype(int)
+    # Use negative values to avoid strange behaviour in the remote eventuality
+    # of someone using a color like [1, 255, 255]
+    for classe, color in class_to_color.items():
+        img[np.all(img == color, axis=-1).astype(bool), 0] = -class_to_idx[classe]
+
+    img = np.abs(img[:, :, 0])
+
+    return img.astype(np.uint8)
+
+
+def __save_patch_dataset(
+    image_patches: np.ndarray,
+    labelled_patches: np.ndarray | None = None,
+    mask_patches: np.ndarray | None = None,
+    labelled_mask: np.ndarray | None = None,
+    output_folder: str = "extraction_data",
+    image_name: str = "example",
+    area_threshold: float = 0.45,
+    area_defect_threshold: float = 0.2,
+    mask_extension: str = "_mask",
+    save_mask: bool = False,
+    mask_output_folder: str | None = None,
+    class_to_idx: dict | None = None,
+) -> None:
+    """Given patches computed with view_as_windows, masks and labelled masks, save all the images in subdirectories
+    divided by name and position in the grid; ambiguous patches, i.e. the ones that contain defects but not enough
+    to go above the defined thresholds, are marked with #DISCARD# and should be discarded in training.
+    Patches of images without ground truth are saved inside the None folder.
+
+    Args:
+        image_patches: [n, m, patch_w, patch_h, channel] numpy array of the image patches
+        mask_patches: [n, m, patch_w, patch_h] numpy array of mask patches
+        labelled_patches: [n, m, patch_w, patch_h] numpy array of labelled mask patches
+        labelled_mask: numpy array in which each defect in the image is labelled using connected components
+        class_to_idx: Dictionary with the mapping {"class" -> class in mask}, it must cover all indices
+            contained in the masks
+        save_mask: flag to save or ignore mask
+        output_folder: folder where to save data
+        mask_extension: postfix of the saved mask based on the image name
+        mask_output_folder: Optional folder in which to save the masks
+        image_name: name to use in order to save the data
+        area_threshold: minimum percentage of defected patch area present in the mask to classify the patch as defect
+        area_defect_threshold: minimum percentage of single defect present in the patch to classify the patch as defect
+
+    Returns:
+        None
+    """
+    if class_to_idx is not None:
+        log.debug("Classes from dict: %s", class_to_idx)
+        index_to_class = {v: k for k, v in class_to_idx.items()}
+        log.debug("Inverse class: %s", index_to_class)
+        reference_classes = index_to_class
+
+        if mask_patches is not None:
+            classes_in_mask = set(np.unique(mask_patches))
+            missing_classes = set(classes_in_mask).difference(class_to_idx.values())
+
+            assert len(missing_classes) == 0, f"Found index in mask that has no corresponding class {missing_classes}"
+    elif mask_patches is not None:
+        reference_classes = {k: str(v) for k, v in enumerate(list(np.unique(mask_patches)))}
+    else:
+        raise ValueError("If no `class_to_idx` is provided, `mask_patches` must be provided")
+
+    log.debug("Classes from mask: %s", reference_classes)
+    class_to_idx = {v: k for k, v in reference_classes.items()}
+    log.debug("Final reference classes: %s", reference_classes)
+
+    # create subdirectories for the saved data
+    for cl in reference_classes.values():
+        os.makedirs(os.path.join(output_folder, str(cl)), exist_ok=True)
+
+        if mask_output_folder is not None:
+            os.makedirs(os.path.join(output_folder, mask_output_folder, str(cl)), exist_ok=True)
+
+    if mask_output_folder is None:
+        mask_output_folder = output_folder
+    else:
+        mask_output_folder = os.path.join(output_folder, mask_output_folder)
+
+    log.debug("Mask out: %s", mask_output_folder)
+
+    if mask_patches is None:
+        os.makedirs(os.path.join(output_folder, str(None)), exist_ok=True)
+    # for [i, j] in patches location
+    for row_index in range(image_patches.shape[0]):
+        for col_index in range(image_patches.shape[1]):
+            # the default class is the one at index 0
+            output_class = reference_classes.get(0)
+            image = image_patches[row_index, col_index]
+
+            discard_in_training = True
+            if mask_patches is not None and labelled_patches is not None:
+                discard_in_training = False
+                max_defected_area = 0
+                mask = mask_patches[row_index, col_index]
+                patch_area_th = mask.shape[0] * mask.shape[1] * area_threshold
+
+                if mask.sum() > 0:
+                    discard_in_training = True
+                    for k, v in class_to_idx.items():
+                        if v == 0:
+                            continue
+
+                        mask_patch = mask == int(v)
+                        defected_area = mask_patch.sum()
+
+                        if defected_area > 0:
+                            # If enough defected area is inside the patch
+                            if defected_area > patch_area_th:
+                                if defected_area > max_defected_area:
+                                    output_class = k
+                                    max_defected_area = defected_area
+                                discard_in_training = False
+                            else:
+                                all_defects_in_patch = mask_patch * labelled_patches[row_index, col_index]
+
+                                # For each different defect inside the area check
+                                # if enough part of it is contained in the patch
+                                for defect_id in np.unique(all_defects_in_patch):
+                                    if defect_id == 0:
+                                        continue
+
+                                    defect_area_in_patch = (all_defects_in_patch == defect_id).sum()
+                                    defect_area_th = (labelled_mask == defect_id).sum() * area_defect_threshold
+
+                                    if defect_area_in_patch > defect_area_th:
+                                        output_class = k
+                                        if defect_area_in_patch > max_defected_area:
+                                            max_defected_area = defect_area_in_patch
+                                        discard_in_training = False
+                else:
+                    discard_in_training = False
+
+                if save_mask:
+                    mask_name = f"{image_name}_{row_index * image_patches.shape[1] + col_index}{mask_extension}.png"
+
+                    if discard_in_training:
+                        mask_name = "#DISCARD#" + mask_name
+                    cv2.imwrite(
+                        os.path.join(
+                            mask_output_folder,
+                            output_class,  # type: ignore[arg-type]
+                            mask_name,
+                        ),
+                        mask.astype(np.uint8),
+                    )
+            else:
+                output_class = "None"
+
+            patch_name = f"{image_name}_{row_index * image_patches.shape[1] + col_index}.png"
+            if discard_in_training:
+                patch_name = "#DISCARD#" + patch_name
+
+            cv2.imwrite(
+                os.path.join(
+                    output_folder,
+                    output_class,  # type: ignore[arg-type]
+                    patch_name,
+                ),
+                image,
+            )
+
+
+def generate_patch_dataset(
+    data_dictionary: list[dict],
+    class_to_idx: dict,
+    val_size: float = 0.3,
+    test_size: float = 0.0,
+    seed: int = 42,
+    patch_number: tuple[int, int] | None = None,
+    patch_size: tuple[int, int] | None = None,
+    overlap: float = 0.0,
+    output_folder: str = "extraction_data",
+    save_original_images_and_masks: bool = True,
+    area_threshold: float = 0.45,
+    area_defect_threshold: float = 0.2,
+    mask_extension: str = "_mask",
+    mask_output_folder: str | None = None,
+    save_mask: bool = False,
+    clear_output_folder: bool = False,
+    mask_preprocessing: Callable | None = None,
+    train_filename: str = "dataset.txt",
+    repeat_good_images: int = 1,
+    balance_defects: bool = True,
+    annotated_good: list[str] | None = None,
+    num_workers: int = 1,
+) -> dict | None:
+    """Given a data_dictionary like:
+    >>> {
+    >>>     'base_name': '163931_1_5.jpg',
+    >>>     'path': 'extraction_data/1/163931_1_5.jpg',
+    >>>     'mask': 'extraction_data/1/163931_1_5_mask.jpg'
+    >>> }
+    this function will generate patch datasets based on the defined split sizes, one for training, one for validation
+    and one for testing, respectively under output_folder/train, output_folder/val and output_folder/test. The training
+    dataset will contain h5 files and a txt file resulting from a call to
+    generate_classification_patch_train_dataset, while the test dataset will contain patches saved on disk divided
+    in subfolders per class; patch extraction is done in a sliding window fashion.
+    Original images and masks (preprocessed if mask_preprocessing is present) will also be saved under
+    output_folder/original/images and output_folder/original/masks.
+    If patch_number is specified the patch size will be calculated accordingly; if the image is not divisible by the
+    patch number two possible behaviours can occur:
+    - if the patch reconstruction is smaller than the original image a new patch will be generated containing the
+        pixels from the edge of the image (e.g. the new patch will contain the last patch_size pixels of the original
+        image)
+    - if the patch reconstruction is bigger than the original image the last patch will contain the pixels from the
+        edge of the image same as above, but without adding a new patch to the count.
+
+    Args:
+        data_dictionary: Dictionary as above
+        val_size: percentage of the dictionary entries to be used for validation
+        test_size: percentage of the dictionary entries to be used for testing
+        seed: seed for rng based operations
+        clear_output_folder: flag used to delete all the data in subfolder
+        class_to_idx: Dictionary {"defect": value in mask.. }
+        output_folder: root_folder where to extract the data
+        save_original_images_and_masks: If True, images and masks will be copied inside output_folder/original/
+        area_threshold: Minimum percentage of defected patch area present in the mask to classify the patch as defect
+        area_defect_threshold: Minimum percentage of single defect present in the patch to classify the patch as defect
+        mask_extension: Extension used to assign image to mask
+        mask_output_folder: Optional folder in which to save the masks
+        save_mask: Flag to save the mask
+        patch_number: Optional number of patches for each side, required if patch_size is None
+        patch_size: Optional dimension of the patch, required if patch_number is None
+        overlap: Overlap of the patches [0, 1]
+        mask_preprocessing: Optional function applied to masks, this can be useful for example to convert an image in
+            range [0-255] to the required [0-1]
+        train_filename: Name of the file containing mapping between h5 files and labels for training
+        repeat_good_images: Number of repetitions for images with empty or None mask
+        balance_defects: If true add one good entry for each defect extracted
+        annotated_good: List of labels that are annotated but considered as good
+        num_workers: Number of workers used for the h5 creation
+
+    Returns:
+        None if data_dictionary is empty, otherwise return a dictionary containing information about the dataset
+    """
+    if len(data_dictionary) == 0:
+        warnings.warn("Input data dictionary is empty!", UserWarning, stacklevel=2)
+        return None
+
+    if val_size < 0 or test_size < 0 or (val_size + test_size) > 1:
+        raise ValueError("Validation and test size must be greater than or equal to zero and sum up to 1 at most")
+    if clear_output_folder and os.path.exists(output_folder):
+        shutil.rmtree(output_folder)
+    os.makedirs(output_folder, exist_ok=True)
+    os.makedirs(os.path.join(output_folder, "original"), exist_ok=True)
+    if save_original_images_and_masks:
+        log.info("Moving original images and masks to dataset folder...")
+        os.makedirs(os.path.join(output_folder, "original", "images"), exist_ok=True)
+        os.makedirs(os.path.join(output_folder, "original", "masks"), exist_ok=True)
+
+        for i, item in enumerate(data_dictionary):
+            img_new_path = os.path.join("original", "images", item["base_name"])
+            shutil.copy(item["path"], os.path.join(output_folder, img_new_path))
+            data_dictionary[i]["path"] = img_new_path
+
+            if item["mask"] is not None:
+                mask = cv2.imread(item["mask"])
+                if mask_preprocessing is not None:
+                    mask = mask_preprocessing(mask).astype(np.uint8)
+                mask_new_path = os.path.join("original", "masks", os.path.splitext(item["base_name"])[0] + ".png")
+                cv2.imwrite(os.path.join(output_folder, mask_new_path), mask)
+                data_dictionary[i]["mask"] = mask_new_path
+
+    shuffled_indices = np.random.default_rng(seed).permutation(len(data_dictionary))
+    data_dictionary = [data_dictionary[i] for i in shuffled_indices]
+    log.info("Performing multilabel stratification...")
+    train_data_dictionary, val_data_dictionary, test_data_dictionary = multilabel_stratification(
+        output_folder=output_folder,
+        data_dictionary=data_dictionary,
+        num_classes=len(class_to_idx.values()),
+        val_size=val_size,
+        test_size=test_size,
+    )
+
+    log.info("Train set size: %d", len(train_data_dictionary))
+    log.info("Validation set size: %d", len(val_data_dictionary))
+    log.info("Test set size: %d", len(test_data_dictionary))
+
+    idx_to_class = {v: k for (k, v) in class_to_idx.items()}
+
+    os.makedirs(output_folder, exist_ok=True)
+
+    dataset_info = {
+        "patch_size": patch_size,
+        "patch_number": patch_number,
+        "overlap": overlap,
+        "annotated_good": annotated_good,
+        "train_files": [{"image_path": x["path"], "mask_path": x["mask"]} for x in train_data_dictionary],
+        "val_files": [{"image_path": x["path"], "mask_path": x["mask"]} for x in val_data_dictionary],
+        "test_files": [{"image_path": x["path"], "mask_path": x["mask"]} for x in test_data_dictionary],
+    }
+
+    with open(os.path.join(output_folder, "info.json"), "w") as f:
+        json.dump(dataset_info, f)
+
+    if len(train_data_dictionary) > 0:
+        log.info("Generating train set")
+        generate_patch_sampling_dataset(
+            data_dictionary=train_data_dictionary,
+            patch_number=patch_number,
+            patch_size=patch_size,
+            overlap=overlap,
+            idx_to_class=idx_to_class,
+            balance_defects=balance_defects,
+            repeat_good_images=repeat_good_images,
+            output_folder=output_folder,
+            subfolder_name="train",
+            train_filename=train_filename,
+            annotated_good=annotated_good if annotated_good is None else [class_to_idx[x] for x in annotated_good],
+            num_workers=num_workers,
+        )
+
+    for phase, split_dict in zip(["val", "test"], [val_data_dictionary, test_data_dictionary]):
+        if len(split_dict) > 0:
+            log.info("Generating %s set", phase)
+            generate_patch_sliding_window_dataset(
+                data_dictionary=split_dict,
+                patch_number=patch_number,
+                patch_size=patch_size,
+                overlap=overlap,
+                output_folder=output_folder,
+                subfolder_name=phase,
+                area_threshold=area_threshold,
+                area_defect_threshold=area_defect_threshold,
+                mask_extension=mask_extension,
+                mask_output_folder=mask_output_folder,
+                save_mask=save_mask,
+                class_to_idx=class_to_idx,
+            )
+
+    log.info("All done! Datasets saved to %s", output_folder)
+
+    return dataset_info
+
+
+def multilabel_stratification(
+    output_folder: str,
+    data_dictionary: list[dict],
+    num_classes: int,
+    val_size: float,
+    test_size: float,
+) -> tuple[list[dict], list[dict], list[dict]]:
+    """Split the data dictionary using multilabel based stratification: place every sample with a None
+    mask inside the test set, and for all the others read the labels contained in the masks
+    to create one-hot encoded labels.
+
+    Args:
+        output_folder: root folder of the dataset
+        data_dictionary: Data dictionary as described in generate_patch_dataset
+        num_classes: Number of classes contained in the dataset, required for one hot encoding
+        val_size: Percentage of data to be used for validation
+        test_size: Percentage of data to be used for test
+
+    Returns:
+        Three data dictionaries, one for training, one for validation and one for test
+    """
+    if val_size + test_size == 0:
+        return data_dictionary, [], []
+    if val_size == 1:
+        return [], data_dictionary, []
+    if test_size == 1:
+        return [], [], data_dictionary
+
+    test_data_dictionary = list(filter(lambda q: q["mask"] is None, data_dictionary))
+    log.info("Number of images with no mask inserted in test_data_dictionary: %d", len(test_data_dictionary))
+    empty_test_size = len(test_data_dictionary) / len(data_dictionary)
+    data_dictionary = list(filter(lambda q: q["mask"] is not None, data_dictionary))
+
+    if len(data_dictionary) == 0:
+        # All the items in the data dictionary have a None mask, put everything in test
+        warnings.warn(
+            "All the images have None mask and the test size is not equal to 1! Put everything in test",
+            UserWarning,
+            stacklevel=2,
+        )
+        return [], [], test_data_dictionary
+
+    x = []
+    y = None
+    for item in data_dictionary:
+        one_hot = np.zeros([1, num_classes], dtype=np.int16)
+        if item["mask"] is None:
+            continue
+        # this works even if item["mask"] is already an absolute path
+        mask = cv2.imread(os.path.join(output_folder, item["mask"]), 0)
+
+        labels = np.unique(mask)
+
+        one_hot[:, labels] = 1
+        x.append(item["base_name"])
+        if y is None:
+            y = one_hot
+        else:
+            y = np.concatenate([y, one_hot])
+
+    x_test: list[Any] | np.ndarray
+
+    if empty_test_size > test_size:
+        warnings.warn(
+            (
+                "The percentage of images with None label is greater than the test_size, the new test_size is"
+                f" {empty_test_size}!"
+            ),
+            UserWarning,
+            stacklevel=2,
+        )
+        x_train, _, x_val, _ = iterative_train_test_split(np.expand_dims(np.array(x), 1), y, val_size)
+        x_test = [q["base_name"] for q in test_data_dictionary]
+    else:
+        test_size -= empty_test_size
+        x_train, _, x_remaining, y_remaining = iterative_train_test_split(
+            np.expand_dims(np.array(x), 1), y, val_size + test_size
+        )
+
+        if x_remaining.shape[0] == 1:
+            if test_size == 0:
+                x_val = x_remaining
+                x_test = np.array([])
+            elif val_size == 0:
+                x_test = x_remaining
+                x_val = np.array([])
+            else:
+                log.warning("Not enough data to create the test split, only a validation set of size 1 will be created")
+                x_val = x_remaining
+                x_test = np.array([])
+        else:
+            x_val, _, x_test, _ = iterative_train_test_split(
+                x_remaining, y_remaining, test_size / (val_size + test_size)
+            )
+        # Here x_test should always be a numpy array, but mypy does not recognize it
+        x_test = [q[0] for q in x_test.tolist()]  # type: ignore[union-attr]
+        x_test.extend([q["base_name"] for q in test_data_dictionary])
+
+    train_data_dictionary = list(filter(lambda q: q["base_name"] in x_train, data_dictionary))
+    val_data_dictionary = list(filter(lambda q: q["base_name"] in x_val, data_dictionary))
+    test_data_dictionary = list(filter(lambda q: q["base_name"] in x_test, data_dictionary + test_data_dictionary))
+
+    return train_data_dictionary, val_data_dictionary, test_data_dictionary
+
+
+def generate_patch_sliding_window_dataset(
+    data_dictionary: list[dict],
+    subfolder_name: str,
+    patch_number: tuple[int, int] | None = None,
+    patch_size: tuple[int, int] | None = None,
+    overlap: float = 0.0,
+    output_folder: str = "extraction_data",
+    area_threshold: float = 0.45,
+    area_defect_threshold: float = 0.2,
+    mask_extension: str = "_mask",
+    mask_output_folder: str | None = None,
+    save_mask: bool = False,
+    class_to_idx: dict | None = None,
+) -> None:
+    """Given a data_dictionary like:
+    >>> {
+    >>>     'base_name': '163931_1_5.jpg',
+    >>>     'path': 'extraction_data/1/163931_1_5.jpg',
+    >>>     'mask': 'extraction_data/1/163931_1_5_mask.jpg'
+    >>> }
+    this function will extract the patches and save the files and the masks in subdirectories.
+
+    Args:
+        data_dictionary: Dictionary as above
+        subfolder_name: Name of the subfolder where to save the extracted patches (output_folder/subfolder_name)
+        class_to_idx: Dictionary {"defect": value in mask.. }
+        output_folder: root_folder where to extract the data
+        area_threshold: minimum percentage of defected patch area present in the mask to classify the patch as defect
+        area_defect_threshold: minimum percentage of single defect present in the patch to classify the patch as defect
+        mask_extension: extension used to assign image to mask
+        mask_output_folder: Optional folder in which to save the masks
+        save_mask: flag to save the mask
+        patch_number: Optional number of patches for each side, required if patch_size is None
+        patch_size: Optional dimension of the patch, required if patch_number is None
+        overlap: overlap of the patches [0, 1].
+
+    Returns:
+        None.
+    """
+    if save_mask and len(mask_extension) == 0 and mask_output_folder is None:
+        raise InvalidParameterCombinationException(
+            "If mask output folder is not set you must specify a mask extension in order to save masks!"
+        )
+
+    if patch_number is None and patch_size is None:
+        raise InvalidParameterCombinationException("One between patch number or patch size must be specified!")
+
+    for data in tqdm(data_dictionary):
+        base_id = data.get("base_name")
+        base_path = data.get("path")
+        base_mask = data.get("mask")
+
+        assert base_id is not None, "Cannot find base id in data_dictionary"
+        assert base_path is not None, "Cannot find image in data_dictionary"
+
+        image = cv2.imread(os.path.join(output_folder, base_path))
+        h = image.shape[0]
+        w = image.shape[1]
+
+        log.debug("Processing %s with shape %s", base_id, image.shape)
+        mask = mask_patches = None
+        labelled_mask = labelled_patches = None
+
+        if base_mask is not None:
+            mask = cv2.imread(os.path.join(output_folder, base_mask), 0)
+            labelled_mask = label(mask)
+
+        if patch_size is not None:
+            [patch_height, patch_width] = patch_size
+            [patch_num_h, patch_num_w], step = compute_patch_info_from_patch_dim(
+                h, w, patch_height, patch_width, overlap
+            )
+        elif patch_number is not None:
+            [patch_height, patch_width], step = compute_patch_info(h, w, patch_number[0], patch_number[1], overlap)
+            [patch_num_h, patch_num_w] = patch_number
+        else:
+            # mypy does not recognize that this is unreachable
+            raise InvalidParameterCombinationException("One between patch number or patch size must be specified!")
+
+        log.debug(
+            "Extracting %s patches with size %s, step %s", [patch_num_h, patch_num_w], [patch_height, patch_width], step
+        )
+        image_patches = extract_patches(image, (patch_num_h, patch_num_w), (patch_height, patch_width), step, overlap)
+
+        if mask is not None:
+            if labelled_mask is None:
+                raise ValueError("Labelled mask cannot be None!")
+            mask_patches = extract_patches(mask, (patch_num_h, patch_num_w), (patch_height, patch_width), step, overlap)
+            labelled_patches = extract_patches(
+                labelled_mask, (patch_num_h, patch_num_w), (patch_height, patch_width), step, overlap
+            )
+            assert image_patches.shape[:-1] == mask_patches.shape, "Image patches and mask patches mismatch!"
+
+        log.debug("Image patches shape: %s", image_patches.shape)
+        __save_patch_dataset(
+            image_patches=image_patches,
+            mask_patches=mask_patches,
+            labelled_patches=labelled_patches,
+            labelled_mask=labelled_mask,
+            image_name=os.path.splitext(base_id)[0],
+            output_folder=os.path.join(output_folder, subfolder_name),
+            area_threshold=area_threshold,
+            area_defect_threshold=area_defect_threshold,
+            mask_extension=mask_extension,
+            save_mask=save_mask,
+            mask_output_folder=mask_output_folder,
+            class_to_idx=class_to_idx,
+        )
+
+
def extract_patches(
|
|
808
|
+
image: np.ndarray,
|
|
809
|
+
patch_number: tuple[int, ...],
|
|
810
|
+
patch_size: tuple[int, ...],
|
|
811
|
+
step: tuple[int, ...],
|
|
812
|
+
overlap: float,
|
|
813
|
+
) -> np.ndarray:
|
|
814
|
+
"""From an image extract N x M Patch[h, w] if the image is not perfectly divided by the number of patches of given
|
|
815
|
+
dimension the last patch will contain a replica of the original image taken in range [-img_h:, :] or [:, -img_w:].
|
|
816
|
+
|
|
817
|
+
Args:
|
|
818
|
+
image: Numpy array of the image
|
|
819
|
+
patch_number: number of patches to be extracted
|
|
820
|
+
patch_size: dimension of the patch
|
|
821
|
+
step: step of the patch extraction
|
|
822
|
+
overlap: horizontal and vertical patch overlapping in range [0, 1]
|
|
823
|
+
|
|
824
|
+
Returns:
|
|
825
|
+
Patches [N, M, 1, image_w, image_h, image_c]
|
|
826
|
+
|
|
827
|
+
"""
|
|
828
|
+
assert 1.0 >= overlap >= 0.0, f"Overlap must be between 0 and 1. Received {overlap}"
|
|
829
|
+
(patch_num_h, patch_num_w) = patch_number
|
|
830
|
+
(patch_height, patch_width) = patch_size
|
|
831
|
+
|
|
832
|
+
pad_h = (patch_num_h - 1) * step[0] + patch_size[0] - image.shape[0]
|
|
833
|
+
pad_w = (patch_num_w - 1) * step[1] + patch_size[1] - image.shape[1]
|
|
834
|
+
# if the image has 3 channel change dimension
|
|
835
|
+
if len(image.shape) == 3:
|
|
836
|
+
patch_size = (patch_size[0], patch_size[1], image.shape[2])
|
|
837
|
+
step = (step[0], step[1], image.shape[2])
|
|
838
|
+
|
|
839
|
+
# If this is not true there's some strange case I didn't take into account
|
|
840
|
+
if pad_h < 0 or pad_w < 0:
|
|
841
|
+
raise ValueError("Something went wrong with the patch extraction, expected positive padding values")
|
|
842
|
+
|
|
843
|
+
if pad_h > 0 or pad_w > 0:
|
|
844
|
+
# We work with copies as view_as_windows returns a view of the original image
|
|
845
|
+
crop_img = deepcopy(image)
|
|
846
|
+
|
|
847
|
+
if pad_h:
|
|
848
|
+
crop_img = crop_img[0 : (patch_num_h - 2) * step[0] + patch_height, :]
|
|
849
|
+
|
|
850
|
+
if pad_w:
|
|
851
|
+
crop_img = crop_img[:, 0 : (patch_num_w - 2) * step[1] + patch_width]
|
|
852
|
+
|
|
853
|
+
# Extract safe patches inside the image
|
|
854
|
+
patches = view_as_windows(crop_img, patch_size, step=step)
|
|
855
|
+
else:
|
|
856
|
+
patches = view_as_windows(image, patch_size, step=step)
|
|
857
|
+
|
|
858
|
+
extra_patches_h = None
|
|
859
|
+
extra_patches_w = None
|
|
860
|
+
|
|
861
|
+
if pad_h > 0:
|
|
862
|
+
# Append extra patches taken from the edge of the image
|
|
863
|
+
extra_patches_h = view_as_windows(image[-patch_height:, :], patch_size, step=step)
|
|
864
|
+
|
|
865
|
+
if pad_w > 0:
|
|
866
|
+
extra_patches_w = view_as_windows(image[:, -patch_width:], patch_size, step=step)
|
|
867
|
+
|
|
868
|
+
if extra_patches_h is not None:
|
|
869
|
+
# Add an extra column and set is content to the bottom right patch area of the original image if both
|
|
870
|
+
# dimension requires extra patches
|
|
871
|
+
if extra_patches_h.ndim == 6:
|
|
872
|
+
# RGB
|
|
873
|
+
extra_patches_h = np.concatenate(
|
|
874
|
+
[
|
|
875
|
+
extra_patches_h,
|
|
876
|
+
(np.zeros([1, 1, 1, patch_size[0], patch_size[1], extra_patches_h.shape[-1]], dtype=np.uint8)),
|
|
877
|
+
],
|
|
878
|
+
axis=1,
|
|
879
|
+
)
|
|
880
|
+
else:
|
|
881
|
+
extra_patches_h = np.concatenate(
|
|
882
|
+
[extra_patches_h, (np.zeros([1, 1, patch_size[0], patch_size[1]], dtype=np.uint8))], axis=1
|
|
883
|
+
)
|
|
884
|
+
|
|
885
|
+
if extra_patches_h is None:
|
|
886
|
+
# Required by mypy as it cannot infer that extra_patch_h cannot be None
|
|
887
|
+
raise ValueError("Extra patch h cannot be None!")
|
|
888
|
+
|
|
889
|
+
extra_patches_h[:, -1, :] = image[-patch_height:, -patch_width:]
|
|
890
|
+
|
|
891
|
+
if patches.ndim == 6:
|
|
892
|
+
# With RGB images there's an extra dimension, axis 2 is important don't use plain squeeze or it breaks if
|
|
893
|
+
# the number of patches is set to 1!
|
|
894
|
+
patches = patches.squeeze(axis=2)
|
|
895
|
+
|
|
896
|
+
if extra_patches_w is not None:
|
|
897
|
+
if extra_patches_w.ndim == 6:
|
|
898
|
+
# RGB
|
|
899
|
+
patches = np.concatenate([patches, extra_patches_w.squeeze(2)], axis=1)
|
|
900
|
+
else:
|
|
901
|
+
patches = np.concatenate([patches, extra_patches_w], axis=1)
|
|
902
|
+
|
|
903
|
+
if extra_patches_h is not None:
|
|
904
|
+
if extra_patches_h.ndim == 6:
|
|
905
|
+
# RGB
|
|
906
|
+
patches = np.concatenate([patches, extra_patches_h.squeeze(2)], axis=0)
|
|
907
|
+
else:
|
|
908
|
+
patches = np.concatenate([patches, extra_patches_h], axis=0)
|
|
909
|
+
|
|
910
|
+
# If this is not true there's some strange case I didn't take into account
|
|
911
|
+
assert (
|
|
912
|
+
patches.shape[0] == patch_num_h and patches.shape[1] == patch_num_w
|
|
913
|
+
), f"Patch shape {patches.shape} does not match the expected shape {patch_number}"
|
|
914
|
+
|
|
915
|
+
return patches
|
|
916
|
+
|
|
917
|
+
|
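For instance, splitting a 256x256 grayscale image into a 2x2 grid with a 128-pixel step (a minimal sketch with illustrative values, assuming `numpy` is available as in this module):

```python
import numpy as np

# Hypothetical input: a 256x256 grayscale image split into a 2x2 grid.
image = np.zeros((256, 256), dtype=np.uint8)

# step equals the patch size, so there is no overlap and no edge replication.
patches = extract_patches(
    image,
    patch_number=(2, 2),
    patch_size=(128, 128),
    step=(128, 128),
    overlap=0.0,
)

assert patches.shape == (2, 2, 128, 128)  # [N, M, patch_h, patch_w]
```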
+def generate_patch_sampling_dataset(
+    data_dictionary: list[dict[Any, Any]],
+    output_folder: str,
+    idx_to_class: dict,
+    overlap: float,
+    repeat_good_images: int = 1,
+    balance_defects: bool = True,
+    patch_number: tuple[int, int] | None = None,
+    patch_size: tuple[int, int] | None = None,
+    subfolder_name: str = "train",
+    train_filename: str = "dataset.txt",
+    annotated_good: list[int] | None = None,
+    num_workers: int = 1,
+) -> None:
+    """Generate a dataset of patches.
+
+    Args:
+        data_dictionary: Dictionary containing image and mask mapping
+        output_folder: root folder
+        idx_to_class: Dict mapping an index to the corresponding class name
+        repeat_good_images: Number of repetitions for images with an empty or None mask
+        balance_defects: If true add one good entry for each defect extracted
+        patch_number: Optional number of patches for each side, required if patch_size is None
+        patch_size: Optional dimension of the patch, required if patch_number is None
+        overlap: Percentage of overlap between patches
+        subfolder_name: Name of the subfolder where to store h5 files for defective images and the dataset txt
+        train_filename: Name of the file in which to store the mappings between h5 files and labels
+        annotated_good: List of class indices that are considered good other than the background
+        num_workers: Number of processes used to create h5 files.
+
+    Returns:
+        Creates a txt file containing path,label tuples where path points to the generated h5 file and label is the
+        corresponding label.
+
+        Each generated h5 file contains five fields:
+            img_path: Pointer to the location of the original image
+            mask_path: Optional pointer to the mask file, missing if the mask is completely empty or is
+                not present
+            patch_size: Dimension of the patches on the image of interest
+            triangles: List of triangles that cover the defect
+            triangles_weights: Weight to be given to each triangle for sampling
+
+    """
+    if patch_number is None and patch_size is None:
+        raise InvalidParameterCombinationException("One of patch_number or patch_size must be specified!")
+
+    sampling_dataset_folder = os.path.join(output_folder, subfolder_name)
+
+    os.makedirs(sampling_dataset_folder, exist_ok=True)
+    labelled_masks_path = os.path.join(output_folder, "original", "labelled_masks")
+    os.makedirs(labelled_masks_path, exist_ok=True)
+
+    with open(os.path.join(sampling_dataset_folder, train_filename), "w") as output_file:
+        if num_workers < 1:
+            raise InvalidNumWorkersNumberException("Workers must be >= 1")
+
+        if num_workers > 1:
+            log.info("Executing generate_patch_sampling_dataset with more than 1 worker!")
+
+            split_data_dictionary = np.array_split(np.asarray(data_dictionary), num_workers)
+
+            with Pool(num_workers) as pool:
+                res_list = pool.map(
+                    partial(
+                        create_h5,
+                        patch_size=patch_size,
+                        patch_number=patch_number,
+                        idx_to_class=idx_to_class,
+                        overlap=overlap,
+                        repeat_good_images=repeat_good_images,
+                        balance_defects=balance_defects,
+                        annotated_good=annotated_good,
+                        output_folder=output_folder,
+                        labelled_masks_path=labelled_masks_path,
+                        sampling_dataset_folder=sampling_dataset_folder,
+                    ),
+                    split_data_dictionary,
+                )
+
+            res = list(itertools.chain(*res_list))
+        else:
+            res = create_h5(
+                data_dictionary=data_dictionary,
+                patch_size=patch_size,
+                patch_number=patch_number,
+                idx_to_class=idx_to_class,
+                overlap=overlap,
+                repeat_good_images=repeat_good_images,
+                balance_defects=balance_defects,
+                annotated_good=annotated_good,
+                output_folder=output_folder,
+                labelled_masks_path=labelled_masks_path,
+                sampling_dataset_folder=sampling_dataset_folder,
+            )
+
+        for line in res:
+            output_file.write(line)
+
+
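A minimal sketch of how this could be invoked (folder layout and file names are hypothetical; `data_dictionary` entries use the `base_name`/`path`/`mask` keys consumed by `create_h5` below):

```python
# Hypothetical mapping: paths are relative to output_folder; "mask" may be None
# for images with no annotated defects.
data_dictionary = [
    {"base_name": "img_0.png", "path": "original/images/img_0.png", "mask": "original/masks/img_0.png"},
    {"base_name": "img_1.png", "path": "original/images/img_1.png", "mask": None},
]

generate_patch_sampling_dataset(
    data_dictionary=data_dictionary,
    output_folder="patch_dataset",  # hypothetical root folder
    idx_to_class={0: "good", 1: "defect"},
    overlap=0.5,
    patch_size=(128, 128),
)
# Produces patch_dataset/train/dataset.txt with one "<file>.h5,<label>" line per entry.
```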
+def create_h5(
+    data_dictionary: list[dict[Any, Any]],
+    idx_to_class: dict,
+    overlap: float,
+    repeat_good_images: int,
+    balance_defects: bool,
+    output_folder: str,
+    labelled_masks_path: str,
+    sampling_dataset_folder: str,
+    annotated_good: list[int] | None = None,
+    patch_size: tuple[int, int] | None = None,
+    patch_number: tuple[int, int] | None = None,
+) -> list[str]:
+    """Create h5 files for each image in the dataset.
+
+    Args:
+        data_dictionary: Dictionary containing image and mask mapping
+        idx_to_class: Dict mapping an index to the corresponding class name
+        overlap: Percentage of overlap between patches
+        repeat_good_images: Number of repetitions for images with an empty or None mask
+        balance_defects: If true add one good entry for each defect extracted
+        output_folder: root folder
+        annotated_good: List of class indices that are considered good other than the background
+        labelled_masks_path: Path where labelled masks are stored
+        sampling_dataset_folder: Folder of the dataset
+        patch_size: Dimension of the patch, required if patch_number is None
+        patch_number: Number of patches for each side, required if patch_size is None.
+
+    Returns:
+        output_list: List of h5 files' names
+
+    """
+    if patch_number is None and patch_size is None:
+        raise InvalidParameterCombinationException("One of patch_number or patch_size must be specified!")
+
+    output_list = []
+    for item in tqdm(data_dictionary):
+        log.debug("Processing %s", item["base_name"])
+        # This works even if item["path"] is already an absolute path
+        img = cv2.imread(os.path.join(output_folder, item["path"]))
+
+        h = img.shape[0]
+        w = img.shape[1]
+
+        if item["mask"] is None:
+            mask = np.zeros([h, w])
+        else:
+            # This works even if item["mask"] is already an absolute path
+            mask = cv2.imread(os.path.join(output_folder, item["mask"]), 0)  # type: ignore[assignment]
+
+        if patch_size is not None:
+            patch_height = patch_size[1]
+            patch_width = patch_size[0]
+        else:
+            # Mypy complains because patch_number is Optional, but we already checked that it is not None.
+            [patch_height, patch_width], _ = compute_patch_info(
+                h,
+                w,
+                patch_number[0],  # type: ignore[index]
+                patch_number[1],  # type: ignore[index]
+                overlap,
+            )
+
+        h5_file_name_good = os.path.join(sampling_dataset_folder, f"{os.path.splitext(item['base_name'])[0]}_good.h5")
+
+        disable_good = False
+
+        with h5py.File(h5_file_name_good, "w") as f:
+            f.create_dataset("img_path", data=item["path"])
+            f.create_dataset("patch_size", data=np.array([patch_height, patch_width]))
+
+            target = idx_to_class[0]
+
+            if mask.sum() == 0:
+                f.create_dataset("triangles", data=np.array([], dtype=np.uint8), dtype=np.uint8)
+                f.create_dataset("triangles_weights", data=np.array([], dtype=np.uint8), dtype=np.uint8)
+
+                for _ in range(repeat_good_images):
+                    output_list.append(f"{os.path.basename(h5_file_name_good)},{target}\n")
+
+                continue
+
+            binary_mask = (mask > 0).astype(np.uint8)
+
+            # Dilate the defects and take the background
+            binary_mask = np.logical_not(cv2.dilate(binary_mask, np.ones([patch_height, patch_width]))).astype(np.uint8)
+
+            temp_binary_mask = deepcopy(binary_mask)
+            # Remove the edges of the image as they are unsafe for sampling without padding
+            temp_binary_mask[0 : patch_height // 2, :] = 0
+            temp_binary_mask[:, 0 : patch_width // 2] = 0
+            temp_binary_mask[-patch_height // 2 :, :] = 0
+            temp_binary_mask[:, -patch_width // 2 :] = 0
+
+            if temp_binary_mask.sum() != 0:
+                # If the mask without the edges is not empty use it, otherwise use the original mask, as it is not
+                # possible to sample a patch that will not exceed the edges; this must be taken care of by the patch
+                # sampler used during training
+                binary_mask = temp_binary_mask
+
+            # In the case of hx1 or 1xw number of patches we must make sure that the sampling row or the sampling
+            # column is empty; if it isn't, remove it from the possible sampling area
+            if patch_height == img.shape[0]:
+                must_clear_indices = np.where(binary_mask.sum(axis=0) != img.shape[0])[0]
+                binary_mask[:, must_clear_indices] = 0
+
+            if patch_width == img.shape[1]:
+                must_clear_indices = np.where(binary_mask.sum(axis=1) != img.shape[1])[0]
+                binary_mask[must_clear_indices, :] = 0
+
+            # If there's no space for sampling good patches skip it
+            if binary_mask.sum() == 0:
+                disable_good = True
+            else:
+                triangles, weights = triangulate_region(binary_mask)
+                if triangles is None:
+                    disable_good = True
+                else:
+                    log.debug(
+                        "Saving %s triangles for %s with label %s",
+                        triangles.shape[0],
+                        os.path.basename(h5_file_name_good),
+                        target,
+                    )
+
+                    f.create_dataset("mask_path", data=item["mask"])
+                    # Points from extracted triangles should be sufficiently far from all the defects, allowing to
+                    # sample good patches almost all the time
+                    f.create_dataset("triangles", data=triangles, dtype=np.int32)
+                    f.create_dataset("triangles_weights", data=weights, dtype=np.float64)
+
+                    # Avoid saving the good h5 file here when balancing, otherwise there would be one more good
+                    # entry than the number of defects
+                    if not balance_defects:
+                        output_list.append(f"{os.path.basename(h5_file_name_good)},{target}\n")
+
+        if disable_good:
+            os.remove(h5_file_name_good)
+
+        labelled_mask = label(mask)
+        cv2.imwrite(os.path.join(labelled_masks_path, f"{os.path.splitext(item['base_name'])[0]}.png"), labelled_mask)
+
+        real_defects_mask = None
+
+        if annotated_good is not None:
+            # Remove the true defective areas from the good labelled mask.
+            # To be even more restrictive we could also include the background, as we don't know for sure
+            # it will not contain any defects
+            real_defects_mask = (~np.isin(mask, [0] + annotated_good)).astype(np.uint8)
+            real_defects_mask = cv2.dilate(real_defects_mask, np.ones([patch_height, patch_width])).astype(bool)
+
+        for i in np.unique(labelled_mask):
+            if i == 0:
+                continue
+
+            current_mask = (labelled_mask == i).astype(np.uint8)
+            target_idx = (mask * current_mask).max()
+
+            # When we have good annotations we want to avoid sampling patches containing true defects; to do so we
+            # reduce the extraction area based on the area covered by the other defects
+            if annotated_good is not None and real_defects_mask is not None and target_idx in annotated_good:
+                # a - b = a & ~b
+                # pylint: disable=invalid-unary-operand-type
+                current_mask = np.bitwise_and(current_mask.astype(bool), ~real_defects_mask).astype(np.uint8)
+            else:
+                # When dealing with small defects the number of points that can be sampled is limited and patches
+                # will mostly be centered around the defect; to overcome this issue enlarge the defect bounding box
+                # by 50% of the difference between patch_size and the defect bbox size. We don't do this on good
+                # labels to avoid invalidating the reduction applied before.
+                props = regionprops(current_mask)[0]
+                bbox_size = [props.bbox[2] - props.bbox[0], props.bbox[3] - props.bbox[1]]
+                diff_bbox = np.array([max(0, patch_height - bbox_size[0]), max(0, patch_width - bbox_size[1])])
+
+                if diff_bbox[0] != 0:
+                    current_mask = cv2.dilate(current_mask, np.ones([diff_bbox[0] // 2, 1]))
+                if diff_bbox[1] != 0:
+                    current_mask = cv2.dilate(current_mask, np.ones([1, diff_bbox[1] // 2]))
+
+            if current_mask.sum() == 0:
+                # Skip if it's not possible to sample a labelled good patch
+                continue
+
+            temp_current_mask = deepcopy(current_mask)
+            # Remove the edges of the image as they are unsafe for sampling without padding
+            temp_current_mask[0 : patch_height // 2, :] = 0
+            temp_current_mask[:, 0 : patch_width // 2] = 0
+            temp_current_mask[-patch_height // 2 :, :] = 0
+            temp_current_mask[:, -patch_width // 2 :] = 0
+
+            if temp_current_mask.sum() != 0:
+                # If the mask without the edges is not empty use it, otherwise use the original mask, as it is not
+                # possible to sample a patch that will not exceed the edges; this must be taken care of by the patch
+                # sampler used during training
+                current_mask = temp_current_mask
+
+            triangles, weights = triangulate_region(current_mask)
+
+            if triangles is not None:
+                h5_file_name = os.path.join(sampling_dataset_folder, f"{os.path.splitext(item['base_name'])[0]}_{i}.h5")
+
+                target = idx_to_class[target_idx]
+
+                log.debug(
+                    "Saving %s triangles for %s with label %s",
+                    triangles.shape[0],
+                    os.path.basename(h5_file_name),
+                    target,
+                )
+
+                with h5py.File(h5_file_name, "w") as f:
+                    f.create_dataset("img_path", data=item["path"])
+                    f.create_dataset("mask_path", data=item["mask"])
+                    f.create_dataset("patch_size", data=np.array([patch_height, patch_width]))
+                    f.create_dataset("triangles", data=triangles, dtype=np.int32)
+                    f.create_dataset("triangles_weights", data=weights, dtype=np.float64)
+                    f.create_dataset("labelled_index", data=i, dtype=np.int32)
+
+                if annotated_good is not None and target_idx in annotated_good:
+                    # Annotated good regions are treated exactly the same as the background
+                    for _ in range(repeat_good_images):
+                        output_list.append(f"{os.path.basename(h5_file_name)},{target}\n")
+                else:
+                    output_list.append(f"{os.path.basename(h5_file_name)},{target}\n")
+
+                if balance_defects:
+                    if not disable_good:
+                        output_list.append(f"{os.path.basename(h5_file_name_good)},{idx_to_class[0]}\n")
+                    else:
+                        log.debug(
+                            "Unable to add a good entry for %s, since there's no way to sample good patches",
+                            h5_file_name,
+                        )
+    return output_list
+
+
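The resulting h5 files can be read back with `h5py`; a sketch assuming a hypothetical file produced for defect region 1 of `img_0`:

```python
import h5py

with h5py.File("patch_dataset/train/img_0_1.h5", "r") as f:
    img_path = f["img_path"][()].decode()   # path of the original image
    patch_h, patch_w = f["patch_size"][()]  # patch dimensions
    triangles = f["triangles"][()]          # [T, 3, 2] triangle vertices in (y, x) order
    weights = f["triangles_weights"][()]    # normalized sampling weight per triangle
```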
+def triangle_area(triangle: np.ndarray) -> float:
+    """Compute the area of a triangle defined by 3 points.
+
+    Args:
+        triangle: Array of shape 3x2 containing the coordinates of a triangle.
+
+    Returns:
+        The area of the triangle
+
+    """
+    [y1, x1], [y2, x2], [y3, x3] = triangle
+    return abs(0.5 * (((x2 - x1) * (y3 - y1)) - ((x3 - x1) * (y2 - y1))))
+
+
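The return value is the shoelace formula applied to the three vertices; for example:

```python
import numpy as np

# Right triangle with legs of length 4 and 3 (points are (y, x)) -> area 6.
assert triangle_area(np.array([[0, 0], [4, 0], [0, 3]])) == 6.0
```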
+def triangulate_region(mask: np.ndarray) -> tuple[np.ndarray | None, np.ndarray | None]:
+    """Extract from a binary image containing a single ROI (with or without holes) a list of triangles
+    (and their normalized areas) that completely subdivide an approximated polygon defined around the mask contours.
+    The output can be used to easily sample uniform points that are almost guaranteed to lie inside the ROI.
+
+    Args:
+        mask: Binary image defining a region of interest
+
+    Returns:
+        Tuple containing:
+            triangles: a numpy array containing a list of lists of vertices (y, x) of the triangles defined over a
+                polygon that contains the entire region
+            weights: areas of each triangle rescaled (area_i / sum(areas))
+
+    """
+    polygon_points, hier = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_TC89_L1)
+
+    if not np.all(hier[:, :, 3] == -1):  # there are holes
+        holes = ndimage.binary_fill_holes(mask).astype(np.uint8)
+        holes -= mask
+        holes = (holes > 0).astype(np.uint8)
+        if holes.sum() > 0:  # there are holes
+            for hole in regionprops(label(holes)):
+                # Clear the row passing through the hole's centroid so the region splits into simple contours
+                y_hole_center = int(hole.centroid[0])
+                mask[y_hole_center] = 0
+
+            polygon_points, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_TC89_L1)
+
+    final_approx = []
+
+    # Extract a simpler approximation of the contour
+    for cnt in polygon_points:
+        epsilon = 0.01 * cv2.arcLength(cnt, True)
+        approx = cv2.approxPolyDP(cnt, epsilon, True)
+        final_approx.append(approx)
+
+    triangles = None
+
+    for approx in final_approx:
+        contours_tripy = [x[0] for x in approx]
+        current_triangles = earclip(contours_tripy)
+
+        if len(current_triangles) == 0:
+            # This can only happen if a defect is about one pixel wide...
+            continue
+
+        current_triangles = np.array([list(x) for x in current_triangles])
+
+        triangles = current_triangles if triangles is None else np.concatenate([triangles, current_triangles])
+
+    if triangles is None:
+        return None, None
+
+    # Swap x and y to match cv2
+    triangles = triangles[..., ::-1]
+
+    weights = np.array([triangle_area(x) for x in triangles])
+    weights = weights / weights.sum()
+
+    return triangles, weights
+
+
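As a sketch, a filled rectangle yields a 4-vertex approximated polygon split into two triangles (mask values are illustrative):

```python
import numpy as np

mask = np.zeros((100, 100), dtype=np.uint8)
mask[20:80, 10:90] = 1  # a single rectangular ROI, no holes

triangles, weights = triangulate_region(mask)
# triangles: [T, 3, 2] vertices in (y, x) order; weights sum to 1
assert triangles is not None and np.isclose(weights.sum(), 1.0)
```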
+class InvalidParameterCombinationException(Exception):
+    """Exception raised when an invalid combination of parameters is passed to a function."""
+
+
+class InvalidNumWorkersNumberException(Exception):
+    """Exception raised when an invalid number of workers is passed to a function."""
+
+
+def load_train_file(
+    train_file_path: str,
+    include_filter: list[str] | None = None,
+    exclude_filter: list[str] | None = None,
+    class_to_skip: list | None = None,
+) -> tuple[list[str], list[str]]:
+    """Load a train file and return a list of samples and a list of targets. The referenced sample files are
+    expected to be in the same location as train_file_path.
+
+    Args:
+        train_file_path: Training file location
+        include_filter: Include only samples that contain one of the elements of this list
+        exclude_filter: Exclude all samples that contain one of the elements of this list
+        class_to_skip: If not None, exclude all the samples with labels present in this list.
+
+    Returns:
+        List of samples and list of targets
+
+    """
+    samples = []
+    targets = []
+
+    with open(train_file_path) as f:
+        lines = f.read().splitlines()
+        for line in lines:
+            sample, target = line.split(",")
+            if class_to_skip is not None and target in class_to_skip:
+                continue
+            samples.append(sample)
+            targets.append(target)
+
+    include_filter = [] if include_filter is None else include_filter
+    exclude_filter = [] if exclude_filter is None else exclude_filter
+
+    valid_samples_indices = [
+        i
+        for (i, x) in enumerate(samples)
+        if (len(include_filter) == 0 or any(f in x for f in include_filter))
+        and (len(exclude_filter) == 0 or not any(f in x for f in exclude_filter))
+    ]
+
+    samples = [samples[i] for i in valid_samples_indices]
+    targets = [targets[i] for i in valid_samples_indices]
+
+    train_folder = os.path.dirname(train_file_path)
+    samples = [os.path.join(train_folder, x) for x in samples]
+
+    return samples, targets
+
+
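Assuming a `dataset.txt` with the `<h5 file>,<label>` lines produced by `create_h5` (file names hypothetical):

```python
# dataset.txt, stored next to the h5 files:
#   img_0_good.h5,good
#   img_0_1.h5,defect
samples, targets = load_train_file(
    "patch_dataset/train/dataset.txt",
    exclude_filter=["_good"],  # drop every good sample
)
# Samples are joined with the train folder: here
# samples == ["patch_dataset/train/img_0_1.h5"] and targets == ["defect"]
```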
+def compute_safe_patch_range(sampled_point: int, patch_size: int, image_size: int) -> tuple[int, int]:
+    """Compute a safe patch range around a sampled point for the given image size.
+
+    Args:
+        sampled_point: the sampled point
+        patch_size: the size of the patch
+        image_size: the size of the image.
+
+    Returns:
+        Tuple containing the safe patch range [left, right] such that
+        [sampled_point - left : sampled_point + right] will be within the image size.
+    """
+    left = patch_size // 2
+    right = patch_size // 2
+
+    if sampled_point + right > image_size:
+        right = image_size - sampled_point
+        left = patch_size - right
+
+    if sampled_point - left < 0:
+        left = sampled_point
+        right = patch_size - left
+
+    return left, right
+
+
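For example, with a 64-pixel patch in a 256-pixel image the range is shifted near the borders so the patch still fits:

```python
# Point 10 pixels from the left border: the window shifts right.
left, right = compute_safe_patch_range(sampled_point=10, patch_size=64, image_size=256)
assert (left, right) == (10, 54)  # 10 - 10 = 0 and 10 + 54 = 64: inside the image

# Near the right border the window shifts left; the total is always patch_size.
assert sum(compute_safe_patch_range(250, 64, 256)) == 64
```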
+def trisample(triangle: np.ndarray) -> tuple[int, int]:
+    """Sample a point uniformly inside a triangle.
+
+    Args:
+        triangle: Array of shape 3x2 containing the coordinates of a triangle.
+
+    Returns:
+        A point (y, x) sampled uniformly inside the triangle
+
+    """
+    [y1, x1], [y2, x2], [y3, x3] = triangle
+
+    r1 = random.random()
+    r2 = random.random()
+
+    # sqrt(r1) warps the first coordinate so points are uniform in area rather than clustered near a vertex
+    s1 = math.sqrt(r1)
+
+    x = x1 * (1.0 - s1) + x2 * (1.0 - r2) * s1 + x3 * r2 * s1
+    y = y1 * (1.0 - s1) + y2 * (1.0 - r2) * s1 + y3 * r2 * s1
+
+    return int(y), int(x)
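Taken together, these helpers support a sampling loop along these lines (a sketch with assumed `mask`, `image`, `patch_h` and `patch_w` variables, not the library's actual sampler):

```python
import numpy as np

triangles, weights = triangulate_region(mask)      # mask as in triangulate_region above
idx = np.random.choice(len(triangles), p=weights)  # pick a triangle proportionally to its area
y, x = trisample(triangles[idx])                   # uniform point inside it

# Clamp the patch window so it stays inside the image.
top, bottom = compute_safe_patch_range(y, patch_h, image.shape[0])
left, right = compute_safe_patch_range(x, patch_w, image.shape[1])
patch = image[y - top : y + bottom, x - left : x + right]
```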