quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import cast
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
from numpy import ndarray
|
|
9
|
+
from pandas import DataFrame
|
|
10
|
+
from sklearn.linear_model import LogisticRegression
|
|
11
|
+
from sklearn.linear_model._base import ClassifierMixin
|
|
12
|
+
from torch.utils.data import DataLoader
|
|
13
|
+
|
|
14
|
+
from quadra.utils import utils
|
|
15
|
+
from quadra.utils.classification import get_results
|
|
16
|
+
from quadra.utils.models import get_feature
|
|
17
|
+
|
|
18
|
+
# Module-level logger for this file, obtained through quadra's ``utils.get_logger`` helper.
log = utils.get_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class SklearnClassificationTrainer:
    """Class to configure and run a classification using torch for feature extraction and sklearn to fit a classifier.

    Args:
        input_shape: Input shape as ``[H, W, C]``; it is reordered to ``(C, H, W)`` when passed to ``get_feature``.
        backbone: The torch feature extractor.
        random_state: Seed used to fix randomness. It is passed to the classifier when the classifier is built
            here, and it seeds the shuffling of cached training features in ``fit``.
        classifier: Classification model. If it can be called (e.g. an sklearn estimator class), it is
            instantiated with ``max_iter=10_000`` and ``random_state``. If that call fails, the object is
            stored as given (e.g. an already-built estimator instance).
        iteration_over_training: The number of iterations over the training set during feature extraction.
    """

    def __init__(
        self,
        input_shape: list,
        backbone: torch.nn.Module,
        random_state: int = 42,
        classifier: ClassifierMixin | type[ClassifierMixin] = LogisticRegression,
        iteration_over_training: int = 1,
    ) -> None:
        super().__init__()

        try:
            # ``max_iter`` must be an integer: sklearn's parameter validation rejects floats such as ``1e4``
            # when ``fit`` is called.
            self.classifier = classifier(max_iter=10_000, random_state=random_state)
        except Exception:
            # Best effort: ``classifier`` is either an already-built instance (not callable) or does not accept
            # these keyword arguments, so it is used exactly as provided.
            self.classifier = classifier

        self.input_shape = input_shape
        self.random_state = random_state
        self.iteration_over_training = iteration_over_training
        self.backbone = backbone

    def change_backbone(self, backbone: torch.nn.Module):
        """Update feature extractor and switch it to eval mode."""
        self.backbone = backbone
        self.backbone.eval()

    def change_classifier(self, classifier: ClassifierMixin):
        """Update classifier."""
        self.classifier = classifier

    def fit(
        self,
        train_dataloader: DataLoader | None = None,
        train_features: ndarray | None = None,
        train_labels: ndarray | None = None,
    ) -> None:
        """Fit classifier on training set.

        When ``train_dataloader`` is given, features are extracted from it with the backbone. Otherwise the cached
        ``train_features``/``train_labels`` are used after a deterministic shuffle seeded with ``random_state``.

        Args:
            train_dataloader: Dataloader used to extract training features.
            train_features: Cached training features, used when ``train_dataloader`` is None.
            train_labels: Cached training labels, used when ``train_dataloader`` is None.

        Raises:
            AssertionError: If no backbone is set, or if cached data is used but features or labels are missing.
        """
        if self.backbone is None:
            raise AssertionError("You must set a model before running execution")

        if train_dataloader is not None:
            log.info("Extracting features from training set")
            train_features, train_labels, _ = get_feature(
                feature_extractor=self.backbone,
                dl=train_dataloader,
                iteration_over_training=self.iteration_over_training,
                gradcam=False,
            )
        else:
            log.info("Using cached features for training set")
            if train_features is None or train_labels is None:
                raise AssertionError("Train features and labels must be provided when using cached data")
            # With the current implementation cached features are not shuffled, so shuffle them here
            # (deterministically), even though it doesn't seem to change anything in practice.
            permuted_indices = np.random.RandomState(seed=self.random_state).permutation(train_features.shape[0])
            train_features = train_features[permuted_indices]
            train_labels = train_labels[permuted_indices]

        log.info("Fitting classifier on %d features", len(train_features))  # type: ignore[arg-type]
        self.classifier.fit(train_features, train_labels)

    def test(
        self,
        test_dataloader: DataLoader,
        test_labels: ndarray | None = None,
        test_features: ndarray | None = None,
        class_to_keep: list[int] | None = None,
        idx_to_class: dict[int, str] | None = None,
        predict_proba: bool = True,
        gradcam: bool = False,
    ) -> (
        tuple[str | dict, DataFrame, float, DataFrame, np.ndarray | None]
        | tuple[None, None, None, DataFrame, np.ndarray | None]
    ):
        """Test classifier on test set.

        Args:
            test_dataloader: Test dataloader; its dataset must expose an ``x`` attribute with the sample names.
            test_labels: Test labels, required when ``test_features`` (cached data) is given.
            test_features: Optional test features used when cache data is available.
            class_to_keep: Classes to keep in the metrics. Entries are compared against the values of
                ``idx_to_class``, and samples of other classes are excluded from the metrics.
            idx_to_class: Dictionary mapping class index to class name.
            predict_proba: If True, also add to ``res`` the probability of the predicted class for each image.
            gradcam: Whether to compute gradcam.

        Returns:
            cl_rep: Classification report (None if no label remains after filtering).
            pd_cm: Confusion matrix dataframe (None if no label remains after filtering).
            accuracy: Test accuracy (None if no label remains after filtering).
            res: Test results with ``sample``, ``real_label``, ``pred_label`` and, if ``predict_proba`` is True,
                ``probability`` columns.
            cams: Gradcams, or None when they were not computed.

        Raises:
            ValueError: If the dataset has no ``x`` attribute, if cached features are given without labels, or if
                ``class_to_keep`` is given without ``idx_to_class``.
        """
        # Validate the dataset before running the (potentially expensive) feature extraction and prediction.
        if not hasattr(test_dataloader.dataset, "x"):
            raise ValueError("Current dataset doesn't provide an `x` attribute")

        cams = None
        if test_features is None:
            log.info("Extracting features from test set")
            test_features, final_test_labels, cams = get_feature(
                feature_extractor=self.backbone,
                dl=test_dataloader,
                gradcam=gradcam,
                classifier=self.classifier,
                # Reorder [H, W, C] into (C, H, W).
                input_shape=(self.input_shape[2], self.input_shape[0], self.input_shape[1]),
            )
        else:
            if test_labels is None:
                raise ValueError("Test labels must be provided when using cached data")
            log.info("Using cached features for test set")
            final_test_labels = test_labels

        log.info("Predict classifier on test set")
        test_prediction_label = self.classifier.predict(test_features)

        res = pd.DataFrame(
            {
                "sample": list(test_dataloader.dataset.x),
                "real_label": final_test_labels,
                "pred_label": test_prediction_label,
            }
        )
        if predict_proba:
            # Keep only the confidence of the predicted (most likely) class; the column is added on every return
            # path so the layout of ``res`` does not depend on the filtering outcome.
            res["probability"] = self.classifier.predict_proba(test_features).max(axis=1)

        if class_to_keep is not None:
            if idx_to_class is None:
                raise ValueError("You must provide `idx_to_class` and `test_labels` when using `class_to_keep`")
            # Labels of classes outside ``class_to_keep`` are masked with -1 and excluded from the metrics.
            filtered_test_labels = [int(x) if idx_to_class[x] in class_to_keep else -1 for x in final_test_labels]
        else:
            filtered_test_labels = cast(list[int], final_test_labels.tolist())

        if all(label == -1 for label in filtered_test_labels):
            # Nothing left to evaluate (every sample filtered out, or no samples at all).
            return None, None, None, res, cams

        test_real_label_cm = np.array(filtered_test_labels)
        keep_mask = test_real_label_cm != -1
        if cams is not None:
            cams = cams[keep_mask]  # TODO: Is class_to_keep still used?
        pred_labels_cm = np.array(test_prediction_label)[keep_mask]
        test_real_label_cm = test_real_label_cm[keep_mask].astype(pred_labels_cm.dtype)
        cl_rep, pd_cm, accuracy = get_results(test_real_label_cm, pred_labels_cm, idx_to_class)

        return cl_rep, pd_cm, accuracy, res, cams
