quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,474 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from ast import literal_eval
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from functools import wraps
|
|
7
|
+
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import numpy as np
|
|
10
|
+
import numpy.typing as npt
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import seaborn as sns
|
|
13
|
+
import segmentation_models_pytorch as smp
|
|
14
|
+
import torch
|
|
15
|
+
import yaml
|
|
16
|
+
from segmentation_models_pytorch.losses import DiceLoss
|
|
17
|
+
from segmentation_models_pytorch.losses.constants import BINARY_MODE, MULTICLASS_MODE
|
|
18
|
+
from skimage.measure import label, regionprops # pylint: disable=no-name-in-module
|
|
19
|
+
|
|
20
|
+
from quadra.utils.logger import get_logger
|
|
21
|
+
from quadra.utils.visualization import UnNormalize, create_grid_figure
|
|
22
|
+
|
|
23
|
+
# onnxruntime is optional. When it can be imported, its RuntimeException is added to the
# exceptions treated as recoverable by automatic_datamodule_batch_size; otherwise only
# RuntimeError is used.
try:
    from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException  # noqa

    ONNX_AVAILABLE = True
except ImportError:
    ONNX_AVAILABLE = False

# Module-level logger used by the helpers in this file.
log = get_logger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def dice(
    input_tensor: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 1.0,
    eps: float = 1e-8,
    reduction: str | None = "mean",
) -> torch.Tensor:
    """Dice loss computation function.

    Both tensors are flattened per sample (everything after the first dimension), so one loss
    value is computed for every element of the batch before the optional reduction.

    Args:
        input_tensor: input tensor coming from a model; the first dimension is the batch.
        target: target tensor to compare with; must hold the same number of elements per sample.
        smooth: smoothing factor added to both numerator and denominator.
        eps: epsilon added to the denominator to avoid zero division.
        reduction: reduction method, one of "mean", "sum", "none" or None. "none" and None keep
            the per-sample losses.

    Returns:
        The computed loss: a scalar tensor for "mean"/"sum", a tensor of shape (batch,) otherwise.

    Raises:
        ValueError: if ``reduction`` is not one of the supported values.
    """
    bs = input_tensor.size(0)
    iflat = input_tensor.contiguous().view(bs, -1)
    tflat = target.contiguous().view(bs, -1)
    intersection = (iflat * tflat).sum(-1)
    loss = 1 - (2.0 * intersection + smooth) / (iflat.sum(-1) + tflat.sum(-1) + smooth + eps)

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction is None or reduction == "none":
        return loss
    raise ValueError(f"Reduction {reduction} not valid, expected one of 'mean', 'sum', 'none' or None.")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def score_dice(
    y_pred,
    y_true,
    reduction=None,
) -> torch.Tensor:
    """Return the Dice similarity score, i.e. ``1 - dice loss``.

    Args:
        y_pred: predicted tensor whose first dimension is the batch.
        y_true: target tensor with the same number of elements per sample.
        reduction: forwarded unchanged to ``dice``; ``None`` keeps one score per sample.

    Returns:
        The dice score, reduced as requested.
    """
    dice_loss = dice(y_pred, y_true, reduction=reduction)
    return 1 - dice_loss
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def score_dice_smp(y_pred: torch.Tensor, y_true: torch.Tensor, mode: str = "binary") -> torch.Tensor:
    """Compute the dice score through segmentation_models_pytorch's ``DiceLoss``.

    Works for both the binary and the multiclass scenario.

    Args:
        y_pred: 1xCxHxW prediction, one channel for each class. The loss is built with
            ``from_logits=False``, so no activation is applied to it here.
        y_true: 1x1xHxW ground-truth mask holding class indices.
        mode: "binary" or "multiclass".

    Returns:
        ``1 - DiceLoss(y_pred, y_true)``.

    Raises:
        ValueError: if ``mode`` is neither the binary nor the multiclass mode.
    """
    supported_modes = {BINARY_MODE, MULTICLASS_MODE}
    if mode in supported_modes:
        dice_loss_fn = DiceLoss(mode=mode, from_logits=False)
        return 1 - dice_loss_fn(y_pred, y_true)
    raise ValueError(f"Mode {mode} not valid.")
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _defect_area_category(area_percentage: float) -> str:
    """Return the report bucket for a defect covering ``area_percentage`` percent of the image."""
    if area_percentage <= 1:
        return "Very Small <1%"
    if area_percentage <= 10:
        return "Small <10%"
    if area_percentage <= 25:
        return "Medium <25%"
    return "Large >25%"


