quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/export.py ADDED
@@ -0,0 +1,585 @@
1
+ from __future__ import annotations
2
+
3
+ import contextlib
4
+ import os
5
+ from collections.abc import Sequence
6
+ from typing import Any, Literal, TypeVar, cast
7
+
8
+ import torch
9
+ from anomalib.models.cflow import CflowLightning
10
+ from omegaconf import DictConfig, ListConfig, OmegaConf
11
+ from torch import nn
12
+
13
+ from quadra.models.base import ModelSignatureWrapper
14
+ from quadra.models.evaluation import (
15
+ BaseEvaluationModel,
16
+ ONNXEvaluationModel,
17
+ TorchEvaluationModel,
18
+ TorchscriptEvaluationModel,
19
+ )
20
+ from quadra.utils.logger import get_logger
21
+
22
+ try:
23
+ import onnx # noqa
24
+ from onnxsim import simplify as onnx_simplify # noqa
25
+ from onnxconverter_common import auto_convert_mixed_precision # noqa
26
+
27
+ ONNX_AVAILABLE = True
28
+ except ImportError:
29
+ ONNX_AVAILABLE = False
30
+
31
+ log = get_logger(__name__)
32
+
33
+ BaseDeploymentModelT = TypeVar("BaseDeploymentModelT", bound=BaseEvaluationModel)
34
+
35
+
36
def generate_torch_inputs(
    input_shapes: list[Any],
    device: str | torch.device,
    half_precision: bool = False,
    dtype: torch.dtype = torch.float32,
    batch_size: int = 1,
) -> list[Any] | tuple[Any, ...] | torch.Tensor:
    """Given a list of input shapes that can contain either lists, tuples or dicts, with tuples being the input shapes
    of the model, generate a list of torch tensors with the given device and dtype.

    Args:
        input_shapes: Shapes of the model inputs. May be a flat shape (list/tuple of ints) or a nested
            structure of lists/tuples/dicts whose leaves are shapes.
        device: Device on which the random tensors are allocated.
        half_precision: If True, cast the generated tensors to half precision.
        dtype: Dtype used to allocate the tensors (before the optional half cast).
        batch_size: Batch size prepended to every generated tensor shape.

    Returns:
        A tensor, or a nested list/tuple/dict of tensors mirroring the structure of ``input_shapes``.

    Raises:
        RuntimeError: If ``input_shapes`` cannot be parsed into tensors.
    """
    inp = None

    # Hydra/OmegaConf configs are converted to plain python containers first so the
    # isinstance checks below see list/tuple/dict.
    if isinstance(input_shapes, (ListConfig, DictConfig)):
        input_shapes = OmegaConf.to_container(input_shapes)

    if isinstance(input_shapes, list):
        if any(isinstance(inp, (Sequence, dict)) for inp in input_shapes):
            # Nested structure: recurse. batch_size is propagated so nested inputs honor a
            # fixed batch size as well (previously it was dropped, reverting to 1).
            return [generate_torch_inputs(inp, device, half_precision, dtype, batch_size) for inp in input_shapes]

        # Base case
        inp = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)

    if isinstance(input_shapes, dict):
        return {k: generate_torch_inputs(v, device, half_precision, dtype, batch_size) for k, v in input_shapes.items()}

    if isinstance(input_shapes, tuple):
        if any(isinstance(inp, (Sequence, dict)) for inp in input_shapes):
            # The tuple contains a list, tuple or dict
            return tuple(
                generate_torch_inputs(inp, device, half_precision, dtype, batch_size) for inp in input_shapes
            )

        # Base case
        inp = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)

    if inp is None:
        raise RuntimeError("Something went wrong during model export, unable to parse input shapes")

    if half_precision:
        inp = inp.half()

    return inp
