quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,1263 @@
1
+ from __future__ import annotations
2
+
3
+ import glob
4
+ import json
5
+ import os
6
+ import typing
7
+ from copy import deepcopy
8
+ from pathlib import Path
9
+ from typing import Any, Generic, cast
10
+
11
+ import cv2
12
+ import hydra
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import pandas as pd
16
+ import timm
17
+ import torch
18
+ from joblib import dump, load
19
+ from omegaconf import DictConfig, ListConfig, OmegaConf
20
+ from pytorch_grad_cam import GradCAM
21
+ from scipy import ndimage
22
+ from sklearn.base import ClassifierMixin
23
+ from sklearn.metrics import ConfusionMatrixDisplay
24
+ from torch import nn
25
+ from torchinfo import summary
26
+ from tqdm import tqdm
27
+
28
+ from quadra.callbacks.mlflow import get_mlflow_logger
29
+ from quadra.callbacks.scheduler import WarmupInit
30
+ from quadra.datamodules import (
31
+ ClassificationDataModule,
32
+ MultilabelClassificationDataModule,
33
+ SklearnClassificationDataModule,
34
+ )
35
+ from quadra.datasets.classification import ImageClassificationListDataset
36
+ from quadra.models.base import ModelSignatureWrapper
37
+ from quadra.models.classification import BaseNetworkBuilder
38
+ from quadra.models.evaluation import BaseEvaluationModel, TorchEvaluationModel, TorchscriptEvaluationModel
39
+ from quadra.modules.classification import ClassificationModule
40
+ from quadra.tasks.base import Evaluation, LightningTask, Task
41
+ from quadra.trainers.classification import SklearnClassificationTrainer
42
+ from quadra.utils import utils
43
+ from quadra.utils.classification import (
44
+ get_results,
45
+ save_classification_result,
46
+ )
47
+ from quadra.utils.evaluation import automatic_datamodule_batch_size
48
+ from quadra.utils.export import export_model, import_deployment_model
49
+ from quadra.utils.models import get_feature, is_vision_transformer
50
+ from quadra.utils.vit_explainability import VitAttentionGradRollout
51
+
52
+ log = utils.get_logger(__name__)
53
+
54
+ SklearnClassificationDataModuleT = typing.TypeVar(
55
+ "SklearnClassificationDataModuleT", bound=SklearnClassificationDataModule
56
+ )
57
+ ClassificationDataModuleT = typing.TypeVar("ClassificationDataModuleT", bound=ClassificationDataModule)
58
+
59
+
60
+ # TODO: Maybe we should have a BaseClassificationTask that is extended by Classification and MultilabelClassification
61
+ # at the current time, multilabel experiments use this Classification task class and they can not generate report
62
+ # (it is written specifically for a vanilla classification). Moreover, this class takes generic
63
+ # ClassificationDataModuleT, but multilabel experim. uses MultilabelClassificationDataModule, which is not a child of
64
+ # ClassificationDataModule
65
class Classification(Generic[ClassificationDataModuleT], LightningTask[ClassificationDataModuleT]):
    """Classification Task.

    Args:
        config: The experiment configuration
        output: The output configuration (under task config). It contains the bool "example" to generate
            figures of discordant/concordant predictions.
        gradcam: Whether to compute gradcams
        checkpoint_path: The path to the checkpoint to load the model from. Defaults to None.
        lr_multiplier: The multiplier for the backbone learning rate. Defaults to None.
        report: Whether to generate a report containing the results after test phase
        run_test: Whether to run the test phase.
    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        checkpoint_path: str | None = None,
        lr_multiplier: float | None = None,
        gradcam: bool = False,
        report: bool = False,
        run_test: bool = False,
    ):
        super().__init__(
            config=config,
            checkpoint_path=checkpoint_path,
            run_test=run_test,
            report=report,
        )
        self.output = output
        self.gradcam = gradcam
        self._lr_multiplier = lr_multiplier
        # The following attributes are declared here (types only) and built
        # lazily by the corresponding property setters during prepare().
        self._pre_classifier: nn.Module
        self._classifier: nn.Module
        self._model: nn.Module
        self._optimizer: torch.optim.Optimizer
        self._scheduler: torch.optim.lr_scheduler._LRScheduler
        # Metadata written next to the exported model (None until export()).
        self.model_json: dict[str, Any] | None = None
        # Folder and file name used by export() for deployment artifacts.
        self.export_folder: str = "deployment_model"
        self.deploy_info_file: str = "model.json"
        # Confusion matrix built by generate_report().
        self.report_confmat: pd.DataFrame
        # Best checkpoint path recorded by train(); None if no checkpoint was saved.
        self.best_model_path: str | None = None
109
+
110
+ @property
111
+ def optimizer(self) -> torch.optim.Optimizer:
112
+ """Get the optimizer."""
113
+ return self._optimizer
114
+
115
+ @optimizer.setter
116
+ def optimizer(self, optimizer_config: DictConfig) -> None:
117
+ """Set the optimizer."""
118
+ if (
119
+ isinstance(self.model.features_extractor, nn.Module)
120
+ and isinstance(self.model.pre_classifier, nn.Module)
121
+ and isinstance(self.model.classifier, nn.Module)
122
+ ):
123
+ log.info("Instantiating optimizer <%s>", self.config.optimizer["_target_"])
124
+ if self._lr_multiplier is not None and self._lr_multiplier > 0:
125
+ params = [
126
+ {
127
+ "params": self.model.features_extractor.parameters(),
128
+ "lr": optimizer_config.lr * self._lr_multiplier,
129
+ }
130
+ ]
131
+ else:
132
+ params = [{"params": self.model.features_extractor.parameters(), "lr": optimizer_config.lr}]
133
+ params.append({"params": self.model.pre_classifier.parameters(), "lr": optimizer_config.lr})
134
+ params.append({"params": self.model.classifier.parameters(), "lr": optimizer_config.lr})
135
+ self._optimizer = hydra.utils.instantiate(optimizer_config, params)
136
+
137
+ @property
138
+ def len_train_dataloader(self) -> int:
139
+ """Get the length of the train dataloader."""
140
+ len_train_dataloader = len(self.datamodule.train_dataloader())
141
+ if self.devices is not None:
142
+ num_gpus = len(self.devices) if isinstance(self.devices, list) else 1
143
+ len_train_dataloader = len_train_dataloader // num_gpus
144
+ if not self.datamodule.train_dataloader().drop_last:
145
+ len_train_dataloader += int(len(self.datamodule.train_dataloader()) % num_gpus != 0)
146
+ return len_train_dataloader
147
+
148
+ @property
149
+ def scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
150
+ """Get the scheduler."""
151
+ return self._scheduler
152
+
153
+ @scheduler.setter
154
+ def scheduler(self, scheduler_config: DictConfig) -> None:
155
+ log.info("Instantiating scheduler <%s>", scheduler_config["_target_"])
156
+ if "CosineAnnealingWithLinearWarmUp" in self.config.scheduler["_target_"]:
157
+ # This scheduler will be overwritten by the SSLCallback
158
+ self._scheduler = hydra.utils.instantiate(
159
+ scheduler_config,
160
+ optimizer=self.optimizer,
161
+ batch_size=1,
162
+ len_loader=1,
163
+ )
164
+ self.add_callback(WarmupInit(scheduler_config=scheduler_config))
165
+ else:
166
+ self._scheduler = hydra.utils.instantiate(scheduler_config, optimizer=self.optimizer)
167
+
168
+ @property
169
+ def module(self) -> ClassificationModule:
170
+ """Get the module of the model."""
171
+ return self._module
172
+
173
+ @LightningTask.module.setter
174
+ def module(self, module_config): # noqa: F811
175
+ """Set the module of the model."""
176
+ module = hydra.utils.instantiate(
177
+ module_config,
178
+ model=self.model,
179
+ optimizer=self.optimizer,
180
+ lr_scheduler=self.scheduler,
181
+ gradcam=self.gradcam,
182
+ )
183
+ if self.checkpoint_path is not None:
184
+ log.info("Loading model from lightning checkpoint: %s", self.checkpoint_path)
185
+ module = module.__class__.load_from_checkpoint(
186
+ self.checkpoint_path,
187
+ model=self.model,
188
+ optimizer=self.optimizer,
189
+ lr_scheduler=self.scheduler,
190
+ criterion=module.criterion,
191
+ gradcam=self.gradcam,
192
+ )
193
+ self._module = module
194
+
195
+ @property
196
+ def pre_classifier(self) -> nn.Module:
197
+ return self._pre_classifier
198
+
199
+ @pre_classifier.setter
200
+ def pre_classifier(self, model_config: DictConfig) -> None:
201
+ if "pre_classifier" in model_config and model_config.pre_classifier is not None:
202
+ log.info("Instantiating pre_classifier <%s>", model_config.pre_classifier["_target_"])
203
+ self._pre_classifier = hydra.utils.instantiate(model_config.pre_classifier, _convert_="partial")
204
+ else:
205
+ log.info("No pre-classifier found in config: instantiate a torch.nn.Identity instead")
206
+ self._pre_classifier = nn.Identity()
207
+
208
+ @property
209
+ def classifier(self) -> nn.Module:
210
+ return self._classifier
211
+
212
+ @classifier.setter
213
+ def classifier(self, model_config: DictConfig) -> None:
214
+ if "classifier" in model_config:
215
+ log.info("Instantiating classifier <%s>", model_config.classifier["_target_"])
216
+ if self.datamodule.num_classes is None or self.datamodule.num_classes < 2:
217
+ raise ValueError(f"Non compliant datamodule.num_classes : {self.datamodule.num_classes}")
218
+ self._classifier = hydra.utils.instantiate(
219
+ model_config.classifier, out_features=self.datamodule.num_classes, _convert_="partial"
220
+ )
221
+ else:
222
+ raise ValueError("A `classifier` definition must be specified in the config")
223
+
224
+ @property
225
+ def model(self) -> nn.Module:
226
+ return self._model
227
+
228
+ @model.setter
229
+ def model(self, model_config: DictConfig) -> None:
230
+ self.pre_classifier = model_config # type: ignore[assignment]
231
+ self.classifier = model_config # type: ignore[assignment]
232
+ log.info("Instantiating backbone <%s>", model_config.model["_target_"])
233
+ self._model = hydra.utils.instantiate(
234
+ model_config.model, classifier=self.classifier, pre_classifier=self.pre_classifier, _convert_="partial"
235
+ )
236
+ if getattr(self.config.backbone, "freeze_parameters_name", None) is not None:
237
+ self.freeze_layers_by_name(self.config.backbone.freeze_parameters_name)
238
+
239
+ if getattr(self.config.backbone, "freeze_parameters_index", None) is not None:
240
+ frozen_parameters_indices: list[int]
241
+ if isinstance(self.config.backbone.freeze_parameters_index, int):
242
+ # Freeze all layers up to the specified index
243
+ frozen_parameters_indices = list(range(self.config.backbone.freeze_parameters_index + 1))
244
+ elif isinstance(self.config.backbone.freeze_parameters_index, ListConfig):
245
+ frozen_parameters_indices = cast(
246
+ list[int], OmegaConf.to_container(self.config.backbone.freeze_parameters_index, resolve=True)
247
+ )
248
+ else:
249
+ raise ValueError("freeze_parameters_index must be an int or a list of int")
250
+
251
+ self.freeze_parameters_by_index(frozen_parameters_indices)
252
+
253
    def prepare(self) -> None:
        """Prepare the experiment.

        Ordering matters: each assignment below goes through a property
        setter, and later setters read objects built by earlier ones (the
        optimizer reads the model, the scheduler reads the optimizer, and
        the module reads all three).
        """
        super().prepare()
        self.model = self.config.model
        self.optimizer = self.config.optimizer
        self.scheduler = self.config.scheduler
        self.module = self.config.model.module
260
+
261
+ def train(self):
262
+ """Train the model."""
263
+ super().train()
264
+ if (
265
+ self.trainer.checkpoint_callback is not None
266
+ and hasattr(self.trainer.checkpoint_callback, "best_model_path")
267
+ and self.trainer.checkpoint_callback.best_model_path is not None
268
+ and len(self.trainer.checkpoint_callback.best_model_path) > 0
269
+ ):
270
+ self.best_model_path = self.trainer.checkpoint_callback.best_model_path
271
+ log.info("Loading best epoch weights...")
272
+
273
+ def test(self) -> None:
274
+ """Test the model."""
275
+ if not self.config.trainer.get("fast_dev_run"):
276
+ log.info("Starting testing!")
277
+ self.trainer.test(datamodule=self.datamodule, model=self.module, ckpt_path=self.best_model_path)
278
+
279
    def export(self) -> None:
        """Generate deployment models for the task.

        Reloads the best checkpoint when one was recorded by train() (falling
        back to the in-memory module otherwise), exports the model via
        export_model, and writes the returned metadata next to the artifacts
        as `model.json`. Does nothing more if no export path was produced.
        """
        # Class mapping for the exported metadata; empty when the datamodule
        # does not expose class_to_idx.
        if self.datamodule.class_to_idx is None:
            log.warning(
                "No `class_to_idx` found in the datamodule, class information will not be saved in the model.json"
            )
            idx_to_class = {}
        else:
            idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}

        # Get best model!
        if self.best_model_path is not None:
            log.info("Saving deployment model for %s checkpoint", self.best_model_path)

            # Reload the best weights; gradcam is disabled for the exported model.
            module = self.module.__class__.load_from_checkpoint(
                self.best_model_path,
                model=self.module.model,
                optimizer=self.optimizer,
                lr_scheduler=self.scheduler,
                criterion=self.module.criterion,
                gradcam=False,
            )
        else:
            log.warning("No checkpoint callback found in the trainer, exporting the last model weights")
            module = self.module

        input_shapes = self.config.export.input_shapes

        # TODO: What happens if we have 64 precision?
        # NOTE(review): assumes trainer.precision is a string (e.g. "16-mixed");
        # confirm against the lightning version in use.
        half_precision = "16" in self.trainer.precision

        self.model_json, export_paths = export_model(
            config=self.config,
            model=module.model,
            export_folder=self.export_folder,
            half_precision=half_precision,
            input_shapes=input_shapes,
            idx_to_class=idx_to_class,
        )

        # Nothing was exported: skip writing the metadata file.
        if len(export_paths) == 0:
            return

        with open(os.path.join(self.export_folder, self.deploy_info_file), "w") as f:
            json.dump(self.model_json, f)
324
+
325
+ def generate_report(self) -> None:
326
+ """Generate a report for the task."""
327
+ if self.datamodule.class_to_idx is None:
328
+ log.warning("No `class_to_idx` found in the datamodule, report will not be generated")
329
+ return
330
+
331
+ if isinstance(self.datamodule, MultilabelClassificationDataModule):
332
+ log.warning("Report generation is not supported for multilabel classification tasks at the moment.")
333
+ return
334
+
335
+ log.info("Generating report!")
336
+ if not self.run_test or self.config.trainer.get("fast_dev_run"):
337
+ self.datamodule.setup(stage="test")
338
+
339
+ # Deepcopy to remove the inference mode from gradients causing issues when loading checkpoints
340
+ # TODO: Why deepcopy of module model removes ModelSignatureWrapper?
341
+ self.module.model.instance = deepcopy(self.module.model.instance)
342
+ if "16" in self.trainer.precision:
343
+ log.warning("Gradcam is currently not supported with half precision, it will be disabled")
344
+ self.module.gradcam = False
345
+ self.gradcam = False
346
+
347
+ predictions_outputs = self.trainer.predict(
348
+ model=self.module, datamodule=self.datamodule, ckpt_path=self.best_model_path
349
+ )
350
+ if not predictions_outputs:
351
+ log.warning("There is no prediction to generate the report. Skipping report generation.")
352
+ return
353
+ all_outputs = [x[0] for x in predictions_outputs]
354
+ all_probs = [x[2] for x in predictions_outputs]
355
+ if not all_outputs or not all_probs:
356
+ log.warning("There is no prediction to generate the report. Skipping report generation.")
357
+ return
358
+ all_outputs = [item for sublist in all_outputs for item in sublist]
359
+ all_probs = [item for sublist in all_probs for item in sublist]
360
+ all_targets = [target.tolist() for im, target in self.datamodule.test_dataloader()]
361
+ all_targets = [item for sublist in all_targets for item in sublist]
362
+
363
+ if self.module.gradcam:
364
+ grayscale_cams = [x[1] for x in predictions_outputs]
365
+ grayscale_cams = [item for sublist in grayscale_cams for item in sublist]
366
+ grayscale_cams = np.stack(grayscale_cams) # N x H x W
367
+ else:
368
+ grayscale_cams = None
369
+
370
+ # creating confusion matrix
371
+ idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}
372
+ _, self.report_confmat, accuracy = get_results(
373
+ test_labels=all_targets,
374
+ pred_labels=all_outputs,
375
+ idx_to_labels=idx_to_class,
376
+ )
377
+ output_folder_test = "test"
378
+ test_dataloader = self.datamodule.test_dataloader()
379
+ test_dataset = cast(ImageClassificationListDataset, test_dataloader.dataset)
380
+ self.res = pd.DataFrame(
381
+ {
382
+ "sample": list(test_dataset.x),
383
+ "real_label": all_targets,
384
+ "pred_label": all_outputs,
385
+ "probability": all_probs,
386
+ }
387
+ )
388
+ os.makedirs(output_folder_test, exist_ok=True)
389
+ save_classification_result(
390
+ results=self.res,
391
+ output_folder=output_folder_test,
392
+ confmat=self.report_confmat,
393
+ accuracy=accuracy,
394
+ test_dataloader=self.datamodule.test_dataloader(),
395
+ config=self.config,
396
+ output=self.output,
397
+ grayscale_cams=grayscale_cams,
398
+ )
399
+
400
+ if len(self.logger) > 0:
401
+ mflow_logger = get_mlflow_logger(trainer=self.trainer)
402
+ tensorboard_logger = utils.get_tensorboard_logger(trainer=self.trainer)
403
+ artifacts = glob.glob(os.path.join(output_folder_test, "**/*"), recursive=True)
404
+ if self.config.core.get("upload_artifacts") and len(artifacts) > 0:
405
+ if mflow_logger is not None:
406
+ log.info("Uploading artifacts to MLFlow")
407
+ for a in artifacts:
408
+ if os.path.isdir(a):
409
+ continue
410
+
411
+ dirname = Path(a).parent.name
412
+ mflow_logger.experiment.log_artifact(
413
+ run_id=mflow_logger.run_id,
414
+ local_path=a,
415
+ artifact_path=os.path.join("classification_output", dirname),
416
+ )
417
+ if tensorboard_logger is not None:
418
+ log.info("Uploading artifacts to Tensorboard")
419
+ for a in artifacts:
420
+ if os.path.isdir(a):
421
+ continue
422
+
423
+ ext = os.path.splitext(a)[1].lower()
424
+
425
+ if ext in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".gif"]:
426
+ try:
427
+ img = cv2.imread(a)
428
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
429
+ except cv2.error:
430
+ log.info("Could not upload artifact image %s", a)
431
+ continue
432
+ output_path = os.path.sep.join(a.split(os.path.sep)[-2:])
433
+ tensorboard_logger.experiment.add_image(output_path, img, 0, dataformats="HWC")
434
+ else:
435
+ utils.upload_file_tensorboard(a, tensorboard_logger)
436
+
437
+ def freeze_layers_by_name(self, freeze_parameters_name: list[str]):
438
+ """Freeze layers specified in freeze_parameters_name.
439
+
440
+ Args:
441
+ freeze_parameters_name: Layers that will be frozen during training.
442
+
443
+ """
444
+ count_frozen = 0
445
+ for name, param in self.model.named_parameters():
446
+ if any(x in name.split(".")[1] for x in freeze_parameters_name):
447
+ log.debug("Freezing layer %s", name)
448
+ param.requires_grad = False
449
+
450
+ if not param.requires_grad:
451
+ count_frozen += 1
452
+
453
+ log.info("Frozen %d parameters", count_frozen)
454
+
455
+ def freeze_parameters_by_index(self, freeze_parameters_index: list[int]):
456
+ """Freeze parameters specified in freeze_parameters_name.
457
+
458
+ Args:
459
+ freeze_parameters_index: Indices of parameters that will be frozen during training.
460
+
461
+ """
462
+ if getattr(self.config.backbone, "freeze_parameters_name", None) is not None:
463
+ log.warning(
464
+ "Please be aware that some of the model's parameters have already been frozen using \
465
+ the specified freeze_parameters_name. You are combining these two actions."
466
+ )
467
+ count_frozen = 0
468
+ for i, (name, param) in enumerate(self.model.named_parameters()):
469
+ if i in freeze_parameters_index:
470
+ log.debug("Freezing layer %s", name)
471
+ param.requires_grad = False
472
+
473
+ if not param.requires_grad:
474
+ count_frozen += 1
475
+
476
+ log.info("Frozen %d parameters", count_frozen)
477
+
478
+
479
class SklearnClassification(Generic[SklearnClassificationDataModuleT], Task[SklearnClassificationDataModuleT]):
    """Sklearn classification task.

    Args:
        config: The experiment configuration
        device: The device to use. Defaults to None.
        output: Dictionary defining which kind of outputs to generate. Defaults to None.
        automatic_batch_size: Whether to automatically find the largest batch size that fits in memory.
        save_model_summary: Whether to save a model_summary.txt file containing the model summary.
        half_precision: Whether to use half precision during training.
        gradcam: Whether to compute gradcams for test results.
    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        device: str,
        automatic_batch_size: DictConfig,
        save_model_summary: bool = False,
        half_precision: bool = False,
        gradcam: bool = False,
    ):
        super().__init__(config=config)

        self._device = device
        self.output = output
        # Backing fields for the corresponding properties; populated in prepare().
        self._backbone: ModelSignatureWrapper
        self._trainer: SklearnClassificationTrainer
        self._model: ClassifierMixin
        # Per-fold results accumulated by train() and consumed by generate_report().
        self.metadata: dict[str, Any] = {
            "test_confusion_matrix": [],
            "test_accuracy": [],
            "test_results": [],
            "test_labels": [],
            "cams": [],
        }
        # Export destination folder and deployment metadata file name used by export().
        self.export_folder = "deployment_model"
        self.deploy_info_file = "model.json"
        # One dataloader per cross-validation split, filled in train().
        self.train_dataloader_list: list[torch.utils.data.DataLoader] = []
        self.test_dataloader_list: list[torch.utils.data.DataLoader] = []
        self.automatic_batch_size = automatic_batch_size
        self.save_model_summary = save_model_summary
        self.half_precision = half_precision
        self.gradcam = gradcam
