quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/datasets/patch.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import random
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
|
|
7
|
+
import cv2
|
|
8
|
+
import h5py
|
|
9
|
+
import numpy as np
|
|
10
|
+
from torch.utils.data import Dataset
|
|
11
|
+
|
|
12
|
+
from quadra.utils import utils
|
|
13
|
+
from quadra.utils.imaging import keep_aspect_ratio_resize
|
|
14
|
+
from quadra.utils.patch.dataset import compute_safe_patch_range, trisample
|
|
15
|
+
|
|
16
|
+
log = utils.get_logger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PatchSklearnClassificationTrainDataset(Dataset):
    """Dataset used for patch sampling, it expects samples to be paths to h5 files containing all the required
    information for patch sampling from images.

    Args:
        data_path: base path to the dataset
        samples: Paths to h5 files
        targets: Labels associated with each sample
        class_to_idx: Mapping between class and corresponding index; if None it is
            built from the sorted unique targets
        resize: Whether to perform an aspect ratio resize of the patch before the transformations
        transform: Optional function applied to the image
        rgb: if False, image will be converted in grayscale
        channel: 1 or 3. If rgb is True, then channel will be set at 3.
        balance_classes: if True, the dataset will be balanced by duplicating samples of the minority class
    """

    def __init__(
        self,
        data_path: str,
        samples: list[str],
        targets: list[str | int],
        class_to_idx: dict | None = None,
        resize: int | None = None,
        transform: Callable | None = None,
        rgb: bool = True,
        channel: int = 3,
        balance_classes: bool = False,
    ):
        super().__init__()

        # Keep-Aspect-Ratio resize (applied in __getitem__ before the transform)
        self.resize = resize
        self.data_path = data_path

        if balance_classes:
            # Oversample (with replacement) every class up to the cardinality of
            # the most frequent one, so all classes are equally represented.
            samples_array = np.array(samples)
            targets_array = np.array(targets)
            samples_to_use: list[str] = []
            targets_to_use: list[str | int] = []

            cls, counts = np.unique(targets_array, return_counts=True)
            max_count = np.max(counts)
            for cl, count in zip(cls, counts):
                idx_to_pick = list(np.where(targets_array == cl)[0])

                if count < max_count:
                    idx_to_pick += random.choices(idx_to_pick, k=max_count - count)

                samples_to_use.extend(samples_array[idx_to_pick])
                targets_to_use.extend(targets_array[idx_to_pick])
        else:
            samples_to_use = samples
            targets_to_use = targets

        # Data
        self.x = np.array(samples_to_use)
        self.y = np.array(targets_to_use)

        if class_to_idx is None:
            # np.unique sorts, so the mapping is deterministic across runs
            unique_targets = np.unique(targets_to_use)
            class_to_idx = {c: i for i, c in enumerate(unique_targets)}

        self.class_to_idx = class_to_idx
        self.idx_to_class = {v: k for k, v in class_to_idx.items()}

        self.samples = [
            (path, self.class_to_idx[self.y[i]] if self.y[i] is not None else None) for i, path in enumerate(self.x)
        ]

        self.rgb = rgb
        self.channel = 3 if rgb else channel

        self.transform = transform

    def __getitem__(self, idx) -> tuple[np.ndarray, np.ndarray]:
        """Sample one patch.

        The patch center is drawn uniformly over the whole image when the h5 file
        contains no defect triangles, otherwise from one of the precomputed
        triangles (chosen according to their weights); a patch of ``patch_size``
        is then cropped around it, converted to the requested color space,
        optionally resized and transformed.
        """
        path, y = self.samples[idx]

        # Fix: open the h5 file with a context manager (and an explicit read
        # mode) so the handle is released even if one of the reads raises.
        with h5py.File(path, "r") as h5_file:
            x = cv2.imread(os.path.join(self.data_path, h5_file["img_path"][()].decode("utf-8")))

            weights = h5_file["triangles_weights"][()]
            patch_size = h5_file["patch_size"][()]

            if weights.shape[0] == 0:  # pylint: disable=no-member
                # If the image is completely good sample a point anywhere
                patch_y = np.random.randint(0, x.shape[0] + 1)
                patch_x = np.random.randint(0, x.shape[1] + 1)
            else:
                random_triangle = np.random.choice(weights.shape[0], p=weights)
                [patch_y, patch_x] = trisample(h5_file["triangles"][random_triangle])

        # If the patch is outside the image reduce the exceeding area by taking more patch from the inner area
        [y_left, y_right] = compute_safe_patch_range(patch_y, patch_size[0], x.shape[0])
        [x_left, x_right] = compute_safe_patch_range(patch_x, patch_size[1], x.shape[1])

        x = x[patch_y - y_left : patch_y + y_right, patch_x - x_left : patch_x + x_right]

        if self.rgb:
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
        else:
            # Grayscale content replicated on 3 channels so transforms that
            # expect an RGB-shaped array keep working
            x = cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)
            x = cv2.cvtColor(x, cv2.COLOR_GRAY2RGB)

        if self.channel == 1:
            x = x[:, :, 0]

        # Resize keeping aspect ratio
        if self.resize:
            x = keep_aspect_ratio_resize(x, self.resize)

        if self.transform:
            aug = self.transform(image=x)
            x = aug["image"]

        return x, y

    def __len__(self):
        return len(self.samples)