def _collect_failure(
    store: dict[str, list[np.ndarray]],
    image: np.ndarray,
    mask: np.ndarray,
    pred: np.ndarray,
    thresh_pred: np.ndarray,
    with_pred: bool,
) -> None:
    """Append one misclassified sample to a failure collection (``fg`` or ``fb``)."""
    store["image"].append(image)
    store["mask"].append(mask)
    if with_pred:
        store["pred"].append(pred)
    store["thresh_pred"].append(thresh_pred)


def calculate_mask_based_metrics(
    images: np.ndarray,
    th_masks: torch.Tensor,
    th_preds: torch.Tensor,
    threshold: float = 0.5,
    show_orj_predictions: bool = False,
    metric: Callable = score_dice,
    multilabel: bool = False,
    n_classes: int | None = None,
) -> tuple[
    dict[str, float],
    dict[str, list[np.ndarray]],
    dict[str, list[np.ndarray]],
    dict[str, list[str | float]],
]:
    """Calculate metrics based on masks and predictions.

    Args:
        images: Images, one per sample along the first axis. The defect area is normalized by
            ``image.shape[0] * image.shape[1]``, so channel-last (H x W x C) images are expected.
        th_masks: Ground-truth masks as a 4D tensor (N x C x H x W); connected regions are
            extracted from the first channel.
        th_preds: Predictions as tensors; they are binarized with ``threshold``.
        threshold: Threshold to apply. Defaults to 0.5.
        show_orj_predictions: Flag to also collect the original (non thresholded) predictions of
            the misclassified images. Defaults to False.
        metric: Metric used to compare thresholded predictions and masks. It is called with
            ``reduction=None`` and must return one score per image. Defaults to `score_dice`.
        multilabel: True if segmentation is multiclass.
        n_classes: Number of classes. Required when multilabel is True.

    Returns:
        A tuple ``(result, fg, fb, area_graph)``:
        ``result`` holds pixel and image level counters, F1, IoU and dice statistics;
        ``fg`` collects defective images whose thresholded prediction is empty (false negatives);
        ``fb`` collects defect-free images with a non-empty thresholded prediction (false positives);
        ``area_graph`` pairs every connected defect region with its size bucket and the pixel
        accuracy (in percent) inside its bounding box.

    Raises:
        ValueError: if ``multilabel`` is True and ``n_classes`` is None.
    """
    if multilabel and n_classes is None:
        raise ValueError("n_classes arg shouldn't be None when multilabel is True")

    # The thresholded predictions are moved to CPU, so the masks are too: otherwise metric,
    # get_stats and the region slicing below would mix devices when inputs live on a GPU.
    cpu_masks = th_masks.cpu()
    masks = cpu_masks.numpy()
    preds = th_preds.squeeze(0).cpu().numpy()
    th_thresh_preds = (th_preds > threshold).float().cpu()
    thresh_preds = th_thresh_preds.squeeze(0).numpy()
    dice_scores = metric(th_thresh_preds, cpu_masks, reduction=None).numpy()

    if multilabel:
        preds_multilabel = (
            torch.nn.functional.one_hot(th_preds.to(torch.int64), num_classes=n_classes).squeeze(1).permute(0, 3, 1, 2)
        )
        masks_multilabel = (
            torch.nn.functional.one_hot(cpu_masks.to(torch.int64), num_classes=n_classes)
            .squeeze(1)
            .permute(0, 3, 1, 2)
            .to(preds_multilabel.device)
        )
        # get_stats multiclass, not considering background channel
        tp, fp, fn, tn = smp.metrics.get_stats(
            preds_multilabel[:, 1:, :, :].long(), masks_multilabel[:, 1:, :, :].long(), mode="multilabel"
        )
    else:
        tp, fp, fn, tn = smp.metrics.get_stats(th_thresh_preds.long(), cpu_masks.long(), mode="binary")

    per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")
    dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
    result: dict = {}
    result["F1_image"] = round(float(smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro-imagewise").item()), 4)
    result["F1_pixel"] = round(float(smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro").item()), 4)
    result["image_iou"] = round(float(per_image_iou.item()), 4) if not per_image_iou.isnan() else np.nan
    result["dataset_iou"] = round(float(dataset_iou.item()), 4) if not dataset_iou.isnan() else np.nan
    result["TP_pixel"] = tp.sum().item()
    result["FP_pixel"] = fp.sum().item()
    result["FN_pixel"] = fn.sum().item()
    result["TN_pixel"] = tn.sum().item()
    result["TP_image"] = 0
    result["FP_image"] = 0
    result["FN_image"] = 0
    result["TN_image"] = 0
    result["num_good_image"] = 0
    result["num_bad_image"] = 0

    bad_dice, good_dice = [], []
    fg: dict[str, list[np.ndarray]] = {"image": [], "mask": [], "thresh_pred": []}
    fb: dict[str, list[np.ndarray]] = {"image": [], "mask": [], "thresh_pred": []}
    if show_orj_predictions:
        fg["pred"] = []
        fb["pred"] = []

    area_graph: dict[str, list[str | float]] = {
        "Defect Area Percentage": [],
        "Accuracy": [],
    }
    for idx, (image, pred, mask, thresh_pred, dice_score) in enumerate(
        zip(images, preds, masks, thresh_preds, dice_scores)
    ):
        has_prediction = thresh_pred.sum() > 0

        if mask.sum() == 0:
            # Defect-free image: only a non-empty prediction is an error (false positive)
            good_dice.append(dice_score)
            result["num_good_image"] += 1
            if has_prediction:
                result["FP_image"] += 1
                _collect_failure(fb, image, mask, pred, thresh_pred, show_orj_predictions)
            else:
                result["TN_image"] += 1
            continue

        # Defective image: an empty prediction is an error (false negative)
        bad_dice.append(dice_score)
        result["num_bad_image"] += 1
        if has_prediction:
            result["TP_image"] += 1
        else:
            result["FN_image"] += 1
            _collect_failure(fg, image, mask, pred, thresh_pred, show_orj_predictions)

        # For each connected defect region: size bucket and pixel accuracy inside its bounding box
        for region in regionprops(label(mask[0])):
            min_row, min_col, max_row, max_col = region.bbox
            mask_partial = cpu_masks[idx, :, min_row:max_row, min_col:max_col]
            pred_partial = th_thresh_preds[idx, :, min_row:max_row, min_col:max_col]
            r_tp, r_fp, r_fn, r_tn = smp.metrics.get_stats(pred_partial.long(), mask_partial.long(), mode="binary")
            area_percentage = (r_tp + r_fn).sum().item() * 100 / (image.shape[0] * image.shape[1])
            defect_acc = smp.metrics.accuracy(r_tp, r_fp, r_fn, r_tn, reduction="micro")
            area_graph["Accuracy"].append(defect_acc.item() * 100)
            area_graph["Defect Area Percentage"].append(_defect_area_category(area_percentage))

    # Plain Python floats (rather than NumPy scalars) keep the result easy to serialize
    result["bad_dice_score_mean"] = float(np.mean(bad_dice)) if len(bad_dice) > 0 else "null"
    result["bad_dice_score_std"] = float(np.std(bad_dice)) if len(bad_dice) > 0 else "null"
    result["good_dice_score_mean"] = float(np.mean(good_dice)) if len(good_dice) > 0 else "null"
    result["good_dice_score_std"] = float(np.std(good_dice)) if len(good_dice) > 0 else "null"
    return result, fg, fb, area_graph
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _to_yaml_value(value):
    """Convert NumPy scalars to builtin Python scalars so ``yaml.dump`` writes plain values.

    ``str()`` of a NumPy float is its shortest round-tripping text, so the dumped precision is
    unchanged. Unlike ``literal_eval(str(...))``, this also handles NaN and the NumPy >= 2
    scalar repr (``np.float32(0.5)``), neither of which is a Python literal.
    """
    if isinstance(value, np.floating):
        return float(str(value))
    if isinstance(value, np.generic):
        return value.item()
    return value


def _save_failure_grid(
    samples: dict[str, list[np.ndarray]],
    nb_samples: int,
    show_orj_predictions: bool,
    row_names: list[str],
    bounds: list,
    fig_h: int,
    file_path: str,
) -> None:
    """Plot at most ``nb_samples`` misclassified samples, one row per entry of ``row_names``."""
    # Rows are listed explicitly so their order matches row_names/bounds
    # (Input, Mask, [Pred], Pred>threshold) instead of depending on dict insertion order.
    row_keys = ["image", "mask", "pred", "thresh_pred"] if show_orj_predictions else ["image", "mask", "thresh_pred"]
    rows = [samples[key][:nb_samples] for key in row_keys]
    ncols = len(rows[0])
    create_grid_figure(
        rows,
        nrows=len(row_names),
        ncols=ncols,
        row_names=row_names,
        file_path=file_path,
        fig_size=(int(ncols * 2), fig_h),
        bounds=bounds,
    )


def create_mask_report(
    stage: str,
    output: dict[str, torch.Tensor],
    mean: npt.ArrayLike,
    std: npt.ArrayLike,
    report_path: str,
    nb_samples: int = 6,
    analysis: bool = False,
    apply_sigmoid: bool = True,
    show_all: bool = False,
    threshold: float = 0.5,
    metric: Callable = score_dice,
    show_orj_predictions: bool = False,
) -> list[str]:
    """Create report for segmentation experiment.

    Args:
        stage: stage name. Train, validation or test
        output: data produced by model; the keys "image", "mask", "mask_pred" and "label" are read
        mean: mean values used to un-normalize the images
        std: std values used to un-normalize the images
        report_path: experiment path, created if missing
        nb_samples: number of samples per figure
        analysis: if True, analysis will be created (metrics yaml, fn/fp grids, accuracy vs area plot)
        apply_sigmoid: if True, predictions are activated with a sigmoid (single channel) or a
            softmax over channels (multiclass) before thresholding
        show_all: if True, images are shown in dataset order instead of best/worst/random selections
        threshold: threshold for single channel predictions
        metric: metric function used to score each image
        show_orj_predictions: if True, original predictions will be shown (single channel only).

    Returns:
        list of paths to created files.
    """
    os.makedirs(report_path, exist_ok=True)

    th_images = output["image"]
    th_masks = output["mask"]
    th_preds = output["mask_pred"]
    th_labels = output["label"]
    n_classes = th_preds.shape[1]

    # `apply_sigmoid` is a legacy name: it selects a sigmoid for single channel outputs and a
    # softmax over channels for multiclass outputs.
    if apply_sigmoid:
        th_preds = torch.sigmoid(th_preds) if n_classes == 1 else torch.softmax(th_preds, dim=1)

    # Thresholding happens regardless of activation, so th_thresh_preds is always defined
    if n_classes == 1:
        th_thresh_preds = (th_preds > threshold).float()
    else:
        th_thresh_preds = torch.argmax(th_preds, dim=1).float().unsqueeze(1)
        # Compute labels from the given masks since by default they are all 0
        th_labels = th_masks.max(dim=2)[0].max(dim=2)[0].squeeze(dim=1)
        # Multiclass predictions have one channel per class, so the "Pred" row is never drawn
        show_orj_predictions = False

    unnormalize = UnNormalize(np.asarray(mean), np.asarray(std))
    images = np.array(
        [(unnormalize(image).cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8) for image in th_images]
    )
    masks = th_masks.cpu().numpy()
    preds = th_preds.squeeze(0).cpu().numpy()
    thresh_preds = th_thresh_preds.squeeze(0).cpu().numpy()
    dice_scores = metric(th_thresh_preds.cpu(), th_masks.cpu(), reduction=None).numpy()

    # Label 0 marks a good image
    is_good = th_labels.cpu().numpy() == 0

    row_names = ["Input", "Mask", "Pred", f"Pred>{threshold}"]
    # Masks and thresholded predictions hold class indices; binary ones still span [0, 1]
    max_class_value = float(max(n_classes - 1, 1))
    bounds = [(0, 255), (0.0, max_class_value), (0.0, 1.0), (0.0, max_class_value)]
    if not show_orj_predictions:
        row_names.pop(2)
        bounds.pop(2)
    fig_h = int(len(row_names) * 2)

    # Ascending score order (worst first) unless every image is shown in dataset order
    sorted_idx = np.arange(len(dice_scores)) if show_all else np.argsort(dice_scores)
    sorted_is_good = is_good[sorted_idx]
    groups = {"good": sorted_idx[sorted_is_good], "bad": sorted_idx[~sorted_is_good]}

    file_paths = []
    for name, current_score_idx in groups.items():
        nb_total_samples = len(current_score_idx)
        if nb_total_samples == 0:
            continue

        nb_selected_samples = min(nb_samples, nb_total_samples)
        fig_w = int(nb_selected_samples * 2)
        if show_all:
            indexes = {"all": current_score_idx[:nb_selected_samples].tolist()}
        else:
            indexes = {
                # Explicit start index: a `[-0:]` slice would select every sample when nb_samples is 0
                "best": current_score_idx[nb_total_samples - nb_selected_samples :].tolist(),
                "worst": current_score_idx[:nb_selected_samples].tolist(),
                "random": np.random.choice(current_score_idx, nb_selected_samples, replace=False).tolist(),
            }

        for selection, selected_idx in indexes.items():
            file_path = os.path.join(report_path, f"{stage}_{name}_{selection}_results.png")
            images_to_show = [images[selected_idx], masks[selected_idx]]
            if show_orj_predictions:
                images_to_show.append(preds[selected_idx])
            images_to_show.append(thresh_preds[selected_idx])
            create_grid_figure(
                images_to_show,
                nrows=len(row_names),
                ncols=nb_selected_samples,
                row_names=row_names,
                file_path=file_path,
                fig_size=(fig_w, fig_h),
                bounds=bounds,
            )
            file_paths.append(file_path)

    if analysis:
        analysis_file_path = os.path.join(report_path, f"{stage}_analysis.yaml")
        result, fg, fb, area_graph = calculate_mask_based_metrics(
            images=images,
            th_masks=th_masks,
            th_preds=th_thresh_preds,
            threshold=threshold,
            show_orj_predictions=show_orj_predictions,
            metric=metric,
            multilabel=bool(n_classes > 1),
            n_classes=n_classes,
        )

        # fg holds false negatives, fb false positives
        for samples, suffix in ((fg, "fn"), (fb, "fp")):
            if len(samples["image"]) == 0:
                continue
            grid_file_path = os.path.join(report_path, f"{stage}_{suffix}_results.png")
            _save_failure_grid(
                samples,
                nb_samples=nb_samples,
                show_orj_predictions=show_orj_predictions,
                row_names=row_names,
                bounds=bounds,
                fig_h=fig_h,
                file_path=grid_file_path,
            )
            file_paths.append(grid_file_path)

        if len(area_graph["Defect Area Percentage"]) > 0:
            fn_area_path = os.path.join(report_path, f"{stage}_acc_area.png")
            fn_area_df = pd.DataFrame(area_graph)
            # Dedicated figure so the boxplot never draws on unrelated open axes
            fig, ax = plt.subplots()
            sns.boxplot(
                x="Defect Area Percentage",
                y="Accuracy",
                data=fn_area_df,
                order=["Very Small <1%", "Small <10%", "Medium <25%", "Large >25%"],
                ax=ax,
            )
            ax.set_facecolor("white")
            fig.savefig(fn_area_path)
            plt.close(fig)

            file_paths.append(fn_area_path)

        with open(analysis_file_path, "w", encoding="utf-8") as file:
            yaml.dump({key: _to_yaml_value(value) for key, value in result.items()}, file, default_flow_style=False)
        file_paths.append(analysis_file_path)

    return file_paths
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def automatic_datamodule_batch_size(batch_size_attribute_name: str = "batch_size"):
    """Automatically scale the datamodule batch size if the given function goes out of memory.

    The decorated method is retried and ``self.datamodule.<batch_size_attribute_name>`` is halved
    after every ``RuntimeError`` (and onnxruntime ``RuntimeException`` when onnxruntime is
    installed). Retrying stops when the call succeeds; if halving reaches 0 a ``RuntimeError``
    is raised.

    If the decorated object has an ``automatic_batch_size`` attribute, it must expose ``disable``
    and ``starting_batch_size``. When it is enabled and no previous call completed
    (``self.automatic_batch_size_completed``), the batch size is first reset to
    ``starting_batch_size``.

    Args:
        batch_size_attribute_name: The name of the attribute to modify in the datamodule

    Returns:
        A decorator for methods of objects exposing a ``datamodule`` attribute. The decorated
        method returns whatever the wrapped function returns.
    """
    recoverable_exceptions: tuple = (RuntimeError,)
    if ONNX_AVAILABLE:
        recoverable_exceptions += (RuntimeException,)

    def decorator(func: Callable):
        """Decorator function."""

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            """Wrapper function."""
            automatic_batch_size_completed = getattr(self, "automatic_batch_size_completed", False)

            starting_batch_size = None
            if hasattr(self, "automatic_batch_size"):
                if not hasattr(self.automatic_batch_size, "disable") or not hasattr(
                    self.automatic_batch_size, "starting_batch_size"
                ):
                    raise ValueError(
                        "The automatic_batch_size attribute should have the disable and starting_batch_size attributes"
                    )
                starting_batch_size = (
                    self.automatic_batch_size.starting_batch_size if not self.automatic_batch_size.disable else None
                )

            if starting_batch_size is not None and not automatic_batch_size_completed:
                # Once a run completed, later calls keep the last working batch size
                log.info("Performing automatic batch size scaling from %d", starting_batch_size)
                setattr(self.datamodule, batch_size_attribute_name, starting_batch_size)

            while True:
                try:
                    output = func(self, *args, **kwargs)
                except recoverable_exceptions as e:
                    # Read and write the configured attribute (not always `batch_size`)
                    new_batch_size = getattr(self.datamodule, batch_size_attribute_name) // 2
                    setattr(self.datamodule, batch_size_attribute_name, new_batch_size)
                    log.debug("Caught %s while running %s: %s", type(e).__name__, func.__name__, e)
                    log.warning(
                        "The function %s went out of memory, trying to reduce the batch size to %d",
                        func.__name__,
                        new_batch_size,
                    )

                    if new_batch_size == 0:
                        raise RuntimeError(
                            f"Unable to run {func.__name__} with batch size 1, the program will exit"
                        ) from e

                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                else:
                    self.automatic_batch_size_completed = True
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    return output

        return wrapper

    return decorator
|