quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/models/base.py ADDED
@@ -0,0 +1,151 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from collections.abc import Sequence
5
+ from typing import Any
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from quadra.utils.logger import get_logger
11
+
12
+ log = get_logger(__name__)
13
+
14
+
15
class ModelSignatureWrapper(nn.Module):
    """Model wrapper used to retrieve the input shapes of the wrapped model.

    It can be used as a decorator of ``nn.Module``: the first call to ``forward`` records the shape of every input
    (batch dimension excluded) in the ``input_shapes`` attribute, in the order of the forward signature. If the shapes
    cannot be retrieved, a warning is logged and the retrieval is disabled for the following calls.

    Apart from the names listed in ``_LOCAL_NAMES``, every attribute read or write is forwarded to the wrapped model,
    so the wrapper can be used in place of the model itself.

    Args:
        model: Model to wrap. If it is already a ``ModelSignatureWrapper``, its inner model is wrapped instead and the
            input shapes it already recorded are kept.
    """

    # Attributes stored in the wrapper's own ``__dict__``. ``disable`` must be listed too: otherwise the flag is
    # written onto the wrapped model, polluting it (or clobbering an attribute with the same name).
    _OWN_ATTRIBUTES = frozenset({"instance", "input_shapes", "disable"})
    # Names resolved on the wrapper by ``__getattribute__``; any other name is looked up on the wrapped model.
    _LOCAL_NAMES = _OWN_ATTRIBUTES | frozenset(
        {
            "__dict__",
            "forward",
            "_get_input_shapes",
            "_get_input_shape",
            "to",
            "half",
            "cpu",
            "call_super_init",
            "_call_impl",
            "_compiled_call_impl",
        }
    )

    def __init__(self, model: nn.Module):
        super().__init__()
        self.instance = model
        self.input_shapes: Any = None
        # Set to True when retrieving the input shapes failed, so the retrieval is not attempted again
        self.disable = False

        if isinstance(self.instance, ModelSignatureWrapper):
            # Handle nested ModelSignatureWrapper: keep the recorded shapes and wrap the innermost model only
            self.input_shapes = self.instance.input_shapes
            self.instance = self.instance.instance

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """Forward the wrapped model, retrieving the input shapes first if they are not known yet.

        The retrieval is attempted only while ``input_shapes`` is None and it has not been disabled by a previous
        failure; a failure is logged and never prevents the forward itself.
        """
        if self.input_shapes is None and not self.disable:
            try:
                self.input_shapes = self._get_input_shapes(*args, **kwargs)
            except Exception:
                # Best effort: shapes can also be provided manually through the export configuration
                log.warning(
                    "Failed to retrieve input shapes after forward! To export the model you'll need to "
                    "provide the input shapes manually setting the config.export.input_shapes parameter! "
                    "Alternatively you could try to use a forward with supported input types (and their compositions) "
                    "(list, tuple, dict, tensors)."
                )
                self.disable = True

        return self.instance.forward(*args, **kwargs)

    def to(self, *args, **kwargs):
        """Call ``to`` on the wrapped model, store its result as the new wrapped model and return the wrapper."""
        self.instance = self.instance.to(*args, **kwargs)

        return self

    def half(self, *args, **kwargs):
        """Call ``half`` on the wrapped model, store its result as the new wrapped model and return the wrapper."""
        self.instance = self.instance.half(*args, **kwargs)

        return self

    def cpu(self, *args, **kwargs):
        """Call ``cpu`` on the wrapped model, store its result as the new wrapped model and return the wrapper."""
        self.instance = self.instance.cpu(*args, **kwargs)

        return self

    def _get_input_shapes(self, *args: Any, **kwargs: Any) -> list[Any]:
        """Retrieve the input shapes in the same order as the forward method signature.

        Positional inputs come first; every remaining parameter of the signature is read from ``kwargs`` or, when not
        provided, from its default value.

        Raises:
            ValueError: If an input (or a default value used in its place) has an unsupported type, or if keyword
                inputs are captured by a ``**kwargs`` parameter and therefore have no position in the signature.
        """
        input_shapes = [self._get_input_shape(arg) for arg in args]

        if isinstance(self.instance.forward, torch.ScriptMethod):
            # Handle torchscript backbones
            for i, argument in enumerate(self.instance.forward.schema.arguments):
                if i < (len(args) + 1):  # +1 for self
                    continue

                if argument.name == "self":
                    continue

                if argument.name in kwargs:
                    input_shapes.append(self._get_input_shape(kwargs[argument.name]))
                else:
                    # Retrieve the default value
                    input_shapes.append(self._get_input_shape(argument.default_value))
        else:
            signature = inspect.signature(self.instance.forward)

            for i, (key, parameter) in enumerate(signature.parameters.items()):
                if i < len(args):
                    continue

                if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
                    # ``*args`` has no default value; any extra positional input was already collected above
                    continue

                if parameter.kind is inspect.Parameter.VAR_KEYWORD:
                    # Keyword inputs swallowed by ``**kwargs`` have no position in the signature: keep failing on
                    # them instead of silently dropping them, but an unused ``**kwargs`` is simply skipped.
                    unplaced = sorted(set(kwargs) - set(signature.parameters))
                    if unplaced:
                        raise ValueError(f"Cannot place keyword inputs {unplaced} in the forward signature")
                    continue

                if key in kwargs:
                    input_shapes.append(self._get_input_shape(kwargs[key]))
                else:
                    # Retrieve the default value
                    input_shapes.append(self._get_input_shape(parameter.default))

        return input_shapes

    def _get_input_shape(
        self, inp: Sequence | dict[str, Any] | torch.Tensor
    ) -> list[Any] | tuple[Any, ...] | dict[str, Any]:
        """Recursively retrieve the shape of an input, dropping the first (batch) dimension of every tensor.

        Lists, tuples and dicts keep their container type; any other type raises ``ValueError``.
        """
        if isinstance(inp, list):
            return [self._get_input_shape(i) for i in inp]

        if isinstance(inp, tuple):
            return tuple(self._get_input_shape(i) for i in inp)

        if isinstance(inp, torch.Tensor):
            return tuple(inp.shape[1:])

        if isinstance(inp, dict):
            return {k: self._get_input_shape(v) for k, v in inp.items()}

        raise ValueError(f"Input type {type(inp)} not supported")

    def __getattr__(self, name: str) -> Any:
        # Only reached when __getattribute__ failed; missing wrapper attributes raise AttributeError (not KeyError)
        # so that hasattr/getattr with a default behave as expected.
        if name in ModelSignatureWrapper._OWN_ATTRIBUTES:
            try:
                return self.__dict__[name]
            except KeyError:
                raise AttributeError(name) from None

        return getattr(self.__dict__["instance"], name)

    def __setattr__(self, name: str, value: Any) -> None:
        if name in ModelSignatureWrapper._OWN_ATTRIBUTES:
            self.__dict__[name] = value
        else:
            setattr(self.instance, name, value)

    def __getattribute__(self, name: str) -> Any:
        if name in ModelSignatureWrapper._LOCAL_NAMES:
            return super().__getattribute__(name)

        return getattr(self.instance, name)
