quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,322 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, cast
5
+
6
+ import numpy as np
7
+ import torch
8
+ from hydra.utils import instantiate
9
+ from omegaconf import DictConfig, OmegaConf
10
+ from torch import nn
11
+ from torch.jit import RecursiveScriptModule
12
+
13
+ from quadra.utils.logger import get_logger
14
+
15
+ try:
16
+ import onnxruntime as ort # noqa
17
+
18
+ ONNX_AVAILABLE = True
19
+ except ImportError:
20
+ ONNX_AVAILABLE = False
21
+
22
+
23
+ log = get_logger(__name__)
24
+
25
+
26
class BaseEvaluationModel(ABC):
    """Common interface implemented by every evaluation model wrapper."""

    def __init__(self, config: DictConfig) -> None:
        # Declared attributes that concrete subclasses populate on load.
        self.model: Any
        self.model_path: str | None
        self.device: str
        self.config = config
        self.is_loaded = False
        self.model_dtype: np.dtype | torch.dtype

    @abstractmethod
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run inference on the wrapped model."""

    @abstractmethod
    def load_from_disk(self, model_path: str, device: str = "cpu"):
        """Load model from disk."""

    @abstractmethod
    def to(self, device: str):
        """Move model to device."""

    @abstractmethod
    def eval(self):
        """Set model to evaluation mode."""

    @abstractmethod
    def half(self):
        """Convert model to half precision."""

    @abstractmethod
    def cpu(self):
        """Move model to cpu."""

    @property
    def training(self) -> bool:
        """Return whether model is in training mode (always False for plain wrappers)."""
        return False

    @property
    def device(self) -> str:
        """Return the device of the model."""
        return self._device

    @device.setter
    def device(self, device: str):
        """Set the device of the model, normalizing a bare "cuda" to "cuda:0"."""
        self._device = f"{device}:0" if device == "cuda" else device
78
+
79
+
80
class TorchscriptEvaluationModel(BaseEvaluationModel):
    """Wrapper for torchscript models."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run inference on the underlying torchscript module."""
        return self.model(*args, **kwargs)

    def load_from_disk(self, model_path: str, device: str = "cpu"):
        """Load a torchscript model from disk and move it to the requested device.

        Args:
            model_path: Path to the serialized torchscript module.
            device: Device the model should be moved to.

        Raises:
            ValueError: If the module contains parameters with more than one dtype
                (mixed precision is not supported here).
        """
        self.model_path = model_path
        self.device = device

        model = cast(RecursiveScriptModule, torch.jit.load(self.model_path))
        model.eval()
        model.to(self.device)

        parameter_types = {param.dtype for param in model.parameters()}
        # Any number of distinct dtypes above one means mixed precision.
        # (The previous check compared against exactly 2, so 3+ dtypes slipped through.)
        if len(parameter_types) > 1:
            # TODO: There could be models with mixed precision?
            raise ValueError(f"Expected only one type of parameters, found {parameter_types}")

        # Parameterless modules fall back to the global default dtype instead of crashing.
        self.model_dtype = next(iter(parameter_types), torch.get_default_dtype())
        self.model = model
        self.is_loaded = True

    def to(self, device: str):
        """Move model to device."""
        self.model.to(device)
        self.device = device

    def eval(self):
        """Set model to evaluation mode."""
        self.model.eval()

    @property
    def training(self) -> bool:
        """Return whether model is in training mode."""
        return self.model.training

    def half(self):
        """Convert model to half precision."""
        self.model.half()

    def cpu(self):
        """Move model to cpu."""
        self.model.cpu()
