quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/losses/classification/focal.py

@@ -0,0 +1,320 @@
+from __future__ import annotations
+
+import warnings
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def one_hot(
+    labels: torch.Tensor,
+    num_classes: int,
+    device: torch.device | None = None,
+    dtype: torch.dtype | None = None,
+    eps: float = 1e-6,
+) -> torch.Tensor:
+    r"""Convert an integer label x-D tensor to a one-hot (x+1)-D tensor.
+
+    Args:
+        labels: tensor with labels of shape :math:`(N, *)`, where N is batch size.
+            Each value is an integer representing correct classification.
+        num_classes: number of classes in labels.
+        device: the desired device of returned tensor.
+        dtype: the desired data type of returned tensor.
+        eps: a value added to the returned tensor.
+
+    Returns:
+        the labels in one hot tensor of shape :math:`(N, C, *)`,
+
+    Examples:
+        >>> labels = torch.LongTensor([[[0, 1], [2, 0]]])
+        >>> one_hot(labels, num_classes=3)
+        tensor([[[[1.0000e+00, 1.0000e-06],
+                  [1.0000e-06, 1.0000e+00]],
+        <BLANKLINE>
+                 [[1.0000e-06, 1.0000e+00],
+                  [1.0000e-06, 1.0000e-06]],
+        <BLANKLINE>
+                 [[1.0000e-06, 1.0000e-06],
+                  [1.0000e+00, 1.0000e-06]]]])
+
+    """
+    if not isinstance(labels, torch.Tensor):
+        raise TypeError(f"Input labels type is not a torch.Tensor. Got {type(labels)}")
+
+    if not labels.dtype == torch.int64:
+        raise ValueError(f"labels must be of the same dtype torch.int64. Got: {labels.dtype}")
+
+    if num_classes < 1:
+        raise ValueError(f"The number of classes must be bigger than one. Got: {num_classes}")
+
+    shape = labels.shape
+    one_hot_output = torch.zeros((shape[0], num_classes) + shape[1:], device=device, dtype=dtype)
+
+    return one_hot_output.scatter_(1, labels.unsqueeze(1), 1.0) + eps
+
+
+# based on: # https://github.com/zhezh/focalloss/blob/master/focalloss.py
+def focal_loss(
+    input_tensor: torch.Tensor,
+    target: torch.Tensor,
+    alpha: float,
+    gamma: float = 2.0,
+    reduction: str = "none",
+    eps: float | None = None,
+) -> torch.Tensor:
+    r"""Criterion that computes Focal loss.
+
+    According to :cite:`lin2018focal`, the Focal loss is computed as follows:
+
+    .. math::
+
+        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+
+    Where:
+        - :math:`p_t` is the model's estimated probability for each class.
+
+    Args:
+        input_tensor: Logits tensor with shape :math:`(N, C, *)` where C = number of classes.
+        target: Labels tensor with shape :math:`(N, *)` where each value is :math:`0 ≤ targets[i] ≤ C−1`.
+        alpha: Weighting factor :math:`\alpha \in [0, 1]`.
+        gamma: Focusing parameter :math:`\gamma >= 0`.
+        reduction: Specifies the reduction to apply to the
+            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+            will be applied, ``'mean'``: the sum of the output will be divided by
+            the number of elements in the output, ``'sum'``: the output will be
+            summed.
+        eps: Deprecated: scalar to enforce numerical stability. This is no longer used.
+
+    Returns:
+        The computed loss.
+
+    Example:
+        >>> N = 5 # num_classes
+        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
+        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
+        >>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
+        >>> output.backward()
+    """
+    if eps is not None and not torch.jit.is_scripting():
+        warnings.warn(
+            "`focal_loss` has been reworked for improved numerical stability "
+            "and the `eps` argument is no longer necessary",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    if not isinstance(input_tensor, torch.Tensor):
+        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input_tensor)}")
+
+    if not len(input_tensor.shape) >= 2:
+        raise ValueError(f"Invalid input shape, we expect BxCx*. Got: {input_tensor.shape}")
+
+    if input_tensor.size(0) != target.size(0):
+        raise ValueError(
+            f"Expected input batch_size ({input_tensor.size(0)}) to match target batch_size ({target.size(0)})."
+        )
+
+    n = input_tensor.size(0)
+    out_size = (n,) + input_tensor.size()[2:]
+    if target.size()[1:] != input_tensor.size()[2:]:
+        raise ValueError(f"Expected target size {out_size}, got {target.size()}")
+
+    if not input_tensor.device == target.device:
+        raise ValueError(f"input and target must be in the same device. Got: {input_tensor.device} and {target.device}")
+
+    # compute softmax over the classes axis
+    input_soft: torch.Tensor = F.softmax(input_tensor, dim=1)
+    log_input_soft: torch.Tensor = F.log_softmax(input_tensor, dim=1)
+
+    # create the labels one hot tensor
+    target_one_hot: torch.Tensor = one_hot(
+        target, num_classes=input_tensor.shape[1], device=input_tensor.device, dtype=input_tensor.dtype
+    )
+
+    # compute the actual focal loss
+    weight = torch.pow(-input_soft + 1.0, gamma)
+
+    focal = -alpha * weight * log_input_soft
+    loss_tmp = torch.einsum("bc...,bc...->b...", (target_one_hot, focal))
+
+    if reduction == "none":
+        loss = loss_tmp
+    elif reduction == "mean":
+        loss = torch.mean(loss_tmp)
+    elif reduction == "sum":
+        loss = torch.sum(loss_tmp)
+    else:
+        raise NotImplementedError(f"Invalid reduction mode: {reduction}")
+    return loss
+
+
+class FocalLoss(nn.Module):
+    r"""Criterion that computes Focal loss.
+
+    According to :cite:`lin2018focal`, the Focal loss is computed as follows:
+
+    .. math::
+
+        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+
+    Where:
+        - :math:`p_t` is the model's estimated probability for each class.
+
+    Args:
+        alpha: Weighting factor :math:`\alpha \in [0, 1]`.
+        gamma: Focusing parameter :math:`\gamma >= 0`.
+        reduction: Specifies the reduction to apply to the
+            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+            will be applied, ``'mean'``: the sum of the output will be divided by
+            the number of elements in the output, ``'sum'``: the output will be
+            summed.
+        eps: Deprecated: scalar to enforce numerical stability. This is no longer
+            used.
+
+    Shape:
+        - Input: :math:`(N, C, *)` where C = number of classes.
+        - Target: :math:`(N, *)` where each value is
+          :math:`0 ≤ targets[i] ≤ C−1`.
+
+    Example:
+        >>> N = 5 # num_classes
+        >>> kwargs = {"alpha": 0.5, "gamma": 2.0, "reduction": 'mean'}
+        >>> criterion = FocalLoss(**kwargs)
+        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
+        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
+        >>> output = criterion(input, target)
+        >>> output.backward()
+    """
+
+    def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none", eps: float | None = None) -> None:
+        super().__init__()
+        self.alpha: float = alpha
+        self.gamma: float = gamma
+        self.reduction: str = reduction
+        self.eps: float | None = eps
+
+    def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Forward call computation."""
+        return focal_loss(input_tensor, target, self.alpha, self.gamma, self.reduction, self.eps)
+
+
+def binary_focal_loss_with_logits(
+    input_tensor: torch.Tensor,
+    target: torch.Tensor,
+    alpha: float = 0.25,
+    gamma: float = 2.0,
+    reduction: str = "none",
+    eps: float | None = None,
+) -> torch.Tensor:
+    r"""Function that computes Binary Focal loss.
+
+    .. math::
+
+        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+
+    where:
+        - :math:`p_t` is the model's estimated probability for each class.
+
+    Args:
+        input_tensor: input data tensor of arbitrary shape.
+        target: the target tensor with shape matching input.
+        alpha: Weighting factor for the rare class :math:`\alpha \in [0, 1]`.
+        gamma: Focusing parameter :math:`\gamma >= 0`.
+        reduction: Specifies the reduction to apply to the
+            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+            will be applied, ``'mean'``: the sum of the output will be divided by
+            the number of elements in the output, ``'sum'``: the output will be
+            summed.
+        eps: Deprecated: scalar for numerical stability when dividing. This is no longer used.
+
+    Returns:
+        the computed loss.
+
+    Examples:
+        >>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
+        >>> logits = torch.tensor([[[6.325]],[[5.26]],[[87.49]]])
+        >>> labels = torch.tensor([[[1.]],[[1.]],[[0.]]])
+        >>> binary_focal_loss_with_logits(logits, labels, **kwargs)
+        tensor(21.8725)
+    """
+    if eps is not None and not torch.jit.is_scripting():
+        warnings.warn(
+            "`binary_focal_loss_with_logits` has been reworked for improved numerical stability "
+            "and the `eps` argument is no longer necessary",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    if not isinstance(input_tensor, torch.Tensor):
+        raise TypeError(f"Input type is not a torch.Tensor. Got {type(input_tensor)}")
+
+    if not len(input_tensor.shape) >= 2:
+        raise ValueError(f"Invalid input shape, we expect BxCx*. Got: {input_tensor.shape}")
+
+    if input_tensor.size(0) != target.size(0):
+        raise ValueError(
+            f"Expected input batch_size ({input_tensor.size(0)}) to match target batch_size ({target.size(0)})."
+        )
+
+    probs_pos = torch.sigmoid(input_tensor)
+    probs_neg = torch.sigmoid(-input_tensor)
+    loss_tmp = -alpha * torch.pow(probs_neg, gamma) * target * F.logsigmoid(input_tensor) - (1 - alpha) * torch.pow(
+        probs_pos, gamma
+    ) * (1.0 - target) * F.logsigmoid(-input_tensor)
+
+    if reduction == "none":
+        loss = loss_tmp
+    elif reduction == "mean":
+        loss = torch.mean(loss_tmp)
+    elif reduction == "sum":
+        loss = torch.sum(loss_tmp)
+    else:
+        raise NotImplementedError(f"Invalid reduction mode: {reduction}")
+    return loss
+
+
+class BinaryFocalLossWithLogits(nn.Module):
+    r"""Criterion that computes Focal loss.
+
+    According to :cite:`lin2018focal`, the Focal loss is computed as follows:
+
+    .. math::
+
+        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
+
+    where:
+        - :math:`p_t` is the model's estimated probability for each class.
+
+    Args:
+        alpha: Weighting factor for the rare class :math:`\alpha \in [0, 1]`.
+        gamma: Focusing parameter :math:`\gamma >= 0`.
+        reduction: Specifies the reduction to apply to the
+            output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
+            will be applied, ``'mean'``: the sum of the output will be divided by
+            the number of elements in the output, ``'sum'``: the output will be
+            summed.
+
+    Shape:
+        - Input: :math:`(N, *)`.
+        - Target: :math:`(N, *)`.
+
+    Examples:
+        >>> kwargs = {"alpha": 0.25, "gamma": 2.0, "reduction": 'mean'}
+        >>> loss = BinaryFocalLossWithLogits(**kwargs)
+        >>> input = torch.randn(1, 3, 5, requires_grad=True)
+        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(2)
+        >>> output = loss(input, target)
+        >>> output.backward()
+    """
+
+    def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = "none") -> None:
+        super().__init__()
+        self.alpha: float = alpha
+        self.gamma: float = gamma
+        self.reduction: str = reduction
+
+    def forward(self, input_tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """Forward call computation."""
+        return binary_focal_loss_with_logits(input_tensor, target, self.alpha, self.gamma, self.reduction)
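The two criteria added above are self-contained `nn.Module`s. A minimal usage sketch follows; only the module path `quadra.losses.classification.focal` and the class names come from this diff, while the shapes and hyperparameter values are illustrative assumptions:

```python
import torch

from quadra.losses.classification.focal import BinaryFocalLossWithLogits, FocalLoss

# Multi-class case: logits (N, C, H, W) and integer labels (N, H, W) in [0, C-1].
logits = torch.randn(4, 3, 8, 8, requires_grad=True)
labels = torch.randint(0, 3, (4, 8, 8))
criterion = FocalLoss(alpha=0.5, gamma=2.0, reduction="mean")
criterion(logits, labels).backward()

# Binary case: raw logits and {0, 1} targets of the same shape.
bin_logits = torch.randn(4, 8, 8, requires_grad=True)
bin_labels = torch.randint(0, 2, (4, 8, 8)).float()
bin_criterion = BinaryFocalLossWithLogits(alpha=0.25, gamma=2.0, reduction="mean")
bin_criterion(bin_logits, bin_labels).backward()
```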
quadra/losses/classification/prototypical.py

@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+import torch
+from torch.nn import functional as F
+
+
+def euclidean_dist(
+    query: torch.Tensor,
+    prototypes: torch.Tensor,
+    sen: bool = True,
+    eps_pos: float = 1.0,
+    eps_neg: float = -1e-7,
+    eps: float = 1e-7,
+) -> torch.Tensor:
+    """Compute euclidean distance between two tensors.
+    SEN dissimilarity from https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123680120.pdf
+    Args:
+        query: feature of the network
+        prototypes: prototypes of the center
+        sen: Sen dissimilarity flag
+        eps_pos: similarity arg
+        eps_neg: similarity arg
+        eps: similarity arg.
+
+    Returns:
+        Euclidean loss
+
+    """
+    # query: (n_classes * n_query) x d
+    # prototypes: n_classes x d
+    n = query.size(0)
+    m = prototypes.size(0)
+    d = query.size(1)
+    if d != prototypes.size(1):
+        raise ValueError("query and prototypes size[1] should be equal")
+
+    if sen:
+        # SEN dissimilarity from https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123680120.pdf
+        norm_query = torch.linalg.norm(query, ord=2, dim=1)  # (n_classes * n_query) X 1
+        norm_prototypes = torch.linalg.norm(prototypes, ord=2, dim=1)  # n_classes X 1
+
+        # We have to compute (||z|| - ||c||)^2 between all query points w.r.t.
+        # all support points
+
+        # Replicate each single query norm value m times
+        norm_query = norm_query.view(-1, 1).unsqueeze(1).expand(n, m, 1)
+        # Replicate all prototypes norm values n times
+        norm_prototypes = norm_prototypes.view(-1, 1).unsqueeze(0).expand(n, m, 1)
+        norm_diff = torch.pow(norm_query - norm_prototypes, 2).squeeze(2)
+        epsilon = torch.full((n, m), eps_neg).type_as(query)
+        if eps_pos != eps_neg:
+            # n_query = n // m
+            # for i in range(m):
+            #     epsilon[i * n_query : (i + 1) * n_query, i] = 1.0
+
+            # Since query points with class i need to have a positive epsilon
+            # whenever they refer to support point with class i and since
+            # query and support points are ordered, we need to set:
+            # the 1st column of the 1st n_query rows to eps_pos
+            # the 2nd column of the 2nd n_query rows to eps_pos
+            # and so on
+            idxs = torch.eye(m, dtype=torch.bool).unsqueeze(1).expand(m, n // m, m).reshape(-1, m)
+            epsilon[idxs] = eps_pos
+        norm_diff = norm_diff * epsilon
+
+    # Replicate each single query point value m times
+    query = query.unsqueeze(1).expand(n, m, d)
+    # Replicate all prototype points values n times
+    prototypes = prototypes.unsqueeze(0).expand(n, m, d)
+
+    norm = torch.pow(query - prototypes, 2).sum(2)
+    if sen:
+        return torch.sqrt(norm + norm_diff + eps)
+
+    return norm
+
+
+def prototypical_loss(
+    coords: torch.Tensor,
+    target: torch.Tensor,
+    n_support: int,
+    prototypes: torch.Tensor | None = None,
+    sen: bool = True,
+    eps_pos: float = 1.0,
+    eps_neg: float = -1e-7,
+):
+    """Prototypical loss implementation.
+
+    Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
+    Compute the barycentres by averaging the features of n_support
+    samples for each class in target, computes then the distances from each
+    samples' features to each one of the barycentres, computes the
+    log_probability for each n_query samples for each one of the current
+    classes, of belonging to a class c, loss and accuracy are then computed and returned.
+
+    Args:
+        coords: The model output for a batch of samples
+        target: Ground truth for the above batch of samples
+        n_support: Number of samples to keep in account when computing
+            barycentres, for each one of the current classes
+        prototypes: if not None, is used for classification
+        sen: Sen dissimilarity flag
+        eps_pos: Sen positive similarity arg
+        eps_neg: Sen negative similarity arg
+    """
+    classes = torch.unique(target, sorted=True)
+    n_classes = len(classes)
+    n_query = len(torch.where(target == classes[0])[0]) - n_support
+
+    # Check equality between classes and target with broadcasting:
+    # class_idxs[i, j] = True iff classes[i] == target[j]
+    class_idxs = classes.unsqueeze(1) == target
+    if prototypes is None:
+        # Get the prototypes as the mean of the support points,
+        # ordered by class
+        prototypes = torch.stack([coords[idx_list][:n_support] for idx_list in class_idxs]).mean(1)  # n_classes X d
+    # Get query samples as the points NOT in the support set,
+    # where, after .view(-1, d), one has that
+    # the 1st n_query points refer to class 1
+    # the 2nd n_query points refer to class 2
+    # and so on
+    query_samples = torch.stack([coords[idx_list][n_support:] for idx_list in class_idxs]).view(
+        -1, prototypes.shape[-1]
+    )  # (n_classes * n_query) X d
+    # Get distances, where dists[i, j] is the distance between
+    # query point i to support point j
+    dists = euclidean_dist(
+        query_samples, prototypes, sen=sen, eps_pos=eps_pos, eps_neg=eps_neg
+    )  # (n_classes * n_query) X n_classes
+    log_p_y = F.log_softmax(-dists, dim=1)
+    log_p_y = log_p_y.view(n_classes, n_query, -1)  # n_classes X n_query X n_classes
+
+    target_inds = torch.arange(0, n_classes).view(n_classes, 1, 1)
+    # One solution is to use type_as(coords[0])
+    target_inds = target_inds.type_as(coords)
+    target_inds = target_inds.expand(n_classes, n_query, 1).long()
+
+    # Since we need to backpropagate the log softmax of query points
+    # of class i that refers to support of the same class for every i,
+    # and since query and support are ordered we select:
+    # from the 1st n_query X n_classes the 1st column
+    # from the 2nd n_query X n_classes the 2nd column
+    # and so on
+    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
+    _, y_hat = log_p_y.max(2)
+    acc_val = y_hat.eq(target_inds.squeeze()).float().mean()
+
+    return loss_val, acc_val, prototypes
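`prototypical_loss` assumes an episode in which every class contributes the same number of embeddings; the first `n_support` of each class become the support set and the rest are queries. A minimal sketch under those assumptions (the episode layout, dimensions, and random embeddings below are illustrative, not taken from the package):

```python
import torch

from quadra.losses.classification.prototypical import prototypical_loss

# Hypothetical episode: 3 classes, 5 support + 5 query samples each, 64-d embeddings.
n_classes, n_support, n_query, d = 3, 5, 5, 64
embeddings = torch.randn(n_classes * (n_support + n_query), d)
target = torch.arange(n_classes).repeat_interleave(n_support + n_query)

loss, acc, prototypes = prototypical_loss(embeddings, target, n_support=n_support)
print(loss.item(), acc.item(), prototypes.shape)  # prototypes: (3, 64)
```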
quadra/losses/ssl/__init__.py

@@ -0,0 +1,17 @@
+from .barlowtwins import BarlowTwinsLoss
+from .byol import BYOLRegressionLoss
+from .dino import DinoDistillationLoss
+from .idmm import IDMMLoss
+from .simclr import SimCLRLoss
+from .simsiam import SimSIAMLoss
+from .vicreg import VICRegLoss
+
+__all__ = [
+    "BarlowTwinsLoss",
+    "BYOLRegressionLoss",
+    "IDMMLoss",
+    "SimCLRLoss",
+    "SimSIAMLoss",
+    "VICRegLoss",
+    "DinoDistillationLoss",
+]
quadra/losses/ssl/barlowtwins.py

@@ -0,0 +1,47 @@
+import torch
+
+
+def barlowtwins_loss(
+    z1: torch.Tensor,
+    z2: torch.Tensor,
+    lambd: float,
+) -> torch.Tensor:
+    """BarlowTwins loss described in https://arxiv.org/abs/2103.03230.
+
+    Args:
+        z1: First `augmented` normalized features (i.e. f(T(x))).
+            The normalization can be obtained with
+            z1_norm = (z1 - z1.mean(0)) / z1.std(0)
+        z2: Second `augmented` normalized features (i.e. f(T(x))).
+            The normalization can be obtained with
+            z2_norm = (z2 - z2.mean(0)) / z2.std(0)
+        lambd: lambda multiplier for redundancy term.
+
+    Returns:
+        BarlowTwins loss
+    """
+    z1 = (z1 - z1.mean(0)) / z1.std(0)
+    z2 = (z2 - z2.mean(0)) / z2.std(0)
+    cov = z1.T @ z2
+    cov.div_(z1.size(0))
+    n = cov.size(0)
+    invariance_term = torch.diagonal(cov).add_(-1).pow_(2).sum()
+    off_diag = cov.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
+    redundancy_term = off_diag.pow_(2).sum()
+    return invariance_term + lambd * redundancy_term
+
+
+class BarlowTwinsLoss(torch.nn.Module):
+    """BarlowTwin loss.
+
+    Args:
+        lambd: lambda of the loss.
+    """
+
+    def __init__(self, lambd: float):
+        super().__init__()
+        self.lambd = lambd
+
+    def forward(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
+        """Compute the BarlowTwins loss."""
+        return barlowtwins_loss(z1, z2, self.lambd)
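`barlowtwins_loss` standardizes each embedding dimension over the batch, builds the cross-correlation matrix between the two views, and penalizes its off-diagonal entries weighted by `lambd`. A minimal usage sketch (batch size, dimensionality, and the `lambd` value are illustrative; the Barlow Twins paper uses a small weight on the order of 5e-3):

```python
import torch

from quadra.losses.ssl.barlowtwins import BarlowTwinsLoss

# Hypothetical projector outputs for two augmented views of the same batch.
z1 = torch.randn(256, 128)
z2 = torch.randn(256, 128)

criterion = BarlowTwinsLoss(lambd=5e-3)
loss = criterion(z1, z2)  # per-dimension standardization happens inside the loss
```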
quadra/losses/ssl/byol.py

@@ -0,0 +1,37 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def byol_regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Byol regression loss
+    Args:
+        x: tensor
+        y: tensor.
+
+    Returns:
+        tensor
+    """
+    x = F.normalize(x, dim=-1)
+    y = F.normalize(y, dim=-1)
+    return 2 - 2 * (x * y).sum(dim=1).mean()
+
+
+class BYOLRegressionLoss(nn.Module):
+    """BYOL regression loss module."""
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute the BYOL regression loss.
+
+        Args:
+            x: First Tensor
+            y: Second Tensor
+
+        Returns:
+            BYOL regression loss
+        """
+        return byol_regression_loss(x, y)
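`byol_regression_loss` L2-normalizes both inputs and returns 2 minus twice the mean cosine similarity, so the value lies in [0, 4] and shrinks as the online prediction aligns with the target projection. A minimal sketch (shapes are illustrative assumptions):

```python
import torch

from quadra.losses.ssl.byol import BYOLRegressionLoss

# Hypothetical online-network prediction vs. target-network projection (batch of 32, 256-d).
prediction = torch.randn(32, 256)
target_projection = torch.randn(32, 256)

loss = BYOLRegressionLoss()(prediction, target_projection)
```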
quadra/losses/ssl/dino.py

@@ -0,0 +1,129 @@
+import logging
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+log = logging.getLogger(__name__)
+
+
+def dino_distillation_loss(
+    student_output: torch.Tensor,
+    teacher_output: torch.Tensor,
+    center_vector: torch.Tensor,
+    teacher_temp: float = 0.04,
+    student_temp: float = 0.1,
+) -> torch.Tensor:
+    """Compute the DINO distillation loss.
+
+    Args:
+        student_output: tensor of the student output
+        teacher_output: tensor of the teacher output
+        center_vector: center vector of distribution
+        teacher_temp: temperature teacher
+        student_temp: temperature student.
+
+    Returns:
+        The computed loss
+    """
+    student_temp = [s / student_temp for s in student_output]
+    teacher_temp = [(t - center_vector) / teacher_temp for t in teacher_output]
+
+    student_sm = [F.log_softmax(s, dim=-1) for s in student_temp]
+    teacher_sm = [F.softmax(t, dim=-1).detach() for t in teacher_temp]
+
+    total_loss = torch.tensor(0.0, device=student_output[0].device)
+    n_loss_terms = torch.tensor(0.0, device=student_output[0].device)
+
+    for t_ix, t in enumerate(teacher_sm):
+        for s_ix, s in enumerate(student_sm):
+            if t_ix == s_ix:
+                continue
+
+            loss = torch.sum(-t * s, dim=-1)  # (n_samples,)
+            total_loss += loss.mean()  # scalar
+            n_loss_terms += 1
+
+    total_loss /= n_loss_terms
+    return total_loss
+
+
+class DinoDistillationLoss(nn.Module):
+    """Dino distillation loss module.
+
+    Args:
+        output_dim: output dim.
+        max_epochs: max epochs.
+        warmup_teacher_temp: warmup temperature.
+        teacher_temp: teacher temperature.
+        warmup_teacher_temp_epochs: warmup teacher epochs.
+        student_temp: student temperature.
+        center_momentum: center momentum.
+    """
+
+    def __init__(
+        self,
+        output_dim: int,
+        max_epochs: int,
+        warmup_teacher_temp: float = 0.04,
+        teacher_temp: float = 0.07,
+        warmup_teacher_temp_epochs: int = 30,
+        student_temp: float = 0.1,
+        center_momentum: float = 0.9,
+    ):
+        super().__init__()
+        self.student_temp = student_temp
+        self.center_momentum = center_momentum
+        self.center: torch.Tensor
+        # we apply a warm up for the teacher temperature because
+        # a too high temperature makes the training unstable at the beginning
+
+        if warmup_teacher_temp_epochs >= max_epochs:
+            raise ValueError(
+                f"Number of warmup epochs ({warmup_teacher_temp_epochs}) must be smaller than max_epochs ({max_epochs})"
+            )
+
+        if warmup_teacher_temp_epochs < 30:
+            log.warning("Warmup teacher epochs is very small (< 30). This may cause instabilities in the training")
+
+        self.teacher_temp_schedule = np.concatenate(
+            (
+                np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
+                np.ones(max_epochs - warmup_teacher_temp_epochs) * teacher_temp,
+            )
+        )
+        self.register_buffer("center", torch.zeros(1, output_dim))
+
+    def forward(
+        self,
+        current_epoch: int,
+        student_output: torch.Tensor,
+        teacher_output: torch.Tensor,
+    ) -> torch.Tensor:
+        """Runs forward."""
+        teacher_temp = self.teacher_temp_schedule[current_epoch]
+        loss = dino_distillation_loss(
+            student_output,
+            teacher_output,
+            center_vector=self.center,
+            teacher_temp=teacher_temp,
+            student_temp=self.student_temp,
+        )
+
+        self.update_center(teacher_output)
+        return loss
+
+    @torch.no_grad()
+    def update_center(self, teacher_output: torch.Tensor) -> None:
+        """Update center of the distribution of the teacher
+
+        Args:
+            teacher_output: teacher output.
+
+        Returns:
+            None
+        """
+        # TODO: check if this is correct
+        # torch.cat expects a list of tensors but teacher_output is a tensor
+        batch_center = torch.cat(teacher_output).mean(dim=0, keepdim=True)  # type: ignore[call-overload]
+        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
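`DinoDistillationLoss` keeps a running center of the teacher outputs and warms up the teacher temperature linearly over `warmup_teacher_temp_epochs`. Although the annotations say `torch.Tensor`, the implementation iterates over the outputs (the TODO above notes the same mismatch), so the sketch below passes one tensor per crop; all sizes and hyperparameters are illustrative assumptions:

```python
import torch

from quadra.losses.ssl.dino import DinoDistillationLoss

batch_size, out_dim = 16, 4096
student_output = [torch.randn(batch_size, out_dim) for _ in range(2)]  # e.g. two global crops
teacher_output = [torch.randn(batch_size, out_dim) for _ in range(2)]

criterion = DinoDistillationLoss(output_dim=out_dim, max_epochs=100, warmup_teacher_temp_epochs=30)
loss = criterion(current_epoch=0, student_output=student_output, teacher_output=teacher_output)
# criterion.center is refreshed from the teacher outputs after every forward pass.
```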