@@ -0,0 +1,8 @@
1
+ from .backbones import BaseNetworkBuilder, TimmNetworkBuilder, TorchHubNetworkBuilder, TorchVisionNetworkBuilder
2
+
3
# Public network builders re-exported from ``.backbones``
__all__ = ["BaseNetworkBuilder", "TorchVisionNetworkBuilder", "TorchHubNetworkBuilder", "TimmNetworkBuilder"]
@@ -0,0 +1,149 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import timm
6
+ import torch
7
+ from timm.models.helpers import load_checkpoint
8
+ from torch import nn
9
+ from torchvision import models
10
+
11
+ from quadra.models.classification.base import BaseNetworkBuilder
12
+ from quadra.utils.logger import get_logger
13
+
14
+ log = get_logger(__name__)
15
+
16
+
17
class TorchHubNetworkBuilder(BaseNetworkBuilder):
    """Feature extractor loaded through ``torch.hub.load``, optionally mapping features onto an hypersphere.

    Args:
        repo_or_dir: Repository name or path to the directory containing the model, as accepted by ``torch.hub.load``.
        model_name: Name of the model inside the repository.
        pretrained: Whether to load the pretrained weights of the model.
        pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
        classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
        freeze: Whether to freeze the feature extractor. Defaults to True.
        hyperspherical: Whether to map features to an hypersphere. Defaults to False.
        flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
        checkpoint_path: Optional checkpoint loaded (with timm's ``load_checkpoint``) once the model is built.
            Defaults to None.
        **torch_hub_kwargs: Extra keyword arguments forwarded to ``torch.hub.load``.
    """

    def __init__(
        self,
        repo_or_dir: str,
        model_name: str,
        pretrained: bool = True,
        pre_classifier: nn.Module | None = None,
        classifier: nn.Module | None = None,
        freeze: bool = True,
        hyperspherical: bool = False,
        flatten_features: bool = True,
        checkpoint_path: str | None = None,
        **torch_hub_kwargs: Any,
    ):
        self.pretrained = pretrained
        hub_model = torch.hub.load(
            repo_or_dir=repo_or_dir,
            model=model_name,
            pretrained=pretrained,
            **torch_hub_kwargs,
        )

        # An empty path is treated like None: no checkpoint is loaded
        if checkpoint_path:
            log.info("Loading checkpoint from %s", checkpoint_path)
            load_checkpoint(model=hub_model, checkpoint_path=checkpoint_path)

        super().__init__(
            hub_model,
            pre_classifier=pre_classifier,
            classifier=classifier,
            freeze=freeze,
            hyperspherical=hyperspherical,
            flatten_features=flatten_features,
        )
