quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/tasks/ssl.py ADDED
@@ -0,0 +1,560 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ from typing import Any, cast
6
+
7
+ import hydra
8
+ import torch
9
+ from omegaconf import DictConfig, open_dict
10
+ from pytorch_lightning import LightningModule
11
+ from torch import nn
12
+ from torch.nn.functional import interpolate
13
+ from torch.utils.tensorboard import SummaryWriter
14
+ from tqdm import tqdm
15
+
16
+ from quadra.callbacks.scheduler import WarmupInit
17
+ from quadra.models.base import ModelSignatureWrapper
18
+ from quadra.models.evaluation import BaseEvaluationModel
19
+ from quadra.tasks.base import LightningTask, Task
20
+ from quadra.utils import utils
21
+ from quadra.utils.export import export_model, import_deployment_model
22
+
23
+ log = utils.get_logger(__name__)  # module-level logger used by the task classes defined below
24
+
25
+
26
class SSL(LightningTask):
    """Base task for self-supervised learning (SSL) experiments.

    Subclasses build their networks in ``prepare`` and must implement
    ``learnable_parameters`` so that the ``optimizer`` setter knows which
    parameters to hand to the optimizer.

    Args:
        config: The experiment configuration.
        run_test: Whether to run final test.
        report: Whether to create the report.
        checkpoint_path: The path to the checkpoint to load the model from. Defaults to None.
    """

    def __init__(
        self,
        config: DictConfig,
        run_test: bool = False,
        report: bool = False,
        checkpoint_path: str | None = None,
    ):
        super().__init__(
            config=config,
            checkpoint_path=checkpoint_path,
            run_test=run_test,
            report=report,
        )
        self._backbone: nn.Module
        self._optimizer: torch.optim.Optimizer
        # Declared under the attribute name the ``scheduler`` property actually reads and writes.
        self._scheduler: torch.optim.lr_scheduler._LRScheduler
        self.export_folder = "deployment_model"

    def learnable_parameters(self) -> list[nn.Parameter]:
        """Get the learnable parameters.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("This method must be implemented by the subclass")

    @property
    def optimizer(self) -> torch.optim.Optimizer:
        """Get the optimizer."""
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer_config: DictConfig) -> None:
        """Instantiate the optimizer from its config over ``learnable_parameters()``."""
        # Log the target of the config that is actually instantiated, not the global one.
        log.info("Instantiating optimizer <%s>", optimizer_config["_target_"])
        self._optimizer = hydra.utils.instantiate(optimizer_config, self.learnable_parameters())

    @property
    def scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
        """Get the scheduler."""
        return self._scheduler

    @scheduler.setter
    def scheduler(self, scheduler_config: DictConfig) -> None:
        """Instantiate the learning rate scheduler from its config.

        The optimizer must already be set, since it is passed to the scheduler.
        """
        scheduler_target = scheduler_config["_target_"]
        log.info("Instantiating scheduler <%s>", scheduler_target)
        # Decide on the config being instantiated so the branch always matches the kwargs passed.
        if "CosineAnnealingWithLinearWarmUp" in scheduler_target:
            # batch_size and len_loader are placeholders here.
            # NOTE(review): the WarmupInit callback registered below presumably re-creates the
            # scheduler with the real values - confirm against quadra.callbacks.scheduler.
            self._scheduler = hydra.utils.instantiate(
                scheduler_config,
                optimizer=self.optimizer,
                batch_size=1,
                len_loader=1,
            )
            self.add_callback(WarmupInit(scheduler_config=scheduler_config))
        else:
            self._scheduler = hydra.utils.instantiate(scheduler_config, optimizer=self.optimizer)

    def test(self) -> None:
        """Test the model.

        Does nothing unless ``run_test`` is set and the trainer config does not enable
        ``fast_dev_run``.
        """
        if self.run_test and not self.config.trainer.get("fast_dev_run"):
            log.info("Starting testing!")
            log.info("Using last epoch's weights for testing.")
            self.trainer.test(datamodule=self.datamodule, model=self.module, ckpt_path=None)

    def export(self) -> None:
        """Deploy a model ready for production.

        Exports ``self.module.model`` through ``export_model`` and, when at least one export
        path is returned, writes the resulting model description to ``model.json`` inside
        ``export_folder``.
        """
        # str() keeps the check valid whether the trainer reports precision as an int (16)
        # or as a string ("16-mixed", "bf16", ...).
        half_precision = "16" in str(self.trainer.precision)

        input_shapes = self.config.export.input_shapes

        model_json, export_paths = export_model(
            config=self.config,
            model=self.module.model,
            export_folder=self.export_folder,
            half_precision=half_precision,
            input_shapes=input_shapes,
            idx_to_class=None,
        )

        # Nothing was exported, so there is no model description to write.
        if len(export_paths) == 0:
            return

        with open(os.path.join(self.export_folder, "model.json"), "w") as f:
            json.dump(model_json, f, cls=utils.HydraEncoder)
116
+
117
+
118
class Simsiam(SSL):
    """Task that trains a SimSiam model (backbone + projection MLP + prediction MLP).

    Args:
        config: the main config
        checkpoint_path: if a checkpoint is specified, the lightning module is re-created
            with weights loaded from that checkpoint. Defaults to None.
        run_test: Whether to run final test
    """

    def __init__(
        self,
        config: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = False,
    ):
        super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
        self.backbone: nn.Module
        self.projection_mlp: nn.Module
        self.prediction_mlp: nn.Module

    def learnable_parameters(self) -> list[nn.Parameter]:
        """Return the parameters of the backbone and both MLP heads, in that order."""
        trainable_parts = (self.backbone, self.projection_mlp, self.prediction_mlp)
        return [param for part in trainable_parts for param in part.parameters()]

    def prepare(self) -> None:
        """Prepare the experiment: build networks, optimizer, scheduler and lightning module."""
        super().prepare()
        model_cfg = self.config.model
        self.backbone = hydra.utils.instantiate(model_cfg.model)
        self.projection_mlp = hydra.utils.instantiate(model_cfg.projection_mlp)
        self.prediction_mlp = hydra.utils.instantiate(model_cfg.prediction_mlp)
        # Order matters: the optimizer needs the networks, the scheduler needs the optimizer,
        # and the module needs all of them.
        self.optimizer = self.config.optimizer
        self.scheduler = self.config.scheduler
        self.module = self.config.model.module

    @property
    def module(self) -> LightningModule:
        """Get the lightning module of the task."""
        return super().module

    @module.setter
    def module(self, module_config):
        """Instantiate the lightning module, reloading it from the checkpoint if one was given."""
        networks = {
            "model": self.backbone,
            "projection_mlp": self.projection_mlp,
            "prediction_mlp": self.prediction_mlp,
        }
        training = {
            "optimizer": self.optimizer,
            "lr_scheduler": self.scheduler,
        }
        module = hydra.utils.instantiate(module_config, **networks, **training)
        if self.checkpoint_path is not None:
            # The freshly built module only supplies its class and its criterion to the reload.
            module_cls = module.__class__
            module = module_cls.load_from_checkpoint(
                self.checkpoint_path,
                **networks,
                criterion=module.criterion,
                **training,
            )
        self._module = module
185
+
186
+
187
class SimCLR(SSL):
    """Task that trains a SimCLR model (backbone + projection MLP).

    Args:
        config: the main config
        checkpoint_path: if a checkpoint is specified, the lightning module is re-created
            with weights loaded from that checkpoint. Defaults to None.
        run_test: Whether to run final test
    """

    def __init__(
        self,
        config: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = False,
    ):
        super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
        self.backbone: nn.Module
        self.projection_mlp: nn.Module

    def learnable_parameters(self) -> list[nn.Parameter]:
        """Return the backbone parameters followed by the projection MLP parameters."""
        trainable_parts = (self.backbone, self.projection_mlp)
        return [param for part in trainable_parts for param in part.parameters()]

    def prepare(self) -> None:
        """Prepare the experiment: build networks, optimizer, scheduler and lightning module."""
        super().prepare()
        model_cfg = self.config.model
        self.backbone = hydra.utils.instantiate(model_cfg.model)
        self.projection_mlp = hydra.utils.instantiate(model_cfg.projection_mlp)
        # Order matters: optimizer needs the networks, scheduler needs the optimizer.
        self.optimizer = self.config.optimizer
        self.scheduler = self.config.scheduler
        self.module = self.config.model.module

    @property
    def module(self) -> LightningModule:
        """Get the lightning module of the task."""
        return super().module

    @module.setter
    def module(self, module_config):
        """Instantiate the lightning module, reloading it from the checkpoint if one was given.

        The module's ``model`` attribute is wrapped in a ``ModelSignatureWrapper`` afterwards.
        """
        networks = {
            "model": self.backbone,
            "projection_mlp": self.projection_mlp,
        }
        training = {
            "optimizer": self.optimizer,
            "lr_scheduler": self.scheduler,
        }
        module = hydra.utils.instantiate(module_config, **networks, **training)
        if self.checkpoint_path is not None:
            # The freshly built module only supplies its class and its criterion to the reload.
            module_cls = module.__class__
            module = module_cls.load_from_checkpoint(
                self.checkpoint_path,
                **networks,
                criterion=module.criterion,
                **training,
            )
        self._module = module
        module.model = ModelSignatureWrapper(module.model)
246
+
247
+
248
class Barlow(SimCLR):
    """Task that trains a Barlow model; reuses the SimCLR module wiring with a scaled projection MLP.

    Args:
        config: the main config
        checkpoint_path: if a checkpoint is specified, the lightning module is re-created
            with weights loaded from that checkpoint. Defaults to None.
        run_test: Whether to run final test
    """

    def __init__(
        self,
        config: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = False,
    ):
        super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)

    def prepare(self) -> None:
        """Prepare the experiment.

        The projection MLP ``hidden_dim`` and ``output_dim`` stored in the config are multiplied
        by ``config.model.projection_mlp_mult`` before the MLP is instantiated. The config is
        modified in place.
        """
        # Deliberately skip SimCLR.prepare: start the lookup after SimCLR in the MRO so the
        # projection MLP is only built here, after its dimensions have been scaled.
        super(SimCLR, self).prepare()
        self.backbone = hydra.utils.instantiate(self.config.model.model)

        with open_dict(self.config):
            model_cfg = self.config.model
            mlp_cfg = model_cfg.projection_mlp
            mlp_cfg.hidden_dim = mlp_cfg.hidden_dim * model_cfg.projection_mlp_mult
            mlp_cfg.output_dim = mlp_cfg.output_dim * model_cfg.projection_mlp_mult

        self.projection_mlp = hydra.utils.instantiate(self.config.model.projection_mlp)
        # Order matters: optimizer needs the networks, scheduler needs the optimizer.
        self.optimizer = self.config.optimizer
        self.scheduler = self.config.scheduler
        self.module = self.config.model.module
283
+
284
+
285
class BYOL(SSL):
    """Task that trains a BYOL model (student and teacher networks with their MLP heads).

    Args:
        config: the main config
        checkpoint_path: if a checkpoint is specified, the lightning module is re-created
            with weights loaded from that checkpoint. Defaults to None.
        run_test: Whether to run final test
        **kwargs: Keyword arguments forwarded to the parent task constructor
    """

    def __init__(
        self,
        config: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = False,
        **kwargs: Any,
    ):
        super().__init__(
            config=config,
            checkpoint_path=checkpoint_path,
            run_test=run_test,
            **kwargs,
        )
        self.student_model: nn.Module
        self.teacher_model: nn.Module
        self.student_projection_mlp: nn.Module
        self.student_prediction_mlp: nn.Module
        self.teacher_projection_mlp: nn.Module

    def learnable_parameters(self) -> list[nn.Parameter]:
        """Return the student-side parameters only: backbone, projection MLP, prediction MLP."""
        trainable_parts = (
            self.student_model,
            self.student_projection_mlp,
            self.student_prediction_mlp,
        )
        return [param for part in trainable_parts for param in part.parameters()]

    def prepare(self) -> None:
        """Prepare the experiment: build networks, optimizer, scheduler and lightning module."""
        super().prepare()
        model_cfg = self.config.model
        # Student and teacher are two separate instances built from the same ``student`` config.
        self.student_model = hydra.utils.instantiate(model_cfg.student)
        self.teacher_model = hydra.utils.instantiate(model_cfg.student)
        self.student_projection_mlp = hydra.utils.instantiate(model_cfg.projection_mlp)
        self.student_prediction_mlp = hydra.utils.instantiate(model_cfg.prediction_mlp)
        # The teacher projection head is likewise a second instance of ``projection_mlp``.
        self.teacher_projection_mlp = hydra.utils.instantiate(model_cfg.projection_mlp)
        # Order matters: optimizer needs the networks, scheduler needs the optimizer.
        self.optimizer = self.config.optimizer
        self.scheduler = self.config.scheduler
        self.module = self.config.model.module

    @property
    def module(self) -> LightningModule:
        """Get the lightning module of the task."""
        return super().module

    @module.setter
    def module(self, module_config):
        """Instantiate the lightning module, reloading it from the checkpoint if one was given."""
        networks = {
            "student": self.student_model,
            "teacher": self.teacher_model,
            "student_projection_mlp": self.student_projection_mlp,
            "student_prediction_mlp": self.student_prediction_mlp,
            "teacher_projection_mlp": self.teacher_projection_mlp,
        }
        training = {
            "optimizer": self.optimizer,
            "lr_scheduler": self.scheduler,
        }
        module = hydra.utils.instantiate(module_config, **networks, **training)
        if self.checkpoint_path is not None:
            # The freshly built module only supplies its class and its criterion to the reload.
            module_cls = module.__class__
            module = module_cls.load_from_checkpoint(
                self.checkpoint_path,
                **networks,
                criterion=module.criterion,
                **training,
            )
        self._module = module
366
+
367
+
368
class DINO(SSL):
    """Task that trains a DINO model (student and teacher networks with projection heads).

    Args:
        config: the main config
        checkpoint_path: if a checkpoint is specified, the lightning module is re-created
            with weights loaded from that checkpoint. Defaults to None.
        run_test: Whether to run final test
    """

    def __init__(
        self,
        config: DictConfig,
        checkpoint_path: str | None = None,
        run_test: bool = False,
    ):
        super().__init__(config=config, checkpoint_path=checkpoint_path, run_test=run_test)
        self.student_model: nn.Module
        self.teacher_model: nn.Module
        self.student_projection_mlp: nn.Module
        self.teacher_projection_mlp: nn.Module

    def learnable_parameters(self) -> list[nn.Parameter]:
        """Return the student-side parameters only: backbone followed by projection MLP."""
        trainable_parts = (self.student_model, self.student_projection_mlp)
        return [param for part in trainable_parts for param in part.parameters()]

    def prepare(self) -> None:
        """Prepare the experiment: build networks, optimizer, scheduler and lightning module."""
        super().prepare()
        model_cfg = self.config.model
        # Student and teacher are two separate instances built from the same ``student`` config.
        self.student_model = cast(nn.Module, hydra.utils.instantiate(model_cfg.student))
        self.teacher_model = cast(nn.Module, hydra.utils.instantiate(model_cfg.student))
        # Unlike the backbones, each projection head has its own config entry.
        self.student_projection_mlp = cast(
            nn.Module, hydra.utils.instantiate(model_cfg.student_projection_mlp)
        )
        self.teacher_projection_mlp = cast(
            nn.Module, hydra.utils.instantiate(model_cfg.teacher_projection_mlp)
        )
        # Order matters: optimizer needs the networks, scheduler needs the optimizer.
        self.optimizer = self.config.optimizer
        self.scheduler = self.config.scheduler
        self.module = self.config.model.module

    @property
    def module(self) -> LightningModule:
        """Get the lightning module of the task."""
        return super().module

    @module.setter
    def module(self, module_config):
        """Instantiate the lightning module, reloading it from the checkpoint if one was given."""
        networks = {
            "student": self.student_model,
            "teacher": self.teacher_model,
            "student_projection_mlp": self.student_projection_mlp,
            "teacher_projection_mlp": self.teacher_projection_mlp,
        }
        training = {
            "optimizer": self.optimizer,
            "lr_scheduler": self.scheduler,
        }
        module = hydra.utils.instantiate(module_config, **networks, **training)
        if self.checkpoint_path is not None:
            # The freshly built module only supplies its class and its criterion to the reload.
            module_cls = module.__class__
            module = module_cls.load_from_checkpoint(
                self.checkpoint_path,
                **networks,
                criterion=module.criterion,
                **training,
            )
        self._module = module
436
+
437
+
438
class EmbeddingVisualization(Task):
    """Visualization task for learned embeddings.

    Args:
        config: The loaded experiment config
        model_path: The path to a deployment model
        report_folder: Where to save the embeddings
        embedding_image_size: If not None rescale the images associated with the embeddings, tensorboard will save
            on disk a large sprite containing all the images in a matrix fashion, if the dimension of this sprite is
            too big it's not possible to load it in the browser. Rescaling the output image from the model input size
            to something smaller can solve this issue. The field is an int to always rescale to a squared image.

    Raises:
        ValueError: If ``model_path`` is None.
    """

    def __init__(
        self,
        config: DictConfig,
        model_path: str,
        report_folder: str = "embeddings",
        embedding_image_size: int | None = None,
    ):
        super().__init__(config=config)

        self.config = config
        self.metadata = {
            "report_files": [],
        }
        self.model_path = model_path
        self._device = utils.get_device()
        log.info("Current device: %s", self._device)

        self.report_folder = report_folder
        if self.model_path is None:
            raise ValueError(
                "Model path cannot be found!, please specify it in the config or pass it as an argument for"
                " evaluation"
            )
        self.embeddings_path = os.path.join(self.model_path, self.report_folder)
        # `exist_ok=True` avoids the check-then-create race of `os.path.exists` + `os.makedirs`.
        os.makedirs(self.embeddings_path, exist_ok=True)
        self.embedding_writer = SummaryWriter(self.embeddings_path)
        self.writer_step = 0  # for tensorboard
        self.embedding_image_size = embedding_image_size
        # Stays None until `prepare` loads the model, so the `device` setter can be used before `prepare`.
        self._deployment_model: BaseEvaluationModel | None = None
        self.deployment_model_type: str

    @property
    def device(self):
        """Get the device currently used by the task."""
        return self._device

    @device.setter
    def device(self, device: str):
        """Set the device and, when a deployment model is already loaded, move it there."""
        self._device = device

        if self._deployment_model is not None:
            # After prepare. Write the private attribute directly: the `deployment_model` setter expects a
            # model path and would call `import_deployment_model` with the model object instead.
            self._deployment_model = self._deployment_model.to(self._device)

    @property
    def deployment_model(self):
        """Get the deployment model."""
        return self._deployment_model

    @deployment_model.setter
    def deployment_model(self, model_path: str):
        """Set the deployment model by importing it from ``model_path`` on the current device."""
        self._deployment_model = import_deployment_model(
            model_path=model_path, device=self.device, inference_config=self.config.inference
        )

    def prepare(self) -> None:
        """Prepare the evaluation."""
        super().prepare()
        self.deployment_model = self.model_path

    @torch.no_grad()
    def test(self) -> None:
        """Run embeddings extraction.

        Every batch of the test dataloader is passed through the deployment model. The resulting embeddings,
        the de-normalized (and optionally resized) input images and a ``(class, class_name, path)`` row per
        sample are written with a single ``add_embedding`` call on the tensorboard writer.
        """
        self.datamodule.setup("fit")
        idx_to_class = self.datamodule.val_dataset.idx_to_class
        self.datamodule.setup("test")
        dataloader = self.datamodule.test_dataloader()
        # Read the samples from the dataloader that is actually iterated instead of building a second one.
        samples = dataloader.dataset.samples

        image_batches = []
        embedding_batches = []
        metadata: list[tuple[int, str, str]] = []
        std = torch.tensor(self.config.transforms.std).view(1, -1, 1, 1)
        mean = torch.tensor(self.config.transforms.mean).view(1, -1, 1, 1)
        counter = 0

        # Inputs are cast to half precision as soon as any model parameter is stored in half precision.
        is_half_precision = any(param.dtype == torch.half for param in self.deployment_model.parameters())

        for batch in tqdm(dataloader):
            im, target = batch
            if is_half_precision:
                im = im.half()

            x = self.deployment_model(im.to(self.device))
            targets = [int(t.item()) for t in target]
            class_names = [idx_to_class[t.item()] for t in target]
            # Pairs batch items with dataset samples by running position.
            # NOTE(review): this assumes the dataloader yields samples in dataset order - confirm.
            file_paths = [s[0] for s in samples[counter : counter + len(im)]]
            embedding_batches.append(x.cpu())
            # Undo the normalization with the std/mean taken from the transforms config.
            im = im * std
            im += mean

            if self.embedding_image_size is not None:
                im = interpolate(im, self.embedding_image_size)

            image_batches.append(im.cpu())
            metadata.extend(zip(targets, class_names, file_paths))
            counter += len(im)

        images = torch.cat(image_batches, dim=0)
        embeddings = torch.cat(embedding_batches, dim=0)
        self.embedding_writer.add_embedding(
            embeddings,
            metadata=metadata,
            label_img=images,
            global_step=self.writer_step,
            metadata_header=["class", "class_name", "path"],
        )
@@ -0,0 +1,3 @@
1
+ # Trainers
2
+
3
+ This package defines custom trainers that can be used in place of PyTorch Lightning. For example, for classification we have implemented a trainer that uses the `scikit-learn` library to train a classifier on top of a torch feature extractor.
File without changes