quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/modules/ssl/idmm.py ADDED
@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+import sklearn
+import timm
+import torch
+import torch.nn.functional as F
+
+from quadra.modules.base import SSLModule
+
+
+class IDMM(SSLModule):
+    """IDMM model.
+
+    Args:
+        model: backbone model
+        prediction_mlp: student prediction MLP
+        criterion: loss function
+        multiview_loss: whether to use the multiview loss as defined in https://arxiv.org/abs/2201.10728.
+            Defaults to True.
+        mixup_fn: the mixup/cutmix function to be applied to a batch of images.
+            Defaults to None.
+        classifier: Standard sklearn classifier
+        optimizer: optimizer of the training. If None a default Adam is used.
+        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
+        lr_scheduler_interval: interval at which the lr scheduler is updated.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        prediction_mlp: torch.nn.Module,
+        criterion: torch.nn.Module,
+        multiview_loss: bool = True,
+        mixup_fn: timm.data.Mixup | None = None,
+        classifier: sklearn.base.ClassifierMixin | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
+        lr_scheduler: object | None = None,
+        lr_scheduler_interval: str | None = "epoch",
+    ):
+        super().__init__(
+            model,
+            criterion,
+            classifier,
+            optimizer,
+            lr_scheduler,
+            lr_scheduler_interval,
+        )
+        # self.save_hyperparameters()
+        self.prediction_mlp = prediction_mlp
+        self.mixup_fn = mixup_fn
+        self.multiview_loss = multiview_loss
+
+    def forward(self, x):
+        z = self.model(x)
+        p = self.prediction_mlp(z)
+        return z, p
+
+    def training_step(self, batch, batch_idx):
+        # pylint: disable=unused-argument
+        # Compute loss
+        if self.multiview_loss:
+            im_x, im_y, target = batch
+
+            # Contrastive loss
+            za, _ = self(im_x)
+            zb, _ = self(im_y)
+            za = F.normalize(za, dim=-1)
+            zb = F.normalize(zb, dim=-1)
+            s_aa = za.T @ za
+            s_ab = za.T @ zb
+            contrastive = (
+                torch.log(torch.exp(s_aa).sum(-1))
+                - torch.diagonal(s_aa)
+                + torch.log(torch.exp(s_ab).sum(-1))
+                - torch.diagonal(s_ab)
+            )
+
+            # Instance discrimination
+            if self.mixup_fn is not None:
+                im_x, target = self.mixup_fn(im_x, target)
+            _, pred = self(im_x)
+            loss = self.criterion(pred, target) + contrastive.mean()
+        else:
+            im_x, target = batch
+            if self.mixup_fn is not None:
+                im_x, target = self.mixup_fn(im_x, target)
+            pred = self(im_x)
+            loss = self.criterion(pred, target)
+
+        self.log(
+            "loss",
+            loss,
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
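Note: a minimal sketch of how the IDMM module above could be assembled. The backbone choice, MLP sizes, Mixup settings, and the SoftTargetCrossEntropy criterion are illustrative assumptions, not values taken from quadra's configs; quadra normally builds these objects from its Hydra configuration files.

# Hypothetical wiring of the IDMM module; sizes and settings are illustrative only.
import timm
import torch
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy

from quadra.modules.ssl.idmm import IDMM

backbone = timm.create_model("resnet18", pretrained=False, num_classes=0)  # pooled 512-d features
prediction_mlp = torch.nn.Sequential(
    torch.nn.Linear(512, 2048),
    torch.nn.ReLU(inplace=True),
    torch.nn.Linear(2048, 10),  # instance/pseudo-label dimension (assumed)
)
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=10)  # optional, produces soft targets

module = IDMM(
    model=backbone,
    prediction_mlp=prediction_mlp,
    criterion=SoftTargetCrossEntropy(),  # handles the soft targets produced by Mixup
    multiview_loss=True,
    mixup_fn=mixup_fn,
)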
quadra/modules/ssl/simclr.py ADDED
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import sklearn
+import torch
+from torch import nn
+
+from quadra.modules.base import SSLModule
+
+
+class SimCLR(SSLModule):
+    """SIMCLR class.
+
+    Args:
+        model: Feature extractor as pytorch `torch.nn.Module`
+        projection_mlp: projection head as
+            pytorch `torch.nn.Module`
+        criterion: SSL loss to be applied
+        classifier: Standard sklearn classifier. Defaults to None.
+        optimizer: optimizer of the training. If None a default Adam is used. Defaults to None.
+        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used. Defaults to None.
+        lr_scheduler_interval: interval at which the lr scheduler is updated. Defaults to "epoch".
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        projection_mlp: nn.Module,
+        criterion: torch.nn.Module,
+        classifier: sklearn.base.ClassifierMixin | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
+        lr_scheduler: object | None = None,
+        lr_scheduler_interval: str | None = "epoch",
+    ):
+        super().__init__(
+            model,
+            criterion,
+            classifier,
+            optimizer,
+            lr_scheduler,
+            lr_scheduler_interval,
+        )
+        self.projection_mlp = projection_mlp
+
+    def forward(self, x):
+        x = self.model(x)
+        x = self.projection_mlp(x)
+        return x
+
+    def training_step(
+        self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int
+    ) -> torch.Tensor:
+        """Args:
+            batch: The batch of data
+            batch_idx: The index of the batch.
+
+        Returns:
+            The computed loss
+        """
+        # pylint: disable=unused-argument
+        (im_x, im_y), _ = batch
+        emb_x = self(im_x)
+        emb_y = self(im_y)
+        loss = self.criterion(emb_x, emb_y)
+
+        self.log(
+            "loss",
+            loss,
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
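Note: a minimal sketch of training the SimCLR module above with a pytorch_lightning Trainer. The backbone, projection head sizes, and the stand-in criterion are assumptions; the real contrastive loss lives in quadra/losses/ssl/simclr.py (its class name is not shown in this diff), and quadra builds everything from Hydra configs such as quadra/configs/experiment/base/ssl/simclr.yaml.

# Assumed wiring; quadra itself instantiates these objects from its Hydra configs.
import pytorch_lightning as pl
import timm
import torch

from quadra.modules.ssl.simclr import SimCLR


class StandInContrastiveLoss(torch.nn.Module):
    # Placeholder for the real SimCLR loss (quadra/losses/ssl/simclr.py); it only needs to
    # accept the two embeddings produced in training_step and return a scalar.
    def forward(self, emb_x: torch.Tensor, emb_y: torch.Tensor) -> torch.Tensor:
        return -torch.nn.functional.cosine_similarity(emb_x, emb_y, dim=-1).mean()


backbone = timm.create_model("resnet18", pretrained=False, num_classes=0)
projection_mlp = torch.nn.Sequential(
    torch.nn.Linear(512, 2048),
    torch.nn.ReLU(inplace=True),
    torch.nn.Linear(2048, 128),
)

module = SimCLR(model=backbone, projection_mlp=projection_mlp, criterion=StandInContrastiveLoss())
trainer = pl.Trainer(max_epochs=10, accelerator="auto")
# trainer.fit(module, datamodule=ssl_datamodule)  # hypothetical datamodule yielding ((view_a, view_b), target) batches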
quadra/modules/ssl/simsiam.py ADDED
@@ -0,0 +1,68 @@
+from __future__ import annotations
+
+import sklearn
+import torch
+
+from quadra.modules.base import SSLModule
+
+
+class SimSIAM(SSLModule):
+    """SimSIAM model.
+
+    Args:
+        model: Feature extractor as pytorch `torch.nn.Module`
+        projection_mlp: optional projection head as pytorch `torch.nn.Module`
+        prediction_mlp: optional prediction head as pytorch `torch.nn.Module`
+        criterion: loss to be applied.
+        classifier: Standard sklearn classifier.
+        optimizer: optimizer of the training. If None a default Adam is used.
+        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
+        lr_scheduler_interval: interval at which the lr scheduler is updated.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        projection_mlp: torch.nn.Module,
+        prediction_mlp: torch.nn.Module,
+        criterion: torch.nn.Module,
+        classifier: sklearn.base.ClassifierMixin | None = None,
+        optimizer: torch.optim.Optimizer | None = None,
+        lr_scheduler: object | None = None,
+        lr_scheduler_interval: str | None = "epoch",
+    ):
+        super().__init__(
+            model,
+            criterion,
+            classifier,
+            optimizer,
+            lr_scheduler,
+            lr_scheduler_interval,
+        )
+        # self.save_hyperparameters()
+        self.projection_mlp = projection_mlp
+        self.prediction_mlp = prediction_mlp
+
+    def forward(self, x):
+        x = self.model(x)
+        z = self.projection_mlp(x)
+        p = self.prediction_mlp(z)
+        return p, z.detach()
+
+    def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+        # pylint: disable=unused-argument
+        # Compute loss
+        (im_x, im_y), _ = batch
+        p1, z1 = self(im_x)
+        p2, z2 = self(im_y)
+        loss = self.criterion(p1, p2, z1, z2)
+
+        self.log(
+            "loss",
+            loss,
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
quadra/modules/ssl/vicreg.py ADDED
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import sklearn
+import torch
+from torch import nn, optim
+
+from quadra.modules.base import SSLModule
+
+
+class VICReg(SSLModule):
+    """VICReg model.
+
+    Args:
+        model: Network Module used to extract features
+        projection_mlp: Module to project extracted features
+        criterion: SSL loss to be applied.
+        classifier: Standard sklearn classifier. Defaults to None.
+        optimizer: optimizer of the training. If None a default Adam is used. Defaults to None.
+        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used. Defaults to None.
+        lr_scheduler_interval: interval at which the lr scheduler is updated. Defaults to "epoch".
+
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        projection_mlp: nn.Module,
+        criterion: nn.Module,
+        classifier: sklearn.base.ClassifierMixin | None = None,
+        optimizer: optim.Optimizer | None = None,
+        lr_scheduler: object | None = None,
+        lr_scheduler_interval: str | None = "epoch",
+    ):
+        super().__init__(
+            model,
+            criterion,
+            classifier,
+            optimizer,
+            lr_scheduler,
+            lr_scheduler_interval,
+        )
+        # self.save_hyperparameters()
+        self.projection_mlp = projection_mlp
+        self.criterion = criterion
+
+    def forward(self, x):
+        x = self.model(x)
+        z = self.projection_mlp(x)
+        return z
+
+    def training_step(self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int):
+        # pylint: disable=unused-argument
+        # Compute loss
+        (im_x, im_y), _ = batch
+        z1 = self(im_x)
+        z2 = self(im_y)
+        loss = self.criterion(z1, z2)
+
+        self.log(
+            "loss",
+            loss,
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+        return loss
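Note: the three view-based modules above (SimCLR, SimSIAM, VICReg) all consume batches shaped ((view_a, view_b), target), as their training_step type hints show. Below is a minimal sketch of a dataset wrapper producing that structure; the wrapper and transform are assumptions, not quadra's own datamodule implementation (see quadra/datamodules/ssl.py).

# Assumed two-view dataset wrapper; the default PyTorch collate turns its items into
# ((batched_view_a, batched_view_b), batched_targets), matching the training_step hints.
from torch.utils.data import Dataset


class TwoViewDataset(Dataset):
    def __init__(self, base: Dataset, transform):
        self.base = base
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, target = self.base[idx]
        # Two independent augmentations of the same image.
        return (self.transform(image), self.transform(image)), target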
quadra/optimizers/lars.py ADDED
@@ -0,0 +1,153 @@
+"""References:
+    - https://arxiv.org/pdf/1708.03888.pdf
+    - https://github.com/pytorch/pytorch/blob/1.6/torch/optim/sgd.py.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable
+
+import torch
+from torch.nn import Parameter
+from torch.optim.optimizer import Optimizer, _RequiredParameter, required
+
+
+class LARS(Optimizer):
+    r"""Extends SGD in PyTorch with LARS scaling from the paper
+    `Large batch training of Convolutional Networks <https://arxiv.org/pdf/1708.03888.pdf>`_.
+
+    Args:
+        params: iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr: learning rate
+        momentum: momentum factor (default: 0)
+        weight_decay: weight decay (L2 penalty) (default: 0)
+        dampening: dampening for momentum (default: 0)
+        nesterov: enables Nesterov momentum (default: False)
+        trust_coefficient: trust coefficient for computing LR (default: 0.001)
+        eps: eps for division denominator (default: 1e-8).
+
+    Example:
+        >>> model = torch.nn.Linear(10, 1)
+        >>> input = torch.Tensor(10)
+        >>> target = torch.Tensor([1.])
+        >>> loss_fn = lambda input, target: (input - target) ** 2
+        >>> #
+        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9)
+        >>> optimizer.zero_grad()
+        >>> loss_fn(model(input), target).backward()
+        >>> optimizer.step()
+
+    .. note::
+        The application of momentum in the SGD part is modified according to
+        the PyTorch standards. LARS scaling fits into the equation in the
+        following fashion.
+
+        .. math::
+            \begin{aligned}
+                g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\
+                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
+                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
+            \end{aligned}
+
+        where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the
+        parameters, gradient, velocity, momentum, and weight decay respectively.
+        The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper.
+        The Nesterov version is analogously modified.
+
+    .. warning::
+        Parameters with weight decay set to 0 will automatically be excluded from
+        layer-wise LR scaling. This is to ensure consistency with papers like SimCLR
+        and BYOL.
+    """
+
+    def __init__(
+        self,
+        params: list[Parameter],
+        lr: _RequiredParameter = required,
+        momentum: float = 0,
+        dampening: float = 0,
+        weight_decay: float = 0,
+        nesterov: bool = False,
+        trust_coefficient: float = 0.001,
+        eps: float = 1e-8,
+    ):
+        if lr is not required and lr < 0.0:  # type: ignore[operator]
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if momentum < 0.0:
+            raise ValueError(f"Invalid momentum value: {momentum}")
+        if weight_decay < 0.0:
+            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+
+        defaults = {
+            "lr": lr,
+            "momentum": momentum,
+            "dampening": dampening,
+            "weight_decay": weight_decay,
+            "nesterov": nesterov,
+        }
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+        self.eps = eps
+        self.trust_coefficient = trust_coefficient
+
+        super().__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+
+        for group in self.param_groups:
+            group.setdefault("nesterov", False)
+
+    @torch.no_grad()
+    def step(self, closure: Callable | None = None):
+        """Performs a single optimization step.
+
+        Args:
+            closure: A closure that reevaluates the model and returns the loss. Defaults to None.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        # exclude scaling for params with 0 weight decay
+        for group in self.param_groups:
+            weight_decay = group["weight_decay"]
+            momentum = group["momentum"]
+            dampening = group["dampening"]
+            nesterov = group["nesterov"]
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                d_p = p.grad
+                p_norm = torch.norm(p.data)
+                g_norm = torch.norm(p.grad.data)
+
+                # lars scaling + weight decay part
+                if weight_decay != 0 and p_norm != 0 and g_norm != 0:
+                    lars_lr = p_norm / (g_norm + p_norm * weight_decay + self.eps)
+                    lars_lr *= self.trust_coefficient
+
+                    d_p = d_p.add(p, alpha=weight_decay)
+                    d_p *= lars_lr
+
+                # sgd part
+                if momentum != 0:
+                    param_state = self.state[p]
+                    if "momentum_buffer" not in param_state:
+                        buf = param_state["momentum_buffer"] = torch.clone(d_p).detach()
+                    else:
+                        buf = param_state["momentum_buffer"]
+                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+                    if nesterov:
+                        d_p = d_p.add(buf, alpha=momentum)
+                    else:
+                        d_p = buf
+
+                p.add_(d_p, alpha=-group["lr"])
+
+        return loss
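Note: as the warning in the LARS docstring states, parameter groups with weight_decay set to 0 skip the layer-wise scaling branch in step(). Below is a small sketch of the usual split between decayed weights and non-decayed biases/norm parameters; the model, learning rate, and weight decay values are illustrative assumptions.

# Assumed param-group split: biases and norm parameters get weight_decay=0.0, so the
# `weight_decay != 0` guard in LARS.step() applies a plain SGD update to them.
import torch

from quadra.optimizers.lars import LARS

model = torch.nn.Sequential(
    torch.nn.Linear(10, 10),
    torch.nn.BatchNorm1d(10),
    torch.nn.Linear(10, 1),
)

decay, no_decay = [], []
for name, param in model.named_parameters():
    if param.ndim == 1 or name.endswith(".bias"):  # norm weights and biases
        no_decay.append(param)
    else:
        decay.append(param)

optimizer = LARS(
    [
        {"params": decay, "weight_decay": 1e-6},
        {"params": no_decay, "weight_decay": 0.0},  # excluded from LARS scaling
    ],
    lr=0.3,
    momentum=0.9,
)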
quadra/optimizers/sam.py ADDED
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any
+
+import torch
+from torch.nn import Parameter
+
+
+class SAM(torch.optim.Optimizer):
+    """PyTorch implementation of the Sharpness-Aware Minimization papers: https://arxiv.org/abs/2010.01412
+    and https://arxiv.org/abs/2102.11600.
+    Taken from: https://github.com/davda54/sam.
+
+    Args:
+        params: model parameters.
+        base_optimizer: optimizer to use.
+        rho: Positive float value used to scale the gradients.
+        adaptive: Boolean flag indicating whether to use adaptive step update.
+        **kwargs: Additional parameters for the base optimizer.
+    """
+
+    def __init__(
+        self,
+        params: list[Parameter],
+        base_optimizer: torch.optim.Optimizer,
+        rho: float = 0.05,
+        adaptive: bool = True,
+        **kwargs: Any,
+    ):
+        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
+
+        defaults = {"rho": rho, "adaptive": adaptive, **kwargs}
+        super().__init__(params, defaults)
+
+        if callable(base_optimizer):
+            self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
+        else:
+            self.base_optimizer = base_optimizer
+        self.rho = rho
+        self.adaptive = adaptive
+        self.param_groups = self.base_optimizer.param_groups
+
+    @torch.no_grad()
+    def first_step(self, zero_grad: bool = False) -> None:
+        """First step for SAM optimizer.
+
+        Args:
+            zero_grad: Boolean flag indicating whether to zero the gradients.
+
+        Returns:
+            None
+        """
+        grad_norm = self._grad_norm()
+        for group in self.param_groups:
+            scale = self.rho / (grad_norm + 1e-12)
+
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                e_w = (torch.pow(p, 2) if self.adaptive else 1.0) * p.grad * scale.to(p)
+                p.add_(e_w)  # climb to the local maximum "w + e(w)"
+                self.state[p]["e_w"] = e_w
+
+        if zero_grad:
+            self.zero_grad()
+
+    @torch.no_grad()
+    def second_step(self, zero_grad: bool = False) -> None:
+        """Second step for SAM optimizer.
+
+        Args:
+            zero_grad: Boolean flag indicating whether to zero the gradients.
+
+        Returns:
+            None
+
+        """
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"
+
+        self.base_optimizer.step()  # do the actual "sharpness-aware" update
+
+        if zero_grad:
+            self.zero_grad()
+
+    @torch.no_grad()
+    def step(self, closure: Callable | None = None) -> None:  # type: ignore[override]
+        """Step for SAM optimizer.
+
+        Args:
+            closure: The Optional closure for enable grad.
+
+        Returns:
+            None
+
+        """
+        if closure is not None:
+            closure = torch.enable_grad()(closure)
+
+        self.first_step(zero_grad=True)
+        if closure is not None:
+            closure()
+        self.second_step(zero_grad=False)
+
+    def _grad_norm(self) -> torch.Tensor:
+        """Put everything on the same device, in case of model parallelism
+
+        Returns:
+            Grad norm.
+        """
+        # put everything on the same device, in case of model parallelism
+        shared_device = self.param_groups[0]["params"][0].device
+        norm = torch.norm(
+            torch.stack(
+                [
+                    ((torch.abs(p) if self.adaptive else 1.0) * p.grad).norm(p=2).to(shared_device)
+                    for group in self.param_groups
+                    for p in group["params"]
+                    if p.grad is not None
+                ]
+            ),
+            p=2,
+        )
+        return norm
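Note: SAM evaluates the loss twice per iteration. step(closure) perturbs the weights with first_step, re-runs the closure to get gradients at the perturbed point, then restores the weights and applies the base optimizer in second_step. Below is a minimal training-loop sketch; the model, data, and hyperparameters are illustrative assumptions.

# Assumed loop using the SAM wrapper above with torch.optim.SGD as the base optimizer.
import torch

from quadra.optimizers.sam import SAM

model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()
# Passing the optimizer class lets SAM build it over its own param_groups (the callable branch in __init__).
optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

inputs, targets = torch.randn(32, 10), torch.randn(32, 1)


def closure():
    # Re-evaluates loss and gradients at the perturbed weights "w + e(w)".
    loss = criterion(model(inputs), targets)
    loss.backward()
    return loss


loss = criterion(model(inputs), targets)
loss.backward()          # gradients at "w", used by first_step to compute e(w)
optimizer.step(closure)  # first_step -> closure() -> second_step
optimizer.zero_grad()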
quadra/schedulers/base.py ADDED
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class LearningRateScheduler(_LRScheduler):
+    """Provides the interface of a learning rate scheduler.
+
+    Note:
+        Do not use this class directly, use one of the subclasses.
+    """
+
+    def __init__(self, optimizer: Optimizer, init_lr: tuple[float, ...]):
+        # pylint: disable=super-init-not-called
+        self.optimizer = optimizer
+        self.init_lr = init_lr
+
+    def step(self, *args, **kwargs):
+        """Base method, must be implemented by the subclasses."""
+        raise NotImplementedError
+
+    def set_lr(self, lr: tuple[float, ...]):
+        """Set the learning rate for the optimizer."""
+        if self.optimizer is not None:
+            for i, g in enumerate(self.optimizer.param_groups):
+                if "fix_lr" in g and g["fix_lr"]:
+                    if len(lr) == 1:
+                        lr_to_set = self.init_lr[0]
+                    else:
+                        lr_to_set = self.init_lr[i]
+                elif len(lr) == 1:
+                    lr_to_set = lr[0]
+                else:
+                    lr_to_set = lr[i]
+                g["lr"] = lr_to_set
+
+    def get_lr(self):
+        """Get the current learning rate if the optimizer is available."""
+        if self.optimizer is not None:
+            for g in self.optimizer.param_groups:
+                return g["lr"]
+
+        return None
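Note: LearningRateScheduler deliberately leaves step() unimplemented. Below is a minimal, hypothetical subclass showing the step()/set_lr() contract; quadra's real warm-up scheduler lives in quadra/schedulers/warmup.py and is not reproduced here.

# Hypothetical linear warm-up subclass; it only illustrates how step() is expected to use set_lr().
from __future__ import annotations

from torch.optim import Optimizer

from quadra.schedulers.base import LearningRateScheduler


class LinearWarmup(LearningRateScheduler):
    def __init__(self, optimizer: Optimizer, init_lr: tuple[float, ...], warmup_steps: int):
        super().__init__(optimizer, init_lr)
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def step(self, *args, **kwargs):
        self.current_step += 1
        factor = min(1.0, self.current_step / self.warmup_steps)
        # Scale every configured base lr and write it into the optimizer's param groups.
        self.set_lr(tuple(lr * factor for lr in self.init_lr))
        return self.get_lr()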