62
+
63
+
64
class TorchVisionNetworkBuilder(BaseNetworkBuilder):
    """Torchvision feature extractor, with the possibility to map features to an hypersphere.

    Args:
        model_name: Name of a model builder function in ``torchvision.models`` (for example ``"resnet18"``). It is
            looked up in ``torchvision.models.__dict__`` and called, so an unknown name raises ``KeyError``.
        pretrained: Whether to load the pretrained weights for the model.
        pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
        classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
        freeze: Whether to freeze the feature extractor. Defaults to True.
        hyperspherical: Whether to map features to an hypersphere. Defaults to False.
        flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
        checkpoint_path: Path to a checkpoint to load (with timm's ``load_checkpoint``) after the model is initialized.
            Defaults to None.
        **torchvision_kwargs: Additional arguments to pass to the model function.
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool = True,
        pre_classifier: nn.Module | None = None,
        classifier: nn.Module | None = None,
        freeze: bool = True,
        hyperspherical: bool = False,
        flatten_features: bool = True,
        checkpoint_path: str | None = None,
        **torchvision_kwargs: Any,
    ):
        self.pretrained = pretrained
        model_function = models.__dict__[model_name]
        features_extractor = model_function(pretrained=self.pretrained, progress=True, **torchvision_kwargs)
        if checkpoint_path:
            log.info("Loading checkpoint from %s", checkpoint_path)
            load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)

        # Replace the model's ``classifier`` attribute with an identity so the extractor outputs features.
        # NOTE(review): only an attribute named ``classifier`` is replaced; models whose head is stored under another
        # name (e.g. ``fc`` for torchvision ResNets) keep their head — confirm this is intended.
        features_extractor.classifier = nn.Identity()
        super().__init__(
            features_extractor=features_extractor,
            pre_classifier=pre_classifier,
            classifier=classifier,
            freeze=freeze,
            hyperspherical=hyperspherical,
            flatten_features=flatten_features,
        )
108
+
109
+
110
class TimmNetworkBuilder(BaseNetworkBuilder):
    """Timm feature extractor, with the possibility to map features to an hypersphere.

    The model is created with ``timm.create_model(..., num_classes=0)``, i.e. without a classification head.

    Args:
        model_name: Timm model name.
        pretrained: Whether to load the pretrained weights for the model.
        pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
        classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
        freeze: Whether to freeze the feature extractor. Defaults to True.
        hyperspherical: Whether to map features to an hypersphere. Defaults to False.
        flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
        checkpoint_path: Path to a checkpoint, forwarded to ``timm.create_model`` which loads it after the model is
            initialized. Defaults to None.
        **timm_kwargs: Additional arguments to pass to timm.create_model
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool = True,
        pre_classifier: nn.Module | None = None,
        classifier: nn.Module | None = None,
        freeze: bool = True,
        hyperspherical: bool = False,
        flatten_features: bool = True,
        checkpoint_path: str | None = None,
        **timm_kwargs: Any,
    ):
        self.pretrained = pretrained
        # num_classes=0 asks timm to build the model without its classification head
        features_extractor = timm.create_model(
            model_name, pretrained=self.pretrained, num_classes=0, checkpoint_path=checkpoint_path, **timm_kwargs
        )

        super().__init__(
            features_extractor=features_extractor,
            pre_classifier=pre_classifier,
            classifier=classifier,
            freeze=freeze,
            hyperspherical=hyperspherical,
            flatten_features=flatten_features,
        )
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ from torch import nn
4
+
5
+ from quadra.utils.models import L2Norm
6
+
7
+
8
class BaseNetworkBuilder(nn.Module):
    """Baseline feature extractor, with the possibility to map features to an hypersphere.

    The forward pass is: ``features_extractor`` -> optional flatten -> ``pre_classifier`` -> optional L2
    normalization (when ``hyperspherical`` is True) -> ``classifier``. The classifier is applied in both modes.

    Args:
        features_extractor: Feature extractor as a torch.nn.Module.
        pre_classifier: Pre classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
        classifier: Classifier as a torch.nn.Module. Defaults to nn.Identity() if None.
        freeze: Whether to freeze the feature extractor. Defaults to True.
        hyperspherical: Whether to map features to an hypersphere. Defaults to False.
        flatten_features: Whether to flatten the features before the pre_classifier. May be required if your model
            is outputting a feature map rather than a vector. Defaults to True.
    """

    def __init__(
        self,
        features_extractor: nn.Module,
        pre_classifier: nn.Module | None = None,
        classifier: nn.Module | None = None,
        freeze: bool = True,
        hyperspherical: bool = False,
        flatten_features: bool = True,
    ):
        super().__init__()
        if pre_classifier is None:
            pre_classifier = nn.Identity()

        if classifier is None:
            classifier = nn.Identity()

        self.features_extractor = features_extractor
        # The property setters update ``requires_grad`` of the extractor and create/drop ``self.l2``. They run before
        # the heads are assigned so that submodules keep the registration order features_extractor, l2, heads.
        self.freeze = freeze
        self.hyperspherical = hyperspherical
        self.pre_classifier = pre_classifier
        self.classifier = classifier
        # Not read by this class; kept as a public attribute for backward compatibility
        self.flatten: bool = False
        self.flatten_features = flatten_features

    @property
    def freeze(self) -> bool:
        """Whether to freeze the feature extractor."""
        return self._freeze

    @freeze.setter
    def freeze(self, value: bool) -> None:
        """Set ``requires_grad`` of every feature extractor parameter to ``not value``."""
        for p in self.features_extractor.parameters():
            p.requires_grad = not value

        self._freeze = value

    @property
    def hyperspherical(self) -> bool:
        """Whether to map the extracted features into an hypersphere."""
        return self._hyperspherical

    @hyperspherical.setter
    def hyperspherical(self, value: bool) -> None:
        """Enable/disable the L2 normalization, creating a new ``L2Norm`` module or setting ``l2`` to None."""
        self._hyperspherical = value
        self.l2 = L2Norm() if value else None

    def forward(self, x):
        x = self.features_extractor(x)

        if self.flatten_features:
            # reshape (not view): also works when the extractor returns a non-contiguous tensor
            x = x.reshape(x.size(0), -1)

        x = self.pre_classifier(x)

        if self.hyperspherical:
            x = self.l2(x)

        x = self.classifier(x)

        return x