@@ -0,0 +1,239 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import albumentations
|
|
8
|
+
import cv2
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from quadra.utils.deprecation import deprecated
|
|
13
|
+
from quadra.utils.imaging import keep_aspect_ratio_resize
|
|
14
|
+
from quadra.utils.segmentation import smooth_mask
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# DEPRECATED -> we can use SegmentationDatasetMulticlass also for one class segmentation
@deprecated("Use SegmentationDatasetMulticlass instead")
class SegmentationDataset(torch.utils.data.Dataset):
    """Custom SegmentationDataset class for loading images and masks.

    Args:
        image_paths: List of paths to images.
        mask_paths: List of paths to masks. A missing/None/NaN entry yields an
            all-zeros (background-only) mask.
        batch_size: Batch size. When set, the dataset length is padded to at
            least this value and indices wrap around via modulo.
        object_masks: List of paths to object masks.
        resize: Resize image to this size.
        mask_preprocess: Preprocess mask.
        labels: List of labels.
        transform: Transformations to apply to images and masks.
        mask_smoothing: Smooth mask.
        defect_transform: Transformations to apply to images and masks for defects.
    """

    def __init__(
        self,
        image_paths: list[str],
        mask_paths: list[str],
        batch_size: int | None = None,
        object_masks: list[np.ndarray | Any] | None = None,
        resize: int = 224,
        mask_preprocess: Callable | None = None,
        labels: list[str] | None = None,
        transform: albumentations.Compose | None = None,
        mask_smoothing: bool = False,
        defect_transform: albumentations.Compose | None = None,
    ):
        self.transform = transform
        self.defect_transform = defect_transform
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.labels = labels
        self.mask_preprocess = mask_preprocess
        self.resize = resize
        self.object_masks = object_masks
        self.data_len = len(self.image_paths)
        # Effective length is never smaller than the number of images: a larger
        # batch_size pads the dataset (images repeat via the modulo in __getitem__)
        self.batch_size = None if batch_size is None else max(batch_size, self.data_len)
        self.smooth_mask = mask_smoothing

    def __getitem__(self, index):
        """Return one ``(image, mask, label)`` sample.

        The mask is binarized and returned with a leading channel dimension
        (1 x H x W); the label is forced consistent with the final mask content.
        """
        # This is required to avoid infinite loop when running the dataset outside of a dataloader
        if self.batch_size is not None and self.batch_size == index:
            raise StopIteration

        if self.batch_size is None and self.data_len == index:
            raise StopIteration

        # Wrap around so padded indices (index >= data_len) reuse images
        index = index % self.data_len
        image_path = self.image_paths[index]

        image = cv2.imread(str(image_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        object_mask_path = self.object_masks[index] if self.object_masks is not None else None
        if object_mask_path is not None:
            object_mask = cv2.imread(str(object_mask_path), 0) if os.path.isfile(object_mask_path) else None
        else:
            object_mask = None
        label = self.labels[index] if self.labels is not None else None
        # Missing / NaN / non-existing mask path -> background-only mask
        if (
            self.mask_paths[index] is np.nan
            or self.mask_paths[index] is None
            or not os.path.isfile(self.mask_paths[index])
        ):
            mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
        else:
            mask_path = self.mask_paths[index]
            mask = cv2.imread(str(mask_path), 0)
        # Synthesize a defect when the sample is labelled defective (label == 1)
        # but its mask is empty
        if self.defect_transform is not None and label == 1 and np.sum(mask) == 0:
            if object_mask is not None:
                # Scale the binary object mask to 0/255 — presumably the defect
                # transform expects an 8-bit mask; TODO confirm
                object_mask *= 255
            aug = self.defect_transform(image=image, mask=mask, object_mask=object_mask, label=label)
            image = aug["image"]
            mask = aug["mask"]
            label = aug["label"]
        if self.mask_preprocess:
            mask = self.mask_preprocess(mask)
            if object_mask is not None:
                object_mask = self.mask_preprocess(object_mask)
        if self.resize:
            image = keep_aspect_ratio_resize(image, self.resize)
            mask = keep_aspect_ratio_resize(mask, self.resize)
            if object_mask is not None:
                object_mask = keep_aspect_ratio_resize(object_mask, self.resize)

        if self.transform is not None:
            aug = self.transform(image=image, mask=mask)
            image = aug["image"]
            mask = aug["mask"]
        # The transform may return either a numpy array or a torch tensor
        if isinstance(mask, np.ndarray):
            mask_sum = np.sum(mask)
        elif isinstance(mask, torch.Tensor):
            mask_sum = torch.sum(mask)
        else:
            raise ValueError("Unsupported type for mask")
        # Keep the label consistent with the (possibly transformed) mask:
        # any foreground pixel forces label 1, an empty mask forces label 0
        if mask_sum > 0 and (label is None or label == 0):
            label = 1
        if mask_sum == 0:
            label = 0

        # Binarize the mask, optionally smooth it, and add the channel dim (1 x H x W)
        if isinstance(image, np.ndarray):
            mask = (mask > 0).astype(np.uint8)

            if self.smooth_mask:
                mask = smooth_mask(mask)
            mask = np.expand_dims(mask, axis=0)
        else:
            mask = (mask > 0).int()
            if self.smooth_mask:
                mask = torch.from_numpy(smooth_mask(mask.numpy()))
            mask = mask.unsqueeze(0)

        return image, mask, label

    def __len__(self):
        """Return the (possibly batch-size padded) dataset length."""
        if self.batch_size is None:
            return self.data_len

        return max(self.data_len, self.batch_size)
class SegmentationDatasetMulticlass(torch.utils.data.Dataset):
    """Custom SegmentationDataset class for loading images and multilabel masks.

    Args:
        image_paths: List of paths to images.
        mask_paths: List of paths to masks.
        idx_to_class: dict with correspondence between mask index and classes:
            {1: class_1, 2: class_2, ..., N: class_N}
        batch_size: Batch size.
        transform: Transformations to apply to images and masks.
        one_hot: if True return a binary mask (n_class x H x W), otherwise the
            labelled mask H x W. The SMP loss requires the second format.
    """

    def __init__(
        self,
        image_paths: list[str],
        mask_paths: list[str],
        idx_to_class: dict,
        batch_size: int | None = None,
        transform: albumentations.Compose | None = None,
        one_hot: bool = False,
    ):
        self.transform = transform
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.idx_to_class = idx_to_class
        self.data_len = len(self.image_paths)
        # When given, batch_size pads the dataset so it never yields less than
        # one full batch (indices wrap via the modulo in __getitem__).
        self.batch_size = max(batch_size, self.data_len) if batch_size is not None else None
        self.one_hot = one_hot

    def _preprocess_mask(self, mask: np.ndarray):
        """Split a labelled mask into per-class binary layers (albumentations-friendly).

        Args:
            mask: a numpy array of dimension HxW with values in [0] + self.idx_to_class.

        Output:
            a binary numpy array with dims (len(self.idx_to_class) + 1) x H x W
        """
        layers = np.zeros((len(self.idx_to_class) + 1, *mask.shape[:2]))
        # Channel 0 carries the background for completeness; the single-channel
        # output path does not use it anyway.
        layers[0] = (mask == 0).astype(np.uint8)
        for class_index in self.idx_to_class:
            layers[int(class_index)] = (mask == int(class_index)).astype(np.uint8)

        return layers

    def __getitem__(self, index):
        """Get image and mask."""
        # Raising StopIteration lets the dataset be iterated directly
        # (outside of a dataloader) without looping forever.
        limit = self.data_len if self.batch_size is None else self.batch_size
        if index == limit:
            raise StopIteration

        index = index % self.data_len

        image = cv2.imread(str(self.image_paths[index]))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = self.mask_paths[index]
        # Missing / NaN / non-existing mask path -> background-only mask
        if mask_path is np.nan or mask_path is None or not os.path.isfile(mask_path):
            mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
        else:
            mask = cv2.imread(str(mask_path), 0)

        # Go back to binary per-class layers to avoid transformation errors
        mask = self._preprocess_mask(mask)

        if self.transform is not None:
            augmented = self.transform(image=image, masks=list(mask))
            image = augmented["image"]
            mask = np.stack(augmented["masks"])  # C x H x W

        if self.one_hot:
            # One-hot encoding is done downstream by the smp dice loss, so the
            # CxHxW binary layers (background included) pass through unchanged.
            mask_out = mask
        else:
            # Collapse the layers back to a single labelled HxW mask where
            # zero is the background.
            mask_out = np.zeros(mask.shape[1:])
            for class_index in range(1, mask.shape[0]):
                mask_out[mask[class_index] == 1] = class_index

        return image, mask_out.astype(int), 0

    def __len__(self):
        """Return the (possibly batch-size padded) dataset length."""
        return self.data_len if self.batch_size is None else max(self.data_len, self.batch_size)
quadra/datasets/ssl.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from collections.abc import Iterable
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
import albumentations as A
|
|
8
|
+
import numpy as np
|
|
9
|
+
from torch.utils.data import Dataset
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AugmentationStrategy(Enum):
    """Augmentation Strategy for TwoAugmentationDataset.

    SAME_IMAGE: both augmented views come from the same image.
    SAME_CLASS: the second view comes from a random image with the same label.
    """

    SAME_IMAGE = 1
    SAME_CLASS = 2


class TwoAugmentationDataset(Dataset):
    """Two Image Augmentation Dataset for using in self-supervised learning.

    Args:
        dataset: A torch Dataset object
        transform: albumentation transformations for each image.
            If you use single transformation, it will be applied to both images.
            If you use tuple, it will be applied to first image and second image separately.
        strategy: Defaults to AugmentationStrategy.SAME_IMAGE.
    """

    def __init__(
        self,
        dataset: Dataset,
        transform: A.Compose | tuple[A.Compose, A.Compose],
        strategy: AugmentationStrategy = AugmentationStrategy.SAME_IMAGE,
    ):
        self.dataset = dataset
        self.transform = transform
        # `stategy` (sic) is kept so existing code reading the misspelled
        # attribute keeps working; `strategy` is the preferred alias.
        self.stategy = strategy
        self.strategy = strategy
        # Bug fix: the original check used len(set(transform)), which wrongly
        # rejected a pair built from the same transform object passed twice
        # (the set collapses duplicates). Count the elements instead.
        if isinstance(transform, Iterable) and not isinstance(transform, str) and len(tuple(transform)) != 2:
            raise ValueError("transform must be an Iterable of length 2")

    def __getitem__(self, index):
        """Return ``([view1, view2], target)`` for the sample at ``index``."""
        image1, target = self.dataset[index]

        if self.stategy == AugmentationStrategy.SAME_IMAGE:
            # Two views of the very same image
            image2 = image1
        elif self.stategy == AugmentationStrategy.SAME_CLASS:
            # Second view drawn from a random image sharing the same label
            positive_pair_idx = random.choice(np.where(self.dataset.y == target)[0])
            image2, _ = self.dataset[positive_pair_idx]
        else:
            raise ValueError("Unknown strategy")

        if isinstance(self.transform, Iterable):
            image1 = self.transform[0](image=image1)["image"]
            image2 = self.transform[1](image=image2)["image"]
        else:
            image1 = self.transform(image=image1)["image"]
            image2 = self.transform(image=image2)["image"]

        return [image1, image2], target

    def __len__(self):
        return len(self.dataset)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class TwoSetAugmentationDataset(Dataset):
    """Multi-view augmentation dataset for self-supervised learning (DINO-style).

    Every item yields ``2 + num_local_transforms`` views of the same sample:
    one per global transform, followed by ``num_local_transforms`` applications
    of the local transform.

    Example:
        >>> images[0] = global_transform[0](original_image)
        >>> images[1] = global_transform[1](original_image)
        >>> images[2:] = local_transform(s)(original_image)

    Args:
        dataset: Base dataset returning ``(image, target)`` pairs.
        global_transforms: Pair of global transformations.
        local_transform: Local transformation, applied ``num_local_transforms`` times.
        num_local_transforms: Number of local views to generate (must be >= 1).

    Raises:
        ValueError: If ``num_local_transforms`` is smaller than 1.
    """

    def __init__(
        self,
        dataset: Dataset,
        global_transforms: tuple[A.Compose, A.Compose],
        local_transform: A.Compose,
        num_local_transforms: int,
    ):
        self.dataset = dataset
        self.global_transforms = global_transforms
        self.local_transform = local_transform
        self.num_local_transforms = num_local_transforms

        if num_local_transforms < 1:
            raise ValueError("num_local_transforms must be greater than 0")

    def __getitem__(self, index):
        """Return ``(views, target)``: global views followed by local views."""
        raw_image, label = self.dataset[index]
        global_views = [tf(image=raw_image)["image"] for tf in self.global_transforms]
        local_views = [
            self.local_transform(image=raw_image)["image"]
            for _ in range(self.num_local_transforms)
        ]
        return global_views + local_views, label

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)
|
|
File without changes
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class AsymmetricLoss(torch.nn.Module):
    """Asymmetric loss for multi-label classification.

    Notice - optimized version, minimizes memory allocation and gpu uploading,
    favors inplace operations.

    Args:
        gamma_neg: focusing parameter (gamma) for negative samples
        gamma_pos: focusing parameter (gamma) for positive samples
        m: probability margin added to negative probabilities (asymmetric clipping)
        eps: epsilon to avoid log(0)
        disable_torch_grad_focal_loss: if True, computes the focal weighting term
            without tracking gradients
        apply_sigmoid: if True, applies sigmoid to input before computing loss
    """

    def __init__(
        self,
        gamma_neg: float = 4,
        gamma_pos: float = 0,
        m: float = 0.05,
        eps: float = 1e-8,
        disable_torch_grad_focal_loss: bool = False,
        apply_sigmoid: bool = True,
    ):
        super().__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.m = m
        self.eps = eps
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.apply_sigmoid = apply_sigmoid

        # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
        self.targets: torch.Tensor
        self.anti_targets: torch.Tensor
        self.xs_pos: torch.Tensor
        self.xs_neg: torch.Tensor
        self.asymmetric_w: torch.Tensor
        self.loss: torch.Tensor

    def _compute_focal_weight(self) -> None:
        """Fill ``self.asymmetric_w`` with the asymmetric focusing term.

        Reads ``self.xs_pos``, ``self.xs_neg``, ``self.targets`` and
        ``self.anti_targets``; overwrites ``self.xs_pos``/``self.xs_neg``.
        """
        self.xs_pos = self.xs_pos * self.targets
        self.xs_neg = self.xs_neg * self.anti_targets
        self.asymmetric_w = torch.pow(
            1 - self.xs_pos - self.xs_neg,
            self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets,
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the asymmetric loss.

        Args:
            x: raw logits when ``apply_sigmoid`` is True, otherwise probabilities
            y: targets (multi-label binarized vector)

        Returns:
            asymmetric loss summed over all elements (scalar tensor)
        """
        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = x
        if self.apply_sigmoid:
            self.xs_pos = torch.sigmoid(self.xs_pos)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric clipping: shift negative probabilities by the margin m
        if self.m is not None and self.m > 0:
            self.xs_neg.add_(self.m).clamp_(max=1)

        # Basic CE calculation
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                # BUGFIX: the previous implementation called
                # torch.set_grad_enabled(True) unconditionally afterwards, which
                # force-enabled autograd even when the caller was inside a
                # torch.no_grad() block (e.g. during validation). The context
                # manager restores the caller's grad mode instead.
                with torch.no_grad():
                    self._compute_focal_weight()
            else:
                self._compute_focal_weight()
            self.loss *= self.asymmetric_w

        return -self.loss.sum()
|