524
+
525
    @property
    def device(self) -> str:
        """str: The device the backbone is moved to."""
        return self._device
528
+
529
    def prepare(self) -> None:
        """Prepare the experiment."""
        self.datamodule = self.config.datamodule

        # The backbone/model assignments go through the property setters below,
        # which instantiate the objects from their hydra configs.
        self.backbone = self.config.backbone

        self.model = self.config.model

        # prepare_data() must be explicitly called if the task does not include a lightining training
        self.datamodule.prepare_data()
        self.datamodule.setup(stage="fit")

        # Must come last: the trainer setter wires in the backbone and classifier set above.
        self.trainer = self.config.trainer
542
+
543
    @property
    def model(self) -> ClassifierMixin:
        """sklearn.base.ClassifierMixin: The model."""
        return self._model

    @model.setter
    def model(self, model_config: DictConfig):
        """Instantiate the sklearn classifier from its hydra config."""
        log.info("Instantiating model <%s>", model_config["_target_"])
        self._model = hydra.utils.instantiate(model_config)
553
+
554
+ @property
555
+ def backbone(self) -> ModelSignatureWrapper:
556
+ """Backbone: The backbone."""
557
+ return self._backbone
558
+
559
+ @backbone.setter
560
+ def backbone(self, backbone_config):
561
+ """Load backbone."""
562
+ if backbone_config.metadata.get("checkpoint"):
563
+ log.info("Loading backbone from <%s>", backbone_config.metadata.checkpoint)
564
+ self._backbone = torch.load(backbone_config.metadata.checkpoint)
565
+ else:
566
+ log.info("Loading backbone from <%s>", backbone_config.model["_target_"])
567
+ self._backbone = hydra.utils.instantiate(backbone_config.model)
568
+
569
+ self._backbone = ModelSignatureWrapper(self._backbone)
570
+ self._backbone.eval()
571
+ if self.half_precision:
572
+ if self.device == "cpu":
573
+ raise ValueError("Half precision is not supported on CPU")
574
+ self._backbone.half()
575
+
576
+ if self.gradcam:
577
+ log.warning("Gradcam is currently not supported with half precision, it will be disabled")
578
+ self.gradcam = False
579
+ self._backbone.to(self.device)
580
+
581
    @property
    def trainer(self) -> SklearnClassificationTrainer:
        """Trainer: The trainer."""
        return self._trainer

    @trainer.setter
    def trainer(self, trainer_config: DictConfig) -> None:
        """Instantiate the trainer from config, wiring in the current backbone and classifier."""
        log.info("Instantiating trainer <%s>", trainer_config["_target_"])
        trainer = hydra.utils.instantiate(trainer_config, backbone=self.backbone, classifier=self.model)
        self._trainer = trainer
