quadra 0.0.1__py3-none-any.whl → 2.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +32 -0
  37. quadra/configs/callbacks/default.yaml +37 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +67 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +46 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +331 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1264 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +579 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +524 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.1.13.dist-info/LICENSE +201 -0
  297. quadra-2.1.13.dist-info/METADATA +386 -0
  298. quadra-2.1.13.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.1.13.dist-info}/WHEEL +1 -1
  300. quadra-2.1.13.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/export.py ADDED
@@ -0,0 +1,579 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from collections.abc import Sequence
5
+ from typing import Any, Literal, TypeVar, cast
6
+
7
+ import torch
8
+ from anomalib.models.cflow import CflowLightning
9
+ from omegaconf import DictConfig, ListConfig, OmegaConf
10
+ from onnxconverter_common import auto_convert_mixed_precision
11
+ from torch import nn
12
+
13
+ from quadra.models.base import ModelSignatureWrapper
14
+ from quadra.models.evaluation import (
15
+ BaseEvaluationModel,
16
+ ONNXEvaluationModel,
17
+ TorchEvaluationModel,
18
+ TorchscriptEvaluationModel,
19
+ )
20
+ from quadra.utils.logger import get_logger
21
+
22
+ try:
23
+ import onnx # noqa
24
+ from onnxsim import simplify as onnx_simplify # noqa
25
+
26
+ ONNX_AVAILABLE = True
27
+ except ImportError:
28
+ ONNX_AVAILABLE = False
29
+
30
+ log = get_logger(__name__)
31
+
32
+ BaseDeploymentModelT = TypeVar("BaseDeploymentModelT", bound=BaseEvaluationModel)
33
+
34
+
35
def generate_torch_inputs(
    input_shapes: list[Any],
    device: str | torch.device,
    half_precision: bool = False,
    dtype: torch.dtype = torch.float32,
    batch_size: int = 1,
) -> list[Any] | tuple[Any, ...] | torch.Tensor:
    """Given a list of input shapes that can contain either lists, tuples or dicts, with tuples being the input shapes
    of the model, generate a list of torch tensors with the given device and dtype.

    Args:
        input_shapes: Nested structure of shapes (lists/tuples of ints, possibly nested in lists/tuples/dicts,
            or their OmegaConf equivalents).
        device: Device on which the random tensors are allocated.
        half_precision: If True, the generated tensors are cast to half precision.
        dtype: Dtype used when creating the tensors (before the optional half cast).
        batch_size: Batch dimension prepended to every generated tensor.

    Returns:
        A tensor, or a list/tuple/dict of (possibly nested) tensors mirroring the structure of ``input_shapes``.

    Raises:
        RuntimeError: If ``input_shapes`` cannot be interpreted as a shape structure.
    """
    inp = None

    # OmegaConf containers are converted to plain python structures first.
    if isinstance(input_shapes, (ListConfig, DictConfig)):
        input_shapes = OmegaConf.to_container(input_shapes)

    if isinstance(input_shapes, list):
        if any(isinstance(inp, (Sequence, dict)) for inp in input_shapes):
            # Nested structure: recurse on each element. batch_size must be propagated,
            # otherwise nested inputs would silently fall back to the default batch size of 1.
            return [
                generate_torch_inputs(inp, device, half_precision, dtype, batch_size) for inp in input_shapes
            ]

        # Base case
        inp = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)

    if isinstance(input_shapes, dict):
        return {
            k: generate_torch_inputs(v, device, half_precision, dtype, batch_size)
            for k, v in input_shapes.items()
        }

    if isinstance(input_shapes, tuple):
        if any(isinstance(inp, (Sequence, dict)) for inp in input_shapes):
            # The tuple contains a list, tuple or dict
            return tuple(
                generate_torch_inputs(inp, device, half_precision, dtype, batch_size) for inp in input_shapes
            )

        # Base case
        inp = torch.randn((batch_size, *input_shapes), dtype=dtype, device=device)

    if inp is None:
        raise RuntimeError("Something went wrong during model export, unable to parse input shapes")

    if half_precision:
        inp = inp.half()

    return inp
