quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,285 @@
+from __future__ import annotations
+
+import torch
+from torch import nn
+
+from quadra.utils.models import trunc_normal_
+
+
+class ProjectionHead(torch.nn.Module):
+    """Base class for all projection and prediction heads.
+
+    Args:
+        blocks:
+            List of tuples, each denoting one block of the projection head MLP.
+            Each tuple reads (linear_layer, batch_norm_layer, non_linearity_layer).
+            `batch_norm` layer can be possibly None, the same happens for
+            `non_linearity_layer`.
+    """
+
+    def __init__(self, blocks: list[tuple[torch.nn.Module | None, ...]]):
+        super().__init__()
+
+        layers: list[nn.Module] = []
+        for linear, batch_norm, non_linearity in blocks:
+            if linear:
+                layers.append(linear)
+            if batch_norm:
+                layers.append(batch_norm)
+            if non_linearity:
+                layers.append(non_linearity)
+        self.layers = torch.nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor):
+        return self.layers(x)
+
+
+class ExpanderReducer(ProjectionHead):
+    """Expander followed by a reducer."""
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__(
+            [
+                (
+                    torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (
+                    torch.nn.Linear(hidden_dim, output_dim, bias=False),
+                    torch.nn.BatchNorm1d(output_dim, affine=False),
+                    torch.nn.ReLU(inplace=True),
+                ),
+            ]
+        )
+
+
+class BarlowTwinsProjectionHead(ProjectionHead):
+    """Projection head used for Barlow Twins.
+    "The projector network has three linear layers, each with 8192 output
+    units. The first two layers of the projector are followed by a batch
+    normalization layer and rectified linear units." https://arxiv.org/abs/2103.03230.
+    """
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__(
+            [
+                (
+                    torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (
+                    torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+            ]
+        )
+
+
+class SimCLRProjectionHead(ProjectionHead):
+    """Projection head used for SimCLR.
+    "We use a MLP with one hidden layer to obtain zi = g(h) = W_2 * σ(W_1 * h)
+    where σ is a ReLU non-linearity." https://arxiv.org/abs/2002.05709.
+    """
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__(
+            [
+                (
+                    torch.nn.Linear(input_dim, hidden_dim),
+                    None,
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (torch.nn.Linear(hidden_dim, output_dim), None, None),
+            ]
+        )
+
+
+class SimCLRPredictionHead(ProjectionHead):
+    """Prediction head used for SimCLR.
+    "We set g(h) = W(2)σ(W(1)h), with the same input and output dimensionality (i.e. 2048)."
+    https://arxiv.org/abs/2002.05709.
+    """
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__(
+            [
+                (
+                    torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+            ]
+        )
+
+
+class SimSiamProjectionHead(ProjectionHead):
+    """Projection head used for SimSiam.
+    "The projection MLP (in f) has BN applied to each fully-connected (fc)
+    layer, including its output fc. Its output fc has no ReLU. The hidden fc is
+    2048-d. This MLP has 3 layers." https://arxiv.org/abs/2011.10566.
+    """
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__(
+            [
+                (
+                    torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (
+                    torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim, affine=False),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (
+                    torch.nn.Linear(hidden_dim, output_dim, bias=False),
+                    torch.nn.BatchNorm1d(output_dim, affine=False),
+                    None,
+                ),
+            ]
+        )
+
+
+class SimSiamPredictionHead(ProjectionHead):
+    """Prediction head used for SimSiam.
+    "The prediction MLP (h) has BN applied to its hidden fc layers. Its output
+    fc does not have BN (...) or ReLU. This MLP has 2 layers." https://arxiv.org/abs/2011.10566.
+    """
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__(
+            [
+                (
+                    torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+            ]
+        )
+
+
+class BYOLPredictionHead(ProjectionHead):
+    """Prediction head used for BYOL."""
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__(
+            [
+                (
+                    torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+            ]
+        )
+
+
+class BYOLProjectionHead(ProjectionHead):
+    """Projection head used for BYOL."""
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__(
+            [
+                (
+                    torch.nn.Linear(input_dim, hidden_dim, bias=False),
+                    torch.nn.BatchNorm1d(hidden_dim),
+                    torch.nn.ReLU(inplace=True),
+                ),
+                (torch.nn.Linear(hidden_dim, output_dim, bias=False), None, None),
+            ]
+        )
+
+
+class DinoProjectionHead(nn.Module):
+    """Projection head used for Dino. This projection head does not have
+    a batch norm layer.
+
+    Args:
+        input_dim: Input dimension for MLP head.
+        output_dim: Output dimension (projection dimension) for MLP head.
+        hidden_dim: Hidden dimension. Defaults to 512.
+        bottleneck_dim: Bottleneck dimension. Defaults to 256.
+        num_layers: Number of hidden layers used in MLP. Defaults to 3.
+        norm_last_layer: Decides applying normalization before last layer.
+            Defaults to False.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        hidden_dim: int,
+        use_bn: bool = False,
+        norm_last_layer: bool = True,
+        num_layers: int = 3,
+        bottleneck_dim: int = 256,
+    ):
+        super().__init__()
+        num_layers = max(num_layers, 1)
+        self.mlp: nn.Linear | nn.Sequential
+        if num_layers == 1:
+            self.mlp = nn.Linear(input_dim, bottleneck_dim)
+        else:
+            layers: list[nn.Module] = [nn.Linear(input_dim, hidden_dim)]
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+            for _ in range(num_layers - 2):
+                layers.append(nn.Linear(hidden_dim, hidden_dim))
+                if use_bn:
+                    layers.append(nn.BatchNorm1d(hidden_dim))
+                layers.append(nn.GELU())
+            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+            self.mlp = nn.Sequential(*layers)
+        self.apply(self._init_weights)
+        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, output_dim, bias=False))
+        self.last_layer.weight_g.data.fill_(1)
+        if norm_last_layer:
+            self.last_layer.weight_g.requires_grad = False
+
+    def _init_weights(self, m):
+        """Initialize the weights of the projection head."""
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.mlp(x)
+        x = nn.functional.normalize(x, dim=-1, p=2)
+        x = self.last_layer(x)
+        return x
+
+
+class MultiCropModel(nn.Module):
+    """MultiCrop model for DINO augmentation.
+
+    It takes 2 global crops and N (possible) local crops as a single tensor.
+
+    Args:
+        backbone: Backbone model.
+        head: Head model.
+    """
+
+    def __init__(self, backbone: nn.Module, head: nn.Module):
+        super().__init__()
+        self.backbone = backbone
+        self.head = head
+
+    def forward(self, x):
+        n_crops = len(x)
+        # (n_samples * n_crops, 3, size, size)
+        concatenated = torch.cat(x, dim=0)
+        # (n_samples * n_crops, in_dim)
+        embedding = self.backbone(concatenated)
+        logits = self.head(embedding)  # (n_samples * n_crops, out_dim)
+        chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)
+
+        return chunks
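The hunk above only defines the head modules. As a reading aid, here is a standalone sketch (not part of the package diff) of how a pair of these heads is typically composed with a feature extractor. The stand-in backbone, all dimensions, and the commented import path are illustration assumptions; going by the +285 line count in the file list, this hunk appears to be the `quadra/modules/ssl/common.py` module.

```python
# Illustrative sketch only, not part of the released code.
# from quadra.modules.ssl.common import SimSiamProjectionHead, SimSiamPredictionHead  # assumed import path
import torch
from torch import nn

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 512))  # stand-in feature extractor
projector = SimSiamProjectionHead(input_dim=512, hidden_dim=2048, output_dim=2048)
predictor = SimSiamPredictionHead(input_dim=2048, hidden_dim=512, output_dim=2048)

view = torch.randn(8, 3, 32, 32)   # one augmented view of a batch of 8 images
features = backbone(view)          # (8, 512) backbone features
z = projector(features)            # (8, 2048) projection fed to the loss
p = predictor(z)                   # (8, 2048) prediction matched against the other view
print(z.shape, p.shape)
```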
@@ -0,0 +1,186 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
+
+import sklearn
+import torch
+from pytorch_lightning.core.optimizer import LightningOptimizer
+from torch import nn
+from torch.optim import Optimizer
+
+from quadra.modules.ssl import BYOL
+from quadra.utils.models import clip_gradients
+from quadra.utils.utils import get_logger
+
+log = get_logger(__name__)
+
+
+class Dino(BYOL):
+    """DINO pytorch-lightning module.
+
+    Args:
+        student : student model
+        teacher : teacher model
+        student_projection_mlp : student projection MLP
+        teacher_projection_mlp : teacher projection MLP
+        criterion : loss function
+        freeze_last_layer : number of layers to freeze in the student model. Default: 1
+        classifier: Standard sklearn classifier
+        optimizer: optimizer of the training. If None a default Adam is used.
+        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
+        lr_scheduler_interval: interval at which the lr scheduler is updated.
+        teacher_momentum: momentum of the teacher parameters
+        teacher_momentum_cosine_decay: whether to use cosine decay for the teacher momentum
+    """
+
+    def __init__(
+        self,
+        student: nn.Module,
+        teacher: nn.Module,
+        student_projection_mlp: nn.Module,
+        teacher_projection_mlp: nn.Module,
+        criterion: nn.Module,
+        freeze_last_layer: int = 1,
+        classifier: sklearn.base.ClassifierMixin | None = None,
+        optimizer: Optimizer | None = None,
+        lr_scheduler: object | None = None,
+        lr_scheduler_interval: str | None = "epoch",
+        teacher_momentum: float = 0.9995,
+        teacher_momentum_cosine_decay: bool | None = True,
+    ):
+        super().__init__(
+            student=student,
+            teacher=teacher,
+            student_projection_mlp=student_projection_mlp,
+            student_prediction_mlp=nn.Identity(),
+            teacher_projection_mlp=teacher_projection_mlp,
+            criterion=criterion,
+            teacher_momentum=teacher_momentum,
+            teacher_momentum_cosine_decay=teacher_momentum_cosine_decay,
+            classifier=classifier,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            lr_scheduler_interval=lr_scheduler_interval,
+        )
+        self.freeze_last_layer = freeze_last_layer
+
+    def initialize_teacher(self):
+        """Initialize teacher from the state dict of the student one,
+        checking also that the student model requires gradients correctly.
+        """
+        self.teacher_projection_mlp.load_state_dict(self.student_projection_mlp.state_dict())
+        for p in self.teacher_projection_mlp.parameters():
+            p.requires_grad = False
+
+        self.teacher.load_state_dict(self.model.state_dict())
+        for p in self.teacher.parameters():
+            p.requires_grad = False
+
+        all_frozen = True
+        for p in self.model.parameters():
+            all_frozen = all_frozen and (not p.requires_grad)
+
+        if all_frozen:
+            log.warning(
+                "All parameters of the student model are frozen, the model will not be trained, automatically"
+                " unfreezing all the layers"
+            )
+
+            for p in self.model.parameters():
+                p.requires_grad = True
+
+        for name, p in self.student_projection_mlp.named_parameters():
+            if name != "last_layer.weight_g":
+                assert p.requires_grad is True
+
+        self.teacher_initialized = True
+
+    def student_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
+        """Student forward on the multicrop images.
+
+        Args:
+            x: List of torch.Tensor containing multicropped augmented images
+
+        Returns:
+            torch.Tensor: a tensor of shape NxBxD, where N is the number of crops
+            corresponding to the length of the input list `x`, B is the batch size
+            and D is the output dimension
+        """
+        n_crops = len(x)
+        concatenated = torch.cat(x, dim=0)  # (n_samples * n_crops, C, H, W)
+        embedding = self.model(concatenated)  # (n_samples * n_crops, in_dim)
+        logits = self.student_projection_mlp(embedding)  # (n_samples * n_crops, out_dim)
+        chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)
+        return chunks
+
+    def teacher_multicrop_forward(self, x: list[torch.Tensor]) -> torch.Tensor:
+        """Teacher forward on the multicrop images.
+
+        Args:
+            x: List of torch.Tensor containing multicropped augmented images
+
+        Returns:
+            torch.Tensor: a tensor of shape NxBxD, where N is the number of crops
+            corresponding to the length of the input list `x`, B is the batch size
+            and D is the output dimension
+        """
+        n_crops = len(x)
+        concatenated = torch.cat(x, dim=0)  # (n_samples * n_crops, C, H, W)
+        embedding = self.teacher(concatenated)  # (n_samples * n_crops, in_dim)
+        logits = self.teacher_projection_mlp(embedding)  # (n_samples * n_crops, out_dim)
+        chunks = logits.chunk(n_crops)  # n_crops * (n_samples, out_dim)
+        return chunks
+
+    def cancel_gradients_last_layer(self, epoch: int, freeze_last_layer: int):
+        """Zero out the gradient of the last layer, as specified in the paper.
+
+        Args:
+            epoch: current epoch
+            freeze_last_layer: maximum freeze epoch: if `epoch` >= `freeze_last_layer`
+                then the gradient of the last layer will not be frozen
+        """
+        if epoch >= freeze_last_layer:
+            return
+        for n, p in self.student_projection_mlp.named_parameters():
+            if "last_layer" in n:
+                p.grad = None
+
+    def training_step(self, batch: tuple[list[torch.Tensor], torch.Tensor], *args: Any) -> torch.Tensor:
+        images, _ = batch
+        with torch.no_grad():
+            teacher_output = self.teacher_multicrop_forward(images[:2])
+
+        student_output = self.student_multicrop_forward(images)
+        loss = self.criterion(self.current_epoch, student_output, teacher_output)
+
+        self.log(name="loss", value=loss, on_step=True, on_epoch=True, prog_bar=True)
+        return loss
+
+    def configure_gradient_clipping(
+        self,
+        optimizer: Optimizer,
+        gradient_clip_val: int | float | None = None,
+        gradient_clip_algorithm: str | None = None,
+    ):
+        """Configure gradient clipping for the optimizer."""
+        if gradient_clip_algorithm is not None and gradient_clip_val is not None:
+            clip_gradients(self.model, gradient_clip_val)
+            clip_gradients(self.student_projection_mlp, gradient_clip_val)
+        self.cancel_gradients_last_layer(self.current_epoch, self.freeze_last_layer)
+
+    def optimizer_step(
+        self,
+        epoch: int,
+        batch_idx: int,
+        optimizer: Optimizer | LightningOptimizer,
+        optimizer_closure: Callable[[], Any] | None = None,
+    ) -> None:
+        """Override optimizer step to update the teacher parameters."""
+        super().optimizer_step(
+            epoch,
+            batch_idx,
+            optimizer,
+            optimizer_closure=optimizer_closure,
+        )
+        self.update_teacher()
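Before the next hunk, a short standalone sketch (again, not part of the diff) of the tensor bookkeeping behind `student_multicrop_forward`, `teacher_multicrop_forward`, and `training_step` above: all crops go through the network as one concatenated batch, the output is split back into one chunk per crop, and the teacher path only ever sees the first two (global) crops. All shapes below are arbitrary illustration values.

```python
# Illustrative sketch only: the multi-crop chunking used by the Dino module above.
import torch

n_crops, n_samples, out_dim = 8, 4, 16   # e.g. 2 global + 6 local crops, arbitrary sizes

# One projector output per crop, produced by a single forward pass over the
# concatenated crops, exactly as in student_multicrop_forward.
logits = torch.randn(n_crops * n_samples, out_dim)
chunks = logits.chunk(n_crops)            # n_crops tensors of shape (n_samples, out_dim)
student_output = torch.stack(chunks)      # the "NxBxD" layout described in the docstrings
teacher_output = student_output[:2]       # teacher path uses only the two global crops (images[:2])
print(student_output.shape, teacher_output.shape)  # (8, 4, 16) and (2, 4, 16)
```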
@@ -0,0 +1,206 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from enum import Enum
+
+import torch
+import torch.nn.functional as F
+from torch import nn, optim
+
+from quadra.losses.ssl import hyperspherical as loss
+from quadra.modules.base import BaseLightningModule
+
+
+class AlignLoss(Enum):
+    """Align loss enum."""
+
+    L2 = 1
+    COSINE = 2
+
+
+class TLHyperspherical(BaseLightningModule):
+    """Hyperspherical model: maps features extracted from a pretrained backbone into
+    a hypersphere.
+
+    Args:
+        model: Feature extractor as pytorch `torch.nn.Module`
+        optimizer: optimizer of the training.
+            If None a default Adam is used.
+        lr_scheduler: lr scheduler.
+            If None a default ReduceLROnPlateau is used.
+        align_weight: Weight for the align loss component for the
+            hyperspherical loss.
+            Defaults to 1.
+        unifo_weight: Weight for the uniform loss component for the
+            hyperspherical loss.
+            Defaults to 1.
+        classifier_weight: Weight for the classifier loss component for the
+            hyperspherical loss.
+            Defaults to 1.
+        align_loss_type: Which type of align loss to use.
+            Defaults to AlignLoss.L2.
+        classifier_loss: Whether to compute a classifier loss to 'enhance'
+            the hyperspherical loss with the classification loss.
+            If True, model.classifier must be defined.
+            Defaults to False.
+        num_classes: Number of classes for a classification problem.
+            Defaults to None.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        optimizer: optim.Optimizer | None = None,
+        lr_scheduler: object | None = None,
+        align_weight: float = 1,
+        unifo_weight: float = 1,
+        classifier_weight: float = 1,
+        align_loss_type: AlignLoss = AlignLoss.L2,
+        classifier_loss: bool = False,
+        num_classes: int | None = None,
+    ):
+        super().__init__(model, optimizer, lr_scheduler)
+        self.align_loss_fun: (
+            Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
+            | Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
+        )
+        self.align_weight = align_weight
+        self.unifo_weight = unifo_weight
+        self.classifier_weight = classifier_weight
+        self.align_loss_type = align_loss_type
+        if align_loss_type == AlignLoss.L2:
+            self.align_loss_fun = loss.align_loss
+        elif align_loss_type == AlignLoss.COSINE:
+            self.align_loss_fun = loss.cosine_align_loss
+        else:
+            raise ValueError("The align loss must be one of 'AlignLoss.L2' (L2 distance) or AlignLoss.COSINE")
+
+        if classifier_loss and model.classifier is None:
+            raise AssertionError("Classifier is not defined")
+
+        self.classifier_loss = classifier_loss
+        self.num_classes = num_classes
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        # pylint: disable=unused-argument
+        im_x, im_y, target = batch
+        emb_x, emb_y = self(torch.cat([im_x, im_y])).chunk(2)
+
+        align_loss = 0.0
+        if self.align_weight > 0:
+            align_loss = self.align_loss_fun(emb_x, emb_y)
+
+        unifo_loss = 0.0
+        if self.unifo_weight > 0:
+            unifo_loss = (loss.uniform_loss(emb_x) + loss.uniform_loss(emb_y)) / 2
+
+        classifier_loss = 0.0
+        if self.classifier_loss:
+            pred = self.model.classifier(emb_x)
+            classifier_loss = F.cross_entropy(pred, target)
+
+        total_loss = (
+            self.align_weight * align_loss + self.unifo_weight * unifo_loss + self.classifier_weight * classifier_loss
+        )
+
+        self.log(
+            "t_loss",
+            total_loss,
+            on_epoch=True,
+            logger=True,
+            prog_bar=False,
+        )
+        self.log(
+            "t_align",
+            align_loss,
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=False,
+        )
+        self.log(
+            "t_classifier",
+            classifier_loss,
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=True,
+        )
+        self.log(
+            "t_unif",
+            unifo_loss,
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=False,
+        )
+        return {"loss": total_loss}
+
+    def train_epoch_end(self, outputs):
+        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
+
+        return {"loss": avg_loss}
+
+    def validation_step(self, batch, batch_idx):
+        # pylint: disable=unused-argument
+        im_x, im_y, target = batch
+        emb_x, emb_y = self(torch.cat([im_x, im_y])).chunk(2)
+
+        align_loss = 0.0
+        if self.align_weight > 0:
+            align_loss = self.align_loss_fun(emb_x, emb_y)
+
+        unifo_loss = 0.0
+        if self.unifo_weight > 0:
+            unifo_loss = (loss.uniform_loss(emb_x) + loss.uniform_loss(emb_y)) / 2
+
+        classifier_loss = 0.0
+        if self.classifier_loss:
+            pred = self.model.classifier(emb_x)
+            classifier_loss = F.cross_entropy(pred, target)
+
+        total_loss = (
+            self.align_weight * align_loss + self.unifo_weight * unifo_loss + self.classifier_weight * classifier_loss
+        )
+
+        self.log(
+            "val_loss",
+            total_loss,
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=False,
+        )
+        self.log(
+            "v_classifier",
+            classifier_loss,
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=True,
+        )
+        self.log(
+            "v_align",
+            align_loss,
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=False,
+        )
+        self.log(
+            "v_unif",
+            unifo_loss,
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=False,
+        )
+        return {"val_loss": total_loss}
+
+    def on_validation_epoch_end(self, outputs):
+        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
+
+        return {"val_loss": avg_loss}