592
+
593
    @typing.no_type_check
    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def train(self) -> None:
        """Train the model."""
        log.info("Starting training...!")
        # Cached features/labels for the whole dataset; only populated when the
        # datamodule has caching enabled.
        all_features = None
        all_labels = None

        class_to_keep = None

        # One (train, val) dataloader pair per cross-validation split.
        self.train_dataloader_list = list(self.datamodule.train_dataloader())
        self.test_dataloader_list = list(self.datamodule.val_dataloader())

        if hasattr(self.datamodule, "class_to_keep_training") and self.datamodule.class_to_keep_training is not None:
            class_to_keep = self.datamodule.class_to_keep_training

        if self.save_model_summary:
            self.extract_model_summary(feature_extractor=self.backbone, dl=self.datamodule.full_dataloader())

        if hasattr(self.datamodule, "cache") and self.datamodule.cache:
            if self.config.trainer.iteration_over_training != 1:
                raise AttributeError("Cache is only supported when iteration over training is set to 1")

            # Extract features once for the entire dataset instead of per split.
            full_dataloader = self.datamodule.full_dataloader()
            all_features, all_labels, _ = get_feature(
                feature_extractor=self.backbone, dl=full_dataloader, iteration_over_training=1
            )

            # Store features sorted by sample path so each split can map its own
            # file order back onto the cached arrays inside the loop below.
            sorted_indices = np.argsort(full_dataloader.dataset.x)
            all_features = all_features[sorted_indices]
            all_labels = all_labels[sorted_indices]

        # cycle over all train/test split
        for train_dataloader, test_dataloader in zip(self.train_dataloader_list, self.test_dataloader_list):
            # Reinit classifier
            self.model = self.config.model
            self.trainer.change_classifier(self.model)

            # Train on current training set
            if all_features is not None and all_labels is not None:
                # Find which are the indices used to pass from the sorted list of string to the disordered one
                sorted_indices = np.argsort(np.concatenate([train_dataloader.dataset.x, test_dataloader.dataset.x]))
                revese_sorted_indices = np.argsort(sorted_indices)

                # Use these indices to correctly match the extracted features with the new file order
                all_features_sorted = all_features[revese_sorted_indices]
                all_labels_sorted = all_labels[revese_sorted_indices]

                # After reordering, the first train_len rows correspond to the
                # training split and the remainder to the test split.
                train_len = len(train_dataloader.dataset.x)

                self.trainer.fit(
                    train_features=all_features_sorted[0:train_len], train_labels=all_labels_sorted[0:train_len]
                )

                _, pd_cm, accuracy, res, cams = self.trainer.test(
                    test_dataloader=test_dataloader,
                    test_features=all_features_sorted[train_len:],
                    test_labels=all_labels_sorted[train_len:],
                    class_to_keep=class_to_keep,
                    idx_to_class=train_dataloader.dataset.idx_to_class,
                    predict_proba=True,
                    gradcam=self.gradcam,
                )
            else:
                # No cache: the trainer extracts features on the fly for this split.
                self.trainer.fit(train_dataloader=train_dataloader)
                _, pd_cm, accuracy, res, cams = self.trainer.test(
                    test_dataloader=test_dataloader,
                    class_to_keep=class_to_keep,
                    idx_to_class=train_dataloader.dataset.idx_to_class,
                    predict_proba=True,
                    gradcam=self.gradcam,
                )

            # save results
            self.metadata["test_confusion_matrix"].append(pd_cm)
            self.metadata["test_accuracy"].append(accuracy)
            self.metadata["test_results"].append(res)
            # -1 marks samples without a known label.
            self.metadata["test_labels"].append(
                [
                    train_dataloader.dataset.idx_to_class[i] if i != -1 else "N/A"
                    for i in res["real_label"].unique().tolist()
                ]
            )
            self.metadata["cams"].append(cams)