75
+
76
+
77
def extract_torch_model_inputs(
    model: nn.Module | ModelSignatureWrapper,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    batch_size: int = 1,
) -> tuple[list[Any] | tuple[Any, ...] | torch.Tensor, list[Any]] | None:
    """Extract the input shapes for the given model and generate a list of torch tensors with the
    given device and dtype.

    Args:
        model: Module or ModelSignatureWrapper
        input_shapes: Inputs shapes
        half_precision: If True, the model will be exported with half precision
        batch_size: Batch size for the input shapes

    Returns:
        A pair (generated inputs, resolved input shapes), or None when no shapes are available.
    """
    # When no explicit shapes are provided, fall back to the ones recorded by the signature wrapper.
    if input_shapes is None and isinstance(model, ModelSignatureWrapper):
        input_shapes = model.input_shapes

    if input_shapes is None:
        log.warning(
            "Input shape is None, can not trace model! Please provide input_shapes in the task export configuration."
        )
        return None

    # TODO: This doesn't support bfloat16!!
    # Half precision inputs are generated directly on the GPU.
    if half_precision:
        device, tensor_dtype = "cuda:0", torch.float16
    else:
        device, tensor_dtype = "cpu", torch.float32

    inp = generate_torch_inputs(
        input_shapes=input_shapes,
        device=device,
        half_precision=half_precision,
        dtype=tensor_dtype,
        batch_size=batch_size,
    )

    return inp, input_shapes
112
+
113
+
114
@torch.inference_mode()
def export_torchscript_model(
    model: nn.Module,
    output_path: str,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    model_name: str = "model.pt",
) -> tuple[str, Any] | None:
    """Export a PyTorch model with TorchScript.

    Args:
        model: PyTorch model to be exported
        input_shapes: Inputs shape for tracing
        output_path: Path to save the model
        half_precision: If True, the model will be exported with half precision
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the path to the model and the input shape are returned.

    """
    # Tracing cflow models is not supported, bail out early.
    if isinstance(model, CflowLightning):
        log.warning("Exporting cflow model with torchscript is not supported yet.")
        return None

    model.eval()
    if half_precision:
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    # Inputs are generated while the signature wrapper (if any) is still in place,
    # since it may hold the recorded input shapes.
    model_inputs = extract_torch_model_inputs(model, input_shapes, half_precision)
    if model_inputs is None:
        return None

    # Unwrap the signature wrapper before tracing so the raw module is serialized.
    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    try:
        try:
            traced_model = torch.jit.trace(model, inp)
        except RuntimeError as e:
            # Some models fail strict tracing (e.g. dict outputs); retry in relaxed mode.
            log.warning("Standard tracing failed with exception %s, attempting tracing with strict=False", e)
            traced_model = torch.jit.trace(model, inp, strict=False)

        os.makedirs(output_path, exist_ok=True)
        model_path = os.path.join(output_path, model_name)
        traced_model.save(model_path)

        log.info("Torchscript model saved to %s", os.path.join(os.getcwd(), model_path))

        return os.path.join(os.getcwd(), model_path), input_shapes
    except Exception as e:
        # Export is best-effort: the caller logs a warning when None is returned.
        log.debug("Failed to export torchscript model with exception: %s", e)
        return None
