quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/utils/utils.py ADDED
@@ -0,0 +1,525 @@
1
+ """Common utility functions.
2
+ Some of them are mostly based on https://github.com/ashleve/lightning-hydra-template.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import glob
8
+ import json
9
+ import logging
10
+ import os
11
+ import subprocess
12
+ import sys
13
+ import warnings
14
+ from collections.abc import Iterable, Iterator, Sequence
15
+ from typing import Any, cast
16
+
17
+ import cv2
18
+ import dotenv
19
+ import mlflow
20
+ import numpy as np
21
+ import pytorch_lightning as pl
22
+ import rich.syntax
23
+ import rich.tree
24
+ import torch
25
+ from hydra.core.hydra_config import HydraConfig
26
+ from hydra.utils import get_original_cwd
27
+ from lightning_fabric.utilities.device_parser import _parse_gpu_ids
28
+ from omegaconf import DictConfig, OmegaConf
29
+ from pytorch_lightning.loggers import TensorBoardLogger
30
+ from pytorch_lightning.utilities import rank_zero_only
31
+
32
+ import quadra
33
+ import quadra.utils.export as quadra_export
34
+ from quadra.callbacks.mlflow import get_mlflow_logger
35
+ from quadra.utils.mlflow import infer_signature_model
36
+
37
+ try:
38
+ import onnx # noqa
39
+
40
+ ONNX_AVAILABLE = True
41
+ except ImportError:
42
+ ONNX_AVAILABLE = False
43
+
44
+
45
+ IMAGE_EXTENSIONS: list[str] = [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".pbm", ".pgm", ".ppm", ".pxm", ".pnm"]
46
+
47
+
48
def get_logger(name=__name__) -> logging.Logger:
    """Return a python logger that is safe to use in multi-GPU runs.

    Every standard level method is wrapped with ``rank_zero_only`` so each
    record is emitted once instead of once per process.
    """
    multi_gpu_logger = logging.getLogger(name)

    # Decorate each level method with the rank zero guard; without this every
    # GPU process in a distributed run would duplicate every log message.
    level_names = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
    for level_name in level_names:
        wrapped = rank_zero_only(getattr(multi_gpu_logger, level_name))
        setattr(multi_gpu_logger, level_name, wrapped)

    return multi_gpu_logger
58
+
59
+
60
def extras(config: DictConfig) -> None:
    """A couple of optional utilities, controlled by main config file:
    - disabling warnings
    - forcing debug friendly configuration
    - verifying experiment name is set when running in experiment mode.

    Modifies DictConfig in place.

    Args:
        config: Configuration composed by Hydra.
    """
    logging.basicConfig()
    logging.getLogger().setLevel(config.core.log_level.upper())

    log = get_logger(__name__)
    # Record how the run was launched and where its outputs are written.
    config.core.command += " ".join(sys.argv)
    config.core.experiment_path = os.getcwd()

    # Disable python warnings when requested via <config.ignore_warnings=True>.
    if config.get("ignore_warnings"):
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # Debuggers don't like GPUs and multiprocessing, so under fast_dev_run fall
    # back to a single CPU device and synchronous data loading.
    trainer_cfg = config.get("trainer")
    if trainer_cfg and trainer_cfg.get("fast_dev_run"):
        log.info("Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>")
        if trainer_cfg.get("gpus"):
            trainer_cfg.devices = 1
            trainer_cfg.accelerator = "cpu"
            trainer_cfg.gpus = None
        if config.datamodule.get("pin_memory"):
            config.datamodule.pin_memory = False
        if config.datamodule.get("num_workers"):
            config.datamodule.num_workers = 0
94
+
95
+
96
@rank_zero_only
def print_config(
    config: DictConfig,
    fields: Sequence[str] = (
        "task",
        "trainer",
        "model",
        "datamodule",
        "callbacks",
        "logger",
        "core",
        "backbone",
        "transforms",
        "optimizer",
        "scheduler",
    ),
    resolve: bool = True,
) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    The rendered tree is printed to stdout and also persisted to
    ``config_tree.txt`` in the current working directory.

    Args:
        config: Configuration composed by Hydra.
        fields: Determines which main fields from config will
            be printed and in what order.
        resolve: Whether to resolve reference fields of DictConfig.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    for field in fields:
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        # Missing sections are rendered via their string form (e.g. "None").
        branch_content = str(config_section)
        if isinstance(config_section, DictConfig):
            branch_content = OmegaConf.to_yaml(config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    # Explicit encoding so the dumped tree does not depend on the platform
    # default (the original open() used the locale encoding).
    with open("config_tree.txt", "w", encoding="utf-8") as fp:
        rich.print(tree, file=fp)
139
+
140
+
141
@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    trainer: pl.Trainer,
) -> None:
    """This method controls which parameters from Hydra config are saved by Lightning loggers.

    Additionaly saves:
        - number of trainable model parameters

    Args:
        config: Configuration composed by Hydra.
        model: Model whose parameter counts are logged.
        trainer: Trainer providing the loggers that receive the hyperparameters.
    """
    log = get_logger(__name__)

    # Nothing to do when hydra is not driving the run or no logger is attached.
    if not HydraConfig.initialized() or trainer.logger is None:
        return

    log.info("Logging hyperparameters!")
    hydra_cfg = HydraConfig.get()
    hydra_choices = OmegaConf.to_container(hydra_cfg.runtime.choices)

    # Initialized up front: the original code only created `hparams` inside the
    # isinstance branch, so the non-dict path crashed with a NameError below.
    hparams: dict[str, Any] = {}

    if isinstance(hydra_choices, dict):
        # For multirun override the choices that are not automatically updated
        for item in hydra_cfg.overrides.task:
            if "." in item:
                continue

            # Split only on the first "=" so override values that themselves
            # contain "=" do not raise a ValueError.
            override, value = item.split("=", 1)
            hydra_choices[override] = value

        hydra_choices_final = {}
        for k, v in hydra_choices.items():
            if isinstance(k, str):
                # "@" is not accepted in keys by some logger backends.
                k_replaced = k.replace("@", "_at_")
                hydra_choices_final[k_replaced] = v
                if v is not None and isinstance(v, str) and "@" in v:
                    hydra_choices_final[k_replaced] = v.replace("@", "_at_")

        hparams.update(hydra_choices_final)
    else:
        logging.warning("Hydra choices is not a dictionary, skip adding them to the logger")

    # save number of model parameters
    hparams["model/params_total"] = sum(p.numel() for p in model.parameters())
    hparams["model/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
    hparams["model/params_not_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    hparams["experiment_path"] = config.core.experiment_path
    hparams["command"] = config.core.command
    hparams["library/version"] = str(quadra.__version__)

    with open(os.devnull, "w") as fnull:
        # A zero exit code from `git status` means we are inside a repository.
        # NOTE(review): the status check runs in get_original_cwd() but the
        # queries below run in the hydra run dir — confirm they always match.
        if subprocess.call(["git", "-C", get_original_cwd(), "status"], stderr=subprocess.STDOUT, stdout=fnull) == 0:
            try:
                hparams["git/commit"] = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip()
                hparams["git/branch"] = (
                    subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("ascii").strip()
                )
                hparams["git/remote"] = (
                    subprocess.check_output(["git", "remote", "get-url", "origin"]).decode("ascii").strip()
                )
            except subprocess.CalledProcessError:
                log.warning(
                    "Could not get git commit, branch or remote information, the repository might not have any commits "
                    " yet or it might have been initialized wrongly."
                )
        else:
            log.warning("Could not find git repository, skipping git commit and branch info")

    # send hparams to all loggers
    trainer.logger.log_hyperparams(hparams)
209
+
210
+
211
def upload_file_tensorboard(file_path: str, tensorboard_logger: TensorBoardLogger) -> None:
    """Upload a file to tensorboard handling different extensions.

    JSON and YAML files are rendered as fenced code blocks; any other file is
    uploaded as plain text.

    Args:
        file_path: Path to the file to upload.
        tensorboard_logger: Tensorboard logger instance.
    """
    tag = os.path.basename(file_path)
    ext = os.path.splitext(file_path)[1].lower()

    if ext == ".json":
        # utf-8 added for consistency with the fallback branch below; the
        # original relied on the platform default encoding here.
        with open(file_path, encoding="utf-8") as f:
            json_content = json.load(f)

        json_content = f"```json\n{json.dumps(json_content, indent=4)}\n```"
        tensorboard_logger.experiment.add_text(tag=tag, text_string=json_content, global_step=0)
    elif ext in [".yaml", ".yml"]:
        with open(file_path, encoding="utf-8") as f:
            yaml_content = f.read()
        yaml_content = f"```yaml\n{yaml_content}\n```"
        tensorboard_logger.experiment.add_text(tag=tag, text_string=yaml_content, global_step=0)
    else:
        # Trailing space before the newline keeps line breaks when the text is
        # rendered as markdown.
        with open(file_path, encoding="utf-8") as f:
            tensorboard_logger.experiment.add_text(tag=tag, text_string=f.read().replace("\n", " \n"), global_step=0)

    tensorboard_logger.experiment.flush()
237
+
238
+
239
def finish(
    config: DictConfig,
    module: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: list[pl.Callback],
    logger: list[pl.loggers.Logger],
    export_folder: str,
) -> None:
    """Upload config files to MLFlow server.

    If `config.core.upload_artifacts` is enabled, uploads the run's config
    files to the attached mlflow and/or tensorboard loggers, and — when a
    `model.json` exists in `export_folder` — logs the exported deployment
    models to mlflow as well.

    Args:
        config: Configuration composed by Hydra.
        module: LightningModule.
        datamodule: LightningDataModule.
        trainer: LightningTrainer.
        callbacks: List of LightningCallbacks.
        logger: List of LightningLoggers.
        export_folder: Folder where the deployment models are exported.
    """
    # pylint: disable=unused-argument
    if len(logger) > 0 and config.core.get("upload_artifacts"):
        mlflow_logger = get_mlflow_logger(trainer=trainer)
        tensorboard_logger = get_tensorboard_logger(trainer=trainer)
        # Artifacts produced by the run in the current working directory.
        file_names = ["config.yaml", "config_resolved.yaml", "config_tree.txt", "data/dataset.csv"]
        # "16" in the trainer precision string selects a GPU device and half
        # precision for the re-imported deployment models.
        # NOTE(review): this also matches "bf16" precision strings — confirm
        # that is intended.
        if "16" in str(trainer.precision):
            index = _parse_gpu_ids(config.trainer.devices, include_cuda=True)[0]
            device = "cuda:" + str(index)
            half_precision = True
        else:
            device = "cpu"
            half_precision = False

        if mlflow_logger is not None:
            config_paths = []

            for f in file_names:
                if os.path.isfile(os.path.join(os.getcwd(), f)):
                    config_paths.append(os.path.join(os.getcwd(), f))

            # Upload the config files under the "metadata" artifact path.
            for path in config_paths:
                mlflow_logger.experiment.log_artifact(
                    run_id=mlflow_logger.run_id, local_path=path, artifact_path="metadata"
                )

            deployed_models = glob.glob(os.path.join(export_folder, "*"))
            # model.json describes the exported models (at least "input_size");
            # without it no model upload is attempted.
            model_json: dict[str, Any] | None = None

            if os.path.exists(os.path.join(export_folder, "model.json")):
                with open(os.path.join(export_folder, "model.json")) as json_file:
                    model_json = json.load(json_file)

            if model_json is not None:
                input_size = model_json["input_size"]
                # Not a huge fan of this check
                if not isinstance(input_size[0], list):
                    # Input size is not a list of lists
                    input_size = [input_size]
                # Dummy inputs used to infer the model signature for mlflow.
                inputs = cast(
                    list[Any],
                    quadra_export.generate_torch_inputs(input_size, device=device, half_precision=half_precision),
                )
                # NOTE(review): if `config.core.upload_models` is absent this is
                # None and the membership test below raises TypeError — confirm
                # the config always defines it when upload_artifacts is set.
                types_to_upload = config.core.get("upload_models")
                for model_path in deployed_models:
                    model_type = model_type_from_path(model_path)
                    if model_type is None:
                        logging.warning("%s model type not supported", model_path)
                        continue
                    # The `model_type is not None` part is redundant after the
                    # `continue` above; kept as-is.
                    if model_type is not None and model_type in types_to_upload:
                        if model_type == "pytorch":
                            logging.warning("Pytorch format still not supported for mlflow upload")
                            continue

                        model = quadra_export.import_deployment_model(
                            model_path,
                            device=device,
                            inference_config=config.inference,
                        )

                        if model_type in ["torchscript", "pytorch"]:
                            signature = infer_signature_model(model.model, inputs)
                            with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
                                mlflow.pytorch.log_model(
                                    model.model,
                                    artifact_path=model_path,
                                    signature=signature,
                                )
                        elif model_type in ["onnx", "simplified_onnx"] and ONNX_AVAILABLE:
                            signature = infer_signature_model(model, inputs)
                            with mlflow.start_run(run_id=mlflow_logger.run_id) as _:
                                if model.model_path is None:
                                    logging.warning(
                                        "Cannot log onnx model on mlflow, \
                                        BaseEvaluationModel 'model_path' attribute is None"
                                    )
                                else:
                                    model_proto = onnx.load(model.model_path)
                                    mlflow.onnx.log_model(
                                        model_proto,
                                        artifact_path=model_path,
                                        signature=signature,
                                    )

        if tensorboard_logger is not None:
            config_paths = []
            for f in file_names:
                if os.path.isfile(os.path.join(os.getcwd(), f)):
                    config_paths.append(os.path.join(os.getcwd(), f))

            for path in config_paths:
                upload_file_tensorboard(file_path=path, tensorboard_logger=tensorboard_logger)

            tensorboard_logger.experiment.flush()
352
+
353
+
354
def load_envs(env_file: str | None = None) -> None:
    """Load the environment variables defined in ``env_file``.

    Behaves like sourcing the file in a shell (``. env_file``); already-set
    variables are overridden. Useful to keep system specific settings out of
    the codebase.

    Args:
        env_file: the file that defines the environment variables to use. If None
            it searches for a `.env` file in the project.
    """
    dotenv.load_dotenv(dotenv_path=env_file, override=True)
365
+
366
+
367
def model_type_from_path(model_path: str) -> str | None:
    """Determine the type of the machine learning model based on its file extension.

    Args:
        model_path: The file path of the machine learning model.

    Returns:
        One of "torchscript" ('.pt'), "pytorch" ('.pth'), "simplified_onnx"
        (ends with 'simplified.onnx'), "onnx" ('.onnx'), "json" ('.json'),
        or None if the extension is not supported.

    Example:
        ```python
        model_path = "path/to/your/model.onnx"
        model_type = model_type_from_path(model_path)
        print(f"The model type is: {model_type}")
        ```
    """
    # Ordered: 'simplified.onnx' must be tested before the generic '.onnx'.
    suffix_to_type = (
        (".pt", "torchscript"),
        (".pth", "pytorch"),
        ("simplified.onnx", "simplified_onnx"),
        (".onnx", "onnx"),
        (".json", "json"),
    )
    for suffix, model_type in suffix_to_type:
        if model_path.endswith(suffix):
            return model_type
    return None
400
+
401
+
402
def setup_opencv() -> None:
    """Configure OpenCV globals: restrict it to a single thread and disable OpenCL."""
    cv2.setNumThreads(1)
    cv2.ocl.setUseOpenCL(False)
406
+
407
+
408
def get_device(cuda: bool = True) -> str:
    """Returns the device to use for training.

    Args:
        cuda: whether to use cuda or not

    Returns:
        The device to use
    """
    use_cuda = cuda and torch.cuda.is_available()
    return "cuda:0" if use_cuda else "cpu"
421
+
422
+
423
def nested_set(dic: dict, keys: list[str], value: str) -> None:
    """Assign the value of a dictionary using nested keys.

    Intermediate dictionaries are created as needed.
    """
    head, tail = keys[0], keys[1:]
    if not tail:
        dic[head] = value
    else:
        nested_set(dic.setdefault(head, {}), tail, value)
429
+
430
+
431
def flatten_list(input_list: Iterable[Any]) -> Iterator[Any]:
    """Return an iterator over the flattened list.

    Args:
        input_list: the list to be flattened

    Yields:
        The iterator over the flattend list
    """
    for element in input_list:
        # Strings and bytes are iterable but are treated as scalar leaves.
        is_leaf = not isinstance(element, Iterable) or isinstance(element, (str, bytes))
        if is_leaf:
            yield element
        else:
            yield from flatten_list(element)
445
+
446
+
447
class HydraEncoder(json.JSONEncoder):
    """JSON encoder aware of OmegaConf configuration objects."""

    def default(self, o):
        """Serialize OmegaConf configs as plain python containers."""
        if o is not None and OmegaConf.is_config(o):
            return OmegaConf.to_container(o)
        # Fall back to the standard behaviour (raises TypeError).
        return super().default(o)
455
+
456
+
457
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that serializes numpy arrays and scalars."""

    def default(self, o):
        """Convert numpy objects into JSON-compatible python values."""
        if isinstance(o, np.ndarray):
            # Single-element arrays collapse to a scalar, others to a list.
            return o.item() if o.size == 1 else o.tolist()
        if isinstance(o, np.number):
            return o.item()
        # Fall back to the standard behaviour (raises TypeError).
        return super().default(o)