677
+
678
+ def extract_model_summary(
679
+ self, feature_extractor: torch.nn.Module | BaseEvaluationModel, dl: torch.utils.data.DataLoader
680
+ ) -> None:
681
+ """Given a dataloader and a PyTorch model, use torchinfo to extract a summary of the model and save it
682
+ to a file.
683
+
684
+ Args:
685
+ dl: PyTorch dataloader
686
+ feature_extractor: PyTorch backbone
687
+ """
688
+ if isinstance(feature_extractor, (TorchEvaluationModel, TorchscriptEvaluationModel)):
689
+ # TODO: I'm not sure torchinfo supports torchscript models
690
+ # If we are working with torch based evaluation models we need to extract the model
691
+ feature_extractor = feature_extractor.model
692
+
693
+ for b in tqdm(dl):
694
+ x1, _ = b
695
+
696
+ if hasattr(feature_extractor, "parameters"):
697
+ # Move input to the correct device
698
+ parameter = next(feature_extractor.parameters())
699
+ x1 = x1.to(parameter.device).to(parameter.dtype)
700
+ x1 = x1[0].unsqueeze(0) # Remove batch dimension
701
+
702
+ model_info = None
703
+
704
+ try:
705
+ try:
706
+ # TODO: Do we want to print the summary to the console as well?
707
+ model_info = summary(feature_extractor, input_data=(x1), verbose=0) # type: ignore[arg-type]
708
+ except Exception:
709
+ log.warning(
710
+ "Failed to retrieve model summary using input data information, retrieving only "
711
+ "parameters information"
712
+ )
713
+ model_info = summary(feature_extractor, verbose=0) # type: ignore[arg-type]
714
+ except Exception as e:
715
+ # If for some reason the summary fails we don't want to stop the training
716
+ log.warning("Failed to retrieve model summary: %s", e)
717
+
718
+ if model_info is not None:
719
+ with open("model_summary.txt", "w") as f:
720
+ f.write(str(model_info))
721
+ else:
722
+ log.warning("Failed to retrieve model summary, current model has no parameters")
723
+
724
+ break
725
+
726
    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def train_full_data(self):
        """Train the model on train + validation."""
        # Reinit classifier so no state leaks from the cross-validation folds.
        self.model = self.config.model
        self.trainer.change_classifier(self.model)

        self.trainer.fit(train_dataloader=self.datamodule.full_dataloader())