174
+
175
+
176
@torch.inference_mode()
def export_onnx_model(
    model: nn.Module,
    output_path: str,
    onnx_config: DictConfig,
    input_shapes: list[Any] | None = None,
    half_precision: bool = False,
    model_name: str = "model.onnx",
) -> tuple[str, Any] | None:
    """Export a PyTorch model with ONNX.

    Args:
        model: PyTorch model to be exported
        output_path: Path to save the model
        input_shapes: Input shapes for tracing
        onnx_config: ONNX export configuration
        half_precision: If True, the model will be exported with half precision
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the path to the (possibly simplified)
        model and the resolved input shapes; None on any failure.
    """
    if not ONNX_AVAILABLE:
        log.warning("ONNX is not installed, can not export model in this format.")
        log.warning("Please install ONNX capabilities for quadra with: poetry install -E onnx")
        return None

    model.eval()
    if half_precision:
        # Half precision export runs on GPU.
        model.to("cuda:0")
        model = model.half()
    else:
        model.cpu()

    # A fixed batch size in the config overrides the default batch size of 1
    # and later disables dynamic batch axes.
    if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
        batch_size = onnx_config.fixed_batch_size
    else:
        batch_size = 1

    # Must be called before unwrapping: the signature wrapper may hold the recorded shapes.
    model_inputs = extract_torch_model_inputs(
        model=model, input_shapes=input_shapes, half_precision=half_precision, batch_size=batch_size
    )
    if model_inputs is None:
        return None

    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    inp, input_shapes = model_inputs

    os.makedirs(output_path, exist_ok=True)

    model_path = os.path.join(output_path, model_name)

    # Input/output names default to input_0..n / output_0..n when not configured.
    input_names = onnx_config.input_names if hasattr(onnx_config, "input_names") else None

    if input_names is None:
        input_names = []
        for i, _ in enumerate(inp):
            input_names.append(f"input_{i}")

    # One forward pass to count the model outputs.
    # NOTE(review): this assumes `inp` is a sequence of per-input tensors and that the
    # model returns a single (possibly tuple) result — confirm for multi-output models.
    output = [model(*inp)]
    output_names = onnx_config.output_names if hasattr(onnx_config, "output_names") else None

    if output_names is None:
        output_names = []
        for i, _ in enumerate(output):
            output_names.append(f"output_{i}")

    dynamic_axes = onnx_config.dynamic_axes if hasattr(onnx_config, "dynamic_axes") else None

    if hasattr(onnx_config, "fixed_batch_size") and onnx_config.fixed_batch_size is not None:
        # Fixed batch size: no dynamic axes at all.
        dynamic_axes = None
    elif dynamic_axes is None:
        # Default: make axis 0 (batch) dynamic for every input and output.
        dynamic_axes = {}
        for i, _ in enumerate(input_names):
            dynamic_axes[input_names[i]] = {0: "batch_size"}

        for i, _ in enumerate(output_names):
            dynamic_axes[output_names[i]] = {0: "batch_size"}

    # Materialize the config as a plain dict so quadra-specific keys can be stripped
    # before forwarding the rest to torch.onnx.export as keyword arguments.
    modified_onnx_config = cast(dict[str, Any], OmegaConf.to_container(onnx_config, resolve=True))

    modified_onnx_config["input_names"] = input_names
    modified_onnx_config["output_names"] = output_names
    modified_onnx_config["dynamic_axes"] = dynamic_axes

    # These keys are consumed here, not by torch.onnx.export.
    simplify = modified_onnx_config.pop("simplify", False)
    _ = modified_onnx_config.pop("fixed_batch_size", None)

    # Single-input models are exported with the bare tensor rather than a 1-tuple.
    if len(inp) == 1:
        inp = inp[0]

    if isinstance(inp, list):
        inp = tuple(inp)  # onnx doesn't like lists representing tuples of inputs

    if isinstance(inp, dict):
        raise ValueError("ONNX export does not support model with dict inputs")

    try:
        torch.onnx.export(model=model, args=inp, f=model_path, **modified_onnx_config)

        onnx_model = onnx.load(model_path)
        # Check if ONNX model is valid
        onnx.checker.check_model(onnx_model)
    except Exception as e:
        log.debug("ONNX export failed with error: %s", e)
        return None

    log.info("ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    if half_precision:
        # fp16 exports can produce NaNs; validate and fall back to a mixed-precision
        # export (which may overwrite model_path in place) when needed.
        is_export_ok = _safe_export_half_precision_onnx(
            model=model,
            export_model_path=model_path,
            inp=inp,
            onnx_config=onnx_config,
            input_shapes=input_shapes,
            input_names=input_names,
        )

        if not is_export_ok:
            return None

    if simplify:
        log.info("Attempting to simplify ONNX model")
        onnx_model = onnx.load(model_path)

        try:
            simplified_model, check = onnx_simplify(onnx_model)
        except Exception as e:
            log.debug("ONNX simplification failed with error: %s", e)
            check = False

        if not check:
            # Simplification failure is non-fatal: the original export is kept.
            log.warning("Something failed during model simplification, only original ONNX model will be exported")
        else:
            # The simplified model is saved alongside the original and becomes the returned path.
            model_filename, model_extension = os.path.splitext(model_name)
            model_name = f"{model_filename}_simplified{model_extension}"
            model_path = os.path.join(output_path, model_name)
            onnx.save(simplified_model, model_path)
            log.info("Simplified ONNX model saved to %s", os.path.join(os.getcwd(), model_path))

    return os.path.join(os.getcwd(), model_path), input_shapes
317
+
318
+
319
def _safe_export_half_precision_onnx(
    model: nn.Module,
    export_model_path: str,
    inp: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...],
    onnx_config: DictConfig,
    input_shapes: list[Any],
    input_names: list[str],
) -> bool:
    """Check that the exported half precision ONNX model does not contain NaN values. If it does, attempt to export
    the model with a more stable export and overwrite the original model.

    Args:
        model: PyTorch model to be exported
        export_model_path: Path to save the model
        inp: Input tensors for the model
        onnx_config: ONNX export configuration
        input_shapes: Input shapes for the model
        input_names: Input names for the model

    Returns:
        True if the model is stable or it was possible to export a more stable model, False otherwise.
    """
    # Load the just-exported fp16 ONNX model back for a sanity-check inference on GPU.
    test_fp_16_model: BaseEvaluationModel = import_deployment_model(
        export_model_path, OmegaConf.create({"onnx": {}}), "cuda:0"
    )
    # Normalize a bare tensor into a 1-element sequence so it can be unpacked below.
    if not isinstance(inp, Sequence):
        inp = [inp]

    test_output = test_fp_16_model(*inp)

    if not isinstance(test_output, Sequence):
        test_output = [test_output]

    # Check if there are nan values in any of the outputs
    is_broken_model = any(torch.isnan(out).any() for out in test_output)

    if is_broken_model:
        try:
            log.warning(
                "The exported half precision ONNX model contains NaN values, attempting with a more stable export..."
            )
            # Cast back the fp16 model to fp32 to simulate the export with fp32
            model = model.float()
            log.info("Starting to export model in full precision")
            # Recursive call with half_precision=False: re-exports the model in fp32,
            # overwriting the file at export_model_path.
            export_output = export_onnx_model(
                model=model,
                output_path=os.path.dirname(export_model_path),
                onnx_config=onnx_config,
                input_shapes=input_shapes,
                half_precision=False,
                model_name=os.path.basename(export_model_path),
            )
            if export_output is not None:
                export_model_path, _ = export_output
            else:
                log.warning("Failed to export model")
                return False

            model_fp32 = onnx.load(export_model_path)
            # Calibration data for the mixed-precision converter, keyed by ONNX input name.
            test_data = {input_names[i]: inp[i].float().cpu().numpy() for i in range(len(inp))}
            log.warning("Attempting to convert model in mixed precision, this may take a while...")
            # Convert only the numerically safe nodes to fp16, within the given tolerances.
            model_fp16 = auto_convert_mixed_precision(model_fp32, test_data, rtol=0.01, atol=0.001, keep_io_types=False)
            onnx.save(model_fp16, export_model_path)

            onnx_model = onnx.load(export_model_path)
            # Check if ONNX model is valid
            onnx.checker.check_model(onnx_model)
            return True
        except Exception as e:
            log.debug("Failed to export model with mixed precision with error: %s", e)
            return False
    else:
        log.info("Exported half precision ONNX model does not contain NaN values, model is stable")
        return True