76
+
77
+
78
def extract_torch_model_inputs(
    model: nn.Module | ModelSignatureWrapper,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    batch_size: int = 1,
) -> tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None:
    """Extract the input shapes for the given model and generate a list of torch tensors with the
    given device and dtype.

    Args:
        model: Module or ModelSignatureWrapper
        input_shapes: Inputs shapes
        half_precision: If True, the model will be exported with half precision
        batch_size: Batch size for the input shapes

    Returns:
        A tuple of (generated input tensors, input shapes), or None if no shapes are available.
    """
    # Fall back to the shapes recorded by the signature wrapper when none are provided.
    if input_shapes is None and isinstance(model, ModelSignatureWrapper):
        input_shapes = model.input_shapes

    if input_shapes is None:
        log.warning(
            "Input shape is None, can not trace model! Please provide input_shapes in the task export configuration."
        )
        return None

    # TODO: This doesn't support bfloat16!!
    # Half precision tracing is done on GPU with fp16; otherwise CPU with fp32.
    device, dtype = ("cuda:0", torch.float16) if half_precision else ("cpu", torch.float32)

    inp = generate_torch_inputs(
        input_shapes=input_shapes,
        device=device,
        half_precision=half_precision,
        dtype=dtype,
        batch_size=batch_size,
    )

    return inp, input_shapes
113
+
114
+
115
@torch.inference_mode()
def export_torchscript_model(
    model: nn.Module,
    output_path: str,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    model_name: str = "model.pt",
) -> tuple[str, Any] | None:
    """Export a PyTorch model with TorchScript.

    Args:
        model: PyTorch model to be exported
        input_shapes: Inputs shape for tracing
        output_path: Path to save the model
        half_precision: If True, the model will be exported with half precision
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the path to the model and the input shape are returned.

    """
    # Cflow models are known not to be traceable, bail out early.
    if isinstance(model, CflowLightning):
        log.warning("Exporting cflow model with torchscript is not supported yet.")
        return None

    model.eval()
    if half_precision:
        # Half precision export runs on GPU; fp16 on CPU is not generally supported.
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    # Inputs are generated while `model` may still be a ModelSignatureWrapper, so the
    # recorded input shapes can be reused when input_shapes is None.
    model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)

    if model_inputs is None:
        return None

    # Unwrap after input extraction: tracing should see the bare module.
    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    try:
        try:
            model_jit = torch.jit.trace(model, inp)
        except RuntimeError as e:
            # Some models (e.g. returning dicts) only trace with strict=False.
            log.warning("Standard tracing failed with exception %s, attempting tracing with strict=False", e)
            model_jit = torch.jit.trace(model, inp, strict=False)

        os.makedirs(output_path, exist_ok=True)

        model_path = os.path.join(output_path, model_name)
        model_jit.save(model_path)

        log.info("Torchscript model saved to %s", os.path.join(os.getcwd(), model_path))

        return os.path.join(os.getcwd(), model_path), input_shapes
    except Exception as e:
        # Export is best-effort: callers treat None as "this format not available".
        log.debug("Failed to export torchscript model with exception: %s", e)
        return None