734
+
735
    def test(self) -> None:
        """Skip test phase (intentional no-op)."""
        # we don't need test phase since sklearn trainer is already running test inside
        # train module to handle cross validation
739
+
740
    @typing.no_type_check
    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def test_full_data(self) -> None:
        """Test model trained on full dataset."""
        # Propagate the class mapping of the full dataset and switch the
        # datamodule config to test phase before re-running setup.
        self.config.datamodule.class_to_idx = self.datamodule.full_dataset.class_to_idx
        self.config.datamodule.phase = "test"
        idx_to_class = self.datamodule.full_dataset.idx_to_class
        self.datamodule.setup("test")
        test_dataloader = self.datamodule.test_dataloader()

        if len(self.datamodule.data["samples"]) == 0:
            log.info("No test data, skipping test")
            return

        # Put backbone on the correct device as it may be moved after export
        self.backbone.to(self.device)
        _, pd_cm, accuracy, res, cams = self.trainer.test(
            test_dataloader=test_dataloader, idx_to_class=idx_to_class, predict_proba=True, gradcam=self.gradcam
        )

        output_folder_test = "test"

        os.makedirs(output_folder_test, exist_ok=True)

        save_classification_result(
            results=res,
            output_folder=output_folder_test,
            confmat=pd_cm,
            accuracy=accuracy,
            test_dataloader=test_dataloader,
            config=self.config,
            output=self.output,
            grayscale_cams=cams,
        )