393
+
394
+
395
def export_pytorch_model(model: nn.Module, output_path: str, model_name: str = "model.pth") -> str:
    """Export pytorch model's parameter dictionary using a deserialized state_dict.

    Args:
        model: PyTorch model to be exported
        output_path: Path to save the model
        model_name: Name of the exported model

    Returns:
        If the model is exported successfully, the path to the model is returned.

    """
    # Persist the wrapped module itself, not the signature wrapper.
    if isinstance(model, ModelSignatureWrapper):
        model = model.instance

    os.makedirs(output_path, exist_ok=True)

    # The checkpoint is always saved from CPU in eval mode.
    model.eval()
    model.cpu()

    destination = os.path.join(output_path, model_name)
    torch.save(model.state_dict(), destination)
    log.info("Pytorch model saved to %s", destination)

    return os.path.join(os.getcwd(), destination)
418
+
419
+
420
def export_model(
    config: DictConfig,
    model: Any,
    export_folder: str,
    half_precision: bool,
    input_shapes: list[Any] | None = None,
    idx_to_class: dict[int, str] | None = None,
    pytorch_model_type: Literal["backbone", "model"] = "model",
) -> tuple[dict[str, Any], dict[str, str]]:
    """Generate deployment models for the task.

    Args:
        config: Experiment config
        model: Model to be exported
        export_folder: Path to save the exported model
        half_precision: Whether to use half precision for the exported model
        input_shapes: Input shapes for the exported model
        idx_to_class: Mapping from class index to class name
        pytorch_model_type: Type of the pytorch model config to be exported, if it's backbone on disk we will save the
            config.backbone config, otherwise we will save the config.model

    Returns:
        If the model is exported successfully, return a dictionary containing information about the exported model and
        a second dictionary containing the paths to the exported models. Otherwise, return two empty dictionaries.
    """
    # Nothing to do when export is disabled or no export types were requested.
    if config.export is None or len(config.export.types) == 0:
        log.info("No export type specified skipping export")
        return {}, {}

    os.makedirs(export_folder, exist_ok=True)

    if input_shapes is None:
        # Try to get input shapes from config
        # If this is also None we will try to retrieve it from the ModelSignatureWrapper, if it fails we can't export
        input_shapes = config.export.input_shapes

    # Maps export type (e.g. "onnx") -> path of the exported artifact.
    export_paths = {}

    for export_type in config.export.types:
        if export_type == "torchscript":
            out = export_torchscript_model(
                model=model,
                input_shapes=input_shapes,
                output_path=export_folder,
                half_precision=half_precision,
            )

            if out is None:
                log.warning("Torchscript export failed, enable debug logging for more details")
                continue

            # NOTE: input_shapes is replaced with the shapes actually used by the
            # export, so later export types in this loop reuse the resolved shapes.
            export_path, input_shapes = out
            export_paths[export_type] = export_path
        elif export_type == "pytorch":
            export_path = export_pytorch_model(
                model=model,
                output_path=export_folder,
            )
            export_paths[export_type] = export_path
            # Save the relevant hydra config next to the state_dict so the
            # architecture can be re-instantiated at load time.
            with open(os.path.join(export_folder, "model_config.yaml"), "w") as f:
                OmegaConf.save(getattr(config, pytorch_model_type), f, resolve=True)
        elif export_type == "onnx":
            if not hasattr(config.export, "onnx"):
                log.warning("No onnx configuration found, skipping onnx export")
                continue

            out = export_onnx_model(
                model=model,
                output_path=export_folder,
                onnx_config=config.export.onnx,
                input_shapes=input_shapes,
                half_precision=half_precision,
            )

            if out is None:
                log.warning("ONNX export failed, enable debug logging for more details")
                continue

            # Same shape-refinement behavior as the torchscript branch above.
            export_path, input_shapes = out
            export_paths[export_type] = export_path
        else:
            log.warning("Export type: %s not implemented", export_type)

    if len(export_paths) == 0:
        log.warning("No export type was successful, no model will be available for deployment")
        return {}, export_paths

    # Metadata consumed at inference time (preprocessing stats + class mapping).
    model_json = {
        "input_size": input_shapes,
        "classes": idx_to_class,
        "mean": list(config.transforms.mean),
        "std": list(config.transforms.std),
    }

    return model_json, export_paths
