quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302) hide show
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
@@ -0,0 +1,291 @@
1
+ from __future__ import annotations
2
+
3
+ import glob
4
+ import os
5
+ from typing import Any, Literal
6
+
7
+ import torch
8
+ from pytorch_lightning import Callback, LightningModule, Trainer
9
+ from pytorch_lightning.callbacks import LearningRateMonitor
10
+ from pytorch_lightning.loggers import MLFlowLogger
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
13
+
14
+ from quadra.utils.mlflow import get_mlflow_logger
15
+
16
+
17
+ def check_minio_credentials() -> None:
18
+ """Check minio credentials for aws based storage such as minio.
19
+
20
+ Returns:
21
+ None
22
+ """
23
+ check = os.environ.get("AWS_ACCESS_KEY_ID") is not None and os.environ.get("AWS_SECRET_ACCESS_KEY") is not None
24
+ if not check:
25
+ raise ValueError(
26
+ "You are trying to upload mlflow artifacts, but minio credentials are not set. Please set them in your"
27
+ " environment variables."
28
+ )
29
+
30
+
31
+ def check_file_server_dependencies() -> None:
32
+ """Check file dependencies as boto3.
33
+
34
+ Returns:
35
+ None
36
+ """
37
+ try:
38
+ # pylint: disable=unused-import,import-outside-toplevel
39
+ import boto3 # noqa
40
+ import minio # noqa
41
+ except ImportError as e:
42
+ raise ImportError(
43
+ "You are trying to upload mlflow artifacts, but boto3 and minio are not installed. Please install them by"
44
+ " calling pip install minio boto3."
45
+ ) from e
46
+
47
+
48
+ def validate_artifact_storage(logger: MLFlowLogger):
49
+ """Validate artifact storage.
50
+
51
+ Args:
52
+ logger: Mlflow logger from pytorch lightning.
53
+
54
+ """
55
+ from quadra.utils.utils import get_logger # pylint: disable=[import-outside-toplevel]
56
+
57
+ log = get_logger(__name__)
58
+
59
+ client = logger.experiment
60
+ # TODO: we have to access the internal api to get the artifact uri, however there could be a better way
61
+ artifact_uri = client._tracking_client._get_artifact_repo( # pylint: disable=protected-access
62
+ logger.run_id
63
+ ).artifact_uri
64
+ if artifact_uri.startswith("s3://"):
65
+ check_minio_credentials()
66
+ check_file_server_dependencies()
67
+ log.info("Mlflow artifact storage is AWS/S3 basedand credentials and dependencies are satisfied.")
68
+ else:
69
+ log.info("Mlflow artifact storage uri is %s. Validation checks are not implemented.", artifact_uri)
70
+
71
+
72
+ class UploadCodeAsArtifact(Callback):
73
+ """Callback used to upload Code as artifact.
74
+
75
+ Uploads all *.py files to mlflow as an artifact, at the beginning of the run but
76
+ after initializing the trainer. It creates project-source folder under mlflow
77
+ artifacts and other necessary subfolders.
78
+
79
+ Args:
80
+ source_dir: Folder where all the source files are stored.
81
+ """
82
+
83
+ def __init__(self, source_dir: str):
84
+ self.source_dir = source_dir
85
+
86
+ @rank_zero_only
87
+ def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
88
+ """Triggered at the end of test. Uploads all *.py files to mlflow as an artifact.
89
+
90
+ Args:
91
+ trainer: Pytorch Lightning trainer.
92
+ pl_module: Pytorch Lightning module.
93
+ """
94
+ logger = get_mlflow_logger(trainer=trainer)
95
+
96
+ if logger is None:
97
+ return
98
+
99
+ experiment = logger.experiment
100
+
101
+ for path in glob.glob(os.path.join(self.source_dir, "**/*.py"), recursive=True):
102
+ stripped_path = path.replace(self.source_dir, "")
103
+ if len(stripped_path.split("/")) > 1:
104
+ file_path_tree = "/" + "/".join(stripped_path.split("/")[:-1])
105
+ else:
106
+ file_path_tree = ""
107
+ experiment.log_artifact(
108
+ run_id=logger.run_id,
109
+ local_path=path,
110
+ artifact_path=f"project-source{file_path_tree}",
111
+ )
112
+
113
+
114
+ class LogGradients(Callback):
115
+ """Callback used to logs of the model at the end of the of each training step.
116
+
117
+ Args:
118
+ norm: Norm to use for the gradient. Default is L2 norm.
119
+ tag: Tag to add to the gradients. If None, no tag will be added.
120
+ sep: Separator to use in the log.
121
+ round_to: Number of decimals to round the gradients to.
122
+ log_all_grads: If True, log all gradients, not just the total norm.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ norm: int = 2,
128
+ tag: str | None = None,
129
+ sep: str = "/",
130
+ round_to: int = 3,
131
+ log_all_grads: bool = False,
132
+ ):
133
+ self.norm = norm
134
+ self.tag = tag
135
+ self.sep = sep
136
+ self.round_to = round_to
137
+ self.log_all_grads = log_all_grads
138
+
139
+ def _grad_norm(self, named_params) -> dict:
140
+ """Compute the gradient norm and return it in a dictionary."""
141
+ grad_tag = "" if self.tag is None else "_" + self.tag
142
+ results = {}
143
+ for name, p in named_params:
144
+ if p.requires_grad and p.grad is not None:
145
+ norm = float(p.grad.data.norm(self.norm))
146
+ key = f"grad_norm_{self.norm}{grad_tag}{self.sep}{name}"
147
+ results[key] = round(norm, 3)
148
+ total_norm = float(torch.tensor(list(results.values())).norm(self.norm))
149
+ if not self.log_all_grads:
150
+ # clear dictionary
151
+ results = {}
152
+ results[f"grad_norm_{self.norm}_total{grad_tag}"] = round(total_norm, self.round_to)
153
+ return results
154
+
155
+ @rank_zero_only
156
+ def on_train_batch_end(
157
+ self,
158
+ trainer: Trainer,
159
+ pl_module: LightningModule,
160
+ outputs: STEP_OUTPUT,
161
+ batch: Any,
162
+ batch_idx: int,
163
+ unused: int | None = 0,
164
+ ) -> None:
165
+ """Method called at the end of the train batch
166
+ Args:
167
+ trainer: pl.trainer
168
+ pl_module: lightning module
169
+ outputs: outputs
170
+ batch: batch
171
+ batch_idx: index
172
+ unused: dl index.
173
+
174
+
175
+ Returns:
176
+ None
177
+ """
178
+ # pylint: disable=unused-argument
179
+ logger = get_mlflow_logger(trainer=trainer)
180
+
181
+ if logger is None:
182
+ return
183
+
184
+ named_params = pl_module.named_parameters()
185
+ grads = self._grad_norm(named_params)
186
+ logger.log_metrics(grads)
187
+
188
+
189
+ class UploadCheckpointsAsArtifact(Callback):
190
+ """Callback used to upload checkpoints as artifacts.
191
+
192
+ Args:
193
+ ckpt_dir: Folder where all the checkpoints are stored in artifact folder.
194
+ ckpt_ext: Extension of checkpoint files (default: ckpt).
195
+ upload_best_only: Only upload best checkpoint (default: False)
196
+ delete_after_upload: Delete the checkpoint from local storage after uploading (default: True)
197
+ upload: If True, upload the checkpoints. If False, only save them on local machine.
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ ckpt_dir: str = "checkpoints/",
203
+ ckpt_ext: str = "ckpt",
204
+ upload_best_only: bool = False,
205
+ delete_after_upload: bool = True,
206
+ upload: bool = True,
207
+ ):
208
+ self.ckpt_dir = ckpt_dir
209
+ self.upload_best_only = upload_best_only
210
+ self.ckpt_ext = ckpt_ext
211
+ self.delete_after_upload = delete_after_upload
212
+ self.upload = upload
213
+
214
+ @rank_zero_only
215
+ def on_test_end(self, trainer: Trainer, pl_module: LightningModule):
216
+ """Triggered at the end of test. Uploads all model checkpoints to mlflow as an artifact.
217
+
218
+ Args:
219
+ trainer: Pytorch Lightning trainer.
220
+ pl_module: Pytorch Lightning module.
221
+ """
222
+ logger = get_mlflow_logger(trainer=trainer)
223
+
224
+ if logger is None:
225
+ return
226
+
227
+ experiment = logger.experiment
228
+
229
+ if (
230
+ trainer.checkpoint_callback
231
+ and self.upload_best_only
232
+ and hasattr(trainer.checkpoint_callback, "best_model_path")
233
+ ):
234
+ if self.upload:
235
+ experiment.log_artifact(
236
+ run_id=logger.run_id,
237
+ local_path=trainer.checkpoint_callback.best_model_path,
238
+ artifact_path="checkpoints",
239
+ )
240
+ else:
241
+ for path in glob.glob(os.path.join(self.ckpt_dir, f"**/*.{self.ckpt_ext}"), recursive=True):
242
+ if self.upload:
243
+ experiment.log_artifact(
244
+ run_id=logger.run_id,
245
+ local_path=path,
246
+ artifact_path="checkpoints",
247
+ )
248
+ if self.delete_after_upload:
249
+ for path in glob.glob(os.path.join(self.ckpt_dir, f"**/*.{self.ckpt_ext}"), recursive=True):
250
+ os.remove(path)
251
+
252
+
253
+ class LogLearningRate(LearningRateMonitor):
254
+ """Learning rate logger at the end of the training step/epoch.
255
+
256
+ Args:
257
+ logging_interval: Logging interval.
258
+ log_momentum: If True, log momentum as well.
259
+ """
260
+
261
+ def __init__(self, logging_interval: Literal["step", "epoch"] | None = None, log_momentum: bool = False):
262
+ super().__init__(logging_interval=logging_interval, log_momentum=log_momentum)
263
+
264
+ def on_train_batch_start(self, trainer, *args, **kwargs):
265
+ """Log learning rate at the beginning of the training step if logging interval is set to step."""
266
+ if not trainer.logger_connector.should_update_logs:
267
+ return
268
+ if self.logging_interval != "epoch":
269
+ logger = get_mlflow_logger(trainer=trainer)
270
+
271
+ if logger is None:
272
+ return
273
+
274
+ interval = "step" if self.logging_interval is None else "any"
275
+ latest_stat = self._extract_stats(trainer, interval)
276
+
277
+ if latest_stat:
278
+ logger.log_metrics(latest_stat, step=trainer.global_step)
279
+
280
+ def on_train_epoch_start(self, trainer, *args, **kwargs):
281
+ """Log learning rate at the beginning of the epoch if logging interval is set to epoch."""
282
+ if self.logging_interval != "step":
283
+ interval = "epoch" if self.logging_interval is None else "any"
284
+ latest_stat = self._extract_stats(trainer, interval)
285
+ logger = get_mlflow_logger(trainer=trainer)
286
+
287
+ if logger is None:
288
+ return
289
+
290
+ if latest_stat:
291
+ logger.log_metrics(latest_stat, step=trainer.global_step)
@@ -0,0 +1,69 @@
1
+ import hydra
2
+ import pytorch_lightning as pl
3
+ from omegaconf import DictConfig
4
+ from pytorch_lightning import Callback
5
+ from pytorch_lightning.utilities import rank_zero_only
6
+
7
+ from quadra.schedulers.warmup import CosineAnnealingWithLinearWarmUp
8
+ from quadra.utils.utils import get_logger
9
+
10
+ log = get_logger(__name__)
11
+
12
+
13
+ class WarmupInit(Callback):
14
+ """Custom callback used to setup a warmup scheduler.
15
+
16
+ Args:
17
+ scheduler_config: scheduler configuration.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ scheduler_config: DictConfig,
23
+ ) -> None:
24
+ self.scheduler_config = scheduler_config
25
+
26
+ @rank_zero_only
27
+ def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
28
+ """Called when fit begins."""
29
+ if not hasattr(trainer, "datamodule"):
30
+ raise ValueError("Trainer must have a datamodule attribute.")
31
+
32
+ if not any(isinstance(s.scheduler, CosineAnnealingWithLinearWarmUp) for s in trainer.lr_scheduler_configs):
33
+ return
34
+
35
+ log.info("Using warmup scheduler, forcing optimizer learning rate to zero.")
36
+ for i, _ in enumerate(trainer.optimizers):
37
+ for param_group in trainer.optimizers[i].param_groups:
38
+ param_group["lr"] = 0.0
39
+ trainer.optimizers[i].defaults["lr"] = 0.0
40
+
41
+ batch_size = trainer.datamodule.batch_size
42
+ train_dataloader = trainer.datamodule.train_dataloader()
43
+ len_train_dataloader = len(train_dataloader)
44
+ if isinstance(trainer.device_ids, list) and pl_module.device.type == "cuda":
45
+ num_gpus = len(trainer.device_ids)
46
+ len_train_dataloader = len_train_dataloader // num_gpus
47
+ if not train_dataloader.drop_last:
48
+ len_train_dataloader += int((len_train_dataloader % num_gpus) != 0)
49
+
50
+ if len_train_dataloader == 1:
51
+ log.warning(
52
+ "From this dataset size, we can only generate single batch. The batch size will be set as lenght of"
53
+ " the dataset "
54
+ )
55
+ batch_size = len(train_dataloader.dataset)
56
+
57
+ if isinstance(trainer.device_ids, list) and pl_module.device.type == "cuda":
58
+ batch_size = batch_size * len(trainer.device_ids)
59
+
60
+ scheduler = hydra.utils.instantiate(
61
+ self.scheduler_config,
62
+ optimizer=pl_module.optimizer,
63
+ batch_size=batch_size,
64
+ len_loader=len_train_dataloader,
65
+ )
66
+
67
+ for i, s in enumerate(trainer.lr_scheduler_configs):
68
+ if isinstance(s.scheduler, CosineAnnealingWithLinearWarmUp):
69
+ trainer.lr_scheduler_configs[i].scheduler = scheduler
File without changes
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: caformer_m36.sail_in22k_ft_in1k
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 576
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: caformer_s36.sail_in22k_ft_in1k
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 512
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: convnextv2_base.fcmae
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 1024
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: convnextv2_femto.fcmae
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 384
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: convnextv2_tiny.fcmae
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 768
@@ -0,0 +1,12 @@
1
+ model:
2
+ _target_: quadra.models.classification.TorchHubNetworkBuilder
3
+ repo_or_dir: facebookresearch/dino:main
4
+ model_name: dino_vitb8
5
+ pretrained: true
6
+ freeze: false
7
+ hyperspherical: false
8
+ metadata:
9
+ input_size: 224
10
+ output_dim: 768
11
+ patch_size: 8
12
+ nb_heads: 12
@@ -0,0 +1,12 @@
1
+ model:
2
+ _target_: quadra.models.classification.TorchHubNetworkBuilder
3
+ repo_or_dir: facebookresearch/dino:main
4
+ model_name: dino_vits8
5
+ pretrained: true
6
+ freeze: false
7
+ hyperspherical: false
8
+ metadata:
9
+ input_size: 224
10
+ output_dim: 384
11
+ patch_size: 8
12
+ nb_heads: 6
@@ -0,0 +1,12 @@
1
+ model:
2
+ _target_: quadra.models.classification.TorchHubNetworkBuilder
3
+ repo_or_dir: facebookresearch/dinov2
4
+ model_name: dinov2_vitb14
5
+ pretrained: true
6
+ freeze: false
7
+ hyperspherical: false
8
+ metadata:
9
+ input_size: 224
10
+ output_dim: 768
11
+ patch_size: 14
12
+ nb_heads: 12
@@ -0,0 +1,12 @@
1
+ model:
2
+ _target_: quadra.models.classification.TorchHubNetworkBuilder
3
+ repo_or_dir: facebookresearch/dinov2
4
+ model_name: dinov2_vits14
5
+ pretrained: true
6
+ freeze: false
7
+ hyperspherical: false
8
+ metadata:
9
+ input_size: 224
10
+ output_dim: 384
11
+ patch_size: 14
12
+ nb_heads: 6
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: tf_efficientnet_b0.ns_jft_in1k
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 1280
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: tf_efficientnet_b1.ns_jft_in1k
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 1280
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: tf_efficientnet_b2.ns_jft_in1k
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 260
8
+ output_dim: 1408
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: tf_efficientnet_b3.ns_jft_in1k
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 300
8
+ output_dim: 1536
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: tf_efficientnetv2_s_in21ft1k
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 384
8
+ output_dim: 1280
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: levit_128s
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 384
@@ -0,0 +1,9 @@
1
+ model:
2
+ _target_: quadra.models.classification.TorchVisionNetworkBuilder
3
+ model_name: mnasnet0_5
4
+ pretrained: true
5
+ freeze: false
6
+ hyperspherical: false
7
+ metadata:
8
+ input_size: 224
9
+ output_dim: 1280
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: resnet101
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 2048
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: resnet18.tv_in1k # Use torchvision weights
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 512
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: ssl_resnet18
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 512
@@ -0,0 +1,8 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: resnet50
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 2048
@@ -0,0 +1,9 @@
1
+ model:
2
+ _target_: quadra.modules.backbone.create_smp_backbone
3
+ arch: unet
4
+ encoder_name: resnet18
5
+ encoder_weights: imagenet
6
+ freeze_encoder: True
7
+ in_channels: 3
8
+ num_classes: 1
9
+ activation: null
@@ -0,0 +1,9 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: tiny_vit_21m_224.dist_in22k_ft_in1k
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 576
9
+ num_heads: 18 # Is it correct?
@@ -0,0 +1,15 @@
1
+ model:
2
+ _target_: monai.networks.nets.unetr.UNETR
3
+ in_channels: 3
4
+ out_channels: 1
5
+ img_size: [448, 448]
6
+ feature_size: 16
7
+ hidden_size: 384 # 192
8
+ mlp_dim: 1536 # 768
9
+ num_heads: 8 # 3
10
+ pos_embed: conv
11
+ norm_name: instance
12
+ conv_block: true
13
+ res_block: true
14
+ dropout_rate: 0
15
+ spatial_dims: 2
@@ -0,0 +1,9 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: vit_base_patch16_224
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 768
9
+ nb_heads: 12 # ViT-Base: 768 hidden dims = 12 heads x 64 (same as dino_vitb8 / dinov2_vitb14)
@@ -0,0 +1,9 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: vit_small_patch16_224
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 384
9
+ nb_heads: 6
@@ -0,0 +1,9 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: vit_tiny_patch16_224
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 192
9
+ nb_heads: 3
@@ -0,0 +1,9 @@
1
+ model:
2
+ _target_: quadra.models.classification.TimmNetworkBuilder
3
+ model_name: xcit_tiny_24_p8_224
4
+ pretrained: true
5
+ freeze: false
6
+ metadata:
7
+ input_size: 224
8
+ output_dim: 192
9
+ nb_heads: 4