774
+
775
+ def export(self) -> None:
776
+ """Generate deployment model for the task."""
777
+ if self.config.export is None or len(self.config.export.types) == 0:
778
+ log.info("No export type specified skipping export")
779
+ return
780
+
781
+ input_shapes = self.config.export.input_shapes
782
+
783
+ idx_to_class = {v: k for k, v in self.datamodule.full_dataset.class_to_idx.items()}
784
+
785
+ model_json, export_paths = export_model(
786
+ config=self.config,
787
+ model=self.backbone,
788
+ export_folder=self.export_folder,
789
+ half_precision=self.half_precision,
790
+ input_shapes=input_shapes,
791
+ idx_to_class=idx_to_class,
792
+ pytorch_model_type="backbone",
793
+ )
794
+
795
+ dump(self.model, os.path.join(self.export_folder, "classifier.joblib"))
796
+
797
+ if len(export_paths) > 0:
798
+ with open(os.path.join(self.export_folder, self.deploy_info_file), "w") as f:
799
+ json.dump(model_json, f)
800
+
801
+ def generate_report(self) -> None:
802
+ """Generate report for the task."""
803
+ log.info("Generating report!")
804
+
805
+ cm_list = []
806
+
807
+ for count in range(len(self.metadata["test_accuracy"])):
808
+ current_output_folder = f"{self.output.folder}_{count}"
809
+ os.makedirs(current_output_folder, exist_ok=True)
810
+
811
+ c_matrix = self.metadata["test_confusion_matrix"][count]
812
+ cm_list.append(c_matrix)
813
+ save_classification_result(
814
+ results=self.metadata["test_results"][count],
815
+ output_folder=current_output_folder,
816
+ confmat=c_matrix,
817
+ accuracy=self.metadata["test_accuracy"][count],
818
+ test_dataloader=self.test_dataloader_list[count],
819
+ config=self.config,
820
+ output=self.output,
821
+ grayscale_cams=self.metadata["cams"][count],
822
+ )
823
+ final_confusion_matrix = sum(cm_list)
824
+
825
+ self.metadata["final_confusion_matrix"] = final_confusion_matrix
826
+ # Save final conf matrix
827
+ final_folder = f"{self.output.folder}"
828
+ os.makedirs(final_folder, exist_ok=True)
829
+ disp = ConfusionMatrixDisplay(
830
+ confusion_matrix=np.array(final_confusion_matrix),
831
+ display_labels=[x.replace("pred:", "") for x in final_confusion_matrix.columns.to_list()],
832
+ )
833
+ disp.plot(include_values=True, cmap=plt.cm.Greens, ax=None, colorbar=False, xticks_rotation=90)
834
+ plt.title(f"Confusion Matrix (Accuracy: {(self.metadata['test_accuracy'][count] * 100):.2f}%)")
835
+ plt.savefig(os.path.join(final_folder, "test_confusion_matrix.png"), bbox_inches="tight", pad_inches=0, dpi=300)
836
+ plt.close()
837
+
838
    def execute(self) -> None:
        """Execute the experiment and all the steps."""
        self.prepare()
        # Cross-validated training over all splits.
        self.train()
        if self.output.report:
            self.generate_report()
        # Retrain on train + validation before exporting the deployment model.
        self.train_full_data()
        if self.config.export is not None and len(self.config.export.types) > 0:
            self.export()
        if self.output.test_full_data:
            self.test_full_data()
        self.finalize()