515
+
516
+
517
def import_deployment_model(
    model_path: str,
    inference_config: DictConfig,
    device: str,
    model_architecture: nn.Module | None = None,
) -> BaseEvaluationModel:
    """Try to import a model for deployment. Currently supports torchscript (.pt)
    files, state dictionaries (.pth) and ONNX (.onnx) files.

    Args:
        model_path: Path to the model
        inference_config: Inference configuration, should contain keys for the different deployment models
        device: Device to load the model on
        model_architecture: Optional model architecture to use for loading a plain pytorch model

    Returns:
        The loaded deployment model.

    Raises:
        ValueError: If the file extension is unsupported, or a .pth file is given
            without a model_architecture.
    """
    log.info("Importing trained model")

    file_extension = os.path.splitext(os.path.basename(model_path))[1]
    deployment_model: BaseEvaluationModel | None = None

    if file_extension == ".pt":
        deployment_model = TorchscriptEvaluationModel(config=inference_config.torchscript)
    elif file_extension == ".pth":
        # A plain state_dict has no architecture information, so the caller must supply it.
        if model_architecture is None:
            raise ValueError("model_architecture must be specified when loading a .pth file")

        deployment_model = TorchEvaluationModel(config=inference_config.pytorch, model_architecture=model_architecture)
    elif file_extension == ".onnx":
        deployment_model = ONNXEvaluationModel(config=inference_config.onnx)

    if deployment_model is not None:
        deployment_model.load_from_disk(model_path=model_path, device=device)

        log.info("Imported %s model", deployment_model.__class__.__name__)

        return deployment_model

    # Fix: the message previously omitted '.onnx' (which is supported above) and
    # misspelled '.pth' as 'pth'.
    raise ValueError(
        f"Unable to load model with extension {file_extension}, valid extensions are: ['.pt', '.pth', '.onnx']"
    )