125
+
126
+
127
class TorchEvaluationModel(TorchscriptEvaluationModel):
    """Wrapper for plain torch models.

    Args:
        config: Evaluation configuration.
        model_architecture: Torch model architecture the saved state dict will be loaded into.
    """

    def __init__(self, config: DictConfig, model_architecture: nn.Module) -> None:
        super().__init__(config=config)
        self.model = model_architecture
        self.model.eval()
        # Infer the starting device from the architecture's own parameters.
        device = next(self.model.parameters()).device
        self.device = str(device)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run inference on the wrapped architecture."""
        return self.model(*args, **kwargs)

    def load_from_disk(self, model_path: str, device: str = "cpu"):
        """Load a state dict from disk into the wrapped architecture.

        Args:
            model_path: Path to the saved state dict.
            device: Device the model should be moved to.

        Raises:
            ValueError: If the loaded parameters use more than one dtype.
        """
        self.model_path = model_path
        self.device = device
        # map_location lets checkpoints saved on another device (e.g. GPU) load on this one.
        self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        self.model.eval()
        self.model.to(self.device)

        parameter_types = {param.dtype for param in self.model.parameters()}
        # Reject any mixed-precision checkpoint (the previous check only caught exactly 2 dtypes).
        if len(parameter_types) > 1:
            # TODO: There could be models with mixed precision?
            raise ValueError(f"Expected only one type of parameters, found {parameter_types}")

        self.model_dtype = next(iter(parameter_types), torch.get_default_dtype())
        self.is_loaded = True
159
+
160
+
161
# Mapping from ONNX Runtime tensor-type strings (as reported by the runtime for
# session inputs/outputs, e.g. NodeArg.type) to the corresponding torch dtypes.
onnx_to_torch_dtype_dict = {
    "tensor(bool)": torch.bool,
    "tensor(uint8)": torch.uint8,
    "tensor(int8)": torch.int8,
    "tensor(int16)": torch.int16,
    "tensor(int32)": torch.int32,
    "tensor(int64)": torch.int64,
    "tensor(float16)": torch.float16,
    "tensor(float32)": torch.float32,
    # Both spellings map to float32: "tensor(float)" is the canonical ONNX name.
    "tensor(float)": torch.float32,
    "tensor(float64)": torch.float64,
    "tensor(complex64)": torch.complex64,
    "tensor(complex128)": torch.complex128,
}
175
+
176
+
177
class ONNXEvaluationModel(BaseEvaluationModel):
    """Wrapper for ONNX models. It's designed to provide a similar interface to standard torch models."""

    def __init__(self, config: DictConfig) -> None:
        if not ONNX_AVAILABLE:
            raise ImportError(
                "onnxruntime is not installed. Please install ONNX capabilities for quadra with: poetry install -E onnx"
            )
        super().__init__(config=config)
        self.session_options = self.generate_session_options()

    def generate_session_options(self) -> ort.SessionOptions:
        """Generate session options from the current config.

        Returns:
            An ``ort.SessionOptions`` populated from ``config.session_options`` (if present).
        """
        session_options = ort.SessionOptions()

        if hasattr(self.config, "session_options") and self.config.session_options is not None:
            session_options_dict = cast(
                dict[str, Any], OmegaConf.to_container(self.config.session_options, resolve=True)
            )
            for key, value in session_options_dict.items():
                final_value = value
                # Nested hydra-style configs (carrying a _target_) are instantiated first.
                if isinstance(value, dict) and "_target_" in value:
                    final_value = instantiate(final_value)

                setattr(session_options, key, final_value)

        return session_options

    def __call__(self, *inputs: np.ndarray | torch.Tensor) -> Any:
        """Run inference on the model and return the output as torch tensors.

        Args:
            *inputs: Positional inputs, either all torch tensors or all numpy arrays.

        Raises:
            ValueError: If an input has an unsupported type, or torch and numpy
                inputs are mixed in the same call.
        """
        # TODO: Maybe we can support also kwargs
        use_pytorch = False
        use_numpy = False

        onnx_inputs: dict[str, np.ndarray | torch.Tensor] = {}

        for onnx_input, current_input in zip(self.model.get_inputs(), inputs):
            if isinstance(current_input, torch.Tensor):
                onnx_inputs[onnx_input.name] = current_input
                use_pytorch = True
            elif isinstance(current_input, np.ndarray):
                onnx_inputs[onnx_input.name] = current_input
                use_numpy = True
            else:
                raise ValueError(f"Invalid input type: {type(current_input)}")

        # Track both kinds while looping: the previous check only inspected the *last*
        # input, so mixes like (numpy, torch) were silently accepted.
        if use_pytorch and use_numpy:
            raise ValueError("Cannot mix torch and numpy inputs")

        if use_pytorch:
            onnx_output = self._forward_from_pytorch(cast(dict[str, torch.Tensor], onnx_inputs))
        else:
            onnx_output = self._forward_from_numpy(cast(dict[str, np.ndarray], onnx_inputs))

        # Outputs are normalized to torch tensors on the wrapper's device.
        onnx_output = [torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x for x in onnx_output]

        if len(onnx_output) == 1:
            onnx_output = onnx_output[0]

        return onnx_output

    def _forward_from_pytorch(self, input_dict: dict[str, torch.Tensor]):
        """Run inference binding torch tensor buffers directly via IO binding."""
        io_binding = self.model.io_binding()
        device_type = self.device.split(":")[0]

        for k, v in input_dict.items():
            if not v.is_contiguous():
                # If not contiguous onnx gives wrong results
                v = v.contiguous()  # noqa: PLW2901

            io_binding.bind_input(
                name=k,
                device_type=device_type,
                # Weirdly enough onnx wants 0 for cpu
                device_id=0 if device_type == "cpu" else int(self.device.split(":")[1]),
                element_type=np.float16 if v.dtype == torch.float16 else np.float32,
                shape=tuple(v.shape),
                buffer_ptr=v.data_ptr(),
            )

        for x in self.model.get_outputs():
            # TODO: Is it possible to also bind the output? We require info about output dimensions
            io_binding.bind_output(name=x.name)

        self.model.run_with_iobinding(io_binding)

        output = io_binding.copy_outputs_to_cpu()

        return output

    def _forward_from_numpy(self, input_dict: dict[str, np.ndarray]):
        """Run inference on the model and return the output as numpy array."""
        ort_outputs = [x.name for x in self.model.get_outputs()]

        onnx_output = self.model.run(ort_outputs, input_dict)

        return onnx_output

    def load_from_disk(self, model_path: str, device: str = "cpu"):
        """Load an ONNX model from disk, creating an inference session on the device."""
        self.model_path = model_path
        self.device = device

        ort_providers = self._get_providers(device)
        self.model = ort.InferenceSession(self.model_path, providers=ort_providers, sess_options=self.session_options)
        self.model_dtype = self.cast_onnx_dtype(self.model.get_inputs()[0].type)
        self.is_loaded = True

    def _get_providers(self, device: str) -> list[tuple[str, dict[str, Any]] | str]:
        """Return the providers for the ONNX model based on the device."""
        ort_providers: list[tuple[str, dict[str, Any]] | str]

        if device == "cpu":
            ort_providers = ["CPUExecutionProvider"]
        else:
            # Non-cpu devices are expected in "cuda:<idx>" form (the device setter
            # normalizes a bare "cuda" to "cuda:0").
            ort_providers = [
                (
                    "CUDAExecutionProvider",
                    {
                        "device_id": int(device.split(":")[1]),
                    },
                )
            ]

        return ort_providers

    def to(self, device: str):
        """Move model to device by re-selecting the execution providers."""
        self.device = device
        ort_providers = self._get_providers(device)
        self.model.set_providers(ort_providers)

    def eval(self):
        """Fake interface to match torch models."""
        return self

    def half(self):
        """Convert model to half precision."""
        raise NotImplementedError("At the moment ONNX models do not support half method.")

    def cpu(self):
        """Move model to cpu."""
        self.to("cpu")

    def cast_onnx_dtype(self, onnx_dtype: str) -> torch.dtype | np.dtype:
        """Cast ONNX dtype to numpy or pytorch dtype."""
        return onnx_to_torch_dtype_dict[onnx_dtype]
