quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch.nn.functional import cosine_similarity
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def cosine_align_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Average cosine distance between paired feature vectors.

    Computes mean(1 - cosine_similarity(x, y)) over the batch, with the
    similarity taken along the feature dimension (dim=1).

    Args:
        x: first batch of feature vectors
        y: second batch of feature vectors.

    Returns:
        Scalar cosine alignment loss.
    """
    distance = 1 - cosine_similarity(x, y, dim=1)
    return distance.mean()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Source: https://arxiv.org/pdf/2005.10242.pdf
def align_loss(x: torch.Tensor, y: torch.Tensor, alpha: int = 2) -> torch.Tensor:
    """Alignment loss: mean of the pairwise L2 distance raised to ``alpha``.

    Args:
        x: first batch of feature vectors
        y: second batch of feature vectors
        alpha: exponent applied to the L2 norm.

    Returns:
        Scalar align loss.
    """
    pairwise_l2 = (x - y).norm(p=2, dim=1)
    return pairwise_l2.pow(alpha).mean()
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def uniform_loss(x: torch.Tensor, t: float = 2.0) -> torch.Tensor:
    """Uniformity loss: log of the mean Gaussian potential over all pairs.

    Computes log(mean(exp(-t * ||xi - xj||^2))) over all distinct pairs of
    rows of ``x``.

    Args:
        x: batch of feature vectors
        t: temperature applied to the squared pairwise distances.

    Returns:
        Scalar uniform loss.
    """
    squared_distances = torch.pdist(x, p=2).pow(2)
    return torch.exp(-t * squared_distances).mean().log()
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def idmm_loss(
    p1: torch.Tensor,
    y1: torch.Tensor,
    smoothing: float = 0.1,
) -> torch.Tensor:
    """IDMM loss described in https://arxiv.org/abs/2201.10728.

    Cross-entropy between predicted instance logits and instance labels,
    with label smoothing.

    Args:
        p1: Prediction labels for `z1`
        y1: Instance labels for `z1`
        smoothing: smoothing factor used for label smoothing.
            Defaults to 0.1.

    Returns:
        IDMM loss
    """
    return F.cross_entropy(p1, y1, label_smoothing=smoothing)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class IDMMLoss(torch.nn.Module):
    """IDMM loss described in https://arxiv.org/abs/2201.10728.

    Args:
        smoothing: label-smoothing factor applied to the cross-entropy.
    """

    def __init__(self, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(
        self,
        p1: torch.Tensor,
        y1: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the IDMM loss.

        Args:
            p1: Prediction labels for `z1`
            y1: Instance labels for `z1`

        Returns:
            IDMM loss
        """
        # Label-smoothed cross-entropy over instance labels
        return F.cross_entropy(p1, y1, label_smoothing=self.smoothing)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn.functional as F
