quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +32 -0
- quadra/configs/callbacks/default.yaml +37 -0
- quadra/configs/callbacks/default_anomalib.yaml +67 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +46 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +331 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1264 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +579 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +524 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.1.13.dist-info/LICENSE +201 -0
- quadra-2.1.13.dist-info/METADATA +386 -0
- quadra-2.1.13.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
- quadra-2.1.13.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
quadra/tasks/classification.py
@@ -0,0 +1,1264 @@
from __future__ import annotations

import glob
import json
import os
import typing
from copy import deepcopy
from pathlib import Path
from typing import Any, Generic, cast

import cv2
import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
from joblib import dump, load
from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_grad_cam import GradCAM
from scipy import ndimage
from sklearn.base import ClassifierMixin
from sklearn.metrics import ConfusionMatrixDisplay
from torch import nn
from torchinfo import summary
from tqdm import tqdm

from quadra.callbacks.mlflow import get_mlflow_logger
from quadra.callbacks.scheduler import WarmupInit
from quadra.datamodules import (
    ClassificationDataModule,
    MultilabelClassificationDataModule,
    SklearnClassificationDataModule,
)
from quadra.datasets.classification import ImageClassificationListDataset
from quadra.models.base import ModelSignatureWrapper
from quadra.models.classification import BaseNetworkBuilder
from quadra.models.evaluation import BaseEvaluationModel, TorchEvaluationModel, TorchscriptEvaluationModel
from quadra.modules.classification import ClassificationModule
from quadra.tasks.base import Evaluation, LightningTask, Task
from quadra.trainers.classification import SklearnClassificationTrainer
from quadra.utils import utils
from quadra.utils.classification import (
    get_results,
    save_classification_result,
)
from quadra.utils.evaluation import automatic_datamodule_batch_size
from quadra.utils.export import export_model, import_deployment_model
from quadra.utils.models import get_feature, is_vision_transformer
from quadra.utils.vit_explainability import VitAttentionGradRollout

log = utils.get_logger(__name__)

SklearnClassificationDataModuleT = typing.TypeVar(
    "SklearnClassificationDataModuleT", bound=SklearnClassificationDataModule
)
ClassificationDataModuleT = typing.TypeVar("ClassificationDataModuleT", bound=ClassificationDataModule)

# TODO: Maybe we should have a BaseClassificationTask that is extended by Classification and MultilabelClassification
# at the current time, multilabel experiments use this Classification task class and they can not generate a report
# (it is written specifically for vanilla classification). Moreover, this class takes a generic
# ClassificationDataModuleT, but multilabel experiments use MultilabelClassificationDataModule, which is not a child of
# ClassificationDataModule
class Classification(Generic[ClassificationDataModuleT], LightningTask[ClassificationDataModuleT]):
    """Classification Task.

    Args:
        config: The experiment configuration
        output: The output configuration (under task config). It contains the bool "example" to generate
            figures of discordant/concordant predictions.
        gradcam: Whether to compute gradcams
        checkpoint_path: The path to the checkpoint to load the model from. Defaults to None.
        lr_multiplier: The multiplier for the backbone learning rate. Defaults to None.
        report: Whether to generate a report containing the results after the test phase
        run_test: Whether to run the test phase.
    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        checkpoint_path: str | None = None,
        lr_multiplier: float | None = None,
        gradcam: bool = False,
        report: bool = False,
        run_test: bool = False,
    ):
        super().__init__(
            config=config,
            checkpoint_path=checkpoint_path,
            run_test=run_test,
            report=report,
        )
        self.output = output
        self.gradcam = gradcam
        self._lr_multiplier = lr_multiplier
        self._pre_classifier: nn.Module
        self._classifier: nn.Module
        self._model: nn.Module
        self._optimizer: torch.optim.Optimizer
        self._scheduler: torch.optim.lr_scheduler._LRScheduler
        self.model_json: dict[str, Any] | None = None
        self.export_folder: str = "deployment_model"
        self.deploy_info_file: str = "model.json"
        self.report_confmat: pd.DataFrame
        self.best_model_path: str | None = None

    @property
    def optimizer(self) -> torch.optim.Optimizer:
        """Get the optimizer."""
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer_config: DictConfig) -> None:
        """Set the optimizer."""
        if (
            isinstance(self.model.features_extractor, nn.Module)
            and isinstance(self.model.pre_classifier, nn.Module)
            and isinstance(self.model.classifier, nn.Module)
        ):
            log.info("Instantiating optimizer <%s>", self.config.optimizer["_target_"])
            # The backbone may use a scaled learning rate, while pre_classifier and classifier
            # always use the base learning rate from the optimizer config
            if self._lr_multiplier is not None and self._lr_multiplier > 0:
                params = [
                    {
                        "params": self.model.features_extractor.parameters(),
                        "lr": optimizer_config.lr * self._lr_multiplier,
                    }
                ]
            else:
                params = [{"params": self.model.features_extractor.parameters(), "lr": optimizer_config.lr}]
            params.append({"params": self.model.pre_classifier.parameters(), "lr": optimizer_config.lr})
            params.append({"params": self.model.classifier.parameters(), "lr": optimizer_config.lr})
            self._optimizer = hydra.utils.instantiate(optimizer_config, params)

    @property
    def len_train_dataloader(self) -> int:
        """Get the length of the train dataloader."""
        len_train_dataloader = len(self.datamodule.train_dataloader())
        if self.devices is not None:
            num_gpus = len(self.devices) if isinstance(self.devices, list) else 1
            len_train_dataloader = len_train_dataloader // num_gpus
            if not self.datamodule.train_dataloader().drop_last:
                len_train_dataloader += int(len(self.datamodule.train_dataloader()) % num_gpus != 0)
        return len_train_dataloader

    @property
    def scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
        """Get the scheduler."""
        return self._scheduler

    @scheduler.setter
    def scheduler(self, scheduler_config: DictConfig) -> None:
        log.info("Instantiating scheduler <%s>", scheduler_config["_target_"])
        if "CosineAnnealingWithLinearWarmUp" in self.config.scheduler["_target_"]:
            # This scheduler will be overwritten by the SSLCallback
            self._scheduler = hydra.utils.instantiate(
                scheduler_config,
                optimizer=self.optimizer,
                batch_size=1,
                len_loader=1,
            )
            self.add_callback(WarmupInit(scheduler_config=scheduler_config))
        else:
            self._scheduler = hydra.utils.instantiate(scheduler_config, optimizer=self.optimizer)

    @property
    def module(self) -> ClassificationModule:
        """Get the module of the model."""
        return self._module

    @LightningTask.module.setter
    def module(self, module_config):  # noqa: F811
        """Set the module of the model."""
        module = hydra.utils.instantiate(
            module_config,
            model=self.model,
            optimizer=self.optimizer,
            lr_scheduler=self.scheduler,
            gradcam=self.gradcam,
        )
        if self.checkpoint_path is not None:
            log.info("Loading model from lightning checkpoint: %s", self.checkpoint_path)
            module = module.__class__.load_from_checkpoint(
                self.checkpoint_path,
                model=self.model,
                optimizer=self.optimizer,
                lr_scheduler=self.scheduler,
                criterion=module.criterion,
                gradcam=self.gradcam,
            )
        self._module = module

    @property
    def pre_classifier(self) -> nn.Module:
        return self._pre_classifier

    @pre_classifier.setter
    def pre_classifier(self, model_config: DictConfig) -> None:
        if "pre_classifier" in model_config and model_config.pre_classifier is not None:
            log.info("Instantiating pre_classifier <%s>", model_config.pre_classifier["_target_"])
            self._pre_classifier = hydra.utils.instantiate(model_config.pre_classifier, _convert_="partial")
        else:
            log.info("No pre-classifier found in config: instantiating a torch.nn.Identity instead")
            self._pre_classifier = nn.Identity()

    @property
    def classifier(self) -> nn.Module:
        return self._classifier

    @classifier.setter
    def classifier(self, model_config: DictConfig) -> None:
        if "classifier" in model_config:
            log.info("Instantiating classifier <%s>", model_config.classifier["_target_"])
            if self.datamodule.num_classes is None or self.datamodule.num_classes < 2:
                raise ValueError(f"Non compliant datamodule.num_classes: {self.datamodule.num_classes}")
            self._classifier = hydra.utils.instantiate(
                model_config.classifier, out_features=self.datamodule.num_classes, _convert_="partial"
            )
        else:
            raise ValueError("A `classifier` definition must be specified in the config")

    @property
    def model(self) -> nn.Module:
        return self._model

    @model.setter
    def model(self, model_config: DictConfig) -> None:
        self.pre_classifier = model_config  # type: ignore[assignment]
        self.classifier = model_config  # type: ignore[assignment]
        log.info("Instantiating backbone <%s>", model_config.model["_target_"])
        self._model = hydra.utils.instantiate(
            model_config.model, classifier=self.classifier, pre_classifier=self.pre_classifier, _convert_="partial"
        )
        if getattr(self.config.backbone, "freeze_parameters_name", None) is not None:
            self.freeze_layers_by_name(self.config.backbone.freeze_parameters_name)

        if getattr(self.config.backbone, "freeze_parameters_index", None) is not None:
            frozen_parameters_indices: list[int]
            if isinstance(self.config.backbone.freeze_parameters_index, int):
                # Freeze all layers up to the specified index
                frozen_parameters_indices = list(range(self.config.backbone.freeze_parameters_index + 1))
            elif isinstance(self.config.backbone.freeze_parameters_index, ListConfig):
                frozen_parameters_indices = cast(
                    list[int], OmegaConf.to_container(self.config.backbone.freeze_parameters_index, resolve=True)
                )
            else:
                raise ValueError("freeze_parameters_index must be an int or a list of int")

            self.freeze_parameters_by_index(frozen_parameters_indices)

    def prepare(self) -> None:
        """Prepare the experiment."""
        super().prepare()
        self.model = self.config.model
        self.optimizer = self.config.optimizer
        self.scheduler = self.config.scheduler
        self.module = self.config.model.module

    def train(self):
        """Train the model."""
        super().train()
        if (
            self.trainer.checkpoint_callback is not None
            and hasattr(self.trainer.checkpoint_callback, "best_model_path")
            and self.trainer.checkpoint_callback.best_model_path is not None
            and len(self.trainer.checkpoint_callback.best_model_path) > 0
        ):
            self.best_model_path = self.trainer.checkpoint_callback.best_model_path
            log.info("Loading best epoch weights...")

    def test(self) -> None:
        """Test the model."""
        if not self.config.trainer.get("fast_dev_run"):
            log.info("Starting testing!")
            self.trainer.test(datamodule=self.datamodule, model=self.module, ckpt_path=self.best_model_path)

    def export(self) -> None:
        """Generate deployment models for the task."""
        if self.datamodule.class_to_idx is None:
            log.warning(
                "No `class_to_idx` found in the datamodule, class information will not be saved in the model.json"
            )
            idx_to_class = {}
        else:
            idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}

        # Get best model!
        if self.best_model_path is not None:
            log.info("Saving deployment model for %s checkpoint", self.best_model_path)

            module = self.module.__class__.load_from_checkpoint(
                self.best_model_path,
                model=self.module.model,
                optimizer=self.optimizer,
                lr_scheduler=self.scheduler,
                criterion=self.module.criterion,
                gradcam=False,
            )
        else:
            log.warning("No checkpoint callback found in the trainer, exporting the last model weights")
            module = self.module

        input_shapes = self.config.export.input_shapes

        # TODO: What happens if we have 64 precision?
        half_precision = "16" in self.trainer.precision

        self.model_json, export_paths = export_model(
            config=self.config,
            model=module.model,
            export_folder=self.export_folder,
            half_precision=half_precision,
            input_shapes=input_shapes,
            idx_to_class=idx_to_class,
        )

        if len(export_paths) == 0:
            return

        with open(os.path.join(self.export_folder, self.deploy_info_file), "w") as f:
            json.dump(self.model_json, f)

    def generate_report(self) -> None:
        """Generate a report for the task."""
        if self.datamodule.class_to_idx is None:
            log.warning("No `class_to_idx` found in the datamodule, report will not be generated")
            return

        if isinstance(self.datamodule, MultilabelClassificationDataModule):
            log.warning("Report generation is not supported for multilabel classification tasks at the moment.")
            return

        log.info("Generating report!")
        if not self.run_test or self.config.trainer.get("fast_dev_run"):
            self.datamodule.setup(stage="test")

        # Deepcopy to remove the inference mode from gradients causing issues when loading checkpoints
        # TODO: Why does deepcopy of the module model remove ModelSignatureWrapper?
        self.module.model.instance = deepcopy(self.module.model.instance)
        if "16" in self.trainer.precision:
            log.warning("Gradcam is currently not supported with half precision, it will be disabled")
            self.module.gradcam = False
            self.gradcam = False

        predictions_outputs = self.trainer.predict(
            model=self.module, datamodule=self.datamodule, ckpt_path=self.best_model_path
        )
        if not predictions_outputs:
            log.warning("There is no prediction to generate the report. Skipping report generation.")
            return
        all_outputs = [x[0] for x in predictions_outputs]
        all_probs = [x[2] for x in predictions_outputs]
        if not all_outputs or not all_probs:
            log.warning("There is no prediction to generate the report. Skipping report generation.")
            return
        all_outputs = [item for sublist in all_outputs for item in sublist]
        all_probs = [item for sublist in all_probs for item in sublist]
        all_targets = [target.tolist() for im, target in self.datamodule.test_dataloader()]
        all_targets = [item for sublist in all_targets for item in sublist]

        if self.module.gradcam:
            grayscale_cams = [x[1] for x in predictions_outputs]
            grayscale_cams = [item for sublist in grayscale_cams for item in sublist]
            grayscale_cams = np.stack(grayscale_cams)  # N x H x W
        else:
            grayscale_cams = None

        # creating confusion matrix
        idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}
        _, self.report_confmat, accuracy = get_results(
            test_labels=all_targets,
            pred_labels=all_outputs,
            idx_to_labels=idx_to_class,
        )
        output_folder_test = "test"
        test_dataloader = self.datamodule.test_dataloader()
        test_dataset = cast(ImageClassificationListDataset, test_dataloader.dataset)
        self.res = pd.DataFrame(
            {
                "sample": list(test_dataset.x),
                "real_label": all_targets,
                "pred_label": all_outputs,
                "probability": all_probs,
            }
        )
        os.makedirs(output_folder_test, exist_ok=True)
        save_classification_result(
            results=self.res,
            output_folder=output_folder_test,
            confmat=self.report_confmat,
            accuracy=accuracy,
            test_dataloader=self.datamodule.test_dataloader(),
            config=self.config,
            output=self.output,
            grayscale_cams=grayscale_cams,
        )

        if len(self.logger) > 0:
            mflow_logger = get_mlflow_logger(trainer=self.trainer)
            tensorboard_logger = utils.get_tensorboard_logger(trainer=self.trainer)
            artifacts = glob.glob(os.path.join(output_folder_test, "**/*"), recursive=True)
            if self.config.core.get("upload_artifacts") and len(artifacts) > 0:
                if mflow_logger is not None:
                    log.info("Uploading artifacts to MLFlow")
                    for a in artifacts:
                        if os.path.isdir(a):
                            continue

                        dirname = Path(a).parent.name
                        mflow_logger.experiment.log_artifact(
                            run_id=mflow_logger.run_id,
                            local_path=a,
                            artifact_path=os.path.join("classification_output", dirname),
                        )
                if tensorboard_logger is not None:
                    log.info("Uploading artifacts to Tensorboard")
                    for a in artifacts:
                        if os.path.isdir(a):
                            continue

                        ext = os.path.splitext(a)[1].lower()

                        if ext in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".gif"]:
                            try:
                                img = cv2.imread(a)
                                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                            except cv2.error:
                                log.info("Could not upload artifact image %s", a)
                                continue
                            output_path = os.path.sep.join(a.split(os.path.sep)[-2:])
                            tensorboard_logger.experiment.add_image(output_path, img, 0, dataformats="HWC")
                        else:
                            utils.upload_file_tensorboard(a, tensorboard_logger)

    def freeze_layers_by_name(self, freeze_parameters_name: list[str]):
        """Freeze layers specified in freeze_parameters_name.

        Args:
            freeze_parameters_name: Layers that will be frozen during training.

        """
        count_frozen = 0
        for name, param in self.model.named_parameters():
            if any(x in name.split(".")[1] for x in freeze_parameters_name):
                log.debug("Freezing layer %s", name)
                param.requires_grad = False

            if not param.requires_grad:
                count_frozen += 1

        log.info("Frozen %d parameters", count_frozen)

    def freeze_parameters_by_index(self, freeze_parameters_index: list[int]):
        """Freeze parameters at the indices specified in freeze_parameters_index.

        Args:
            freeze_parameters_index: Indices of parameters that will be frozen during training.

        """
        if getattr(self.config.backbone, "freeze_parameters_name", None) is not None:
            log.warning(
                "Please be aware that some of the model's parameters have already been frozen using "
                "the specified freeze_parameters_name. You are combining these two actions."
            )
        count_frozen = 0
        for i, (name, param) in enumerate(self.model.named_parameters()):
            if i in freeze_parameters_index:
                log.debug("Freezing layer %s", name)
                param.requires_grad = False

            if not param.requires_grad:
                count_frozen += 1

        log.info("Frozen %d parameters", count_frozen)

class SklearnClassification(Generic[SklearnClassificationDataModuleT], Task[SklearnClassificationDataModuleT]):
    """Sklearn classification task.

    Args:
        config: The experiment configuration
        device: The device to use. Defaults to None.
        output: Dictionary defining which kind of outputs to generate. Defaults to None.
        automatic_batch_size: Whether to automatically find the largest batch size that fits in memory.
        save_model_summary: Whether to save a model_summary.txt file containing the model summary.
        half_precision: Whether to use half precision during training.
        gradcam: Whether to compute gradcams for test results.
    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        device: str,
        automatic_batch_size: DictConfig,
        save_model_summary: bool = False,
        half_precision: bool = False,
        gradcam: bool = False,
    ):
        super().__init__(config=config)

        self._device = device
        self.output = output
        self._backbone: ModelSignatureWrapper
        self._trainer: SklearnClassificationTrainer
        self._model: ClassifierMixin
        self.metadata: dict[str, Any] = {
            "test_confusion_matrix": [],
            "test_accuracy": [],
            "test_results": [],
            "test_labels": [],
            "cams": [],
        }
        self.export_folder = "deployment_model"
        self.deploy_info_file = "model.json"
        self.train_dataloader_list: list[torch.utils.data.DataLoader] = []
        self.test_dataloader_list: list[torch.utils.data.DataLoader] = []
        self.automatic_batch_size = automatic_batch_size
        self.save_model_summary = save_model_summary
        self.half_precision = half_precision
        self.gradcam = gradcam

    @property
    def device(self) -> str:
        return self._device

    def prepare(self) -> None:
        """Prepare the experiment."""
        self.datamodule = self.config.datamodule

        self.backbone = self.config.backbone

        self.model = self.config.model

        # prepare_data() must be explicitly called if the task does not include a lightning training
        self.datamodule.prepare_data()
        self.datamodule.setup(stage="fit")

        self.trainer = self.config.trainer

    @property
    def model(self) -> ClassifierMixin:
        """sklearn.base.ClassifierMixin: The model."""
        return self._model

    @model.setter
    def model(self, model_config: DictConfig):
        """sklearn.base.ClassifierMixin: The model."""
        log.info("Instantiating model <%s>", model_config["_target_"])
        self._model = hydra.utils.instantiate(model_config)

    @property
    def backbone(self) -> ModelSignatureWrapper:
        """Backbone: The backbone."""
        return self._backbone

    @backbone.setter
    def backbone(self, backbone_config):
        """Load backbone."""
        if backbone_config.metadata.get("checkpoint"):
            log.info("Loading backbone from <%s>", backbone_config.metadata.checkpoint)
            self._backbone = torch.load(backbone_config.metadata.checkpoint)
        else:
            log.info("Loading backbone from <%s>", backbone_config.model["_target_"])
            self._backbone = hydra.utils.instantiate(backbone_config.model)

        self._backbone = ModelSignatureWrapper(self._backbone)
        self._backbone.eval()
        if self.half_precision:
            if self.device == "cpu":
                raise ValueError("Half precision is not supported on CPU")
            self._backbone.half()

            if self.gradcam:
                log.warning("Gradcam is currently not supported with half precision, it will be disabled")
                self.gradcam = False
        self._backbone.to(self.device)

    @property
    def trainer(self) -> SklearnClassificationTrainer:
        """Trainer: The trainer."""
        return self._trainer

    @trainer.setter
    def trainer(self, trainer_config: DictConfig) -> None:
        """Trainer: The trainer."""
        log.info("Instantiating trainer <%s>", trainer_config["_target_"])
        trainer = hydra.utils.instantiate(trainer_config, backbone=self.backbone, classifier=self.model)
        self._trainer = trainer

    @typing.no_type_check
    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def train(self) -> None:
        """Train the model."""
        log.info("Starting training...")
        all_features = None
        all_labels = None

        class_to_keep = None

        self.train_dataloader_list = list(self.datamodule.train_dataloader())
        self.test_dataloader_list = list(self.datamodule.val_dataloader())

        if hasattr(self.datamodule, "class_to_keep_training") and self.datamodule.class_to_keep_training is not None:
            class_to_keep = self.datamodule.class_to_keep_training

        if self.save_model_summary:
            self.extract_model_summary(feature_extractor=self.backbone, dl=self.datamodule.full_dataloader())

        if hasattr(self.datamodule, "cache") and self.datamodule.cache:
            if self.config.trainer.iteration_over_training != 1:
                raise AttributeError("Cache is only supported when iteration over training is set to 1")

            full_dataloader = self.datamodule.full_dataloader()
            all_features, all_labels, _ = get_feature(
                feature_extractor=self.backbone, dl=full_dataloader, iteration_over_training=1
            )

            sorted_indices = np.argsort(full_dataloader.dataset.x)
            all_features = all_features[sorted_indices]
            all_labels = all_labels[sorted_indices]

        # cycle over all train/test splits
        for train_dataloader, test_dataloader in zip(self.train_dataloader_list, self.test_dataloader_list):
            # Reinit classifier
            self.model = self.config.model
            self.trainer.change_classifier(self.model)

            # Train on current training set
            if all_features is not None and all_labels is not None:
                # Compute the permutation that maps the sorted file order back to the current (unsorted) one
                sorted_indices = np.argsort(np.concatenate([train_dataloader.dataset.x, test_dataloader.dataset.x]))
                reverse_sorted_indices = np.argsort(sorted_indices)

                # Use these indices to correctly match the extracted features with the new file order
                all_features_sorted = all_features[reverse_sorted_indices]
                all_labels_sorted = all_labels[reverse_sorted_indices]

                train_len = len(train_dataloader.dataset.x)

                self.trainer.fit(
                    train_features=all_features_sorted[0:train_len], train_labels=all_labels_sorted[0:train_len]
                )

                _, pd_cm, accuracy, res, cams = self.trainer.test(
                    test_dataloader=test_dataloader,
                    test_features=all_features_sorted[train_len:],
                    test_labels=all_labels_sorted[train_len:],
                    class_to_keep=class_to_keep,
                    idx_to_class=train_dataloader.dataset.idx_to_class,
                    predict_proba=True,
                    gradcam=self.gradcam,
                )
            else:
                self.trainer.fit(train_dataloader=train_dataloader)
                _, pd_cm, accuracy, res, cams = self.trainer.test(
                    test_dataloader=test_dataloader,
                    class_to_keep=class_to_keep,
                    idx_to_class=train_dataloader.dataset.idx_to_class,
                    predict_proba=True,
                    gradcam=self.gradcam,
                )

            # save results
            self.metadata["test_confusion_matrix"].append(pd_cm)
            self.metadata["test_accuracy"].append(accuracy)
            self.metadata["test_results"].append(res)
            self.metadata["test_labels"].append(
                [
                    train_dataloader.dataset.idx_to_class[i] if i != -1 else "N/A"
                    for i in res["real_label"].unique().tolist()
                ]
            )
            self.metadata["cams"].append(cams)

    def extract_model_summary(
        self, feature_extractor: torch.nn.Module | BaseEvaluationModel, dl: torch.utils.data.DataLoader
    ) -> None:
        """Given a dataloader and a PyTorch model, use torchinfo to extract a summary of the model and save it
        to a file.

        Args:
            dl: PyTorch dataloader
            feature_extractor: PyTorch backbone
        """
        if isinstance(feature_extractor, (TorchEvaluationModel, TorchscriptEvaluationModel)):
            # TODO: I'm not sure torchinfo supports torchscript models
            # If we are working with torch based evaluation models we need to extract the model
            feature_extractor = feature_extractor.model

        for b in tqdm(dl):
            x1, _ = b

            if hasattr(feature_extractor, "parameters"):
                # Move input to the correct device
                parameter = next(feature_extractor.parameters())
                x1 = x1.to(parameter.device).to(parameter.dtype)
                x1 = x1[0].unsqueeze(0)  # Keep a single sample with batch size 1

                model_info = None

                try:
                    try:
                        # TODO: Do we want to print the summary to the console as well?
                        model_info = summary(feature_extractor, input_data=(x1), verbose=0)  # type: ignore[arg-type]
                    except Exception:
                        log.warning(
                            "Failed to retrieve model summary using input data information, retrieving only "
                            "parameters information"
                        )
                        model_info = summary(feature_extractor, verbose=0)  # type: ignore[arg-type]
                except Exception as e:
                    # If for some reason the summary fails we don't want to stop the training
                    log.warning("Failed to retrieve model summary: %s", e)

                if model_info is not None:
                    with open("model_summary.txt", "w") as f:
                        f.write(str(model_info))
            else:
                log.warning("Failed to retrieve model summary, current model has no parameters")

            break

    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def train_full_data(self):
        """Train the model on train + validation."""
        # Reinit classifier
        self.model = self.config.model
        self.trainer.change_classifier(self.model)

        self.trainer.fit(train_dataloader=self.datamodule.full_dataloader())

    def test(self) -> None:
        """Skip test phase."""
        # We don't need a separate test phase since the sklearn trainer already runs the test
        # inside the train method to handle cross validation

    @typing.no_type_check
    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def test_full_data(self) -> None:
        """Test model trained on full dataset."""
        self.config.datamodule.class_to_idx = self.datamodule.full_dataset.class_to_idx
        self.config.datamodule.phase = "test"
        idx_to_class = self.datamodule.full_dataset.idx_to_class
        self.datamodule.setup("test")
        test_dataloader = self.datamodule.test_dataloader()

        if len(self.datamodule.data["samples"]) == 0:
            log.info("No test data, skipping test")
            return

        # Put backbone on the correct device as it may be moved after export
        self.backbone.to(self.device)
        _, pd_cm, accuracy, res, cams = self.trainer.test(
            test_dataloader=test_dataloader, idx_to_class=idx_to_class, predict_proba=True, gradcam=self.gradcam
        )

        output_folder_test = "test"

        os.makedirs(output_folder_test, exist_ok=True)

        save_classification_result(
            results=res,
            output_folder=output_folder_test,
            confmat=pd_cm,
            accuracy=accuracy,
            test_dataloader=test_dataloader,
            config=self.config,
            output=self.output,
            grayscale_cams=cams,
        )

    def export(self) -> None:
        """Generate deployment model for the task."""
        if self.config.export is None or len(self.config.export.types) == 0:
            log.info("No export type specified, skipping export")
            return

        input_shapes = self.config.export.input_shapes

        idx_to_class = {v: k for k, v in self.datamodule.full_dataset.class_to_idx.items()}

        model_json, export_paths = export_model(
            config=self.config,
            model=self.backbone,
            export_folder=self.export_folder,
            half_precision=self.half_precision,
            input_shapes=input_shapes,
            idx_to_class=idx_to_class,
            pytorch_model_type="backbone",
        )

        dump(self.model, os.path.join(self.export_folder, "classifier.joblib"))

        if len(export_paths) > 0:
            with open(os.path.join(self.export_folder, self.deploy_info_file), "w") as f:
                json.dump(model_json, f)

    def generate_report(self) -> None:
        """Generate report for the task."""
        log.info("Generating report!")

        cm_list = []

        for count in range(len(self.metadata["test_accuracy"])):
            current_output_folder = f"{self.output.folder}_{count}"
            os.makedirs(current_output_folder, exist_ok=True)

            c_matrix = self.metadata["test_confusion_matrix"][count]
            cm_list.append(c_matrix)
            save_classification_result(
                results=self.metadata["test_results"][count],
                output_folder=current_output_folder,
                confmat=c_matrix,
                accuracy=self.metadata["test_accuracy"][count],
                test_dataloader=self.test_dataloader_list[count],
                config=self.config,
                output=self.output,
                grayscale_cams=self.metadata["cams"][count],
            )
        final_confusion_matrix = sum(cm_list)

        self.metadata["final_confusion_matrix"] = final_confusion_matrix
        # Save final conf matrix
        final_folder = f"{self.output.folder}"
        os.makedirs(final_folder, exist_ok=True)
        disp = ConfusionMatrixDisplay(
            confusion_matrix=np.array(final_confusion_matrix),
            display_labels=[x.replace("pred:", "") for x in final_confusion_matrix.columns.to_list()],
        )
        disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
        plt.title(f"Confusion Matrix (Accuracy: {(self.metadata['test_accuracy'][count] * 100):.2f}%)")
        plt.savefig(os.path.join(final_folder, "test_confusion_matrix.png"), bbox_inches="tight", pad_inches=0, dpi=300)
        plt.close()

    def execute(self) -> None:
        """Execute the experiment and all the steps."""
        self.prepare()
        self.train()
        if self.output.report:
            self.generate_report()
        self.train_full_data()
        if self.config.export is not None and len(self.config.export.types) > 0:
            self.export()
        if self.output.test_full_data:
            self.test_full_data()
        self.finalize()

class SklearnTestClassification(Evaluation[SklearnClassificationDataModuleT]):
|
|
853
|
+
"""Perform a test using an imported SklearnClassification pytorch model.
|
|
854
|
+
|
|
855
|
+
Args:
|
|
856
|
+
config: The experiment configuration
|
|
857
|
+
output: where to save results
|
|
858
|
+
model_path: path to trained model generated from SklearnClassification task.
|
|
859
|
+
device: the device where to run the model (cuda or cpu)
|
|
860
|
+
gradcam: Whether to compute gradcams
|
|
861
|
+
**kwargs: Additional arguments to pass to the task
|
|
862
|
+
"""
|
|
863
|
+
|
|
864
|
+
def __init__(
|
|
865
|
+
self,
|
|
866
|
+
config: DictConfig,
|
|
867
|
+
output: DictConfig,
|
|
868
|
+
model_path: str,
|
|
869
|
+
device: str,
|
|
870
|
+
gradcam: bool = False,
|
|
871
|
+
**kwargs: Any,
|
|
872
|
+
):
|
|
873
|
+
super().__init__(config=config, model_path=model_path, device=device, **kwargs)
|
|
874
|
+
self.gradcam = gradcam
|
|
875
|
+
self.output = output
|
|
876
|
+
self._backbone: BaseEvaluationModel
|
|
877
|
+
self._classifier: ClassifierMixin
|
|
878
|
+
self.class_to_idx: dict[str, int]
|
|
879
|
+
self.idx_to_class: dict[int, str]
|
|
880
|
+
self.test_dataloader: torch.utils.data.DataLoader
|
|
881
|
+
self.metadata: dict[str, Any] = {
|
|
882
|
+
"test_confusion_matrix": None,
|
|
883
|
+
"test_accuracy": None,
|
|
884
|
+
"test_results": None,
|
|
885
|
+
"test_labels": None,
|
|
886
|
+
"cams": None,
|
|
887
|
+
}
|
|
888
|
+
|
|
889
|
+
def prepare(self) -> None:
|
|
890
|
+
"""Prepare the experiment."""
|
|
891
|
+
super().prepare()
|
|
892
|
+
|
|
893
|
+
idx_to_class = {}
|
|
894
|
+
class_to_idx = {}
|
|
895
|
+
for k, v in self.model_data["classes"].items():
|
|
896
|
+
idx_to_class[int(k)] = v
|
|
897
|
+
class_to_idx[v] = int(k)
|
|
898
|
+
|
|
899
|
+
self.idx_to_class = idx_to_class
|
|
900
|
+
self.class_to_idx = class_to_idx
|
|
901
|
+
|
|
902
|
+
self.config.datamodule.class_to_idx = class_to_idx
|
|
903
|
+
|
|
904
|
+
self.datamodule = self.config.datamodule
|
|
905
|
+
# prepare_data() must be explicitly called because there is no lightning training
|
|
906
|
+
self.datamodule.prepare_data()
|
|
907
|
+
self.datamodule.setup(stage="test")
|
|
908
|
+
|
|
909
|
+
# Configure trainer
|
|
910
|
+
self.trainer = self.config.trainer
|
|
911
|
+
|
|
912
|
+
    @property
    def deployment_model(self):
        """Deployment model."""
        return None

    @deployment_model.setter
    def deployment_model(self, model_path: str):
        """Set backbone and classifier."""
        self.backbone = model_path  # type: ignore[assignment]
        # Load classifier
        self.classifier = os.path.join(Path(model_path).parent, "classifier.joblib")

    @property
    def classifier(self) -> ClassifierMixin:
        """Classifier: The classifier."""
        return self._classifier

    @classifier.setter
    def classifier(self, classifier_path: str) -> None:
        """Load classifier."""
        self._classifier = load(classifier_path)

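    @staticmethod
    def _example_classifier_roundtrip() -> None:
        # Hypothetical helper, not in the source: the `load` used by the
        # classifier setter above is presumably joblib.load, paired with a
        # dump() performed by the training task.
        from joblib import dump, load
        from sklearn.linear_model import LogisticRegression

        clf = LogisticRegression().fit([[0.0], [1.0]], [0, 1])
        dump(clf, "classifier.joblib")        # what training would save
        restored = load("classifier.joblib")  # what the setter reloads
        assert restored.predict([[1.0]]) == [1]
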
    @property
    def backbone(self) -> BaseEvaluationModel:
        """Backbone: The backbone."""
        return self._backbone

    @backbone.setter
    def backbone(self, model_path: str) -> None:
        """Load backbone."""
        file_extension = os.path.splitext(model_path)[1]

        model_architecture = None
        if file_extension == ".pth":
            backbone_config_path = os.path.join(Path(model_path).parent, "model_config.yaml")
            log.info("Loading backbone from config")
            backbone_config = OmegaConf.load(backbone_config_path)

            if backbone_config.metadata.get("checkpoint"):
                log.info("Loading backbone from <%s>", backbone_config.metadata.checkpoint)
                model_architecture = torch.load(backbone_config.metadata.checkpoint)
            else:
                log.info("Loading backbone from <%s>", backbone_config.model["_target_"])
                model_architecture = hydra.utils.instantiate(backbone_config.model)

        self._backbone = import_deployment_model(
            model_path=model_path,
            device=self.device,
            inference_config=self.config.inference,
            model_architecture=model_architecture,
        )

        if self.gradcam and not isinstance(self._backbone, TorchEvaluationModel):
            log.warning("Gradcam is supported only for pytorch models. Skipping gradcam")
            self.gradcam = False

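    @staticmethod
    def _example_extension_dispatch() -> None:
        # Hypothetical helper, not in the source: shows the extension-based
        # dispatch used by the backbone setter above. Only exported .pth state
        # dicts need a rebuilt architecture; other formats are assumed to be
        # self-describing. The file names here are placeholders.
        import os

        def needs_architecture(model_path: str) -> bool:
            return os.path.splitext(model_path)[1] == ".pth"

        assert needs_architecture("deployment_model.pth")
        assert not needs_architecture("deployment_model.onnx")
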
    @property
    def trainer(self) -> SklearnClassificationTrainer:
        """Trainer: The trainer."""
        return self._trainer

    @trainer.setter
    def trainer(self, trainer_config: DictConfig) -> None:
        """Instantiate the trainer from its config."""
        log.info("Instantiating trainer <%s>", trainer_config["_target_"])

        if self.backbone.training:
            self.backbone.eval()

        trainer = hydra.utils.instantiate(trainer_config, backbone=self.backbone, classifier=self.classifier)
        self._trainer = trainer

    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def test(self) -> None:
        """Run the test."""
        self.test_dataloader = self.datamodule.test_dataloader()

        _, pd_cm, accuracy, res, cams = self.trainer.test(
            test_dataloader=self.test_dataloader,
            idx_to_class=self.idx_to_class,
            predict_proba=True,
            gradcam=self.gradcam,
        )

        # Save results
        self.metadata["test_confusion_matrix"] = pd_cm
        self.metadata["test_accuracy"] = accuracy
        self.metadata["test_results"] = res
        self.metadata["test_labels"] = [
            self.idx_to_class[i] if i != -1 else "N/A" for i in res["real_label"].unique().tolist()
        ]
        self.metadata["cams"] = cams

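    @staticmethod
    def _example_batch_size_fallback() -> None:
        # Hypothetical helper, not in the source: sketches what a decorator
        # like automatic_datamodule_batch_size plausibly does -- retry the
        # wrapped call on CUDA OOM with a halved datamodule batch size. The
        # real quadra implementation may differ.
        import functools

        import torch

        def automatic_batch_size_sketch(batch_size_attribute_name: str):
            def decorator(fn):
                @functools.wraps(fn)
                def wrapper(self, *args, **kwargs):
                    while True:
                        try:
                            return fn(self, *args, **kwargs)
                        except torch.cuda.OutOfMemoryError:
                            current = getattr(self.datamodule, batch_size_attribute_name)
                            if current <= 1:
                                raise
                            setattr(self.datamodule, batch_size_attribute_name, current // 2)

                return wrapper

            return decorator
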
    def generate_report(self) -> None:
        """Generate a report for the task."""
        log.info("Generating report!")
        os.makedirs(self.output.folder, exist_ok=True)
        save_classification_result(
            results=self.metadata["test_results"],
            output_folder=self.output.folder,
            confmat=self.metadata["test_confusion_matrix"],
            accuracy=self.metadata["test_accuracy"],
            test_dataloader=self.test_dataloader,
            config=self.config,
            output=self.output,
            grayscale_cams=self.metadata["cams"],
        )

    def execute(self) -> None:
        """Execute the experiment and all the steps."""
        self.prepare()
        self.test()
        if self.output.report:
            self.generate_report()
        self.finalize()

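# Hypothetical invocation of the evaluation task defined above; the config
# path and model location are placeholders, not from the source.
def _run_sklearn_test_example() -> None:
    from omegaconf import OmegaConf

    config = OmegaConf.load("experiment_config.yaml")  # assumed full experiment config
    task = SklearnTestClassification(
        config=config,
        output=config.task.output,
        model_path="deployment_model/model.pth",
        device="cuda:0",
        gradcam=True,
    )
    task.execute()  # prepare -> test -> (optional) report -> finalize
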
class ClassificationEvaluation(Evaluation[ClassificationDataModuleT]):
    """Perform a test on an imported Classification PyTorch model.

    Args:
        config: Task configuration
        output: Configuration for the output
        model_path: Path to the PyTorch .pt model file
        report: Whether to generate the report of the predictions
        gradcam: Whether to compute gradcams
        device: Device to use for evaluation. If None, the device is automatically determined
    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        model_path: str,
        report: bool = True,
        gradcam: bool = False,
        device: str | None = None,
    ):
        super().__init__(config=config, model_path=model_path, device=device)
        self.report_path = "test_output"
        self.output = output
        self.report = report
        self.gradcam = gradcam
        self.cam: GradCAM

    def get_torch_model(self, model_config: DictConfig) -> nn.Module:
        """Instantiate the torch model from the config."""
        pre_classifier = self.get_pre_classifier(model_config)
        classifier = self.get_classifier(model_config)
        log.info("Instantiating backbone <%s>", model_config.model["_target_"])

        return hydra.utils.instantiate(
            model_config.model, classifier=classifier, pre_classifier=pre_classifier, _convert_="partial"
        )

    def get_pre_classifier(self, model_config: DictConfig) -> nn.Module:
        """Instantiate the pre-classifier from the config."""
        if "pre_classifier" in model_config and model_config.pre_classifier is not None:
            log.info("Instantiating pre_classifier <%s>", model_config.pre_classifier["_target_"])
            pre_classifier = hydra.utils.instantiate(model_config.pre_classifier, _convert_="partial")
        else:
            log.info("No pre-classifier found in config: instantiating a torch.nn.Identity instead")
            pre_classifier = nn.Identity()

        return pre_classifier

    def get_classifier(self, model_config: DictConfig) -> nn.Module:
        """Instantiate the classifier from the config."""
        if "classifier" in model_config:
            log.info("Instantiating classifier <%s>", model_config.classifier["_target_"])
            return hydra.utils.instantiate(
                model_config.classifier, out_features=len(self.model_data["classes"]), _convert_="partial"
            )

        raise ValueError("A `classifier` definition must be specified in the config")

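    @staticmethod
    def _example_hydra_instantiate() -> None:
        # Hypothetical helper, not in the source: the helpers above rely on
        # Hydra turning any config node carrying a _target_ key into the named
        # object. The classifier node here is illustrative.
        import hydra
        from omegaconf import OmegaConf

        cfg = OmegaConf.create({"_target_": "torch.nn.Linear", "in_features": 512, "out_features": 10})
        classifier = hydra.utils.instantiate(cfg)  # -> torch.nn.Linear(512, 10)
        assert classifier.out_features == 10
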
    @property
    def deployment_model(self) -> BaseEvaluationModel:
        """Deployment model."""
        return self._deployment_model

    @deployment_model.setter
    def deployment_model(self, model_path: str):
        """Set the deployment model."""
        file_extension = os.path.splitext(model_path)[1]
        model_architecture = None
        if file_extension == ".pth":
            model_config = OmegaConf.load(os.path.join(Path(model_path).parent, "model_config.yaml"))

            if not isinstance(model_config, DictConfig):
                raise ValueError(f"The model config must be a DictConfig, got {type(model_config)}")

            model_architecture = self.get_torch_model(model_config)

        self._deployment_model = import_deployment_model(
            model_path=model_path,
            device=self.device,
            inference_config=self.config.inference,
            model_architecture=model_architecture,
        )

        if self.gradcam and not isinstance(self.deployment_model, TorchEvaluationModel):
            log.warning("To compute gradcams you need to provide the path to an exported .pth state_dict file")
            self.gradcam = False

    def prepare(self) -> None:
        """Prepare the evaluation."""
        super().prepare()
        self.datamodule = self.config.datamodule
        self.datamodule.class_to_idx = {v: int(k) for k, v in self.model_data["classes"].items()}
        self.datamodule.num_classes = len(self.datamodule.class_to_idx)

        # prepare_data() must be explicitly called because there is no training
        self.datamodule.prepare_data()
        self.datamodule.setup(stage="test")

    def prepare_gradcam(self) -> None:
        """Initialize gradcam for the predictions."""
        if not hasattr(self.deployment_model.model, "features_extractor"):
            log.warning("Gradcam not implemented for this backbone, it will not be computed")
            self.gradcam = False
            return

        if isinstance(self.deployment_model.model.features_extractor, timm.models.resnet.ResNet):
            target_layers = [cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor.layer4[-1]]
            self.cam = GradCAM(
                model=self.deployment_model.model,
                target_layers=target_layers,
                use_cuda=(self.device != "cpu"),
            )
            for p in self.deployment_model.model.features_extractor.layer4[-1].parameters():
                p.requires_grad = True
        elif is_vision_transformer(cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor):
            self.grad_rollout = VitAttentionGradRollout(cast(nn.Module, self.deployment_model.model))
        else:
            log.warning("Gradcam not implemented for this backbone, it will not be computed")
            self.gradcam = False

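    @staticmethod
    def _example_resnet_gradcam() -> None:
        # Hypothetical helper, not in the source: mirrors the ResNet Grad-CAM
        # setup in prepare_gradcam() with a plain torchvision ResNet stand-in.
        # The GradCAM/use_cuda signature suggests the pytorch_grad_cam package.
        import torch
        from pytorch_grad_cam import GradCAM
        from torchvision.models import resnet18

        model = resnet18().eval()
        cam = GradCAM(model=model, target_layers=[model.layer4[-1]], use_cuda=False)
        heatmaps = cam(input_tensor=torch.randn(2, 3, 224, 224), targets=None)
        assert heatmaps.shape == (2, 224, 224)  # one grayscale map per image
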
    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def test(self) -> None:
        """Perform test."""
        log.info("Running test")
        test_dataloader = self.datamodule.test_dataloader()

        image_labels = []
        probabilities = []
        predicted_classes = []
        grayscale_cams_list = []

        if self.gradcam:
            self.prepare_gradcam()

        with torch.set_grad_enabled(self.gradcam):
            for batch_item in tqdm(test_dataloader):
                im, target = batch_item
                im = im.to(device=self.device, dtype=self.deployment_model.model_dtype).detach()

                if self.gradcam:
                    # When gradcam is used we need to remove gradients from the outputs
                    outputs = self.deployment_model(im).detach()
                else:
                    outputs = self.deployment_model(im)

                probs = torch.softmax(outputs, dim=1)
                preds = torch.max(probs, dim=1).indices

                probabilities.append(probs.tolist())
                predicted_classes.append(preds.tolist())
                image_labels.extend(target.tolist())
                if self.gradcam and hasattr(self.deployment_model.model, "features_extractor"):
                    with torch.inference_mode(False):
                        im = im.clone()
                        if isinstance(self.deployment_model.model.features_extractor, timm.models.resnet.ResNet):
                            grayscale_cam = self.cam(input_tensor=im, targets=None)
                            grayscale_cams_list.append(torch.from_numpy(grayscale_cam))
                        elif is_vision_transformer(
                            cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor
                        ):
                            grayscale_cam_low_res = self.grad_rollout(input_tensor=im, targets_list=preds.tolist())
                            orig_shape = grayscale_cam_low_res.shape
                            new_shape = (orig_shape[0], im.shape[2], im.shape[3])
                            zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
                            grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
                            grayscale_cams_list.append(torch.from_numpy(grayscale_cam))

        grayscale_cams: torch.Tensor | None = None
        if self.gradcam:
            grayscale_cams = torch.cat(grayscale_cams_list, dim=0)

        predicted_classes = [item for sublist in predicted_classes for item in sublist]
        probabilities = [max(item) for sublist in probabilities for item in sublist]
        idx_to_class: dict[int, str] = {}
        if self.datamodule.class_to_idx is not None:
            idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}

        _, pd_cm, test_accuracy = get_results(
            test_labels=image_labels,
            pred_labels=predicted_classes,
            idx_to_labels=idx_to_class,
        )

        res = pd.DataFrame(
            {
                "sample": list(test_dataloader.dataset.x),  # type: ignore[attr-defined]
                "real_label": image_labels,
                "pred_label": predicted_classes,
                "probability": probabilities,
            }
        )

        log.info("Avg classification accuracy: %s", test_accuracy)

        # Keep the per-sample results both as an attribute and in the metadata
        self.res = res

        # Save results
        self.metadata["test_confusion_matrix"] = pd_cm
        self.metadata["test_accuracy"] = test_accuracy
        self.metadata["predictions"] = predicted_classes
        self.metadata["test_results"] = res
        self.metadata["probabilities"] = probabilities
        self.metadata["test_labels"] = image_labels
        self.metadata["grayscale_cams"] = grayscale_cams

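    @staticmethod
    def _example_cam_upsampling() -> None:
        # Hypothetical helper, not in the source: shows the low-resolution CAM
        # upsampling applied to the ViT attention maps in test() above.
        import numpy as np
        from scipy import ndimage

        low_res = np.random.rand(2, 14, 14)       # (batch, patch_h, patch_w) rollout map
        new_shape = (low_res.shape[0], 224, 224)  # match the input H and W
        zoom_factors = tuple(np.array(new_shape) / np.array(low_res.shape))
        full_res = ndimage.zoom(low_res, zoom_factors, order=1)  # linear interpolation
        assert full_res.shape == new_shape
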
    def generate_report(self) -> None:
        """Generate a report for the task."""
        log.info("Generating report!")
        os.makedirs(self.report_path, exist_ok=True)

        save_classification_result(
            results=self.metadata["test_results"],
            output_folder=self.report_path,
            confmat=self.metadata["test_confusion_matrix"],
            accuracy=self.metadata["test_accuracy"],
            test_dataloader=self.datamodule.test_dataloader(),
            config=self.config,
            output=self.output,
            grayscale_cams=self.metadata["grayscale_cams"],
        )

    def execute(self) -> None:
        """Execute the evaluation."""
        self.prepare()
        self.test()
        if self.report:
            self.generate_report()
        self.finalize()
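
# Hypothetical end-to-end use of ClassificationEvaluation; the config path and
# model location are placeholders, not from the source.
def _run_classification_evaluation_example() -> None:
    from omegaconf import OmegaConf

    config = OmegaConf.load("experiment_config.yaml")  # assumed evaluation config
    evaluation = ClassificationEvaluation(
        config=config,
        output=config.task.output,
        model_path="deployment_model/model.pth",  # a .pth state dict keeps gradcam available
        report=True,
        gradcam=True,
    )
    evaluation.execute()  # prepare -> test -> report -> finalize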