175
+
176
+
177
@torch.inference_mode()
def export_onnx_model(
    model: nn.Module,
    output_path: str,
    onnx_config: DictConfig,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    model_name: str = "model.onnx",
) -> tuple[str, Any] | None:
    """Export a PyTorch model with ONNX.

    Args:
        model: PyTorch model to be exported
        output_path: Path to save the model
        input_shapes: Input shapes for tracing
        onnx_config: ONNX export configuration
        half_precision: If True, the model will be exported with half precision
        model_name: Name of the exported model

    Returns:
        If the export succeeds, the path to the (possibly simplified) model and the input
        shapes; None on any failure.
    """
    if not ONNX_AVAILABLE:
        log.warning("ONNX is not installed, can not export model in this format.")
        log.warning("Please install ONNX capabilities for quadra with: poetry install -E onnx")
        return None

    model.eval()
    if half_precision:
        # Half precision export runs on GPU.
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    # A fixed batch size in the config overrides the default dummy batch of 1.
    if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
        batch_size = onnx_config.fixed_batch_size
    else:
        batch_size = 1

    model_inputs = extract_torch_model_inputs(
        model=model, input_shapes=input_shapes, half_precision=half_precision, batch_size=batch_size
    )
    if model_inputs is None:
        return None

    # Unwrap after input extraction: torch.onnx.export should see the bare module.
    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    os.makedirs(output_path, exist_ok=True)

    model_path = os.path.join(output_path, model_name)

    input_names = onnx_config.input_names if hasattr(onnx_config, "input_names") else None

    # Default input names: input_0, input_1, ... one per traced input.
    if input_names is None:
        input_names = []
        for i, _ in enumerate(inp):
            input_names.append(f"input_{i}")

    # Run one forward pass to discover how many outputs the model produces.
    output = [model(*inp)]
    output_names = onnx_config.output_names if hasattr(onnx_config, "output_names") else None

    if output_names is None:
        output_names = []
        for i, _ in enumerate(output):
            output_names.append(f"output_{i}")

    dynamic_axes = onnx_config.dynamic_axes if hasattr(onnx_config, "dynamic_axes") else None

    # With a fixed batch size there are no dynamic axes; otherwise default to a
    # dynamic batch dimension (axis 0) on every input and output.
    if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
        dynamic_axes = None
    elif dynamic_axes is None:
        dynamic_axes = {}
        for i, _ in enumerate(input_names):
            dynamic_axes[input_names[i]] = {0: "batch_size"}

        for i, _ in enumerate(output_names):
            dynamic_axes[output_names[i]] = {0: "batch_size"}

    # Build the kwargs for torch.onnx.export from the config plus the resolved names/axes.
    modified_onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))

    modified_onnx_config["input_names"] = input_names
    modified_onnx_config["output_names"] = output_names
    modified_onnx_config["dynamic_axes"] = dynamic_axes

    # These two keys are quadra-specific and must not reach torch.onnx.export.
    simplify = modified_onnx_config.pop("simplify", False)
    _ = modified_onnx_config.pop("fixed_batch_size", None)

    if len(inp) == 1:
        inp = inp[0]

    if isinstance(inp, list):
        inp = tuple(inp)  # onnx doesn't like lists representing tuples of inputs

    if isinstance(inp, dict):
        raise ValueError("ONNX export does not support model with dict inputs")

    try:
        torch.onnx.export(model=model, args=inp, f=model_path, **modified_onnx_config)

        onnx_model = onnx.load(model_path)
        # Check if ONNX model is valid
        onnx.checker.check_model(onnx_model)
    except Exception as e:
        # Export is best-effort: callers treat None as "this format not available".
        log.debug("ONNX export failed with error: %s", e)
        return None

    log.info("ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    if half_precision:
        # fp16 export can produce NaNs; verify and, if needed, re-export in mixed precision.
        is_export_ok = _safe_export_half_precision_onnx(
            model=model,
            export_model_path=model_path,
            inp=inp,
            onnx_config=onnx_config,
            input_shapes=input_shapes,
            input_names=input_names,
        )

        if not is_export_ok:
            return None

    if simplify:
        log.info("Attempting to simplify ONNX model")
        onnx_model = onnx.load(model_path)

        try:
            simplified_model, check = onnx_simplify(onnx_model)
        except Exception as e:
            log.debug("ONNX simplification failed with error: %s", e)
            check = False

        if not check:
            # Simplification failure is non-fatal: keep the original exported model.
            log.warning("Something failed during model simplification, only original ONNX model will be exported")
        else:
            model_filename, model_extension = os.path.splitext(model_name)
            model_name = f"{model_filename}_simplified{model_extension}"
            model_path = os.path.join(output_path, model_name)
            onnx.save(simplified_model, model_path)
            log.info("Simplified ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    return os.path.join(os.getcwd(), model_path), input_shapes
318
+
319
+
320
def _safe_export_half_precision_onnx(
    model: nn.Module,
    export_model_path: str,
    inp: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...],
    onnx_config: DictConfig,
    input_shapes: list[Any],
    input_names: list[str],
) -> bool:
    """Check that the exported half precision ONNX model does not contain NaN values. If it does, attempt to export
    the model with a more stable export and overwrite the original model.

    Args:
        model: PyTorch model to be exported
        export_model_path: Path to save the model
        inp: Input tensors for the model
        onnx_config: ONNX export configuration
        input_shapes: Input shapes for the model
        input_names: Input names for the model

    Returns:
        True if the model is stable or it was possible to export a more stable model, False otherwise.
    """
    # Load the just-exported fp16 model and run it once on GPU to probe for NaNs.
    test_fp_16_model: BaseEvaluationModel = import_deployment_model(
        export_model_path, OmegaConf.create({"onnx": {}}), "cuda:0"
    )
    if not isinstance(inp, Sequence):
        inp = [inp]

    test_output = test_fp_16_model(*inp)

    if not isinstance(test_output, Sequence):
        test_output = [test_output]

    # Check if there are nan values in any of the outputs
    is_broken_model = any(torch.isnan(out).any() for out in test_output)

    if is_broken_model:
        try:
            log.warning(
                "The exported half precision ONNX model contains NaN values, attempting with a more stable export..."
            )
            # Cast back the fp16 model to fp32 to simulate the export with fp32
            model = model.float()
            log.info("Starting to export model in full precision")
            # Re-export in fp32 (overwriting the broken fp16 file), then convert that
            # fp32 model to mixed precision below.
            export_output = export_onnx_model(
                model=model,
                output_path=os.path.dirname(export_model_path),
                onnx_config=onnx_config,
                input_shapes=input_shapes,
                half_precision=False,
                model_name=os.path.basename(export_model_path),
            )
            if export_output is not None:
                export_model_path, _ = export_output
            else:
                log.warning("Failed to export model")
                return False

            model_fp32 = onnx.load(export_model_path)
            # Reference inputs/outputs for the mixed-precision tolerance check.
            test_data = {input_names[i]: inp[i].float().cpu().numpy() for i in range(len(inp))}
            log.warning("Attempting to convert model in mixed precision, this may take a while...")
            with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
                # This function prints a lot of information that is not useful for the user
                model_fp16 = auto_convert_mixed_precision(
                    model_fp32, test_data, rtol=0.01, atol=0.001, keep_io_types=False
                )
                onnx.save(model_fp16, export_model_path)

            onnx_model = onnx.load(export_model_path)
            # Check if ONNX model is valid
            onnx.checker.check_model(onnx_model)
            return True
        except Exception as e:
            raise RuntimeError(
                "Failed to export model with automatic mixed precision, check your model or disable ONNX export"
            ) from e
    else:
        log.info("Exported half precision ONNX model does not contain NaN values, model is stable")
        return True
399
+
400
+
401
def export_pytorch_model(model: nn.Module, output_path: str, model_name: str = "model.pth") -> str:
    """Export pytorch model's parameter dictionary using a deserialized state_dict.

    Args:
        model: PyTorch model to be exported
        output_path: Path to save the model
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the path to the model is returned.

    """
    # Unwrap the signature wrapper so only the underlying module's weights are serialized.
    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    model.eval()
    model.cpu()

    os.makedirs(output_path, exist_ok=True)
    saved_model_path = os.path.join(output_path, model_name)
    torch.save(model.state_dict(), saved_model_path)
    log.info("Pytorch model saved to %s", os.path.join(output_path, model_name))

    return os.path.join(os.getcwd(), saved_model_path)
424
+
425
+
426
def export_model(
    config: DictConfig,
    model: Any,
    export_folder: str,
    half_precision: bool,
    input_shapes: list[Any] | None = None,
    idx_to_class: dict[int, str] | None = None,
    pytorch_model_type: Literal["backbone", "model"] = "model",
) -> tuple[dict[str, Any], dict[str, str]]:
    """Generate deployment models for the task.

    Args:
        config: Experiment config
        model: Model to be exported
        export_folder: Path to save the exported model
        half_precision: Whether to use half precision for the exported model
        input_shapes: Input shapes for the exported model
        idx_to_class: Mapping from class index to class name
        pytorch_model_type: Type of the pytorch model config to be exported, if it's backbone on disk we will save the
            config.backbone config, otherwise we will save the config.model

    Returns:
        If the model is exported successfully, return a dictionary containing information about the exported model and
        a second dictionary containing the paths to the exported models. Otherwise, return two empty dictionaries.
    """
    if config.export is None or len(config.export.types) == 0:
        log.info("No export type specified skipping export")
        return {}, {}

    os.makedirs(export_folder, exist_ok=True)

    if input_shapes is None:
        # Fall back to the config value; this may still be None, in which case the
        # exporters attempt to recover shapes from the ModelSignatureWrapper.
        input_shapes = config.export.input_shapes

    export_paths: dict[str, str] = {}

    for requested_type in config.export.types:
        if requested_type == "torchscript":
            result = export_torchscript_model(
                model=model,
                input_shapes=input_shapes,
                output_path=export_folder,
                half_precision=half_precision,
            )

            if result is None:
                log.warning("Torchscript export failed, enable debug logging for more details")
                continue

            # Exporters return the (possibly inferred) input shapes so later exports can reuse them
            export_paths[requested_type], input_shapes = result
        elif requested_type == "pytorch":
            export_paths[requested_type] = export_pytorch_model(
                model=model,
                output_path=export_folder,
            )
            # Persist the model/backbone config alongside the weights so it can be rebuilt at load time
            config_file = os.path.join(export_folder, "model_config.yaml")
            with open(config_file, "w") as f:
                OmegaConf.save(getattr(config, pytorch_model_type), f, resolve=True)
        elif requested_type == "onnx":
            if not hasattr(config.export, "onnx"):
                log.warning("No onnx configuration found, skipping onnx export")
                continue

            result = export_onnx_model(
                model=model,
                output_path=export_folder,
                onnx_config=config.export.onnx,
                input_shapes=input_shapes,
                half_precision=half_precision,
            )

            if result is None:
                log.warning("ONNX export failed, enable debug logging for more details")
                continue

            export_paths[requested_type], input_shapes = result
        else:
            log.warning("Export type: %s not implemented", requested_type)

    if not export_paths:
        log.warning("No export type was successful, no model will be available for deployment")
        return {}, export_paths

    model_json = {
        "input_size": input_shapes,
        "classes": idx_to_class,
        "mean": list(config.transforms.mean),
        "std": list(config.transforms.std),
    }

    return model_json, export_paths
521
+
522
+
523
def import_deployment_model(
    model_path: str,
    inference_config: DictConfig,
    device: str,
    model_architecture: nn.Module | None = None,
) -> BaseEvaluationModel:
    """Try to import a model for deployment, currently supports torchscript (.pt) files,
    state dictionaries (.pth) and ONNX (.onnx) files.

    Args:
        model_path: Path to the model
        inference_config: Inference configuration, should contain keys for the different deployment models
        device: Device to load the model on
        model_architecture: Optional model architecture to use for loading a plain pytorch model

    Returns:
        The loaded evaluation model.

    Raises:
        ValueError: If the file extension is not supported, or if a .pth file is given
            without a model_architecture.
    """
    log.info("Importing trained model")

    # Dispatch on the file extension to select the matching evaluation wrapper
    file_extension = os.path.splitext(os.path.basename(model_path))[1]
    deployment_model: BaseEvaluationModel | None = None

    if file_extension == ".pt":
        deployment_model = TorchscriptEvaluationModel(config=inference_config.torchscript)
    elif file_extension == ".pth":
        # A plain state_dict carries no architecture, so one must be supplied to load into
        if model_architecture is None:
            raise ValueError("model_architecture must be specified when loading a .pth file")

        deployment_model = TorchEvaluationModel(config=inference_config.pytorch, model_architecture=model_architecture)
    elif file_extension == ".onnx":
        deployment_model = ONNXEvaluationModel(config=inference_config.onnx)

    if deployment_model is not None:
        deployment_model.load_from_disk(model_path=model_path, device=device)

        log.info("Imported %s model", deployment_model.__class__.__name__)

        return deployment_model

    # Fixed message: previously listed 'pth' without the leading dot and omitted '.onnx'
    # even though ONNX files are supported above.
    raise ValueError(
        f"Unable to load model with extension {file_extension}, valid extensions are: ['.pt', '.pth', '.onnx']"
    )
564
+
565
+
566
# Mapping from export type to the file extension of the artifact it produces
# (resolves the old "may be better as a dict?" TODO).
_EXPORT_EXTENSIONS = {
    "onnx": "onnx",
    "torchscript": "pt",
    "pytorch": "pth",
}


def get_export_extension(export_type: str) -> str:
    """Get the extension of the exported model.

    Args:
        export_type: The type of the exported model.

    Returns:
        The extension of the exported model.

    Raises:
        ValueError: If the export type is not supported.
    """
    try:
        return _EXPORT_EXTENSIONS[export_type]
    except KeyError as e:
        raise ValueError(f"Unsupported export type {export_type}") from e
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+
7
def crop_image(image: np.ndarray, roi: tuple[int, int, int, int]) -> np.ndarray:
    """Crop an image given a roi in proper format.

    Args:
        image: array of size HxW or HxWxC
        roi: (w_upper_left, h_upper_left, w_bottom_right, h_bottom_right)

    Returns:
        Cropped image based on roi
    """
    # Unpack the roi for clarity: widths index the columns, heights index the rows
    w_start, h_start, w_end, h_end = roi
    return image[h_start:h_end, w_start:w_end]
18
+
19
+
20
def keep_aspect_ratio_resize(image: np.ndarray, size: int = 224, interpolation: int = 1) -> np.ndarray:
    """Resize input image while keeping its aspect ratio.

    The shorter side is scaled to ``size`` and the longer side is scaled by the
    same factor, preserving the input's aspect ratio.

    Args:
        image: array of size HxW or HxWxC
        size: target length of the shorter side
        interpolation: OpenCV interpolation flag (defaults to 1, i.e. cv2.INTER_LINEAR)

    Returns:
        The resized image.
    """
    h, w = image.shape[:2]

    if h < w:
        # Height is the shorter side: pin it to `size`, scale width proportionally
        target_h, target_w = size, int(w * size / h)
    else:
        # Width is the shorter (or equal) side: pin it to `size`, scale height proportionally
        target_h, target_w = int(h * size / w), size

    # cv2.resize takes (width, height) order
    return cv2.resize(image, (target_w, target_h), interpolation=interpolation)
quadra/utils/logger.py ADDED
@@ -0,0 +1,15 @@
1
+ import logging
2
+
3
+ from pytorch_lightning.utilities import rank_zero_only
4
+
5
+
6
def get_logger(name: str = __name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger.

    Args:
        name: Name of the logger; defaults to this module's name.

    Returns:
        A logger whose level methods are wrapped so that only the rank-zero
        process emits records, avoiding duplicated lines in multi-GPU runs.
    """
    logger = logging.getLogger(name)

    # Wrap every level method with the rank-zero decorator; otherwise each GPU
    # process in a multi-GPU setup would emit its own copy of every message.
    level_methods = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for method_name in level_methods:
        wrapped = rank_zero_only(getattr(logger, method_name))
        setattr(logger, method_name, wrapped)

    return logger