558
+
559
+
560
# Resolved the old "may be better as a dict?" TODO: a single lookup table keeps
# the mapping declarative and trivially extensible.
_EXPORT_TYPE_TO_EXTENSION = {
    "onnx": "onnx",
    "torchscript": "pt",
    "pytorch": "pth",
}


def get_export_extension(export_type: str) -> str:
    """Get the extension of the exported model.

    Args:
        export_type: The type of the exported model.

    Returns:
        The extension of the exported model (without the leading dot).

    Raises:
        ValueError: If the export type is not supported.
    """
    try:
        return _EXPORT_TYPE_TO_EXTENSION[export_type]
    except KeyError:
        raise ValueError(f"Unsupported export type {export_type}") from None
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+
7
def crop_image(image: np.ndarray, roi: tuple[int, int, int, int]) -> np.ndarray:
    """Crop an image given a roi in proper format.

    Args:
        image: array of size HxW or HxWxC
        roi: (w_upper_left, h_upper_left, w_bottom_right, h_bottom_right)

    Returns:
        Cropped image based on roi
    """
    left, top, right, bottom = roi
    # Rows are indexed by height, columns by width; channels (if any) are kept.
    return image[top:bottom, left:right]
18
+
19
+
20
def keep_aspect_ratio_resize(image: np.ndarray, size: int = 224, interpolation: int = 1) -> np.ndarray:
    """Resize input image while keeping its aspect ratio.

    The shorter side is scaled to `size` and the longer side is scaled by the
    same factor (truncated to int).

    Args:
        image: array of size HxW or HxWxC
        size: target length of the shorter side
        interpolation: OpenCV interpolation flag (1 == cv2.INTER_LINEAR)

    Returns:
        The resized image.
    """
    h, w = image.shape[:2]

    # cv2.resize expects (width, height) ordering.
    if h < w:
        target = (int(w * size / h), size)
    else:
        target = (size, int(h * size / w))

    return cv2.resize(image, target, interpolation=interpolation)
quadra/utils/logger.py ADDED
@@ -0,0 +1,15 @@
1
+ import logging
2
+
3
+ from pytorch_lightning.utilities import rank_zero_only
4
+
5
+
6
def get_logger(name=__name__) -> logging.Logger:
    """Initializes multi-GPU-friendly python logger."""
    logger = logging.getLogger(name)

    # Decorate every level method with rank_zero_only so that in multi-GPU
    # runs only the rank-0 process emits records (avoids duplicated log lines).
    level_methods = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for method_name in level_methods:
        original_method = getattr(logger, method_name)
        setattr(logger, method_name, rank_zero_only(original_method))

    return logger