quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/modules/classification/base.py (+327 lines):

```diff
@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+from typing import Any, cast
+
+import numpy as np
+import timm
+import torch
+import torchmetrics
+import torchmetrics.functional as TMF
+from pytorch_grad_cam import GradCAM
+from scipy import ndimage
+from torch import nn, optim
+
+from quadra.models.classification import BaseNetworkBuilder
+from quadra.modules.base import BaseLightningModule
+from quadra.utils.models import is_vision_transformer
+from quadra.utils.utils import get_logger
+from quadra.utils.vit_explainability import VitAttentionGradRollout
+
+log = get_logger(__name__)
+
+
+class ClassificationModule(BaseLightningModule):
+    """Lightning module for classification tasks.
+
+    Args:
+        model: Feature extractor as PyTorch `torch.nn.Module`
+        criterion: the loss to be applied as a PyTorch `torch.nn.Module`.
+        optimizer: optimizer of the training. Defaults to None.
+        lr_scheduler: PyTorch learning rate scheduler.
+            If None a default ReduceLROnPlateau is used.
+            Defaults to None.
+        lr_scheduler_interval: the learning rate scheduler interval.
+            Defaults to "epoch".
+        gradcam: Whether to compute gradcam during the prediction step.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        criterion: nn.Module,
+        optimizer: None | optim.Optimizer = None,
+        lr_scheduler: None | object = None,
+        lr_scheduler_interval: str | None = "epoch",
+        gradcam: bool = False,
+    ):
+        super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
+
+        self.criterion = criterion
+        self.gradcam = gradcam
+        self.train_acc = torchmetrics.Accuracy()
+        self.val_acc = torchmetrics.Accuracy()
+        self.test_acc = torchmetrics.Accuracy()
+        self.cam: GradCAM | None = None
+        self.grad_rollout: VitAttentionGradRollout | None = None
+
+        if not isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and not is_vision_transformer(
+            cast(BaseNetworkBuilder, self.model).features_extractor
+        ):
+            log.warning(
+                "Backbone not compatible with gradcam. Only timm ResNets, timm ViTs and TorchHub dinoViTs supported",
+            )
+            self.gradcam = False
+
+        self.original_requires_grads: list[bool] = []
+
+    def forward(self, x: torch.Tensor):
+        return self.model(x)
+
+    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+        # pylint: disable=unused-argument
+        im, target = batch
+        outputs = self(im)
+        loss = self.criterion(outputs, target)
+
+        self.log_dict(
+            {"train_loss": loss},
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        self.log_dict(
+            {"train_acc": self.train_acc(outputs.argmax(1), target)},
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
+
+    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+        # pylint: disable=unused-argument
+        im, target = batch
+        outputs = self(im)
+        loss = self.criterion(outputs, target)
+
+        self.log_dict(
+            {"val_loss": loss},
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        self.log_dict(
+            {"val_acc": self.val_acc(outputs.argmax(1), target)},
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
+
+    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+        # pylint: disable=unused-argument
+        im, target = batch
+        outputs = self(im)
+
+        loss = self.criterion(outputs, target)
+
+        self.log_dict(
+            {"test_loss": loss},
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=False,
+        )
+        self.log_dict(
+            {"test_acc": self.test_acc(outputs.argmax(1), target)},
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            prog_bar=False,
+        )
+
+    def prepare_gradcam(self) -> None:
+        """Instantiate gradcam handlers."""
+        if isinstance(self.model.features_extractor, timm.models.resnet.ResNet):
+            target_layers = [cast(BaseNetworkBuilder, self.model).features_extractor.layer4[-1]]
+
+            self.cam = GradCAM(
+                model=self.model,
+                target_layers=target_layers,
+            )
+            # Activating gradients
+            for p in self.model.features_extractor.layer4[-1].parameters():
+                p.requires_grad = True
+        elif is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor):
+            self.grad_rollout = VitAttentionGradRollout(self.model)
+        else:
+            log.warning("Gradcam not implemented for this backbone, it won't be computed")
+            self.original_requires_grads.clear()
+            self.gradcam = False
+
+    def on_predict_start(self) -> None:
+        """If gradcam is enabled, prepare gradcam and save the params' requires_grad state."""
+        if self.gradcam:
+            # Saving params requires_grad state
+            for p in self.model.parameters():
+                self.original_requires_grads.append(p.requires_grad)
+            self.prepare_gradcam()
+
+        return super().on_predict_start()
+
+    def on_predict_end(self) -> None:
+        """If we computed gradcam, reset requires_grad values to their original state."""
+        if self.gradcam:
+            # Get back to initial state
+            for i, p in enumerate(self.model.parameters()):
+                p.requires_grad = self.original_requires_grads[i]
+
+            # We are using the GradCAM package only for resnets at the moment
+            if isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and self.cam is not None:
+                # Needed to solve jitting bug
+                self.cam.activations_and_grads.release()
+            elif (
+                is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor)
+                and self.grad_rollout is not None
+            ):
+                for handle in self.grad_rollout.f_hook_handles:
+                    handle.remove()
+                for handle in self.grad_rollout.b_hook_handles:
+                    handle.remove()
+
+        return super().on_predict_end()
+
+    # pylint: disable=unused-argument
+    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+        """Prediction step.
+
+        Args:
+            batch: Tuple composed by (image, target)
+            batch_idx: Batch index
+            dataloader_idx: Dataloader index
+
+        Returns:
+            Tuple containing:
+                predicted_classes: indexes of predicted classes
+                grayscale_cam: grayscale gradcams
+        """
+        im, _ = batch
+        outputs = self(im)
+        probs = torch.softmax(outputs, dim=1)
+        predicted_classes = torch.max(probs, dim=1).indices.tolist()
+        if self.gradcam:
+            # inference_mode set to False because gradcam needs gradients
+            with torch.inference_mode(False):
+                im = im.clone()
+
+                if isinstance(self.model.features_extractor, timm.models.resnet.ResNet) and self.cam:
+                    grayscale_cam = self.cam(input_tensor=im, targets=None)
+                elif (
+                    is_vision_transformer(cast(BaseNetworkBuilder, self.model).features_extractor) and self.grad_rollout
+                ):
+                    grayscale_cam_low_res = self.grad_rollout(input_tensor=im, targets_list=predicted_classes)
+                    orig_shape = grayscale_cam_low_res.shape
+                    new_shape = (orig_shape[0], im.shape[2], im.shape[3])
+                    zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
+                    grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
+        else:
+            grayscale_cam = None
+        return predicted_classes, grayscale_cam, torch.max(probs, dim=1)[0].tolist()
+
+
+class MultilabelClassificationModule(BaseLightningModule):
+    """Lightning module to train a generic multilabel classification model.
+
+    Args:
+        model: Feature extractor as PyTorch `torch.nn.Module`
+        criterion: the loss to be applied as a PyTorch `torch.nn.Module`.
+        optimizer: optimizer of the training. Defaults to None.
+        lr_scheduler: PyTorch learning rate scheduler.
+            If None a default ReduceLROnPlateau is used.
+            Defaults to None.
+        lr_scheduler_interval: the learning rate scheduler interval.
+            Defaults to "epoch".
+        gradcam: Whether to compute gradcam during the prediction step. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        model: nn.Sequential,
+        criterion: nn.Module,
+        optimizer: None | optim.Optimizer = None,
+        lr_scheduler: None | object = None,
+        lr_scheduler_interval: str | None = "epoch",
+        gradcam: bool = False,
+    ):
+        super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
+        self.criterion = criterion
+        self.gradcam = gradcam
+
+        # TODO: can we use gradcam with more backbones?
+        if self.gradcam:
+            if not isinstance(model[0].features_extractor, timm.models.resnet.ResNet):
+                log.warning(
+                    "Backbone must be compatible with gradcam, at the moment only ResNets supported, disabling gradcam"
+                )
+                self.gradcam = False
+            else:
+                target_layers = [model[0].features_extractor.layer4[-1]]
+                self.cam = GradCAM(model=model, target_layers=target_layers)
+
+    def forward(self, x):
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        # pylint: disable=unused-argument
+        im, target = batch
+        outputs = self(im)
+        with torch.no_grad():
+            outputs_sig = torch.sigmoid(outputs)
+        loss = self.criterion(outputs, target)
+
+        self.log_dict(
+            {
+                "t_loss": loss,
+                "t_map": TMF.label_ranking_average_precision(outputs_sig, target.bool()),
+                "t_f1": TMF.f1_score(outputs_sig, target.bool(), average="samples"),
+            },
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        # pylint: disable=unused-argument
+        im, target = batch
+        outputs = self(im)
+        with torch.no_grad():
+            outputs_sig = torch.sigmoid(outputs)
+        loss = self.criterion(outputs, target)
+
+        self.log_dict(
+            {
+                "val_loss": loss,
+                "val_map": TMF.label_ranking_average_precision(outputs_sig, target.bool()),
+                "val_f1": TMF.f1_score(outputs_sig, target.bool(), average="samples"),
+            },
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        # pylint: disable=unused-argument
+        im, target = batch
+        outputs = self(im)
+        with torch.no_grad():
+            outputs_sig = torch.sigmoid(outputs)
+        loss = self.criterion(outputs, target)
+
+        self.log_dict(
+            {
+                "test_loss": loss,
+                "test_map": TMF.label_ranking_average_precision(outputs_sig, target.bool()),
+                "test_f1": TMF.f1_score(outputs_sig, target.bool(), average="samples"),
+            },
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=False,
+        )
+        return loss
```
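For orientation, here is a minimal usage sketch of `ClassificationModule` (not part of the diff). The only structural assumption the module makes at runtime is that the wrapped model exposes a `features_extractor` attribute, so the hypothetical `TinyClassifier` below stands in for the package's `BaseNetworkBuilder`; the sketch also assumes the older torchmetrics API where `Accuracy()` takes no `task` argument, as used in the module above.

```python
# Hypothetical sketch, not from the package: any model exposing a
# `features_extractor` attribute satisfies the gradcam check in __init__.
import timm
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Stand-in for BaseNetworkBuilder: a timm backbone plus a linear head."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # num_classes=0 makes timm return pooled features instead of logits
        self.features_extractor = timm.create_model("resnet18", num_classes=0)
        self.classifier = nn.Linear(self.features_extractor.num_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features_extractor(x))


# A timm ResNet passes the compatibility check, so gradcam stays enabled and
# predict_step returns (predicted_classes, grayscale_cam, probabilities).
module = ClassificationModule(model=TinyClassifier(), criterion=nn.CrossEntropyLoss(), gradcam=True)
logits = module(torch.randn(2, 3, 224, 224))  # shape (2, 10)
```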
quadra/modules/ssl/__init__.py (+17 lines):

```diff
@@ -0,0 +1,17 @@
+from .barlowtwins import BarlowTwins
+from .byol import BYOL
+from .dino import Dino
+from .idmm import IDMM
+from .simclr import SimCLR
+from .simsiam import SimSIAM
+from .vicreg import VICReg
+
+__all__ = [
+    "BarlowTwins",
+    "BYOL",
+    "Dino",
+    "IDMM",
+    "SimCLR",
+    "SimSIAM",
+    "VICReg",
+]
```
quadra/modules/ssl/barlowtwins.py (+59 lines):

```diff
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import sklearn
+import torch
+from torch import nn, optim
+
+from quadra.modules.base import SSLModule
+
+
+class BarlowTwins(SSLModule):
+    """BarlowTwins model.
+
+    Args:
+        model: Network module used to extract features
+        projection_mlp: Module to project extracted features
+        criterion: SSL loss to be applied
+        classifier: Standard sklearn classifier. Defaults to None.
+        optimizer: optimizer of the training. If None a default Adam is used. Defaults to None.
+        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used. Defaults to None.
+        lr_scheduler_interval: interval at which the lr scheduler is updated. Defaults to "epoch".
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        projection_mlp: nn.Module,
+        criterion: nn.Module,
+        classifier: sklearn.base.ClassifierMixin | None = None,
+        optimizer: optim.Optimizer | None = None,
+        lr_scheduler: object | None = None,
+        lr_scheduler_interval: str | None = "epoch",
+    ):
+        super().__init__(model, criterion, classifier, optimizer, lr_scheduler, lr_scheduler_interval)
+        # self.save_hyperparameters()
+        self.projection_mlp = projection_mlp
+        self.criterion = criterion
+
+    def forward(self, x):
+        x = self.model(x)
+        z = self.projection_mlp(x)
+        return z
+
+    def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+        # pylint: disable=unused-argument
+        # Compute loss
+        (im_x, im_y), _ = batch
+        z1 = self(im_x)
+        z2 = self(im_y)
+        loss = self.criterion(z1, z2)
+
+        self.log(
+            "loss",
+            loss,
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
```
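A shapes-only sketch of the flow above (hypothetical, not from the diff): `forward` runs the backbone and then the projection MLP, and `training_step` applies the criterion to the projections of two augmented views of the same batch. The tiny networks and their sizes are made up, `nn.MSELoss` stands in for the package's Barlow Twins loss, and the sketch assumes `SSLModule`'s defaults need nothing beyond these arguments.

```python
# Hypothetical stand-ins: a flat backbone and projector with arbitrary sizes.
import torch
from torch import nn

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.ReLU())
projector = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 64))

# MSELoss is a placeholder for the real loss in quadra.losses.ssl
ssl_module = BarlowTwins(model=backbone, projection_mlp=projector, criterion=nn.MSELoss())

# training_step receives ((view_1, view_2), labels); forward maps each view
# to its projection z, and the criterion compares the two projections.
view_1, view_2 = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
z1, z2 = ssl_module(view_1), ssl_module(view_2)  # each of shape (4, 64)
loss = ssl_module.criterion(z1, z2)
```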
quadra/modules/ssl/byol.py (+172 lines):

```diff
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+import math
+from collections.abc import Callable, Sized
+from typing import Any
+
+import sklearn
+import torch
+from pytorch_lightning.core.optimizer import LightningOptimizer
+from torch import nn
+from torch.optim import Optimizer
+
+from quadra.modules.base import SSLModule
+
+
+class BYOL(SSLModule):
+    """BYOL module, inspired by https://arxiv.org/abs/2006.07733.
+
+    Args:
+        student: student model.
+        teacher: teacher model.
+        student_projection_mlp: student projection MLP.
+        student_prediction_mlp: student prediction MLP.
+        teacher_projection_mlp: teacher projection MLP.
+        criterion: loss function.
+        classifier: Standard sklearn classifier.
+        optimizer: optimizer of the training. If None a default Adam is used.
+        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
+        lr_scheduler_interval: interval at which the lr scheduler is updated.
+        teacher_momentum: momentum of the teacher parameters.
+        teacher_momentum_cosine_decay: whether to use cosine decay for the teacher momentum. Default: True
+    """
+
+    def __init__(
+        self,
+        student: nn.Module,
+        teacher: nn.Module,
+        student_projection_mlp: nn.Module,
+        student_prediction_mlp: nn.Module,
+        teacher_projection_mlp: nn.Module,
+        criterion: nn.Module,
+        classifier: sklearn.base.ClassifierMixin | None = None,
+        optimizer: Optimizer | None = None,
+        lr_scheduler: object | None = None,
+        lr_scheduler_interval: str | None = "epoch",
+        teacher_momentum: float = 0.9995,
+        teacher_momentum_cosine_decay: bool | None = True,
+    ):
+        super().__init__(
+            model=student,
+            criterion=criterion,
+            classifier=classifier,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            lr_scheduler_interval=lr_scheduler_interval,
+        )
+        # Student model
+        self.max_steps: int
+        self.student_projection_mlp = student_projection_mlp
+        self.student_prediction_mlp = student_prediction_mlp
+
+        # Teacher model
+        self.teacher = teacher
+        self.teacher_projection_mlp = teacher_projection_mlp
+        self.teacher_initialized = False
+        self.teacher_momentum = teacher_momentum
+        self.teacher_momentum_cosine_decay = teacher_momentum_cosine_decay
+
+        self.initialize_teacher()
+
+    def initialize_teacher(self):
+        """Initialize the teacher from the state dict of the student,
+        also checking that the student model requires gradients correctly.
+        """
+        self.teacher_projection_mlp.load_state_dict(self.student_projection_mlp.state_dict())
+        for p in self.teacher_projection_mlp.parameters():
+            p.requires_grad = False
+
+        self.teacher.load_state_dict(self.model.state_dict())
+        for p in self.teacher.parameters():
+            p.requires_grad = False
+
+        for p in self.student_projection_mlp.parameters():
+            assert p.requires_grad is True
+        for p in self.student_prediction_mlp.parameters():
+            assert p.requires_grad is True
+
+        self.teacher_initialized = True
+
+    def update_teacher(self):
+        """Update the teacher with an exponential moving average of the student
+        parameters, that is: theta_t * tau + theta_s * (1 - tau), where `theta_s` and
+        `theta_t` are the parameters of the student and the teacher model, while `tau`
+        is the teacher momentum. If `self.teacher_momentum_cosine_decay` is True, the
+        teacher momentum follows a cosine schedule from `self.teacher_momentum` to 1:
+        tau = 1 - (1 - tau) * (cos(pi * t / T) + 1) / 2, where `t` is the current step
+        and `T` is the max number of steps.
+        """
+        with torch.no_grad():
+            if self.teacher_momentum_cosine_decay:
+                teacher_momentum = (
+                    1
+                    - (1 - self.teacher_momentum)
+                    * (math.cos(math.pi * self.trainer.global_step / self.max_steps) + 1)
+                    / 2
+                )
+            else:
+                teacher_momentum = self.teacher_momentum
+            self.log("teacher_momentum", teacher_momentum, prog_bar=True)
+            for student_ps, teacher_ps in zip(
+                list(self.model.parameters()) + list(self.student_projection_mlp.parameters()),
+                list(self.teacher.parameters()) + list(self.teacher_projection_mlp.parameters()),
+            ):
+                teacher_ps.data = teacher_ps.data * teacher_momentum + (1 - teacher_momentum) * student_ps.data
+
+    def on_train_start(self) -> None:
+        if isinstance(self.trainer.train_dataloader, Sized) and isinstance(self.trainer.max_epochs, int):
+            self.max_steps = len(self.trainer.train_dataloader) * self.trainer.max_epochs
+        else:
+            raise ValueError("BYOL requires `max_epochs` to be set and `train_dataloader` to be initialized.")
+
+    def training_step(self, batch: tuple[list[torch.Tensor], torch.Tensor], *args: Any) -> torch.Tensor:
+        [image1, image2], _ = batch
+
+        online_pred_one = self.student_prediction_mlp(self.student_projection_mlp(self.model(image1)))
+        online_pred_two = self.student_prediction_mlp(self.student_projection_mlp(self.model(image2)))
+
+        with torch.no_grad():
+            target_proj_one = self.teacher_projection_mlp(self.teacher(image1))
+            target_proj_two = self.teacher_projection_mlp(self.teacher(image2))
+
+        loss_one = self.criterion(online_pred_one, target_proj_two.detach())
+        loss_two = self.criterion(online_pred_two, target_proj_one.detach())
+        loss = loss_one + loss_two
+
+        self.log(name="loss", value=loss, on_step=True, on_epoch=True, prog_bar=True)
+        return loss
+
+    def optimizer_step(
+        self,
+        epoch: int,
+        batch_idx: int,
+        optimizer: Optimizer | LightningOptimizer,
+        optimizer_closure: Callable[[], Any] | None = None,
+    ) -> None:
+        """Override optimizer step to update the teacher parameters."""
+        super().optimizer_step(
+            epoch,
+            batch_idx,
+            optimizer,
+            optimizer_closure=optimizer_closure,
+        )
+        self.update_teacher()
+
+    def calculate_accuracy(self, batch):
+        """Calculate accuracy on the given batch."""
+        images, labels = batch
+        embedding = self.model(images).detach().cpu().numpy()
+        predictions = self.classifier.predict(embedding)
+        labels = labels.detach()
+        acc = self.val_acc(torch.tensor(predictions, device=self.device), labels)
+
+        return acc
+
+    def on_test_epoch_start(self) -> None:
+        self.fit_estimator()
+
+    def test_step(self, batch, *args: list[Any]):
+        """Calculate accuracy on the test set for the given batch."""
+        acc = self.calculate_accuracy(batch)
+        self.log(name="test_acc", value=acc, on_step=False, on_epoch=True, prog_bar=True)
+        return acc
```