quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/models.py
ADDED
@@ -0,0 +1,523 @@
+from __future__ import annotations
+
+import math
+import warnings
+from collections.abc import Callable
+from typing import Union, cast
+
+import numpy as np
+import timm
+import torch
+import torch.nn.functional as F
+import tqdm
+from pytorch_grad_cam import GradCAM
+from scipy import ndimage
+from sklearn.linear_model._base import ClassifierMixin
+from timm.models.layers import DropPath
+from timm.models.vision_transformer import Mlp
+from torch import nn
+
+from quadra.models.evaluation import (
+    BaseEvaluationModel,
+    ONNXEvaluationModel,
+    TorchEvaluationModel,
+    TorchscriptEvaluationModel,
+)
+from quadra.utils import utils
+from quadra.utils.vit_explainability import VitAttentionGradRollout
+
+log = utils.get_logger(__name__)
+
+
+def net_hat(input_size: int, output_size: int) -> torch.nn.Sequential:
+    """Create a linear layer with input and output neurons.
+
+    Args:
+        input_size: Number of input neurons
+        output_size: Number of output neurons.
+
+    Returns:
+        A sequential containing a single Linear layer taking input neurons and producing output neurons
+
+    """
+    return torch.nn.Sequential(torch.nn.Linear(input_size, output_size))
+
+
+def create_net_hat(dims: list[int], act_fun: Callable = torch.nn.ReLU, dropout_p: float = 0) -> torch.nn.Sequential:
+    """Create a sequence of linear layers with activation functions and dropout.
+
+    Args:
+        dims: Dimension of hidden layers and output
+        act_fun: activation function to use between layers, default ReLU
+        dropout_p: Dropout probability. Defaults to 0.
+
+    Returns:
+        Sequence of linear layers of dimension specified by the input, each linear layer is followed
+        by an activation function and optionally a dropout layer with the input probability
+    """
+    components: list[nn.Module] = []
+    for i, _ in enumerate(dims[:-2]):
+        if dropout_p > 0:
+            components.append(torch.nn.Dropout(dropout_p))
+        components.append(net_hat(dims[i], dims[i + 1]))
+        components.append(act_fun())
+    components.append(net_hat(dims[-2], dims[-1]))
+    components.append(L2Norm())
+    return torch.nn.Sequential(*components)
+
+
+class L2Norm(torch.nn.Module):
+    """Compute L2 Norm."""
+
+    def forward(self, x: torch.Tensor):
+        return x / torch.norm(x, p=2, dim=1, keepdim=True)
+
+
+def init_weights(m):
+    """Basic weight initialization."""
+    classname = m.__class__.__name__
+    if classname.find("Conv2d") != -1 or classname.find("ConvTranspose2d") != -1:
+        nn.init.kaiming_uniform_(m.weight)
+        nn.init.zeros_(m.bias)
+    elif classname.find("BatchNorm") != -1:
+        nn.init.normal_(m.weight, 1.0, 0.02)
+        nn.init.zeros_(m.bias)
+    elif classname.find("Linear") != -1:
+        nn.init.xavier_normal_(m.weight)
+        m.bias.data.fill_(0)
+
+
+def get_feature(
+    feature_extractor: torch.nn.Module | BaseEvaluationModel,
+    dl: torch.utils.data.DataLoader,
+    iteration_over_training: int = 1,
+    gradcam: bool = False,
+    classifier: ClassifierMixin | None = None,
+    input_shape: tuple[int, int, int] | None = None,
+    limit_batches: int | None = None,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
+    """Given a dataloader and a PyTorch model, extract features with the model and return features and labels.
+
+    Args:
+        dl: PyTorch dataloader
+        feature_extractor: Pretrained PyTorch backbone
+        iteration_over_training: Extract feature iteration_over_training times for each image
+            (best if used with augmentation)
+        gradcam: Whether to compute gradcams. Notice that it will slow the function
+        classifier: Scikit-learn classifier
+        input_shape: [H,W,C], backbone input shape, needed by classifier's pytorch wrapper
+        limit_batches: Limit the number of batches to be processed
+
+    Returns:
+        Tuple containing:
+            features: Model features
+            labels: Input labels
+            grayscale_cams: Gradcam output maps, None if gradcam arg is False
+    """
+    if isinstance(feature_extractor, (TorchEvaluationModel, TorchscriptEvaluationModel)):
+        # If we are working with torch based evaluation models we need to extract the model
+        feature_extractor = feature_extractor.model
+    elif isinstance(feature_extractor, ONNXEvaluationModel):
+        gradcam = False
+
+    feature_extractor.eval()
+
+    # Setup gradcam
+    if gradcam:
+        if not hasattr(feature_extractor, "features_extractor"):
+            gradcam = False
+        elif isinstance(feature_extractor.features_extractor, timm.models.resnet.ResNet):
+            target_layers = [feature_extractor.features_extractor.layer4[-1]]
+            cam = GradCAM(
+                model=feature_extractor,
+                target_layers=target_layers,
+            )
+            for p in feature_extractor.features_extractor.layer4[-1].parameters():
+                p.requires_grad = True
+        elif is_vision_transformer(feature_extractor.features_extractor):
+            grad_rollout = VitAttentionGradRollout(
+                feature_extractor.features_extractor,
+                classifier=classifier,
+                example_input=None if input_shape is None else torch.randn(1, *input_shape),
+            )
+        else:
+            gradcam = False
+
+        if not gradcam:
+            log.warning("Gradcam not implemented for this backbone, it will not be computed")
+
+    # Extract features from data
+
+    for iteration in range(iteration_over_training):
+        for i, b in enumerate(tqdm.tqdm(dl)):
+            x1, y1 = b
+
+            if hasattr(feature_extractor, "parameters"):
+                # Move input to the correct device and dtype
+                parameter = next(feature_extractor.parameters())
+                x1 = x1.to(parameter.device).to(parameter.dtype)
+            elif isinstance(feature_extractor, BaseEvaluationModel):
+                x1 = x1.to(feature_extractor.device).to(feature_extractor.model_dtype)
+
+            if gradcam:
+                y_hat = cast(
+                    Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1).detach()
+                )
+                # mypy can't detect that gradcam is true only if we have a features_extractor
+                if is_vision_transformer(feature_extractor.features_extractor):  # type: ignore[union-attr]
+                    grayscale_cam_low_res = grad_rollout(
+                        input_tensor=x1, targets_list=y1
+                    )  # TODO: We are using labels (y1) but it would be better to use preds
+                    orig_shape = grayscale_cam_low_res.shape
+                    new_shape = (orig_shape[0], x1.shape[2], x1.shape[3])
+                    zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
+                    grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
+                else:
+                    grayscale_cam = cam(input_tensor=x1, targets=None)
+                feature_extractor.zero_grad(set_to_none=True)  # type: ignore[union-attr]
+            else:
+                with torch.no_grad():
+                    y_hat = cast(Union[list[torch.Tensor], tuple[torch.Tensor], torch.Tensor], feature_extractor(x1))
+                    grayscale_cams = None
+
+            if isinstance(y_hat, (list, tuple)):
+                y_hat = y_hat[0].cpu()
+            else:
+                y_hat = y_hat.cpu()
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            if i == 0 and iteration == 0:
+                features = torch.cat([y_hat], dim=0)
+                labels = np.concatenate([y1])
+                if gradcam:
+                    grayscale_cams = grayscale_cam
+            else:
+                features = torch.cat([features, y_hat], dim=0)
+                labels = np.concatenate([labels, y1], axis=0)
+                if gradcam:
+                    grayscale_cams = np.concatenate([grayscale_cams, grayscale_cam], axis=0)
+
+            if limit_batches is not None and (i + 1) >= limit_batches:
+                break
+
+    return features.detach().numpy(), labels, grayscale_cams
+
+
+def is_vision_transformer(model: torch.nn.Module) -> bool:
+    """Verify if pytorch module is a Vision Transformer.
+    This check is primarily needed for gradcam computation in classification tasks.
+
+    Args:
+        model: Model
+    """
+    return type(model).__name__ == "VisionTransformer"
+
+
+def _no_grad_trunc_normal_(tensor: torch.Tensor, mean: float, std: float, a: float, b: float):
+    """Cut & paste from PyTorch official master until it's in a few official releases - RW
+    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff
+        b: the maximum cutoff
+    """
+
+    def norm_cdf(x: float):
+        """Computes standard normal cumulative distribution function."""
+        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            (
+                "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+                "The distribution of values may be incorrect."
+            ),
+            stacklevel=2,
+        )
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        l = norm_cdf((a - mean) / std)
+        u = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [l, u], then translate to
+        # [2l-1, 2u-1].
+        tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.0))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0):
+    """Call `_no_grad_trunc_normal_` with `torch.no_grad()`.
+
+    Args:
+        tensor: an n-dimensional `torch.Tensor`
+        mean: the mean of the normal distribution
+        std: the standard deviation of the normal distribution
+        a: the minimum cutoff
+        b: the maximum cutoff
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def clip_gradients(model: nn.Module, clip: float) -> list[float]:
+    """Clip model gradients parameter by parameter and return their norms.
+    Args:
+        model: The model
+        clip: The clip value.
+
+    Returns:
+        The norms of the gradients
+    """
+    norms = []
+    for _, p in model.named_parameters():
+        if p.grad is not None:
+            param_norm = p.grad.data.norm(2)
+            norms.append(param_norm.item())
+            clip_coef = clip / (param_norm + 1e-6)
+            if clip_coef < 1:
+                p.grad.data.mul_(clip_coef)
+    return norms
+
+
+# TODO: do not use this implementation for new models
+
+
+class AttentionExtractor(torch.nn.Module):
+    """General attention extractor.
+
+    Args:
+        model: Backbone model which contains the attention layer.
+        attention_layer_name: Attention layer for extracting attention maps.
+            Defaults to "attn_drop".
+    """
+
+    def __init__(self, model: torch.nn.Module, attention_layer_name: str = "attn_drop"):
+        super().__init__()
+        self.model = model
+        modules = [module for module_name, module in self.model.named_modules() if attention_layer_name in module_name]
+        if modules:
+            modules[-1].register_forward_hook(self.get_attention)
+        self.attentions = torch.zeros((1, 0))
+
+    def clear(self):
+        """Clear the grabbed attentions."""
+        self.attentions = torch.zeros((1, 0))
+
+    def get_attention(self, module: nn.Module, input_tensor: torch.Tensor, output: torch.Tensor):  # pylint: disable=unused-argument
+        """Method to be registered to grab attentions."""
+        self.attentions = output.detach().clone().cpu()
+
+    @staticmethod
+    def process_attention_maps(attentions: torch.Tensor, img_width: int, img_height: int) -> torch.Tensor:
+        """Preprocess attentions maps to be visualized.
+
+        Args:
+            attentions: grabbed attentions
+            img_width: image width
+            img_height: image height
+
+        Returns:
+            torch.Tensor: preprocessed attentions, with the shape equal to the one of the image from
+                which attentions has been computed
+        """
+        if len(attentions.shape) == 4:
+            # vit
+            # batch, heads, N, N (class attention layer)
+            attentions = attentions[:, :, 0, 1:]  # batch, heads, N-1
+
+        else:
+            # xcit
+            # batch, heads, N
+            attentions = attentions[:, :, 1:]  # batch, heads, dim-1
+        nh = attentions.shape[1]
+        patch_size = int(math.sqrt(img_width * img_height / attentions.shape[-1]))
+        w_featmap = img_width // patch_size
+        h_featmap = img_height // patch_size
+
+        # we keep only the output patch attention, we don't want the cls token
+        attentions = attentions.reshape(attentions.shape[0], nh, w_featmap, h_featmap)
+        attentions = F.interpolate(attentions, scale_factor=patch_size, mode="nearest")
+        return attentions
+
+    def forward(self, t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self.clear()
+        out = self.model(t)
+        return (out, self.attentions)  # torch.jit.trace does not complain
+
+
+# TODO: do not use this implementation for new models
+
+
+class PositionalEncoding1D(torch.nn.Module):
+    """Standard sine-cosine positional encoding from https://arxiv.org/abs/2010.11929.
+
+    Args:
+        d_model: Embedding dimension
+        temperature: Temperature for the positional encoding. Defaults to 10000.0.
+        dropout: Dropout rate. Defaults to 0.0.
+        max_len: Maximum length of the sequence. Defaults to 5000.
+    """
+
+    def __init__(self, d_model: int, temperature: float = 10000.0, dropout: float = 0.0, max_len: int = 5000):
+        super().__init__()
+        self.dropout: torch.nn.Dropout | torch.nn.Identity
+        if dropout > 0:
+            self.dropout = torch.nn.Dropout(p=dropout)
+        else:
+            self.dropout = torch.nn.Identity()
+
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(temperature) / d_model))
+        self.pe = torch.zeros(max_len, 1, d_model)
+        self.pe[:, 0, 0::2] = torch.sin(position * div_term)
+        self.pe[:, 0, 1::2] = torch.cos(position * div_term)
+        self.pe = self.pe.permute(1, 0, 2)
+        self.pe = torch.nn.Parameter(self.pe)
+        self.pe.requires_grad = False
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the positional encoding.
+
+        Args:
+            x: torch tensor [batch_size, seq_len, embedding_dim].
+        """
+        x = x + self.pe[:, : x.size(1), :]
+        return self.dropout(x)
+
+
+class LSABlock(torch.nn.Module):
+    """Local Self Attention Block from https://arxiv.org/abs/2112.13492.
+
+    Args:
+        dim: embedding dimension
+        num_heads: number of attention heads
+        mlp_ratio: ratio of mlp hidden dim to embedding dim
+        qkv_bias: enable bias for qkv if True
+        drop: dropout rate
+        attn_drop: attention dropout rate
+        drop_path: stochastic depth rate
+        act_layer: activation layer
+        norm_layer: normalization layer
+        mask_diagonal: whether to mask Q^T x K diagonal with -infinity so not to
+            count self relationship between tokens. Defaults to True
+        learnable_temperature: whether to use a learnable temperature as specified in
+            https://arxiv.org/abs/2112.13492. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        drop_path: float = 0.0,
+        act_layer: type[nn.Module] = torch.nn.GELU,
+        norm_layer: type[torch.nn.LayerNorm] = torch.nn.LayerNorm,
+        mask_diagonal: bool = True,
+        learnable_temperature: bool = True,
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = LocalSelfAttention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            mask_diagonal=mask_diagonal,
+            learnable_temperature=learnable_temperature,
+        )
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class LocalSelfAttention(torch.nn.Module):
+    """Local Self Attention from https://arxiv.org/abs/2112.13492.
+
+    Args:
+        dim: embedding dimension.
+        num_heads: number of attention heads.
+        qkv_bias: enable bias for qkv if True.
+        attn_drop: attention dropout rate.
+        proj_drop: projection dropout rate.
+        mask_diagonal: whether to mask Q^T x K diagonal with -infinity
+            so not to count self relationship between tokens. Defaults to True.
+        learnable_temperature: whether to use a learnable temperature as specified in
+            https://arxiv.org/abs/2112.13492. Defaults to True.
+    """

+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+        mask_diagonal: bool = True,
+        learnable_temperature: bool = True,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.mask_diagonal = mask_diagonal
+        if learnable_temperature:
+            self.register_parameter("scale", torch.nn.Parameter(torch.tensor(head_dim**-0.5, requires_grad=True)))
+        else:
+            self.scale = head_dim**-0.5
+
+        self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = torch.nn.Dropout(attn_drop)
+        self.proj = torch.nn.Linear(dim, dim)
+        self.proj_drop = torch.nn.Dropout(proj_drop)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Computes the local self attention.
+
+        Args:
+            x: input tensor
+
+        Returns:
+            Output of the local self attention.
+        """
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        if self.mask_diagonal:
+            attn[torch.eye(N, device=attn.device, dtype=torch.bool).repeat(B, self.num_heads, 1, 1)] = -float("inf")
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
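For orientation, a minimal usage sketch of two of the helpers added above, create_net_hat and trunc_normal_. This is illustrative only and not part of the package diff; it assumes quadra 2.2.7 and torch are installed, and the example dimensions are arbitrary:

import torch
from quadra.utils.models import create_net_hat, trunc_normal_

# Projection head 512 -> 256 -> 128: Dropout + Linear + ReLU for the hidden
# layer, a final Linear, then L2Norm so each output row has unit L2 norm.
head = create_net_hat(dims=[512, 256, 128], dropout_p=0.1)
features = head(torch.randn(4, 512))  # shape (4, 128), rows L2-normalized

# Truncated-normal initialization; values are clipped to [a, b] = [-2, 2] by default
w = torch.empty(128, 512)
trunc_normal_(w, std=0.02)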
quadra/utils/patch/__init__.py
ADDED
@@ -0,0 +1,15 @@
+from .dataset import generate_patch_dataset, get_image_mask_association
+from .metrics import compute_patch_metrics, reconstruct_patch
+from .model import RleEncoder, save_classification_result
+from .visualization import plot_patch_reconstruction, plot_patch_results
+
+__all__ = [
+    "generate_patch_dataset",
+    "reconstruct_patch",
+    "save_classification_result",
+    "plot_patch_reconstruction",
+    "plot_patch_results",
+    "get_image_mask_association",
+    "compute_patch_metrics",
+    "RleEncoder",
+]
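Similarly, a short hypothetical sketch of the transformer building blocks defined in quadra/utils/models.py (PositionalEncoding1D and LSABlock). Shapes follow the docstrings above; the batch size, sequence length, and embedding dimension are example values, not defaults from the package:

import torch
from quadra.utils.models import LSABlock, PositionalEncoding1D

tokens = torch.randn(2, 50, 64)  # [batch_size, seq_len, embedding_dim]

# Add fixed sine-cosine positional information (pe is a frozen parameter)
pos = PositionalEncoding1D(d_model=64)
tokens = pos(tokens)

# One block of diagonal-masked local self-attention followed by an MLP;
# with mask_diagonal=True each token cannot attend to itself.
block = LSABlock(dim=64, num_heads=4)
out = block(tokens)  # same shape: [2, 50, 64]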