File without changes
@@ -0,0 +1,30 @@
1
+ from typing import Any
2
+
3
+ import segmentation_models_pytorch as smp
4
+
5
+
6
def create_smp_backbone(
    arch: str,
    encoder_name: str,
    freeze_encoder: bool = False,
    in_channels: int = 3,
    num_classes: int = 0,
    **kwargs: Any,
):
    """Create a segmentation_models_pytorch model backbone.

    Args:
        arch: Architecture name.
        encoder_name: Encoder (backbone) name.
        freeze_encoder: If True, disable gradients for all encoder parameters.
        in_channels: Number of input channels.
        num_classes: Number of classes.
        **kwargs: Extra arguments forwarded to ``smp.create_model`` (for example the
            classification head).

    Returns:
        The instantiated smp model.
    """
    model = smp.create_model(
        arch=arch, encoder_name=encoder_name, in_channels=in_channels, classes=num_classes, **kwargs
    )
    if freeze_encoder:
        # Iterate the encoder's parameters directly: the previous children()-based
        # loop skipped parameters registered on the encoder module itself.
        for param in model.encoder.parameters():
            param.requires_grad = False
    return model
quadra/modules/base.py ADDED
@@ -0,0 +1,312 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from typing import Any
5
+
6
+ import pytorch_lightning as pl
7
+ import sklearn
8
+ import torch
9
+ import torchmetrics
10
+ from sklearn.linear_model import LogisticRegression
11
+ from torch import nn
12
+ from torch.optim import Optimizer
13
+
14
+ from quadra.models.base import ModelSignatureWrapper
15
+
16
+ __all__ = ["BaseLightningModule", "SSLModule"]
17
+
18
+
19
class BaseLightningModule(pl.LightningModule):
    """Base lightning module.

    Args:
        model: Network Module used for extract features
        optimizer: optimizer of the training. If None a default Adam is used.
        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: Optimizer | None = None,
        lr_scheduler: object | None = None,
        lr_scheduler_interval: str | None = "epoch",
    ):
        super().__init__()
        self.model = ModelSignatureWrapper(model)
        self.optimizer = optimizer
        self.schedulers = lr_scheduler
        self.lr_scheduler_interval = lr_scheduler_interval

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on the given input tensor and return its output."""
        return self.model(x)

    def configure_optimizers(self) -> tuple[list[Any], list[dict[str, Any]]]:
        """Build the optimizer and scheduler, falling back to defaults when unset.

        Returns:
            A tuple of (list of optimizers, list of lr-scheduler configuration dicts).
        """
        # `not getattr(...)` covers both a missing/None optimizer and a falsy one
        if not getattr(self, "optimizer", None):
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-4)

        if not getattr(self, "schedulers", None):
            self.schedulers = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, patience=30)

        scheduler_config = {
            "scheduler": self.schedulers,
            "interval": self.lr_scheduler_interval,
            "monitor": "val_loss",
            "strict": False,
        }
        return [self.optimizer], [scheduler_config]

    # pylint: disable=unused-argument
    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx: int = 0):
        """Zero gradients using set_to_none=True."""
        optimizer.zero_grad(set_to_none=True)