|
quadra/utils/__init__.py
ADDED
|
File without changes
|
quadra/utils/anomaly.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Anomaly Score Normalization Callback that uses min-max normalization."""
|
|
2
|
+
|
|
3
|
+
# Copyright (C) 2022 Intel Corporation
|
|
4
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from typing import Any, TypeAlias
|
|
10
|
+
except ImportError:
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from typing_extensions import TypeAlias # noqa
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# MyPy wants TypeAlias, but pylint has problems dealing with it
|
|
17
|
+
import numpy as np # pylint: disable=unused-import
|
|
18
|
+
import pytorch_lightning as pl
|
|
19
|
+
import torch # pylint: disable=unused-import
|
|
20
|
+
from anomalib.models.components import AnomalyModule
|
|
21
|
+
from pytorch_lightning import Callback
|
|
22
|
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
|
23
|
+
|
|
24
|
+
# Type accepted and returned by ``normalize_anomaly_score``: a scalar, a torch tensor or a numpy array.
# The union is written as a string literal rather than a bare expression; see the linked CPython issue for why:
# https://github.com/python/cpython/issues/90015#issuecomment-1172996118
MapOrValue: TypeAlias = "float | torch.Tensor | np.ndarray"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def normalize_anomaly_score(raw_score: MapOrValue, threshold: float) -> MapOrValue:
    """Rescale an anomaly score value or map so that the threshold maps to 100.

    The formula depends on the sign of ``threshold``:

    * positive: ``raw_score / threshold * 100``
    * zero: ``(raw_score + 1) * 100``
    * negative: ``200 - raw_score / threshold * 100``

    In every case a raw value equal to the threshold becomes 100. Higher raw scores never produce lower
    normalized scores.

    Args:
        raw_score: Raw anomaly score value or map (float, numpy array or torch tensor).
        threshold: Threshold for anomaly detection.

    Returns:
        Normalized anomaly score value or map clipped between 0 and 1000. Torch tensors are clamped with
        ``torch.clamp``; any other input goes through ``np.clip``.
    """
    lower_bound, upper_bound = 0.0, 1000.0

    if threshold > 0:
        rescaled = raw_score / threshold * 100.0
    elif threshold == 0:
        # TODO: Is this the best way to handle this case?
        rescaled = (raw_score + 1) * 100.0
    else:
        # Dividing by a negative threshold flips the ordering; subtracting from 200 restores it so that
        # higher raw scores still yield higher normalized scores.
        rescaled = 200.0 - raw_score / threshold * 100.0

    if not isinstance(rescaled, torch.Tensor):
        return np.clip(rescaled, lower_bound, upper_bound)
    return torch.clamp(rescaled, lower_bound, upper_bound)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class ThresholdNormalizationCallback(Callback):
    """Callback that normalizes the image-level and pixel-level anomaly scores dividing by the threshold value.

    Scores are rescaled with ``normalize_anomaly_score``, which maps a value equal to the threshold to 100.
    For this reason the module metrics are given a fixed threshold of 100 when testing starts.

    Args:
        threshold_type: Threshold used to normalize pixel level anomaly scores (anomaly maps and box scores),
            either ``"image"`` or ``"pixel"`` (default). Image-level ``pred_scores`` always use the image threshold.

    Raises:
        ValueError: If ``threshold_type`` is neither ``"image"`` nor ``"pixel"``.
    """

    def __init__(self, threshold_type: str = "pixel"):
        super().__init__()
        # Any value other than "pixel" would otherwise silently fall back to the image threshold, so reject
        # unknown values (e.g. typos) up front.
        if threshold_type not in ("image", "pixel"):
            raise ValueError(f"`threshold_type` must be either 'image' or 'pixel', got {threshold_type!r}")
        self.threshold_type = threshold_type

    def on_test_start(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
        """Called when the test begins, sets the metrics threshold to the normalized threshold value (100)."""
        del trainer  # `trainer` variable is not used.

        for metric in (pl_module.image_metrics, pl_module.pixel_metrics):
            if metric is not None:
                metric.set_threshold(100.0)

    def on_test_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: AnomalyModule,
        outputs: STEP_OUTPUT | None,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the test batch ends, normalizes the predicted scores and anomaly maps."""
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

        self._normalize_batch(outputs, pl_module)

    def on_predict_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: AnomalyModule,
        outputs: Any,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Called when the predict batch ends, normalizes the predicted scores and anomaly maps."""
        del trainer, batch, batch_idx, dataloader_idx  # These variables are not used.

        self._normalize_batch(outputs, pl_module)

    def _normalize_batch(self, outputs: Any, pl_module: AnomalyModule) -> None:
        """Normalize in place the scores contained in a batch of predictions.

        ``pred_scores`` are normalized with the image threshold. ``anomaly_maps`` and every entry of
        ``box_scores`` (when present) use the threshold selected by ``threshold_type``.
        A ``None`` batch output is left untouched.
        """
        if outputs is None:
            # The step returned nothing, so there is nothing to normalize.
            return

        image_threshold = pl_module.image_threshold.value.cpu()
        pixel_threshold = pl_module.pixel_threshold.value.cpu()
        outputs["pred_scores"] = normalize_anomaly_score(outputs["pred_scores"], image_threshold.item())

        threshold = (pixel_threshold if self.threshold_type == "pixel" else image_threshold).item()

        if "anomaly_maps" in outputs:
            outputs["anomaly_maps"] = normalize_anomaly_score(outputs["anomaly_maps"], threshold)

        if "box_scores" in outputs:
            outputs["box_scores"] = [normalize_anomaly_score(scores, threshold) for scores in outputs["box_scores"]]
|