quadra-0.0.1-py3-none-any.whl → quadra-2.1.13-py3-none-any.whl
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/classification.py (new file, +618):

@@ -0,0 +1,618 @@

````python
from __future__ import annotations

import math
import os
import random
import re
from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from omegaconf import DictConfig
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from torch.utils.data import DataLoader

from quadra.models.base import ModelSignatureWrapper
from quadra.utils import utils
from quadra.utils.models import get_feature
from quadra.utils.visualization import UnNormalize, plot_classification_results

if TYPE_CHECKING:
    from quadra.datamodules.classification import SklearnClassificationDataModule
    from quadra.datamodules.patch import PatchSklearnClassificationDataModule

log = utils.get_logger(__name__)


def get_file_condition(
    file_name: str, root: str, exclude_filter: list[str] | None = None, include_filter: list[str] | None = None
):
    """Check if a file should be included or excluded based on the filters provided.

    Args:
        file_name: Name of the file
        root: Root directory of the file
        exclude_filter: List of string filters used to exclude images. If None no filter will be applied.
        include_filter: List of string filters used to include images. If None no filter will be applied.
    """
    if exclude_filter is not None:
        if any(fil in file_name for fil in exclude_filter):
            return False

        if any(fil in root for fil in exclude_filter):
            return False

    if include_filter is not None and (
        not any(fil in file_name for fil in include_filter) and not any(fil in root for fil in include_filter)
    ):
        return False

    return True


def natural_key(string_):
    """See http://www.codinghorror.com/blog/archives/001018.html."""
    return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]


def find_images_and_targets(
    folder: str,
    types: list | None = None,
    class_to_idx: dict[str, int] | None = None,
    leaf_name_only: bool = True,
    sort: bool = True,
    exclude_filter: list | None = None,
    include_filter: list | None = None,
    label_map: dict[str, Any] | None = None,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Given a folder, extract the absolute path of all the files with a valid extension,
    then assign a label based on the subfolder name.

    Args:
        folder: path to the main folder
        types: valid file extensions
        class_to_idx: dictionary mapping folder names to indices.
            Only files whose label is in the dictionary key list will be considered. If None all files will
            be considered and a custom conversion is created.
        leaf_name_only: if True use only the leaf folder name as label, otherwise use the full path
        sort: if True sort the images and labels based on the image name
        exclude_filter: list of string filters used to exclude images.
            If None no filter will be applied.
        include_filter: list of string filters used to include images.
            Only images that satisfy at least one of the filters will be included.
        label_map: dictionary mapping folder names to labels.
    """
    if types is None:
        types = [".png", ".jpg", ".jpeg", ".bmp"]
    labels = []
    filenames = []

    for root, _, files in os.walk(folder, topdown=False, followlinks=True):
        if root != folder:
            rel_path = os.path.relpath(root, folder)
        else:
            rel_path = ""

        if leaf_name_only:
            label = os.path.basename(rel_path)
        else:
            aa = rel_path.split(os.path.sep)
            if len(aa) == 2:
                aa = aa[-1:]
            else:
                aa = aa[-2:]
            label = "_".join(aa)  # rel_path.replace(os.path.sep, "_")

        for f in files:
            if not get_file_condition(
                file_name=f, root=root, exclude_filter=exclude_filter, include_filter=include_filter
            ):
                continue

            if f.startswith(".") or "checkpoint" in f:
                continue
            _, ext = os.path.splitext(f)
            if ext.lower() in types:
                filenames.append(os.path.join(root, f))
                labels.append(label)

    if label_map is not None:
        labels, _ = group_labels(labels, label_map)

    if class_to_idx is None:
        # building class index
        unique_labels = set(labels)
        sorted_labels = sorted(unique_labels, key=natural_key)
        class_to_idx = {str(c): idx for idx, c in enumerate(sorted_labels)}

    images_and_targets = [(f, l) for f, l in zip(filenames, labels) if l in class_to_idx]

    if sort:
        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))

    return np.array(images_and_targets)[:, 0], np.array(images_and_targets)[:, 1], class_to_idx


def find_test_image(
    folder: str,
    types: list[str] | None = None,
    exclude_filter: list[str] | None = None,
    include_filter: list[str] | None = None,
    include_none_class: bool = True,
    test_split_file: str | None = None,
    label_map=None,
) -> tuple[list[str], list[str | None]]:
    """Given a path, extract images and labels with filters; labels are based on the parent folder name of the images.

    Args:
        folder: root directory containing the images
        types: only choose images with the extensions specified, if None use default extensions
        exclude_filter: list of string filters used to exclude images. If None no filter will be applied.
        include_filter: list of string filters used to include images. If None no filter will be applied.
        include_none_class: if set to True convert all 'None' labels to None, otherwise ignore the image
        test_split_file: if defined use the split defined inside the file
        label_map: dictionary mapping folder names to labels. If None no grouping is applied.

    Returns:
        Two lists, one containing the image paths and the other one containing the labels. Labels can be None.
    """
    if types is None:
        types = [".png", ".jpg", ".jpeg", ".bmp"]

    labels = []
    filenames = []

    for root, _, files in os.walk(folder, topdown=False, followlinks=True):
        rel_path = os.path.relpath(root, folder) if root != folder else ""
        label: str | None = os.path.basename(rel_path)
        for f in files:
            if not get_file_condition(
                file_name=f, root=root, exclude_filter=exclude_filter, include_filter=include_filter
            ):
                continue
            if f.startswith(".") or "checkpoint" in f:
                continue
            _, ext = os.path.splitext(f)
            if ext.lower() in types:
                if label == "None":
                    if include_none_class:
                        label = None
                    else:
                        continue
                filenames.append(os.path.join(root, f))
                labels.append(label)

    if test_split_file is not None:
        if not os.path.isabs(test_split_file):
            log.info(
                "test_split_file is not an absolute path. Trying to use the folder argument %s as parent folder", folder
            )
            test_split_file = os.path.join(folder, test_split_file)

        if not os.path.exists(test_split_file):
            raise FileNotFoundError(f"test_split_file {test_split_file} does not exist")

        with open(test_split_file) as test_file:
            test_split = test_file.read().splitlines()

        file_samples = []
        for row in test_split:
            csv_values = row.split(",")
            if len(csv_values) == 1:
                # ensuring backward compatibility with the old split file format
                # old format: sample, new format: sample,class
                sample_path = os.path.join(folder, csv_values[0])
            else:
                sample_path = os.path.join(folder, ",".join(csv_values[:-1]))

            file_samples.append(sample_path)

        test_split = [os.path.join(folder, sample.strip()) for sample in file_samples]
        labels = [t for s, t in zip(filenames, labels) if s in file_samples]
        filenames = [s for s in filenames if s in file_samples]
        log.info("Selected %d images using test_split_file for the test", len(filenames))
        if len(filenames) != len(file_samples):
            log.warning(
                "test_split_file contains %d images but only %d images were found in the folder. "
                "This may be due to duplicate lines in the test_split_file.",
                len(file_samples),
                len(filenames),
            )
    else:
        log.info("No test_split_file. Selected all %s images for the test", folder)

    if label_map is not None:
        labels, _ = group_labels(labels, label_map)

    return filenames, labels


def group_labels(labels: Sequence[str | None], class_mapping: dict[str, str | None | list[str]]) -> tuple[list, dict]:
    """Group labels based on class_mapping.

    Raises:
        ValueError: if a label is not in class_mapping
        ValueError: if a label is in class_mapping but has no corresponding value

    Returns:
        The list of grouped labels and a dictionary mapping each group to its index.

    Example:
        ```python
        grouped_labels, class_to_idx = group_labels(labels, class_mapping={"Good": "A", "Bad": None})
        assert grouped_labels.count("Good") == labels.count("A")
        assert len(class_to_idx.keys()) == 2

        grouped_labels, class_to_idx = group_labels(labels, class_mapping={"Good": "A", "Defect": "B", "Bad": None})
        assert grouped_labels.count("Bad") == labels.count("C") + labels.count("D")
        assert len(class_to_idx.keys()) == 3

        grouped_labels, class_to_idx = group_labels(labels, class_mapping={"Good": "A", "Bad": ["B", "C", "D"]})
        assert grouped_labels.count("Bad") == labels.count("B") + labels.count("C") + labels.count("D")
        assert len(class_to_idx.keys()) == 2
        ```
    """
    grouped_labels = []
    specified_targets = [k for k in class_mapping if class_mapping[k] is not None]
    non_specified_targets = [k for k in class_mapping if class_mapping[k] is None]
    if len(non_specified_targets) > 1:
        raise ValueError(f"More than one non specified target: {non_specified_targets}")
    for label in labels:
        found = False
        for target in specified_targets:
            if not found:
                current_mapping = class_mapping[target]
                if current_mapping is None:
                    continue

                if any(label in list(related_label) for related_label in current_mapping if related_label is not None):
                    grouped_labels.append(target)
                    found = True
        if not found:
            if len(non_specified_targets) > 0:
                grouped_labels.append(non_specified_targets[0])
            else:
                raise ValueError(f"No target found for label: {label}")
    class_to_idx = {k: i for i, k in enumerate(class_mapping.keys())}
    return grouped_labels, class_to_idx
````
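For context, here is a minimal, self-contained sketch of how `group_labels` collapses raw folder labels into coarser classes. The label names below are invented for the example; only the `class_mapping` shape mirrors the docstring above:

```python
from quadra.utils.classification import group_labels

# Invented labels: "B" and "C" are grouped under "Bad"; anything unmapped
# falls into the single None-valued catch-all class "Unknown".
labels = ["A", "B", "C", "A"]
grouped, class_to_idx = group_labels(
    labels, class_mapping={"Good": ["A"], "Bad": ["B", "C"], "Unknown": None}
)
assert grouped == ["Good", "Bad", "Bad", "Good"]
assert class_to_idx == {"Good": 0, "Bad": 1, "Unknown": 2}
```

The file continues with the file-based filtering and split helpers: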
```python
def filter_with_file(list_of_full_paths: list[str], file_path: str, root_path: str) -> tuple[list[str], list[bool]]:
    """Filter a list of items using a file containing the items to keep. Paths inside the file
    should be relative to root_path, not absolute, to avoid user-specific path issues.

    Args:
        list_of_full_paths: list of items to filter
        file_path: path to the file containing the items to keep
        root_path: root path of the dataset

    Returns:
        The list of items to keep and the boolean mask to apply to other lists later.
    """
    filtered_full_paths = []
    filter_mask = []

    with open(file_path) as f:
        for relative_path in f.read().splitlines():
            full_path = os.path.join(root_path, relative_path)
            if full_path in list_of_full_paths:
                filtered_full_paths.append(full_path)
                filter_mask.append(True)
            else:
                filter_mask.append(False)

    return filtered_full_paths, filter_mask


def get_split(
    image_dir: str,
    exclude_filter: list[str] | None = None,
    include_filter: list[str] | None = None,
    test_size: float = 0.3,
    random_state: int = 42,
    class_to_idx: dict[str, int] | None = None,
    label_map: dict | None = None,
    n_splits: int = 1,
    include_none_class: bool = False,
    limit_training_data: int | None = None,
    train_split_file: str | None = None,
) -> tuple[np.ndarray, np.ndarray, Generator[list, None, None], dict]:
    """Given a folder, extract the absolute path of all the files with a valid extension and name,
    and split them into train/test.

    Args:
        image_dir: Path to the folder containing the images
        exclude_filter: List of file name filters to be excluded. If None no filter will be applied
        include_filter: List of file name filters to be included. If None no filter will be applied
        test_size: Percentage of data to be used for test
        random_state: Random state to be used for reproducibility
        class_to_idx: Dictionary mapping folder names to indices.
            Only files whose label is in the dictionary key list will be considered.
            If None all files will be considered and a custom conversion is created.
        label_map: Dictionary mapping folder names to labels.
        n_splits: Number of dataset subdivisions (default 1 -> train/test)
        include_none_class: If set to True convert all 'None' labels to None
        limit_training_data: If set to a value, limit the number of training samples to this value
        train_split_file: If set to a path, use the file to split the dataset
    """
    # TODO: Why is include_none_class not used?
    # pylint: disable=unused-argument
    assert os.path.isdir(image_dir), f"Folder {image_dir} does not exist."
    # Get samples and targets
    samples, targets, class_to_idx = find_images_and_targets(
        folder=image_dir,
        exclude_filter=exclude_filter,
        include_filter=include_filter,
        class_to_idx=class_to_idx,
        label_map=label_map,
        # include_none_class=include_none_class,
    )

    cl, counts = np.unique(targets, return_counts=True)

    # Classes with a single sample cannot be stratified, drop them
    for num, _cl in zip(counts, cl):
        if num == 1:
            to_remove = np.where(np.array(targets) == _cl)[0][0]
            samples = np.delete(np.array(samples), to_remove)
            targets = np.delete(np.array(targets), to_remove)
            class_to_idx.pop(_cl)

    if train_split_file is not None:
        with open(train_split_file) as f:
            train_split = f.read().splitlines()

        file_samples = []
        for row in train_split:
            csv_values = row.split(",")

            if len(csv_values) == 1:
                # ensuring backward compatibility with the old split file format
                # old format: sample, new format: sample,class
                sample_path = os.path.join(image_dir, csv_values[0])
            else:
                sample_path = os.path.join(image_dir, ",".join(csv_values[:-1]))

            file_samples.append(sample_path)

        train_split = [os.path.join(image_dir, sample.strip()) for sample in file_samples]
        targets = np.array([t for s, t in zip(samples, targets) if s in file_samples])
        samples = np.array([s for s in samples if s in file_samples])

    if limit_training_data is not None:
        idx_to_keep = []
        for cl in np.unique(targets):
            cl_idx = np.where(np.array(targets) == cl)[0].tolist()
            random.seed(random_state)
            random.shuffle(cl_idx)
            idx_to_keep.extend(cl_idx[:limit_training_data])

        samples = np.asarray([samples[i] for i in idx_to_keep])
        targets = np.asarray([targets[i] for i in idx_to_keep])

    _, counts = np.unique(targets, return_counts=True)

    if n_splits == 1:
        split_technique = StratifiedShuffleSplit(n_splits=1, test_size=test_size, random_state=random_state)
    else:
        split_technique = StratifiedKFold(n_splits=n_splits, random_state=random_state, shuffle=True)

    split = split_technique.split(samples, targets)

    return np.array(samples), np.array(targets), split, class_to_idx
```
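As a rough usage sketch, `get_split` is driven like this; the `data/train` layout (one subfolder per class) is hypothetical and not part of the package:

```python
from quadra.utils.classification import get_split

# Hypothetical dataset root with one subfolder per class.
samples, targets, split, class_to_idx = get_split(
    image_dir="data/train",
    test_size=0.3,
    random_state=42,
    n_splits=1,  # a single stratified shuffle split into train/test
)
for train_idx, test_idx in split:
    train_samples, test_samples = samples[train_idx], samples[test_idx]
```

The file continues with the result-saving and reporting helpers: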
```python
def save_classification_result(
    results: pd.DataFrame,
    output_folder: str,
    test_dataloader: DataLoader,
    config: DictConfig,
    output: DictConfig,
    accuracy: float | None = None,
    confmat: pd.DataFrame | None = None,
    grayscale_cams: np.ndarray | None = None,
):
    """Save csv results, confusion matrix and example images.

    Args:
        results: Dataframe containing the results
        output_folder: Path to the output folder
        test_dataloader: Dataloader used for testing
        config: Configuration file
        output: Output configuration
        accuracy: Accuracy of the model, None if all test labels are unknown
        confmat: Confusion matrix as a pandas dataframe, may be None if all test labels are unknown
        grayscale_cams: List of grayscale grad-cam outputs ordered as the results
    """
    # Save csv
    results.to_csv(os.path.join(output_folder, "test_results.csv"), index_label="index")
    if grayscale_cams is None:
        log.info("Plotting only original examples, set gradcam = true in the config file to also plot gradcam examples")
        save_gradcams = False
    else:
        log.info("Plotting original and gradcam examples")
        save_gradcams = True

    if confmat is not None and accuracy is not None:
        # Save confusion matrix
        disp = ConfusionMatrixDisplay(
            confusion_matrix=np.array(confmat),
            display_labels=[x.replace("pred:", "") for x in confmat.columns.to_list()],
        )
        disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
        plt.title(f"Confusion Matrix (Accuracy: {(accuracy * 100):.2f}%)")
        plt.savefig(
            os.path.join(output_folder, "test_confusion_matrix.png"),
            bbox_inches="tight",
            pad_inches=0,
            dpi=300,
        )
        plt.close()

    if output is not None and output.example:
        log.info("Saving discordant/concordant examples in the test folder")
        idx_to_class = test_dataloader.dataset.idx_to_class  # type: ignore[attr-defined]

        # Get misclassified samples
        images_folder = os.path.join(output_folder, "example")
        if not os.path.isdir(images_folder):
            os.makedirs(images_folder)
        original_images_folder = os.path.join(images_folder, "original")
        if not os.path.isdir(original_images_folder):
            os.makedirs(original_images_folder)

        gradcam_folder = os.path.join(images_folder, "gradcam")
        if save_gradcams and not os.path.isdir(gradcam_folder):
            os.makedirs(gradcam_folder)

        for v in np.unique([results["real_label"], results["pred_label"]]):
            if np.isnan(v) or v == -1:
                continue

            k = idx_to_class[v]
            # Plot concordant ("con") and discordant ("dis") examples for this class
            plot_classification_results(
                test_dataloader.dataset,
                unorm=UnNormalize(mean=config.transforms.mean, std=config.transforms.std),
                pred_labels=results["pred_label"].to_numpy(),
                test_labels=results["real_label"].to_numpy(),
                grayscale_cams=grayscale_cams,
                class_name=k,
                original_folder=original_images_folder,
                gradcam_folder=gradcam_folder,
                idx_to_class=idx_to_class,
                pred_class_to_plot=v,
                what="con",
                rows=output.get("rows", 3),
                cols=output.get("cols", 2),
                figsize=output.get("figsize", (20, 20)),
                gradcam=save_gradcams,
            )

            plot_classification_results(
                test_dataloader.dataset,
                unorm=UnNormalize(mean=config.transforms.mean, std=config.transforms.std),
                pred_labels=results["pred_label"].to_numpy(),
                test_labels=results["real_label"].to_numpy(),
                grayscale_cams=grayscale_cams,
                class_name=k,
                original_folder=original_images_folder,
                gradcam_folder=gradcam_folder,
                idx_to_class=idx_to_class,
                pred_class_to_plot=v,
                what="dis",
                rows=output.get("rows", 3),
                cols=output.get("cols", 2),
                figsize=output.get("figsize", (20, 20)),
                gradcam=save_gradcams,
            )
    else:
        log.info("Not generating discordant/concordant examples. Check task.output.example in the config file")


def get_results(
    test_labels: np.ndarray | list[int],
    pred_labels: np.ndarray | list[int],
    idx_to_labels: dict | None = None,
    cl_rep_digits: int = 3,
) -> tuple[str | dict, pd.DataFrame, float]:
    """Get prediction results from predicted and test labels.

    Args:
        test_labels: test labels
        pred_labels: predicted labels
        idx_to_labels: dictionary mapping indices to labels
        cl_rep_digits: number of digits to use in the classification report. Default: 3

    Returns:
        A tuple containing the classification report `cl_rep`, the confusion matrix `pd_cm` as a
        pd.DataFrame, and the computed accuracy `acc`.
    """
    unique_labels = np.unique([test_labels, pred_labels])
    cl_rep = classification_report(
        y_true=test_labels,
        y_pred=pred_labels,
        labels=unique_labels,
        digits=cl_rep_digits,
        zero_division=0,
    )

    cm = confusion_matrix(y_true=test_labels, y_pred=pred_labels, labels=unique_labels)

    acc = accuracy_score(y_true=test_labels, y_pred=pred_labels)

    if idx_to_labels:
        pd_cm = pd.DataFrame(
            cm,
            index=[f"true:{idx_to_labels[x]}" for x in unique_labels],
            columns=[f"pred:{idx_to_labels[x]}" for x in unique_labels],
        )
    else:
        pd_cm = pd.DataFrame(
            cm,
            index=[f"true:{x}" for x in unique_labels],
            columns=[f"pred:{x}" for x in unique_labels],
        )
    return cl_rep, pd_cm, acc
```
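A minimal sketch of `get_results` on toy labels (the values are invented for illustration):

```python
from quadra.utils.classification import get_results

# Toy labels: one "good" sample is misclassified as "bad".
cl_rep, pd_cm, acc = get_results(
    test_labels=[0, 0, 1, 1],
    pred_labels=[0, 1, 1, 1],
    idx_to_labels={0: "good", 1: "bad"},
)
print(acc)    # 0.75
print(pd_cm)  # rows are "true:<label>", columns are "pred:<label>"
```

The file closes with the automatic batch size search: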
```python
def automatic_batch_size_computation(
    datamodule: SklearnClassificationDataModule | PatchSklearnClassificationDataModule,
    backbone: ModelSignatureWrapper,
    starting_batch_size: int,
) -> int:
    """Find the optimal batch size for feature extraction. The algorithm starts from the largest batch size
    possible and divides it by 2 until it finds the largest batch size that fits in memory.

    Args:
        datamodule: Datamodule used for feature extraction
        backbone: Backbone used for feature extraction
        starting_batch_size: Starting batch size to use for the search

    Returns:
        Optimal batch size
    """
    log.info("Finding optimal batch size...")
    optimal = False
    batch_size = starting_batch_size

    while not optimal:
        datamodule.batch_size = batch_size
        base_dataloader = datamodule.train_dataloader()

        if isinstance(base_dataloader, Sequence):
            base_dataloader = base_dataloader[0]

        if len(base_dataloader) == 1:
            # If it fits in memory this is the largest batch size possible,
            # if it crashes restart with the previous batch size // 2
            datamodule.batch_size = len(base_dataloader.dataset)  # type: ignore[arg-type]
            # The new restarting batch size is the closest power of 2 above the dataset size,
            # it will be divided by 2 on failure
            batch_size = 2 ** math.ceil(math.log2(datamodule.batch_size))
            base_dataloader = datamodule.train_dataloader()
            if isinstance(base_dataloader, Sequence):
                base_dataloader = base_dataloader[0]
            optimal = True

        try:
            log.info("Trying batch size: %d", datamodule.batch_size)
            _ = get_feature(feature_extractor=backbone, dl=base_dataloader, iteration_over_training=1, limit_batches=1)
        except RuntimeError as e:
            if batch_size > 1:
                batch_size = batch_size // 2
                optimal = False
                continue

            log.error("Unable to run the model with batch size 1")
            raise e

        log.info("Found optimal batch size: %d", datamodule.batch_size)
        optimal = True

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return datamodule.batch_size
```
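The search above is essentially a halving loop. Here is a self-contained sketch of the same strategy, with a toy `try_batch` standing in for the real feature extraction (both names are hypothetical):

```python
def try_batch(batch_size: int) -> None:
    # Toy stand-in: pretend only batch sizes <= 8 fit in memory.
    if batch_size > 8:
        raise RuntimeError("out of memory")

def find_largest_batch_size(starting_batch_size: int) -> int:
    """Halve the batch size until the first size that runs successfully."""
    batch_size = starting_batch_size
    while True:
        try:
            try_batch(batch_size)
            return batch_size
        except RuntimeError:
            if batch_size == 1:
                raise  # even batch size 1 does not fit
            batch_size //= 2

assert find_largest_batch_size(64) == 8
```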
quadra/utils/deprecation.py (new file, +31):

@@ -0,0 +1,31 @@

```python
import functools
from collections.abc import Callable

from quadra.utils.utils import get_logger

logger = get_logger(__name__)


def deprecated(message: str) -> Callable:
    """Decorator to mark a function as deprecated.

    Args:
        message: Message to be displayed when the function is called.

    Returns:
        Decorated function.
    """

    def deprecated_decorator(func_or_class: Callable) -> Callable:
        """Decorator to mark a function as deprecated."""

        @functools.wraps(func_or_class)
        def wrapper(*args, **kwargs):
            """Wrapper function to display a warning message."""
            warning_msg = f"{func_or_class.__name__} is deprecated. {message}"
            logger.warning(warning_msg)
            return func_or_class(*args, **kwargs)

        return wrapper

    return deprecated_decorator
```
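A short usage sketch of the decorator (the function names are invented):

```python
from quadra.utils.deprecation import deprecated

@deprecated("Use new_resize instead.")
def old_resize(size: int) -> int:
    return size * 2

# Logs "old_resize is deprecated. Use new_resize instead." and returns 20.
old_resize(10)
```

Because the inner wrapper uses `functools.wraps`, the wrapped callable keeps its original `__name__` and docstring, which is what lets the warning message reference `func_or_class.__name__` accurately.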