850
+
851
+
852
class SklearnTestClassification(Evaluation[SklearnClassificationDataModuleT]):
    """Perform a test using an imported SklearnClassification pytorch model.

    Args:
        config: The experiment configuration
        output: where to save results
        model_path: path to trained model generated from SklearnClassification task.
        device: the device where to run the model (cuda or cpu)
        gradcam: Whether to compute gradcams
        **kwargs: Additional arguments to pass to the task
    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        model_path: str,
        device: str,
        gradcam: bool = False,
        **kwargs: Any,
    ):
        super().__init__(config=config, model_path=model_path, device=device, **kwargs)
        self.gradcam = gradcam
        self.output = output
        # Backing fields for the corresponding properties; populated in prepare().
        self._backbone: BaseEvaluationModel
        self._classifier: ClassifierMixin
        self.class_to_idx: dict[str, int]
        self.idx_to_class: dict[int, str]
        self.test_dataloader: torch.utils.data.DataLoader
        # Test results filled by test() and consumed by generate_report().
        self.metadata: dict[str, Any] = {
            "test_confusion_matrix": None,
            "test_accuracy": None,
            "test_results": None,
            "test_labels": None,
            "cams": None,
        }
888
+
889
+ def prepare(self) -> None:
890
+ """Prepare the experiment."""
891
+ super().prepare()
892
+
893
+ idx_to_class = {}
894
+ class_to_idx = {}
895
+ for k, v in self.model_data["classes"].items():
896
+ idx_to_class[int(k)] = v
897
+ class_to_idx[v] = int(k)
898
+
899
+ self.idx_to_class = idx_to_class
900
+ self.class_to_idx = class_to_idx
901
+
902
+ self.config.datamodule.class_to_idx = class_to_idx
903
+
904
+ self.datamodule = self.config.datamodule
905
+ # prepare_data() must be explicitly called because there is no lightning training
906
+ self.datamodule.prepare_data()
907
+ self.datamodule.setup(stage="test")
908
+
909
+ # Configure trainer
910
+ self.trainer = self.config.trainer
911
+
912
    @property
    def deployment_model(self):
        """Deployment model.

        Always None: this task keeps the model split into a backbone and a
        sklearn classifier instead of a single deployment model.
        """
        return None

    @deployment_model.setter
    def deployment_model(self, model_path: str):
        """Set backbone and classifier from the exported model folder."""
        self.backbone = model_path  # type: ignore[assignment]
        # Load classifier saved next to the backbone by the training task's export()
        self.classifier = os.path.join(Path(model_path).parent, "classifier.joblib")
923
+
924
    @property
    def classifier(self) -> ClassifierMixin:
        """Classifier: The classifier."""
        return self._classifier

    @classifier.setter
    def classifier(self, classifier_path: str) -> None:
        """Load the fitted sklearn classifier from a joblib file."""
        self._classifier = load(classifier_path)
933
+
934
    @property
    def backbone(self) -> BaseEvaluationModel:
        """Backbone: The backbone."""
        return self._backbone

    @backbone.setter
    def backbone(self, model_path: str) -> None:
        """Load backbone.

        For ``.pth`` state dicts the architecture is first rebuilt from the
        ``model_config.yaml`` saved next to the weights; other formats are
        handled directly by ``import_deployment_model``.
        """
        file_extension = os.path.splitext(model_path)[1]

        model_architecture = None
        if file_extension == ".pth":
            backbone_config_path = os.path.join(Path(model_path).parent, "model_config.yaml")
            log.info("Loading backbone from config")
            backbone_config = OmegaConf.load(backbone_config_path)

            if backbone_config.metadata.get("checkpoint"):
                log.info("Loading backbone from <%s>", backbone_config.metadata.checkpoint)
                # NOTE(review): torch.load without map_location assumes the
                # checkpoint's original device is available here — confirm.
                model_architecture = torch.load(backbone_config.metadata.checkpoint)
            else:
                log.info("Loading backbone from <%s>", backbone_config.model["_target_"])
                model_architecture = hydra.utils.instantiate(backbone_config.model)

        self._backbone = import_deployment_model(
            model_path=model_path,
            device=self.device,
            inference_config=self.config.inference,
            model_architecture=model_architecture,
        )

        # Gradcam needs gradients on torch modules, which only the torch-based
        # evaluation wrapper exposes.
        if self.gradcam and not isinstance(self._backbone, TorchEvaluationModel):
            log.warning("Gradcam is supported only for pytorch models. Skipping gradcam")
            self.gradcam = False
967
+
968
    @property
    def trainer(self) -> SklearnClassificationTrainer:
        """Trainer: The trainer."""
        return self._trainer

    @trainer.setter
    def trainer(self, trainer_config: DictConfig) -> None:
        """Instantiate the trainer from config with the loaded backbone and classifier."""
        log.info("Instantiating trainer <%s>", trainer_config["_target_"])

        # Evaluation must run with the backbone in eval mode.
        if self.backbone.training:
            self.backbone.eval()

        trainer = hydra.utils.instantiate(trainer_config, backbone=self.backbone, classifier=self.classifier)
        self._trainer = trainer
983
+
984
    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def test(self) -> None:
        """Run the test."""
        self.test_dataloader = self.datamodule.test_dataloader()

        _, pd_cm, accuracy, res, cams = self.trainer.test(
            test_dataloader=self.test_dataloader,
            idx_to_class=self.idx_to_class,
            predict_proba=True,
            gradcam=self.gradcam,
        )

        # save results
        self.metadata["test_confusion_matrix"] = pd_cm
        self.metadata["test_accuracy"] = accuracy
        self.metadata["test_results"] = res
        # Map label indices back to class names; -1 marks samples with no known label.
        self.metadata["test_labels"] = [
            self.idx_to_class[i] if i != -1 else "N/A" for i in res["real_label"].unique().tolist()
        ]
        self.metadata["cams"] = cams
1004
+
1005
+ def generate_report(self) -> None:
1006
+ """Generate a report for the task."""
1007
+ log.info("Generating report!")
1008
+ os.makedirs(self.output.folder, exist_ok=True)
1009
+ save_classification_result(
1010
+ results=self.metadata["test_results"],
1011
+ output_folder=self.output.folder,
1012
+ confmat=self.metadata["test_confusion_matrix"],
1013
+ accuracy=self.metadata["test_accuracy"],
1014
+ test_dataloader=self.test_dataloader,
1015
+ config=self.config,
1016
+ output=self.output,
1017
+ grayscale_cams=self.metadata["cams"],
1018
+ )
1019
+
1020
    def execute(self) -> None:
        """Execute the experiment and all the steps."""
        self.prepare()
        self.test()
        if self.output.report:
            self.generate_report()
        self.finalize()
1027
+
1028
+
1029
class ClassificationEvaluation(Evaluation[ClassificationDataModuleT]):
    """Perform a test on an imported Classification pytorch model.

    Args:
        config: Task configuration
        output: Configuration for the output
        model_path: Path to pytorch .pt model file
        report: Whether to generate the report of the predictions
        gradcam: Whether to compute gradcams
        device: Device to use for evaluation. If None, the device is automatically determined

    """

    def __init__(
        self,
        config: DictConfig,
        output: DictConfig,
        model_path: str,
        report: bool = True,
        gradcam: bool = False,
        device: str | None = None,
    ):
        super().__init__(config=config, model_path=model_path, device=device)
        # Folder where generate_report() writes its output.
        self.report_path = "test_output"
        self.output = output
        self.report = report
        self.gradcam = gradcam
        # Populated by prepare_gradcam() for ResNet backbones.
        self.cam: GradCAM
1057
+
1058
    def get_torch_model(self, model_config: DictConfig) -> nn.Module:
        """Instantiate the torch model from the config.

        Builds the pre-classifier and classifier first, then injects both into
        the backbone's hydra target.
        """
        pre_classifier = self.get_pre_classifier(model_config)
        classifier = self.get_classifier(model_config)
        log.info("Instantiating backbone <%s>", model_config.model["_target_"])

        return hydra.utils.instantiate(
            model_config.model, classifier=classifier, pre_classifier=pre_classifier, _convert_="partial"
        )
1067
+
1068
+ def get_pre_classifier(self, model_config: DictConfig) -> nn.Module:
1069
+ """Instantiate the pre-classifier from the config."""
1070
+ if "pre_classifier" in model_config and model_config.pre_classifier is not None:
1071
+ log.info("Instantiating pre_classifier <%s>", model_config.pre_classifier["_target_"])
1072
+ pre_classifier = hydra.utils.instantiate(model_config.pre_classifier, _convert_="partial")
1073
+ else:
1074
+ log.info("No pre-classifier found in config: instantiate a torch.nn.Identity instead")
1075
+ pre_classifier = nn.Identity()
1076
+
1077
+ return pre_classifier
1078
+
1079
    def get_classifier(self, model_config: DictConfig) -> nn.Module:
        """Instantiate the classifier from the config.

        Raises:
            ValueError: If the config has no `classifier` section.
        """
        if "classifier" in model_config:
            log.info("Instantiating classifier <%s>", model_config.classifier["_target_"])
            # The output dimension comes from the class mapping stored in the
            # exported model metadata.
            return hydra.utils.instantiate(
                model_config.classifier, out_features=len(self.model_data["classes"]), _convert_="partial"
            )

        raise ValueError("A `classifier` definition must be specified in the config")
1088
+
1089
    @property
    def deployment_model(self) -> BaseEvaluationModel:
        """Deployment model."""
        return self._deployment_model

    @deployment_model.setter
    def deployment_model(self, model_path: str):
        """Set the deployment model."""
        file_extension = os.path.splitext(model_path)[1]
        model_architecture = None
        # Plain .pth state dicts need the architecture rebuilt from the config
        # stored next to the weights before they can be loaded.
        if file_extension == ".pth":
            model_config = OmegaConf.load(os.path.join(Path(model_path).parent, "model_config.yaml"))

            if not isinstance(model_config, DictConfig):
                raise ValueError(f"The model config must be a DictConfig, got {type(model_config)}")

            model_architecture = self.get_torch_model(model_config)

        self._deployment_model = import_deployment_model(
            model_path=model_path,
            device=self.device,
            inference_config=self.config.inference,
            model_architecture=model_architecture,
        )

        # Gradcam needs gradients on torch modules, which only the torch-based
        # evaluation wrapper provides.
        if self.gradcam and not isinstance(self.deployment_model, TorchEvaluationModel):
            log.warning("To compute gradcams you need to provide the path to an exported .pth state_dict file")
            self.gradcam = False
1117
+
1118
    def prepare(self) -> None:
        """Prepare the evaluation."""
        super().prepare()
        self.datamodule = self.config.datamodule
        # Rebuild the class mapping from the exported model metadata
        # (keys are stringified indices, values are class names).
        self.datamodule.class_to_idx = {v: int(k) for k, v in self.model_data["classes"].items()}
        self.datamodule.num_classes = len(self.datamodule.class_to_idx)

        # prepare_data() must be explicitly called because there is no training
        self.datamodule.prepare_data()
        self.datamodule.setup(stage="test")
1128
+
1129
    def prepare_gradcam(self) -> None:
        """Initializing gradcam for the predictions.

        Disables gradcam (``self.gradcam = False``) when the backbone is not
        supported; otherwise sets up ``self.cam`` (ResNet) or
        ``self.grad_rollout`` (vision transformer).
        """
        if not hasattr(self.deployment_model.model, "features_extractor"):
            log.warning("Gradcam not implemented for this backbone, it will not be computed")
            self.gradcam = False
            return

        if isinstance(self.deployment_model.model.features_extractor, timm.models.resnet.ResNet):
            # For ResNets hook the last block of layer4.
            target_layers = [cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor.layer4[-1]]
            self.cam = GradCAM(
                model=self.deployment_model.model,
                target_layers=target_layers,
            )
            # GradCAM requires gradients on the hooked layer.
            for p in self.deployment_model.model.features_extractor.layer4[-1].parameters():
                p.requires_grad = True
        elif is_vision_transformer(cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor):
            # ViT backbones use attention gradient rollout instead of GradCAM.
            self.grad_rollout = VitAttentionGradRollout(cast(nn.Module, self.deployment_model.model))
        else:
            log.warning("Gradcam not implemented for this backbone, it will not be computed")
            self.gradcam = False
1149
+
1150
    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
    def test(self) -> None:
        """Perform test."""
        log.info("Running test")
        test_dataloader = self.datamodule.test_dataloader()

        image_labels = []
        probabilities = []
        predicted_classes = []
        grayscale_cams_list = []

        if self.gradcam:
            self.prepare_gradcam()

        # Gradients are only needed when computing gradcams.
        with torch.set_grad_enabled(self.gradcam):
            for batch_item in tqdm(test_dataloader):
                im, target = batch_item
                im = im.to(device=self.device, dtype=self.deployment_model.model_dtype).detach()

                if self.gradcam:
                    # When gradcam is used we need to remove gradients
                    outputs = self.deployment_model(im).detach()
                else:
                    outputs = self.deployment_model(im)

                probs = torch.softmax(outputs, dim=1)
                preds = torch.max(probs, dim=1).indices

                probabilities.append(probs.tolist())
                predicted_classes.append(preds.tolist())
                image_labels.extend(target.tolist())
                if self.gradcam and hasattr(self.deployment_model.model, "features_extractor"):
                    # inference_mode(False) re-enables autograd bookkeeping on the cloned input.
                    with torch.inference_mode(False):
                        im = im.clone()
                        if isinstance(self.deployment_model.model.features_extractor, timm.models.resnet.ResNet):
                            grayscale_cam = self.cam(input_tensor=im, targets=None)
                            grayscale_cams_list.append(torch.from_numpy(grayscale_cam))
                        elif is_vision_transformer(
                            cast(BaseNetworkBuilder, self.deployment_model.model).features_extractor
                        ):
                            # Rollout cams are low resolution; upscale them to the input size.
                            grayscale_cam_low_res = self.grad_rollout(input_tensor=im, targets_list=preds.tolist())
                            orig_shape = grayscale_cam_low_res.shape
                            new_shape = (orig_shape[0], im.shape[2], im.shape[3])
                            zoom_factors = tuple(np.array(new_shape) / np.array(orig_shape))
                            grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1)
                            grayscale_cams_list.append(torch.from_numpy(grayscale_cam))

        grayscale_cams: torch.Tensor | None = None
        if self.gradcam:
            grayscale_cams = torch.cat(grayscale_cams_list, dim=0)

        # Flatten per-batch lists; keep only the winning probability per sample.
        predicted_classes = [item for sublist in predicted_classes for item in sublist]
        probabilities = [max(item) for sublist in probabilities for item in sublist]
        if self.datamodule.class_to_idx is not None:
            idx_to_class = {v: k for k, v in self.datamodule.class_to_idx.items()}

        # NOTE(review): idx_to_class is unbound below if class_to_idx is None —
        # verify that prepare() always sets the mapping before test() runs.
        _, pd_cm, test_accuracy = get_results(
            test_labels=image_labels,
            pred_labels=predicted_classes,
            idx_to_labels=idx_to_class,
        )

        res = pd.DataFrame(
            {
                "sample": list(test_dataloader.dataset.x),  # type: ignore[attr-defined]
                "real_label": image_labels,
                "pred_label": predicted_classes,
                "probability": probabilities,
            }
        )

        log.info("Avg classification accuracy: %s", test_accuracy)

        # NOTE(review): self.res duplicates the DataFrame built above — looks
        # redundant; confirm whether both are needed before simplifying.
        self.res = pd.DataFrame(
            {
                "sample": list(test_dataloader.dataset.x),  # type: ignore[attr-defined]
                "real_label": image_labels,
                "pred_label": predicted_classes,
                "probability": probabilities,
            }
        )

        # save results
        self.metadata["test_confusion_matrix"] = pd_cm
        self.metadata["test_accuracy"] = test_accuracy
        self.metadata["predictions"] = predicted_classes
        self.metadata["test_results"] = res
        self.metadata["probabilities"] = probabilities
        self.metadata["test_labels"] = image_labels
        self.metadata["grayscale_cams"] = grayscale_cams
1240
+
1241
+ def generate_report(self) -> None:
1242
+ """Generate a report for the task."""
1243
+ log.info("Generating report!")
1244
+ os.makedirs(self.report_path, exist_ok=True)
1245
+
1246
+ save_classification_result(
1247
+ results=self.metadata["test_results"],
1248
+ output_folder=self.report_path,
1249
+ confmat=self.metadata["test_confusion_matrix"],
1250
+ accuracy=self.metadata["test_accuracy"],
1251
+ test_dataloader=self.datamodule.test_dataloader(),
1252
+ config=self.config,
1253
+ output=self.output,
1254
+ grayscale_cams=self.metadata["grayscale_cams"],
1255
+ )
1256
+
1257
+ def execute(self) -> None:
1258
+ """Execute the evaluation."""
1259
+ self.prepare()
1260
+ self.test()
1261
+ if self.report:
1262
+ self.generate_report()
1263
+ self.finalize()