|
|
5
|
+
|
|
6
|
+
from quadra.utils.utils import AllGatherSyncFunction
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def simclr_loss(
    features1: torch.Tensor,
    features2: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """SimCLR (NT-Xent) loss described in https://arxiv.org/pdf/2002.05709.pdf.

    Args:
        features1: First augmented features (i.e. T(features))
        features2: Second augmented features (i.e. T'(features))
        temperature: optional temperature

    Returns:
        SimCLR loss
    """
    # Project both views onto the unit hypersphere so dot products below are cosine similarities
    features1 = F.normalize(features1, dim=-1)
    features2 = F.normalize(features2, dim=-1)
    # Under distributed training, gather features from all ranks so each sample
    # is contrasted against the full global batch of negatives
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        features1_dist = AllGatherSyncFunction.apply(features1)
        features2_dist = AllGatherSyncFunction.apply(features2)
    else:
        features1_dist = features1
        features2_dist = features2
    features = torch.cat([features1, features2], dim=0)  # [2B, d]
    features_dist = torch.cat([features1_dist, features2_dist], dim=0)  # [2B * DIST_SIZE, d]

    # Similarity matrix
    sim = torch.exp(torch.div(torch.mm(features, features_dist.t()), temperature))  # [2B, 2B * DIST_SIZE]

    # Negatives
    neg = sim.sum(dim=-1)

    # From each row, subtract e^(1/temp) to remove similarity measure for zi * zi, since
    # (zi^T * zi) / ||zi||^2 = 1
    row_sub = torch.full_like(neg, math.e ** (1 / temperature), device=neg.device)
    neg = torch.clamp(neg - row_sub, min=1e-6)  # clamp for numerical stability

    # Positive similarity, pos becomes [2 * batch_size]
    pos = torch.exp(torch.div(torch.sum(features1 * features2, dim=-1), temperature))
    pos = torch.cat([pos, pos], dim=0)

    loss = -torch.log(pos / (neg + 1e-6)).mean()
    return loss
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class SimCLRLoss(torch.nn.Module):
    """SimCLR loss module.

    Args:
        temperature: softmax temperature forwarded to the SimCLR loss.
    """

    def __init__(self, temperature: float = 1.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Compute the SimCLR loss between the two augmented views."""
        return simclr_loss(x1, x2, self.temperature)
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def simsiam_loss(
    p1: torch.Tensor,
    p2: torch.Tensor,
    z1: torch.Tensor,
    z2: torch.Tensor,
) -> torch.Tensor:
    """SimSIAM loss described in https://arxiv.org/abs/2011.10566.

    Symmetrized negative cosine similarity: each predictor output is compared
    against the projection of the other view and the two halves are averaged.

    Args:
        p1: First `predicted` features (i.e. h(f(T(x1))))
        p2: Second `predicted` features (i.e. h(f(T'(x2))))
        z1: First 'projected features (i.e. f(T(x1)))
        z2: Second 'projected features (i.e. f(T(x2)))

    Returns:
        SimSIAM loss
    """
    first_half = F.cosine_similarity(p1, z2).mean()
    second_half = F.cosine_similarity(p2, z1).mean()
    return -0.5 * (first_half + second_half)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SimSIAMLoss(torch.nn.Module):
    """SimSIAM loss module (https://arxiv.org/abs/2011.10566)."""

    def forward(self, p1: torch.Tensor, p2: torch.Tensor, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Compute the symmetrized negative cosine similarity loss."""
        cross_view = F.cosine_similarity(p1, z2).mean() + F.cosine_similarity(p2, z1).mean()
        return -0.5 * cross_view
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def vicreg_loss(
    z1: torch.Tensor,
    z2: torch.Tensor,
    lambd: float,
    mu: float,
    nu: float = 1,
    gamma: float = 1,
) -> torch.Tensor:
    """VICReg loss described in https://arxiv.org/abs/2105.04906.

    Total loss is ``lambd * invariance + mu * variance + nu * covariance``.

    Args:
        z1: First `augmented` normalized features (i.e. f(T(x))). The normalization can be obtained with
            z1_norm = (z1 - z1.mean(0)) / z1.std(0)
        z2: Second `augmented` normalized features (i.e. f(T(x))). The normalization can be obtained with
            z2_norm = (z2 - z2.mean(0)) / z2.std(0)
        lambd: multiplier for the invariance (MSE similarity) term.
        mu: multiplier for the variance term.
        nu: multiplier for the covariance term. Default: 1
        gamma: target standard deviation used in the variance hinge. Default: 1

    Returns:
        VICReg loss
    """
    # Variance term: hinge keeping each dimension's std above gamma
    eps = 0.0001
    std_first = torch.sqrt(z1.var(dim=0) + eps)
    std_second = torch.sqrt(z2.var(dim=0) + eps)
    var_loss = torch.nn.functional.relu(gamma - std_first).mean() + torch.nn.functional.relu(gamma - std_second).mean()

    # Invariance term: MSE between the two views
    sim_loss = torch.nn.functional.mse_loss(z1, z2)

    # Covariance term: penalize off-diagonal entries of each view's covariance matrix
    batch_size = z1.size(0)
    dim = z1.size(1)
    centered1 = z1 - z1.mean(dim=0)
    centered2 = z2 - z2.mean(dim=0)
    cov1 = (centered1.T @ centered1) / (batch_size - 1)
    cov2 = (centered2.T @ centered2) / (batch_size - 1)

    def _off_diagonal(cov: torch.Tensor) -> torch.Tensor:
        # Drop the diagonal by reshaping the flattened (d*d - 1)-element vector
        return cov.flatten()[:-1].view(dim - 1, dim + 1)[:, 1:].flatten()

    cov_loss = _off_diagonal(cov1).pow(2).sum() / dim + _off_diagonal(cov2).pow(2).sum() / dim

    return lambd * sim_loss + mu * var_loss + nu * cov_loss
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class VICRegLoss(torch.nn.Module):
    """VICReg loss module (https://arxiv.org/abs/2105.04906).

    Args:
        lambd: multiplier for the invariance (MSE similarity) term.
        mu: multiplier for the variance term.
        nu: multiplier for the covariance term. Default: 1.
        gamma: target standard deviation used in the variance hinge. Default: 1.
    """

    def __init__(self, lambd: float, mu: float, nu: float = 1, gamma: float = 1):
        super().__init__()
        self.lambd = lambd
        self.mu = mu
        self.nu = nu
        self.gamma = gamma

    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Computes VICReg loss between the two (pre-normalized) views."""
        return vicreg_loss(z1, z2, lambd=self.lambd, mu=self.mu, nu=self.nu, gamma=self.gamma)
|
quadra/main.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import time
|
|
2
|
+
|
|
3
|
+
import hydra
|
|
4
|
+
import matplotlib
|
|
5
|
+
from omegaconf import DictConfig
|
|
6
|
+
from pytorch_lightning import seed_everything
|
|
7
|
+
|
|
8
|
+
from quadra.tasks.base import Task
|
|
9
|
+
from quadra.utils.resolver import register_resolvers
|
|
10
|
+
from quadra.utils.utils import get_logger, load_envs, setup_opencv
|
|
11
|
+
from quadra.utils.validator import validate_config
|
|
12
|
+
|
|
13
|
+
# Load environment variables from .env files and register custom OmegaConf
# resolvers before Hydra composes any configuration below.
load_envs()
register_resolvers()


# Non-interactive matplotlib backend: tasks render figures to files, never to a display.
matplotlib.use("Agg")
log = get_logger(__name__)  # module-level logger
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@hydra.main(config_path="configs/", config_name="config.yaml", version_base="1.3.0")
def main(config: DictConfig):
    """Main entry function for any of the tasks.

    Args:
        config: the Hydra-composed configuration for the run.
    """
    # Optional schema validation of the composed config, timed for visibility
    if config.validate:
        start = time.time()
        validate_config(config)
        stop = time.time()
        log.info("Config validation took %f seconds", stop - start)

    from quadra.utils import utils  # pylint: disable=import-outside-toplevel

    utils.extras(config)

    # Prints the resolved configuration to the console
    if config.get("print_config"):
        utils.print_config(config, resolve=True)

    # Set seed for random number generators in pytorch, numpy and python.random
    seed_everything(config.core.seed, workers=True)
    setup_opencv()

    # Run specified task using the configuration composition
    task: Task = hydra.utils.instantiate(config.task, config, _recursive_=False)
    task.execute()


if __name__ == "__main__":
    # pylint: disable=no-value-for-parameter
    main()
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import cast
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from scipy.optimize import linear_sum_assignment
|
|
8
|
+
from skimage.measure import label, regionprops # pylint: disable=no-name-in-module
|
|
9
|
+
|
|
10
|
+
from quadra.utils.evaluation import dice
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _pad_to_shape(a: np.ndarray, shape: tuple, constant_values: int = 0) -> np.ndarray:
|
|
14
|
+
"""Pad lower - right with 0s
|
|
15
|
+
Args:
|
|
16
|
+
a: numpy array to pad
|
|
17
|
+
shape: shape of the resulting np.array
|
|
18
|
+
constant_values: value to pad.
|
|
19
|
+
|
|
20
|
+
Returns:
|
|
21
|
+
Padded array
|
|
22
|
+
"""
|
|
23
|
+
y_, x_ = shape
|
|
24
|
+
y, x = a.shape
|
|
25
|
+
y_pad = y_ - y
|
|
26
|
+
x_pad = x_ - x
|
|
27
|
+
return np.pad(
|
|
28
|
+
a,
|
|
29
|
+
(
|
|
30
|
+
(0, y_pad),
|
|
31
|
+
(0, x_pad),
|
|
32
|
+
),
|
|
33
|
+
mode="constant",
|
|
34
|
+
constant_values=constant_values,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _get_iou(bboxes1: np.ndarray, bboxes2: np.ndarray, approx_iou: bool = False) -> np.ndarray:
|
|
39
|
+
"""Intersect over union
|
|
40
|
+
Args:
|
|
41
|
+
bboxes1: extracted bounding boxes
|
|
42
|
+
bboxes2: ground truth
|
|
43
|
+
approx_iou: flag to approximate.
|
|
44
|
+
|
|
45
|
+
Returns:
|
|
46
|
+
Intersect over union array
|
|
47
|
+
"""
|
|
48
|
+
x11, y11, x12, y12 = np.split(bboxes1, 4, axis=1)
|
|
49
|
+
x21, y21, x22, y22 = np.split(bboxes2, 4, axis=1)
|
|
50
|
+
|
|
51
|
+
# determine the (x, y)-coordinates of the intersection rectangle
|
|
52
|
+
xA = np.maximum(x11, x21.T)
|
|
53
|
+
yA = np.maximum(y11, y21.T)
|
|
54
|
+
xB = np.minimum(x12, x22.T)
|
|
55
|
+
yB = np.minimum(y12, y22.T)
|
|
56
|
+
|
|
57
|
+
# compute the area of intersection rectangle
|
|
58
|
+
inter_area = np.maximum((xB - xA), 0) * np.maximum((yB - yA), 0)
|
|
59
|
+
|
|
60
|
+
# compute the area of both the prediction and ground-truth rectangles
|
|
61
|
+
box_a_area = (x12 - x11) * (y12 - y11)
|
|
62
|
+
box_b_area = (x22 - x21) * (y22 - y21)
|
|
63
|
+
|
|
64
|
+
if approx_iou:
|
|
65
|
+
iou = inter_area / box_b_area.T
|
|
66
|
+
else:
|
|
67
|
+
iou = inter_area / (box_a_area + box_b_area.T - inter_area)
|
|
68
|
+
|
|
69
|
+
return iou
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _get_dice_matrix(
    labels_pred: np.ndarray,
    n_labels_pred: int,
    labels_gt: np.ndarray,
    n_labels_gt: int,
) -> np.ndarray:
    """Dice score between every (prediction, ground-truth) component pair.

    Components are identified by positive integer labels 1..n in the label
    maps; label 0 (background) is never compared.

    Args:
        labels_pred: labelled connected components of the prediction
        n_labels_pred: number of predicted components
        labels_gt: labelled connected components of the ground truth
        n_labels_gt: number of ground-truth components.

    Returns:
        [n_labels_pred, n_labels_gt] dice matrix.
    """
    dice_matrix = np.zeros((n_labels_pred, n_labels_gt))
    for pred_idx in range(n_labels_pred):
        # Binary mask of one predicted component, shaped [1, 1, H, W] as `dice` expects
        pred_mask = torch.Tensor(labels_pred == pred_idx + 1).unsqueeze(0).unsqueeze(0)
        for gt_idx in range(n_labels_gt):
            gt_mask = torch.Tensor(labels_gt == gt_idx + 1).unsqueeze(0).unsqueeze(0)
            dice_matrix[pred_idx, gt_idx] = dice(pred_mask, gt_mask, reduction="none")
    return dice_matrix
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def segmentation_props(
    pred: np.ndarray, mask: np.ndarray
) -> tuple[float, float, float, float, list[float], float, int, int, int, int]:
    """Return matching statistics between predicted and ground truth segmentations.

    Connected components are extracted from both binary images. A matrix of
    (1 - Dice)-style scores between every (prediction, ground truth) component
    pair is built, thresholded at 0.9, padded to a square with dummy entries,
    and solved with LSA (Linear Sum Assignment) to obtain a unique 1-to-1
    matching. From the assignment:

    - a prediction assigned to a dummy column is a False Positive;
    - a ground truth assigned to a dummy row is a False Negative;
    - an assigned pair with score <= 0.9 is a True Positive;
    - an assigned pair above the threshold counts as BOTH a FP and a FN.

    Args:
        pred: prediction of a segmentation model as a binary image.
        mask: ground truth mask as a binary image.

    Returns:
        Tuple of:
            - global Dice score between `pred` and `mask` as whole images;
            - sum of the matched pairs' matrix scores (divide by `tp_num`
              for the average over True Positives);
            - sum of the matched pairs' bounding-box IoU, computed between the
              minimum enclosing bbox of a prediction and a ground truth
              (divide by `tp_num` for the average);
            - total area (pixel count) of False Positives (divide by `fp_num`
              for the average);
            - histogram (list) of individual False Positive areas;
            - total area (pixel count) of False Negatives (divide by `fn_num`
              for the average);
            - number of True Positives;
            - number of False Positives;
            - number of False Negatives;
            - number of connected components in the ground truth mask.
    """

    def _bbox_sum(image: np.ndarray, bbox: np.ndarray) -> float:
        # Pixel count inside a regionprops-style bbox
        # (min_row, min_col, max_row, max_col) of a binary image.
        return image[bbox[0] : bbox[2], bbox[1] : bbox[3]].sum()

    labels_pred, n_labels_pred = label(pred, connectivity=2, return_num=True, background=0)
    labels_mask, n_labels_mask = label(mask, connectivity=2, return_num=True, background=0)

    labels_pred = cast(np.ndarray, labels_pred)
    labels_mask = cast(np.ndarray, labels_mask)
    n_labels_pred = cast(int, n_labels_pred)
    n_labels_mask = cast(int, n_labels_mask)

    props_pred = regionprops(labels_pred)
    props_mask = regionprops(labels_mask)
    pred_bbox = np.array([prop.bbox for prop in props_pred])
    mask_bbox = np.array([prop.bbox for prop in props_mask])

    global_dice = float(
        dice(
            torch.Tensor(pred).unsqueeze(0).unsqueeze(0),
            torch.Tensor(mask).unsqueeze(0).unsqueeze(0),
        ).item()
    )
    lsa_iou = 0.0
    lsa_dice = 0.0
    tp_num = 0
    fp_num = 0
    fn_num = 0
    fp_area = 0.0
    fn_area = 0.0
    fp_hist: list[float] = []
    if n_labels_pred > 0 and n_labels_mask > 0:
        dice_mat = _get_dice_matrix(labels_pred, n_labels_pred, labels_mask, n_labels_mask)
        # Thresholding over Dice scores: anything above 0.9 is treated as a
        # full mismatch so it cannot be picked up as a match by the LSA.
        dice_mat = np.where(dice_mat <= 0.9, dice_mat, 1.0)
        iou_mat = _get_iou(pred_bbox, mask_bbox, approx_iou=False)
        max_dim = np.max(dice_mat.shape)
        # Pad to a square with dummy scores so the LSA is unique and unmatched
        # rows/columns can be read off as FP and FN.
        dice_mat = _pad_to_shape(dice_mat, (max_dim, max_dim), 1)
        lsa = linear_sum_assignment(dice_mat, maximize=False)
        for row, col in zip(lsa[0], lsa[1]):
            # More preds than GTs --> False Positive
            if row < n_labels_pred and col >= n_labels_mask:
                fp_num += 1
                area = _bbox_sum(pred, pred_bbox[row])
                fp_area += area
                fp_hist.append(area)
                continue

            # More GTs than preds --> False Negative
            if col < n_labels_mask and row >= n_labels_pred:
                fn_num += 1
                fn_area += _bbox_sum(mask, mask_bbox[col])
                continue

            # Real True Positive: a prediction has been assigned to a gt
            # with a matrix score of at most 0.9.
            if dice_mat[row, col] <= 0.9:
                tp_num += 1
                lsa_iou += iou_mat[row, col]
                lsa_dice += dice_mat[row, col]
            else:
                # Assignment above the threshold: here we have both a FP and a FN.
                fp_num += 1
                area = _bbox_sum(pred, pred_bbox[row])
                fp_area += area
                fp_hist.append(area)

                fn_num += 1
                fn_area += _bbox_sum(mask, mask_bbox[col])
    elif len(pred_bbox) > 0 and len(mask_bbox) == 0:  # No GTs --> every pred is a FP
        for p_bbox in pred_bbox:
            fp_num += 1
            area = _bbox_sum(pred, p_bbox)
            fp_area += area
            fp_hist.append(area)
    elif len(pred_bbox) == 0 and len(mask_bbox) > 0:  # No preds --> every GT is a FN
        for m_bbox in mask_bbox:
            fn_num += 1
            fn_area += _bbox_sum(mask, m_bbox)
    return (
        global_dice,
        lsa_dice,
        lsa_iou,
        fp_area,
        fp_hist,
        fn_area,
        tp_num,
        fp_num,
        fn_num,
        n_labels_mask,
    )
|
|
File without changes
|