470
+
471
+
472
class AllGatherSyncFunction(torch.autograd.Function):
    """Function to gather gradients from multiple GPUs."""

    @staticmethod
    def forward(ctx, tensor):
        """Gather `tensor` from every process and concatenate along dim 0."""
        # Remember the local batch size so backward can slice out this
        # rank's portion of the gradient.
        ctx.batch_size = tensor.shape[0]

        world_size = torch.distributed.get_world_size()
        buckets = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(buckets, tensor)

        return torch.cat(buckets, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """Sum the gradient across processes and return this rank's slice."""
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)

        rank = torch.distributed.get_rank()
        start = rank * ctx.batch_size
        stop = start + ctx.batch_size
        return grad_input[start:stop]
494
+
495
+
496
@torch.no_grad()
def concat_all_gather(tensor):
    """Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    collected = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(collected, tensor, async_op=False)

    return torch.cat(collected, dim=0)
506
+
507
+
508
def get_tensorboard_logger(trainer: pl.Trainer) -> TensorBoardLogger | None:
    """Safely get tensorboard logger from Lightning Trainer loggers.

    Args:
        trainer: Pytorch Lightning Trainer.

    Returns:
        A tensorboard logger if available, else None.
    """
    # trainer.logger may be a single logger or a list of loggers.
    if isinstance(trainer.logger, TensorBoardLogger):
        return trainer.logger

    if isinstance(trainer.logger, list):
        # Return the first tensorboard logger found.
        for logger in trainer.logger:
            if isinstance(logger, TensorBoardLogger):
                return logger

    return None
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+
3
+ import difflib
4
+ import importlib
5
+ import inspect
6
+ from collections.abc import Iterable
7
+ from typing import Any
8
+
9
+ from omegaconf import DictConfig, ListConfig, OmegaConf
10
+
11
+ from quadra.utils.utils import get_logger
12
+
13
# Hydra/OmegaConf structural keys that are not real constructor arguments and
# therefore must be excluded when collecting a node's configured arguments.
OMEGACONF_FIELDS: tuple[str, ...] = ("_target_", "_convert_", "_recursive_", "_args_")
# Any config key containing one of these substrings is skipped entirely
# during validation (hydra's own configuration tree).
EXCLUDE_KEYS: tuple[str, ...] = ("hydra",)

logger = get_logger(__name__)
17
+
18
+
19
def get_callable_arguments(full_module_path: str) -> tuple[list[str], bool]:
    """Gets all arguments from module path.

    Args:
        full_module_path: Full module path to the target class or function.

    Raises:
        ValueError: If the target is not a class or a function.

    Returns:
        All arguments from the target class or function, and whether the
        target accepts ``**kwargs`` (in which case arguments cannot be
        strictly checked).
    """
    module_path, attribute_name = full_module_path.rsplit(".", 1)
    target = getattr(importlib.import_module(module_path), attribute_name)

    accepts_kwargs = False
    if inspect.isclass(target):
        # Walk the MRO (stopping at `object`) and merge the constructor
        # arguments declared anywhere in the class hierarchy.
        collected: set[str] = set()
        for ancestor in target.__mro__:
            if ancestor is object:
                break
            # We don't access the instance but mypy complains
            spec = inspect.getfullargspec(ancestor.__init__)  # type: ignore
            collected.update(spec.args[1:])  # drop `self`
            collected.update(spec.kwonlyargs)
            # **kwargs anywhere in the hierarchy disables strict checking.
            accepts_kwargs = accepts_kwargs or spec.varkw is not None
        arg_names = list(collected)
    elif inspect.isfunction(target):
        spec = inspect.getfullargspec(target)
        arg_names = [*spec.args, *spec.kwonlyargs]
        accepts_kwargs = spec.varkw is not None
    else:
        raise ValueError("The target must be a class or a function.")

    return arg_names, accepts_kwargs
60
+
61
+
62
def check_all_arguments(callable_variable: str, configuration_arguments: list[str], argument_names: list[str]) -> None:
    """Checks if all arguments passed from configuration are valid for the target class or function.

    Args:
        callable_variable: Full module path to the target class or function.
        configuration_arguments: All arguments passed from configuration.
        argument_names: All arguments from the target class or function.

    Raises:
        ValueError: If the argument is not valid for the target class or function.
    """
    known = set(argument_names)
    for candidate in configuration_arguments:
        if candidate in known:
            continue
        message = f"`{candidate}` is not a valid argument passed from configuration to `{callable_variable}`."
        # Suggest the closest valid argument name, if any is similar enough.
        suggestions = difflib.get_close_matches(candidate, argument_names, n=1, cutoff=0.5)
        if suggestions:
            message = f"{message} Did you mean `{suggestions[0]}`?"
        raise ValueError(message)
82
+
83
+
84
def validate_config(_cfg: DictConfig | ListConfig, package_name: str = "quadra") -> None:
    """Recursively traverse OmegaConf object and check if arguments are valid for the target class or function.
    If not, raise a ValueError with a suggestion for the closest match of the argument name.

    Args:
        _cfg: OmegaConf object
        package_name: package name to check for instantiation. Only `_target_`
            values starting with this prefix are validated.

    Raises:
        ValueError: If a configured argument is not accepted by the target.
    """
    # The below lines of code for looping over a DictConfig/ListConfig are
    # borrowed from OmegaConf PR #719.
    itr: Iterable[Any]
    if isinstance(_cfg, ListConfig):
        itr = range(len(_cfg))
    else:
        itr = _cfg
    for key in itr:
        # Missing ("???") values cannot be resolved or validated.
        if OmegaConf.is_missing(_cfg, key):
            continue
        # Skip hydra's own configuration subtree.
        if isinstance(key, str) and any(x in key for x in EXCLUDE_KEYS):
            continue
        if OmegaConf.is_config(_cfg[key]):
            # Fix: propagate package_name so nested nodes honor a custom
            # package prefix (previously the recursion silently reset it to
            # the default "quadra").
            validate_config(_cfg[key], package_name)
        elif isinstance(_cfg[key], str):
            if key == "_target_":
                callable_variable = str(_cfg[key])
                if callable_variable.startswith(package_name):
                    # Configured arguments are the node's keys minus the
                    # OmegaConf/Hydra structural fields.
                    configuration_arguments = [str(x) for x in _cfg if x not in OMEGACONF_FIELDS]
                    argument_names, accepts_kwargs = get_callable_arguments(callable_variable)
                    # A **kwargs target accepts anything, so strict checking
                    # is impossible there.
                    if not accepts_kwargs:
                        check_all_arguments(callable_variable, configuration_arguments, argument_names)
        else:
            logger.debug("Skipping %s from config. It is not supported.", key)