77
+
78
+
79
class SSLModule(BaseLightningModule):
    """Base module for self supervised learning.

    Args:
        model: Network Module used for extract features
        criterion: SSL loss to be applied
        classifier: Standard sklearn classifiers
        optimizer: optimizer of the training. If None a default Adam is used.
        lr_scheduler: lr scheduler. If None a default ReduceLROnPlateau is used.
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        classifier: sklearn.base.ClassifierMixin | None = None,
        optimizer: Optimizer | None = None,
        lr_scheduler: object | None = None,
        lr_scheduler_interval: str | None = "epoch",
    ):
        super().__init__(model, optimizer, lr_scheduler, lr_scheduler_interval)
        self.criterion = criterion
        # Initialized to None (previously only annotated) so attribute access in
        # on_validation_start is safe even when the trainer has no datamodule.
        self.classifier_train_loader: torch.utils.data.DataLoader | None = None
        if classifier is None:
            self.classifier = LogisticRegression(max_iter=10000, n_jobs=8, random_state=42)
        else:
            self.classifier = classifier

        self.val_acc = torchmetrics.Accuracy()

    def fit_estimator(self):
        """Fit a classifier on the embeddings extracted from the current trained model."""
        targets = []
        train_embeddings = []
        self.model.eval()
        with torch.no_grad():
            for im, target in self.classifier_train_loader:
                emb = self.model(im.to(self.device))
                targets.append(target)
                train_embeddings.append(emb)
        targets = torch.cat(targets, dim=0).cpu().numpy()
        train_embeddings = torch.cat(train_embeddings, dim=0).cpu().numpy()
        self.classifier.fit(train_embeddings, targets)

    def calculate_accuracy(self, batch):
        """Calculate accuracy on a batch of data using the fitted sklearn classifier."""
        images, labels = batch
        with torch.no_grad():
            embedding = self.model(images).cpu().numpy()

        predictions = self.classifier.predict(embedding)
        labels = labels.detach()
        acc = self.val_acc(torch.tensor(predictions, device=self.device), labels)

        return acc

    # TODO: In multiprocessing mode, this function is called multiple times, how can we avoid this?
    def on_validation_start(self) -> None:
        """Lazily fetch the classifier train dataloader and fit the classifier on it."""
        if self.classifier_train_loader is None and hasattr(self.trainer, "datamodule"):
            self.classifier_train_loader = self.trainer.datamodule.classifier_train_dataloader()

        if self.classifier_train_loader is not None:
            self.fit_estimator()

    def validation_step(
        self, batch: tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor], batch_idx: int
    ) -> torch.Tensor | None:
        """Log SSL loss when no classifier dataloader is available, accuracy otherwise.

        Returns:
            The SSL loss tensor in loss mode, None in accuracy mode.
            (Return annotation fixed: the original declared -> None but returned the loss.)
        """
        # pylint: disable=unused-argument
        if self.classifier_train_loader is None:
            # Compute loss
            (im_x, im_y), _ = batch
            z1 = self(im_x)
            z2 = self(im_y)
            loss = self.criterion(z1, z2)

            self.log(
                "val_loss",
                loss,
                on_epoch=True,
                on_step=True,
                logger=True,
                prog_bar=True,
            )
            return loss

        acc = self.calculate_accuracy(batch)
        self.log("val_acc", acc, on_epoch=True, on_step=False, logger=True, prog_bar=True)
        return None
165
+
166
+
167
class SegmentationModel(BaseLightningModule):
    """Generic segmentation model.

    Args:
        model: segmentation model to be used.
        loss_fun: loss function to be used.
        optimizer: Optimizer to be used. Defaults to None.
        lr_scheduler: lr scheduler to be used. Defaults to None.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fun: Callable,
        optimizer: Optimizer | None = None,
        lr_scheduler: object | None = None,
    ):
        super().__init__(model, optimizer, lr_scheduler)
        self.loss_fun = loss_fun

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the underlying segmentation network on the input and return its output."""
        return self.model(x)

    def step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Run inference and align prediction/target mask shapes.

        Args:
            batch: (images, target_masks, labels); labels are unused here.

        Returns:
            Prediction and target masks, both carrying an explicit channel dimension.
        """
        images, target_masks, _ = batch
        pred_masks = self(images)
        # Add a channel dimension to either tensor that is missing one
        if pred_masks.ndim == 3:
            pred_masks = pred_masks.unsqueeze(1)
        if target_masks.ndim == 3:
            target_masks = target_masks.unsqueeze(1)
        assert pred_masks.shape == target_masks.shape

        return pred_masks, target_masks

    def compute_loss(self, pred_masks: torch.Tensor, target_masks: torch.Tensor) -> torch.Tensor:
        """Apply the configured loss function to the predicted and target masks.

        Args:
            pred_masks: predicted masks
            target_masks: target masks.

        Returns:
            The computed loss
        """
        return self.loss_fun(pred_masks, target_masks)

    def _shared_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], log_key: str) -> torch.Tensor:
        """Common train/val/test logic: forward, loss, and logging under log_key."""
        pred_masks, target_masks = self.step(batch)
        loss = self.compute_loss(pred_masks, target_masks)
        self.log_dict(
            {log_key: loss},
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )
        return loss

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
        """Training step."""
        # pylint: disable=unused-argument
        return self._shared_step(batch, "loss")

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx):
        """Validation step."""
        # pylint: disable=unused-argument
        return self._shared_step(batch, "val_loss")

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int):
        """Test step."""
        # pylint: disable=unused-argument
        return self._shared_step(batch, "test_loss")

    def predict_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        batch_idx: int,
        dataloader_idx: int | None = None,
    ) -> Any:
        """Predict step: return CPU copies of inputs, masks, predictions and labels."""
        # pylint: disable=unused-argument
        images, masks, labels = batch
        pred_masks = self(images)
        return images.cpu(), masks.cpu(), pred_masks.cpu(), labels.cpu()
279
+
280
+
281
class SegmentationModelMulticlass(SegmentationModel):
    """Generic multiclass segmentation model.

    Args:
        model: segmentation model to be used.
        loss_fun: loss function to be used.
        optimizer: Optimizer to be used. Defaults to None.
        lr_scheduler: lr scheduler to be used. Defaults to None.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        loss_fun: Callable,
        optimizer: Optimizer | None = None,
        lr_scheduler: object | None = None,
    ):
        super().__init__(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, loss_fun=loss_fun)

    def step(self, batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward the images and return raw predictions with their targets.

        Unlike the binary parent, no channel-dimension alignment is performed.

        Args:
            batch: (images, target_masks, labels); labels are unused here.

        Returns:
            prediction, target
        """
        images, target_masks, _ = batch
        return self(images), target_masks
@@ -0,0 +1,3 @@
1
+ from .base import ClassificationModule, MultilabelClassificationModule
2
+
3
+ __all__ = ["ClassificationModule", "MultilabelClassificationModule"]