quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (302)
  1. hydra_plugins/quadra_searchpath_plugin.py +30 -0
  2. quadra/__init__.py +6 -0
  3. quadra/callbacks/__init__.py +0 -0
  4. quadra/callbacks/anomalib.py +289 -0
  5. quadra/callbacks/lightning.py +501 -0
  6. quadra/callbacks/mlflow.py +291 -0
  7. quadra/callbacks/scheduler.py +69 -0
  8. quadra/configs/__init__.py +0 -0
  9. quadra/configs/backbone/caformer_m36.yaml +8 -0
  10. quadra/configs/backbone/caformer_s36.yaml +8 -0
  11. quadra/configs/backbone/convnextv2_base.yaml +8 -0
  12. quadra/configs/backbone/convnextv2_femto.yaml +8 -0
  13. quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
  14. quadra/configs/backbone/dino_vitb8.yaml +12 -0
  15. quadra/configs/backbone/dino_vits8.yaml +12 -0
  16. quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
  17. quadra/configs/backbone/dinov2_vits14.yaml +12 -0
  18. quadra/configs/backbone/efficientnet_b0.yaml +8 -0
  19. quadra/configs/backbone/efficientnet_b1.yaml +8 -0
  20. quadra/configs/backbone/efficientnet_b2.yaml +8 -0
  21. quadra/configs/backbone/efficientnet_b3.yaml +8 -0
  22. quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
  23. quadra/configs/backbone/levit_128s.yaml +8 -0
  24. quadra/configs/backbone/mnasnet0_5.yaml +9 -0
  25. quadra/configs/backbone/resnet101.yaml +8 -0
  26. quadra/configs/backbone/resnet18.yaml +8 -0
  27. quadra/configs/backbone/resnet18_ssl.yaml +8 -0
  28. quadra/configs/backbone/resnet50.yaml +8 -0
  29. quadra/configs/backbone/smp.yaml +9 -0
  30. quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
  31. quadra/configs/backbone/unetr.yaml +15 -0
  32. quadra/configs/backbone/vit16_base.yaml +9 -0
  33. quadra/configs/backbone/vit16_small.yaml +9 -0
  34. quadra/configs/backbone/vit16_tiny.yaml +9 -0
  35. quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
  36. quadra/configs/callbacks/all.yaml +45 -0
  37. quadra/configs/callbacks/default.yaml +34 -0
  38. quadra/configs/callbacks/default_anomalib.yaml +64 -0
  39. quadra/configs/config.yaml +33 -0
  40. quadra/configs/core/default.yaml +11 -0
  41. quadra/configs/datamodule/base/anomaly.yaml +16 -0
  42. quadra/configs/datamodule/base/classification.yaml +21 -0
  43. quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
  44. quadra/configs/datamodule/base/segmentation.yaml +18 -0
  45. quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
  46. quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
  47. quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
  48. quadra/configs/datamodule/base/ssl.yaml +21 -0
  49. quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
  50. quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
  51. quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
  52. quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
  53. quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
  54. quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
  55. quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
  56. quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
  57. quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
  58. quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
  59. quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
  60. quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
  61. quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
  62. quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
  63. quadra/configs/experiment/base/classification/classification.yaml +73 -0
  64. quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
  65. quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
  66. quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
  67. quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
  68. quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
  69. quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
  70. quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
  71. quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
  72. quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
  73. quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
  74. quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
  75. quadra/configs/experiment/base/ssl/byol.yaml +43 -0
  76. quadra/configs/experiment/base/ssl/dino.yaml +46 -0
  77. quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
  78. quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
  79. quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
  80. quadra/configs/experiment/custom/cls.yaml +12 -0
  81. quadra/configs/experiment/default.yaml +15 -0
  82. quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
  83. quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
  84. quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
  85. quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
  86. quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
  87. quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
  88. quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
  89. quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
  90. quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
  91. quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
  92. quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
  93. quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
  94. quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
  95. quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
  96. quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
  97. quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
  98. quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
  99. quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
  100. quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
  101. quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
  102. quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
  103. quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
  104. quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
  105. quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
  106. quadra/configs/export/default.yaml +13 -0
  107. quadra/configs/hydra/anomaly_custom.yaml +15 -0
  108. quadra/configs/hydra/default.yaml +14 -0
  109. quadra/configs/inference/default.yaml +26 -0
  110. quadra/configs/logger/comet.yaml +10 -0
  111. quadra/configs/logger/csv.yaml +5 -0
  112. quadra/configs/logger/mlflow.yaml +12 -0
  113. quadra/configs/logger/tensorboard.yaml +8 -0
  114. quadra/configs/loss/asl.yaml +7 -0
  115. quadra/configs/loss/barlow.yaml +2 -0
  116. quadra/configs/loss/bce.yaml +1 -0
  117. quadra/configs/loss/byol.yaml +1 -0
  118. quadra/configs/loss/cross_entropy.yaml +1 -0
  119. quadra/configs/loss/dino.yaml +8 -0
  120. quadra/configs/loss/simclr.yaml +2 -0
  121. quadra/configs/loss/simsiam.yaml +1 -0
  122. quadra/configs/loss/smp_ce.yaml +3 -0
  123. quadra/configs/loss/smp_dice.yaml +2 -0
  124. quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
  125. quadra/configs/loss/smp_mcc.yaml +2 -0
  126. quadra/configs/loss/vicreg.yaml +5 -0
  127. quadra/configs/model/anomalib/cfa.yaml +35 -0
  128. quadra/configs/model/anomalib/cflow.yaml +30 -0
  129. quadra/configs/model/anomalib/csflow.yaml +34 -0
  130. quadra/configs/model/anomalib/dfm.yaml +19 -0
  131. quadra/configs/model/anomalib/draem.yaml +29 -0
  132. quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
  133. quadra/configs/model/anomalib/fastflow.yaml +32 -0
  134. quadra/configs/model/anomalib/padim.yaml +32 -0
  135. quadra/configs/model/anomalib/patchcore.yaml +36 -0
  136. quadra/configs/model/barlow.yaml +16 -0
  137. quadra/configs/model/byol.yaml +25 -0
  138. quadra/configs/model/classification.yaml +10 -0
  139. quadra/configs/model/dino.yaml +26 -0
  140. quadra/configs/model/logistic_regression.yaml +4 -0
  141. quadra/configs/model/multilabel_classification.yaml +9 -0
  142. quadra/configs/model/simclr.yaml +18 -0
  143. quadra/configs/model/simsiam.yaml +24 -0
  144. quadra/configs/model/smp.yaml +4 -0
  145. quadra/configs/model/smp_multiclass.yaml +4 -0
  146. quadra/configs/model/vicreg.yaml +16 -0
  147. quadra/configs/optimizer/adam.yaml +5 -0
  148. quadra/configs/optimizer/adamw.yaml +3 -0
  149. quadra/configs/optimizer/default.yaml +4 -0
  150. quadra/configs/optimizer/lars.yaml +8 -0
  151. quadra/configs/optimizer/sgd.yaml +4 -0
  152. quadra/configs/scheduler/default.yaml +5 -0
  153. quadra/configs/scheduler/rop.yaml +5 -0
  154. quadra/configs/scheduler/step.yaml +3 -0
  155. quadra/configs/scheduler/warmrestart.yaml +2 -0
  156. quadra/configs/scheduler/warmup.yaml +6 -0
  157. quadra/configs/task/anomalib/cfa.yaml +5 -0
  158. quadra/configs/task/anomalib/cflow.yaml +5 -0
  159. quadra/configs/task/anomalib/csflow.yaml +5 -0
  160. quadra/configs/task/anomalib/draem.yaml +5 -0
  161. quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
  162. quadra/configs/task/anomalib/fastflow.yaml +5 -0
  163. quadra/configs/task/anomalib/inference.yaml +3 -0
  164. quadra/configs/task/anomalib/padim.yaml +5 -0
  165. quadra/configs/task/anomalib/patchcore.yaml +5 -0
  166. quadra/configs/task/classification.yaml +6 -0
  167. quadra/configs/task/classification_evaluation.yaml +6 -0
  168. quadra/configs/task/default.yaml +1 -0
  169. quadra/configs/task/segmentation.yaml +9 -0
  170. quadra/configs/task/segmentation_evaluation.yaml +3 -0
  171. quadra/configs/task/sklearn_classification.yaml +13 -0
  172. quadra/configs/task/sklearn_classification_patch.yaml +11 -0
  173. quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
  174. quadra/configs/task/sklearn_classification_test.yaml +8 -0
  175. quadra/configs/task/ssl.yaml +2 -0
  176. quadra/configs/trainer/lightning_cpu.yaml +36 -0
  177. quadra/configs/trainer/lightning_gpu.yaml +35 -0
  178. quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
  179. quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
  180. quadra/configs/trainer/lightning_multigpu.yaml +37 -0
  181. quadra/configs/trainer/sklearn_classification.yaml +7 -0
  182. quadra/configs/transforms/byol.yaml +47 -0
  183. quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
  184. quadra/configs/transforms/default.yaml +37 -0
  185. quadra/configs/transforms/default_numpy.yaml +24 -0
  186. quadra/configs/transforms/default_resize.yaml +22 -0
  187. quadra/configs/transforms/dino.yaml +63 -0
  188. quadra/configs/transforms/linear_eval.yaml +18 -0
  189. quadra/datamodules/__init__.py +20 -0
  190. quadra/datamodules/anomaly.py +180 -0
  191. quadra/datamodules/base.py +375 -0
  192. quadra/datamodules/classification.py +1003 -0
  193. quadra/datamodules/generic/__init__.py +0 -0
  194. quadra/datamodules/generic/imagenette.py +144 -0
  195. quadra/datamodules/generic/mnist.py +81 -0
  196. quadra/datamodules/generic/mvtec.py +58 -0
  197. quadra/datamodules/generic/oxford_pet.py +163 -0
  198. quadra/datamodules/patch.py +190 -0
  199. quadra/datamodules/segmentation.py +742 -0
  200. quadra/datamodules/ssl.py +140 -0
  201. quadra/datasets/__init__.py +17 -0
  202. quadra/datasets/anomaly.py +287 -0
  203. quadra/datasets/classification.py +241 -0
  204. quadra/datasets/patch.py +138 -0
  205. quadra/datasets/segmentation.py +239 -0
  206. quadra/datasets/ssl.py +110 -0
  207. quadra/losses/__init__.py +0 -0
  208. quadra/losses/classification/__init__.py +6 -0
  209. quadra/losses/classification/asl.py +83 -0
  210. quadra/losses/classification/focal.py +320 -0
  211. quadra/losses/classification/prototypical.py +148 -0
  212. quadra/losses/ssl/__init__.py +17 -0
  213. quadra/losses/ssl/barlowtwins.py +47 -0
  214. quadra/losses/ssl/byol.py +37 -0
  215. quadra/losses/ssl/dino.py +129 -0
  216. quadra/losses/ssl/hyperspherical.py +45 -0
  217. quadra/losses/ssl/idmm.py +50 -0
  218. quadra/losses/ssl/simclr.py +67 -0
  219. quadra/losses/ssl/simsiam.py +30 -0
  220. quadra/losses/ssl/vicreg.py +76 -0
  221. quadra/main.py +49 -0
  222. quadra/metrics/__init__.py +3 -0
  223. quadra/metrics/segmentation.py +251 -0
  224. quadra/models/__init__.py +0 -0
  225. quadra/models/base.py +151 -0
  226. quadra/models/classification/__init__.py +8 -0
  227. quadra/models/classification/backbones.py +149 -0
  228. quadra/models/classification/base.py +92 -0
  229. quadra/models/evaluation.py +322 -0
  230. quadra/modules/__init__.py +0 -0
  231. quadra/modules/backbone.py +30 -0
  232. quadra/modules/base.py +312 -0
  233. quadra/modules/classification/__init__.py +3 -0
  234. quadra/modules/classification/base.py +327 -0
  235. quadra/modules/ssl/__init__.py +17 -0
  236. quadra/modules/ssl/barlowtwins.py +59 -0
  237. quadra/modules/ssl/byol.py +172 -0
  238. quadra/modules/ssl/common.py +285 -0
  239. quadra/modules/ssl/dino.py +186 -0
  240. quadra/modules/ssl/hyperspherical.py +206 -0
  241. quadra/modules/ssl/idmm.py +98 -0
  242. quadra/modules/ssl/simclr.py +73 -0
  243. quadra/modules/ssl/simsiam.py +68 -0
  244. quadra/modules/ssl/vicreg.py +67 -0
  245. quadra/optimizers/__init__.py +4 -0
  246. quadra/optimizers/lars.py +153 -0
  247. quadra/optimizers/sam.py +127 -0
  248. quadra/schedulers/__init__.py +3 -0
  249. quadra/schedulers/base.py +44 -0
  250. quadra/schedulers/warmup.py +127 -0
  251. quadra/tasks/__init__.py +24 -0
  252. quadra/tasks/anomaly.py +582 -0
  253. quadra/tasks/base.py +397 -0
  254. quadra/tasks/classification.py +1263 -0
  255. quadra/tasks/patch.py +492 -0
  256. quadra/tasks/segmentation.py +389 -0
  257. quadra/tasks/ssl.py +560 -0
  258. quadra/trainers/README.md +3 -0
  259. quadra/trainers/__init__.py +0 -0
  260. quadra/trainers/classification.py +179 -0
  261. quadra/utils/__init__.py +0 -0
  262. quadra/utils/anomaly.py +112 -0
  263. quadra/utils/classification.py +618 -0
  264. quadra/utils/deprecation.py +31 -0
  265. quadra/utils/evaluation.py +474 -0
  266. quadra/utils/export.py +585 -0
  267. quadra/utils/imaging.py +32 -0
  268. quadra/utils/logger.py +15 -0
  269. quadra/utils/mlflow.py +98 -0
  270. quadra/utils/model_manager.py +320 -0
  271. quadra/utils/models.py +523 -0
  272. quadra/utils/patch/__init__.py +15 -0
  273. quadra/utils/patch/dataset.py +1433 -0
  274. quadra/utils/patch/metrics.py +449 -0
  275. quadra/utils/patch/model.py +153 -0
  276. quadra/utils/patch/visualization.py +217 -0
  277. quadra/utils/resolver.py +42 -0
  278. quadra/utils/segmentation.py +31 -0
  279. quadra/utils/tests/__init__.py +0 -0
  280. quadra/utils/tests/fixtures/__init__.py +1 -0
  281. quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
  282. quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
  283. quadra/utils/tests/fixtures/dataset/classification.py +406 -0
  284. quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
  285. quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
  286. quadra/utils/tests/fixtures/models/__init__.py +3 -0
  287. quadra/utils/tests/fixtures/models/anomaly.py +89 -0
  288. quadra/utils/tests/fixtures/models/classification.py +45 -0
  289. quadra/utils/tests/fixtures/models/segmentation.py +33 -0
  290. quadra/utils/tests/helpers.py +70 -0
  291. quadra/utils/tests/models.py +27 -0
  292. quadra/utils/utils.py +525 -0
  293. quadra/utils/validator.py +115 -0
  294. quadra/utils/visualization.py +422 -0
  295. quadra/utils/vit_explainability.py +349 -0
  296. quadra-2.2.7.dist-info/LICENSE +201 -0
  297. quadra-2.2.7.dist-info/METADATA +381 -0
  298. quadra-2.2.7.dist-info/RECORD +300 -0
  299. {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
  300. quadra-2.2.7.dist-info/entry_points.txt +3 -0
  301. quadra-0.0.1.dist-info/METADATA +0 -14
  302. quadra-0.0.1.dist-info/RECORD +0 -4
quadra/callbacks/lightning.py
@@ -0,0 +1,501 @@
+ from __future__ import annotations
+
+ import os
+ import uuid
+ from copy import deepcopy
+ from typing import Any
+
+ import pytorch_lightning as pl
+ from pytorch_lightning.callbacks import Callback
+ from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder as LightningBatchSizeFinder
+ from pytorch_lightning.utilities import rank_zero_only
+ from pytorch_lightning.utilities.exceptions import _TunerExitException
+ from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
+ from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr
+ from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
+ from torch import nn
+
+ from quadra.utils.utils import get_logger
+
+ log = get_logger(__name__)
+
+ # pylint: disable=protected-access
+
+
+ def _scale_batch_size(
+     trainer: pl.Trainer,
+     mode: str = "power",
+     steps_per_trial: int = 3,
+     init_val: int = 2,
+     max_trials: int = 25,
+     batch_arg_name: str = "batch_size",
+ ) -> int | None:
+     """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
+     error.
+
+     Args:
+         trainer: A Trainer instance.
+         mode: Search strategy to update the batch size:
+
+             - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
+             - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
+               do a binary search between the last successful batch size and the batch size that failed.
+
+         steps_per_trial: number of steps to run with a given batch size.
+             Ideally 1 should be enough to test if an OOM error occurs,
+             however in practice a few are needed.
+         init_val: initial batch size to start the search with.
+         max_trials: max number of increases in batch size done before
+             the algorithm is terminated.
+         batch_arg_name: name of the attribute that stores the batch size.
+             It is expected that the user has provided a model or datamodule that has a hyperparameter
+             with that name. We will look for this attribute name in the following places
+
+             - ``model``
+             - ``model.hparams``
+             - ``trainer.datamodule`` (the datamodule passed to the tune method)
+
+     """
+     if trainer.fast_dev_run:  # type: ignore[attr-defined]
+         rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
+         return None
+
+     # Save initial model, that is loaded after batch size is found
+     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
+     trainer.save_checkpoint(ckpt_path)
+
+     # Arguments we adjust during the batch size finder, save for restoring
+     params = __scale_batch_dump_params(trainer)
+
+     # Set to values that are required by the algorithm
+     __scale_batch_reset_params(trainer, steps_per_trial)
+
+     if trainer.progress_bar_callback:
+         trainer.progress_bar_callback.disable()
+
+     lightning_setattr(trainer.lightning_module, batch_arg_name, init_val)
+
+     if mode == "power":
+         new_size = _run_power_scaling(trainer, init_val, batch_arg_name, max_trials, params)
+     elif mode == "binsearch":
+         new_size = _run_binary_scaling(trainer, init_val, batch_arg_name, max_trials, params)
+
+     garbage_collection_cuda()
+
+     log.info("Finished batch size finder, will continue with full run using batch size %d", new_size)
+
+     __scale_batch_restore_params(trainer, params)
+
+     if trainer.progress_bar_callback:
+         trainer.progress_bar_callback.enable()
+
+     trainer._checkpoint_connector.restore(ckpt_path)
+     trainer.strategy.remove_checkpoint(ckpt_path)
+
+     return new_size
+
+
+ def __scale_batch_dump_params(trainer: pl.Trainer) -> dict[str, Any]:
+     """Dump the parameters that need to be reset after the batch size finder."""
+     dumped_params = {
+         "loggers": trainer.loggers,
+         "callbacks": trainer.callbacks,  # type: ignore[attr-defined]
+     }
+     loop = trainer._active_loop
+     assert loop is not None
+     if isinstance(loop, pl.loops._FitLoop):
+         dumped_params["max_steps"] = trainer.max_steps
+         dumped_params["limit_train_batches"] = trainer.limit_train_batches
+         dumped_params["limit_val_batches"] = trainer.limit_val_batches
+     elif isinstance(loop, pl.loops._EvaluationLoop):
+         stage = trainer.state.stage
+         assert stage is not None
+         dumped_params["limit_eval_batches"] = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches")
+         dumped_params["loop_verbose"] = loop.verbose
+
+     dumped_params["loop_state_dict"] = deepcopy(loop.state_dict())
+     return dumped_params
+
+
+ def __scale_batch_reset_params(trainer: pl.Trainer, steps_per_trial: int) -> None:
+     """Set the trainer parameters to the values required by the batch size finder."""
+     from pytorch_lightning.loggers.logger import DummyLogger  # pylint: disable=import-outside-toplevel
+
+     trainer.logger = DummyLogger() if trainer.logger is not None else None
+     trainer.callbacks = []  # type: ignore[attr-defined]
+
+     loop = trainer._active_loop
+     assert loop is not None
+     if isinstance(loop, pl.loops._FitLoop):
+         trainer.limit_train_batches = 1.0
+         trainer.limit_val_batches = steps_per_trial
+         trainer.fit_loop.epoch_loop.max_steps = steps_per_trial
+     elif isinstance(loop, pl.loops._EvaluationLoop):
+         stage = trainer.state.stage
+         assert stage is not None
+         setattr(trainer, f"limit_{stage.dataloader_prefix}_batches", steps_per_trial)
+         loop.verbose = False
+
+
+ def __scale_batch_restore_params(trainer: pl.Trainer, params: dict[str, Any]) -> None:
+     """Restore the parameters that need to be reset after the batch size finder."""
+     # TODO: There are more states that need to be reset (#4512 and #4870)
+     trainer.loggers = params["loggers"]
+     trainer.callbacks = params["callbacks"]  # type: ignore[attr-defined]
+
+     loop = trainer._active_loop
+     assert loop is not None
+     if isinstance(loop, pl.loops._FitLoop):
+         loop.epoch_loop.max_steps = params["max_steps"]
+         trainer.limit_train_batches = params["limit_train_batches"]
+         trainer.limit_val_batches = params["limit_val_batches"]
+     elif isinstance(loop, pl.loops._EvaluationLoop):
+         stage = trainer.state.stage
+         assert stage is not None
+         setattr(trainer, f"limit_{stage.dataloader_prefix}_batches", params["limit_eval_batches"])
+
+     loop.load_state_dict(deepcopy(params["loop_state_dict"]))
+     loop.restarting = False
+     if isinstance(loop, pl.loops._EvaluationLoop) and "loop_verbose" in params:
+         loop.verbose = params["loop_verbose"]
+
+     # make sure the loop's state is reset
+     _reset_dataloaders(trainer)
+     loop.reset()
+
+
+ def _run_power_scaling(
+     trainer: pl.Trainer,
+     new_size: int,
+     batch_arg_name: str,
+     max_trials: int,
+     params: dict[str, Any],
+ ) -> int:
+     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
+     # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
+     # if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size
+     any_success = False
+     # In the original
+     for i in range(max_trials):
+         garbage_collection_cuda()
+
+         # reset after each try
+         _reset_progress(trainer)
+
+         try:
+             if i == 0:
+                 rank_zero_info(f"Starting batch size finder with batch size {new_size}")
+                 new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=1.0, desc=None)
+                 changed = True
+             else:
+                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+             # Force the train dataloader to reset as the batch size has changed
+             _reset_dataloaders(trainer)
+             _try_loop_run(trainer, params)
+
+             any_success = True
+
+             # In the original lightning implementation this is done before _reset_dataloaders
+             # As such the batch size is not checked for the last iteration!!!
+             if not changed:
+                 break
+         except RuntimeError as exception:
+             if is_oom_error(exception):
+                 # If we fail in power mode, halve the size and return
+                 garbage_collection_cuda()
+                 if any_success:
+                     # In the original lightning code there's a line that doesn't halve the size properly if batch_size
+                     # is bigger than the dataset length
+                     rank_zero_info(f"Batch size {new_size} failed, using batch size {new_size // 2}")
+                     new_size = new_size // 2
+                     lightning_setattr(trainer.lightning_module, batch_arg_name, new_size)
+                 else:
+                     # In this case it means the first iteration fails already, probably due to a way too big
+                     # initial batch size; since the next iteration would start from (new_size // 2) * 2, which is
+                     # the same, divide by 4 instead and retry
+                     rank_zero_info(f"Batch size {new_size} failed at first iteration, using batch size {new_size // 4}")
+                     new_size = new_size // 4
+                     lightning_setattr(trainer.lightning_module, batch_arg_name, new_size)
+
+                 # Force the train dataloader to reset as the batch size has changed
+                 _reset_dataloaders(trainer)
+                 if any_success:
+                     break
+             else:
+                 raise  # some other error not memory related
+
+     return new_size
+
+
+ def _run_binary_scaling(
+     trainer: pl.Trainer,
+     new_size: int,
+     batch_arg_name: str,
+     max_trials: int,
+     params: dict[str, Any],
+ ) -> int:
+     """Batch scaling mode where the size is initially doubled at each iteration until an OOM error is encountered.
+
+     Hereafter, the batch size is further refined using a binary search.
+
+     """
+     low = 1
+     high = None
+     count = 0
+     while True:
+         garbage_collection_cuda()
+
+         # reset after each try
+         _reset_progress(trainer)
+
+         try:
+             # run loop
+             _try_loop_run(trainer, params)
+             count += 1
+             if count > max_trials:
+                 break
+             # Double in size
+             low = new_size
+             if high:
+                 if high - low <= 1:
+                     break
+                 midval = (high + low) // 2
+                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded")
+             else:
+                 new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+
+             if not changed:
+                 break
+
+             # Force the train dataloader to reset as the batch size has changed
+             _reset_dataloaders(trainer)
+
+         except RuntimeError as exception:
+             # Only these errors should trigger an adjustment
+             if is_oom_error(exception):
+                 # If we fail in power mode, halve the size and return
+                 garbage_collection_cuda()
+
+                 high = new_size
+                 midval = (high + low) // 2
+                 new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed")
+
+                 # Force the train dataloader to reset as the batch size has changed
+                 _reset_dataloaders(trainer)
+
+                 if high - low <= 1:
+                     break
+             else:
+                 raise  # some other error not memory related
+
+     return new_size
+
+
+ def _adjust_batch_size(
+     trainer: pl.Trainer,
+     batch_arg_name: str = "batch_size",
+     factor: float = 1.0,
+     value: int | None = None,
+     desc: str | None = None,
+ ) -> tuple[int, bool]:
+     """Helper function for adjusting the batch size.
+
+     Args:
+         trainer: instance of pytorch_lightning.Trainer
+         batch_arg_name: name of the attribute that stores the batch size
+         factor: value which the old batch size is multiplied by to get the
+             new batch size
+         value: if a value is given, will override the batch size with this value.
+             Note that the value of `factor` will not have an effect in this case
+         desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
+
+     Returns:
+         The new batch size for the next trial and a bool that signals whether the
+         new value is different than the previous batch size.
+
+     """
+     model = trainer.lightning_module
+     batch_size = lightning_getattr(model, batch_arg_name)
+     assert batch_size is not None
+
+     loop = trainer._active_loop
+     assert loop is not None
+     loop.setup_data()
+     combined_loader = loop._combined_loader
+     assert combined_loader is not None
+     try:
+         combined_dataset_length = combined_loader._dataset_length()
+         if batch_size >= combined_dataset_length:
+             rank_zero_info(f"The batch size {batch_size} is greater or equal than the length of your dataset.")
+             return batch_size, False
+     except NotImplementedError:
+         # all datasets are iterable style
+         pass
+
+     new_size = value if value is not None else int(batch_size * factor)
+     if desc:
+         rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
+     changed = new_size != batch_size
+     lightning_setattr(model, batch_arg_name, new_size)
+
+     return new_size, changed
+
+
+ def _reset_dataloaders(trainer: pl.Trainer) -> None:
+     """Reset the dataloaders to force a reload."""
+     loop = trainer._active_loop
+     assert loop is not None
+     loop._combined_loader = None  # force a reload
+     loop.setup_data()
+     if isinstance(loop, pl.loops._FitLoop):
+         loop.epoch_loop.val_loop._combined_loader = None
+         loop.epoch_loop.val_loop.setup_data()
+
+
+ def _try_loop_run(trainer: pl.Trainer, params: dict[str, Any]) -> None:
+     """Try to run the loop with the current batch size."""
+     loop = trainer._active_loop
+     assert loop is not None
+     loop.load_state_dict(deepcopy(params["loop_state_dict"]))
+     loop.restarting = False
+     loop.run()
+
+
+ def _reset_progress(trainer: pl.Trainer) -> None:
+     """Reset the progress of the trainer."""
+     if trainer.lightning_module.automatic_optimization:
+         trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.reset()
+     else:
+         trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.reset()
+
+     trainer.fit_loop.epoch_progress.reset()
+
+
+ # Most of the code above is copied from the original lightning implementation since almost everything is private
+
+
+ class LightningTrainerBaseSetup(Callback):
+     """Custom callback used to set up a lightning trainer with default options.
+
+     Args:
+         log_every_n_steps: Default value for trainer.log_every_n_steps if the dataloader is too small.
+     """
+
+     def __init__(self, log_every_n_steps: int = 1) -> None:
+         self.log_every_n_steps = log_every_n_steps
+
+     @rank_zero_only
+     def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         """Called on every stage."""
+         if not hasattr(trainer, "datamodule") or not hasattr(trainer, "log_every_n_steps"):
+             raise ValueError("Trainer must have a datamodule and log_every_n_steps attribute.")
+
+         len_train_dataloader = len(trainer.datamodule.train_dataloader())
+         if len_train_dataloader <= trainer.log_every_n_steps:
+             if len_train_dataloader > self.log_every_n_steps:
+                 trainer.log_every_n_steps = self.log_every_n_steps
+                 log.info("`trainer.log_every_n_steps` is too high, setting it to %d", self.log_every_n_steps)
+             else:
+                 trainer.log_every_n_steps = 1
+                 log.warning(
+                     "The default log_every_n_steps %d is too high given the datamodule length %d, falling back to 1",
+                     self.log_every_n_steps,
+                     len_train_dataloader,
+                 )
+
+
+ class BatchSizeFinder(LightningBatchSizeFinder):
+     """Batch size finder setting the proper model training status as the current one from lightning seems bugged.
+     It also allows skipping some batch size finding steps.
+
+     Args:
+         find_train_batch_size: Whether to find the training batch size.
+         find_validation_batch_size: Whether to find the validation batch size.
+         find_test_batch_size: Whether to find the test batch size.
+         find_predict_batch_size: Whether to find the predict batch size.
+         mode: The mode to use for batch size finding. See `pytorch_lightning.callbacks.BatchSizeFinder` for more
+             details.
+         steps_per_trial: The number of steps per trial. See `pytorch_lightning.callbacks.BatchSizeFinder` for more
+             details.
+         init_val: The initial value for batch size. See `pytorch_lightning.callbacks.BatchSizeFinder` for more details.
+         max_trials: The maximum number of trials. See `pytorch_lightning.callbacks.BatchSizeFinder` for more details.
+         batch_arg_name: The name of the batch size argument. See `pytorch_lightning.callbacks.BatchSizeFinder` for more
+             details.
+     """
+
+     def __init__(
+         self,
+         find_train_batch_size: bool = True,
+         find_validation_batch_size: bool = False,
+         find_test_batch_size: bool = False,
+         find_predict_batch_size: bool = False,
+         mode: str = "power",
+         steps_per_trial: int = 3,
+         init_val: int = 2,
+         max_trials: int = 25,
+         batch_arg_name: str = "batch_size",
+     ) -> None:
+         super().__init__(mode, steps_per_trial, init_val, max_trials, batch_arg_name)
+         self.find_train_batch_size = find_train_batch_size
+         self.find_validation_batch_size = find_validation_batch_size
+         self.find_test_batch_size = find_test_batch_size
+         self.find_predict_batch_size = find_predict_batch_size
+
+     def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         if not self.find_train_batch_size or trainer.state.stage is None:
+             # If called during validation skip it as it will be triggered during on_validation_start
+             return None
+
+         if trainer.state.stage.value != "train":
+             return None
+
+         if not isinstance(pl_module.model, nn.Module):
+             raise ValueError("The model must be a nn.Module")
+         pl_module.model.train()
+
+         return super().on_fit_start(trainer, pl_module)
+
+     def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         if not self.find_validation_batch_size:
+             return None
+
+         if not isinstance(pl_module.model, nn.Module):
+             raise ValueError("The model must be a nn.Module")
+         pl_module.model.eval()
+
+         return super().on_validation_start(trainer, pl_module)
+
+     def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         if not self.find_test_batch_size:
+             return None
+
+         if not isinstance(pl_module.model, nn.Module):
+             raise ValueError("The model must be a nn.Module")
+         pl_module.model.eval()
+
+         return super().on_test_start(trainer, pl_module)
+
+     def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         if not self.find_predict_batch_size:
+             return None
+
+         if not isinstance(pl_module.model, nn.Module):
+             raise ValueError("The model must be a nn.Module")
+         pl_module.model.eval()
+
+         return super().on_predict_start(trainer, pl_module)
+
+     def scale_batch_size(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+         """Scale the batch size."""
+         new_size = _scale_batch_size(
+             trainer,
+             self._mode,
+             self._steps_per_trial,
+             self._init_val,
+             self._max_trials,
+             self._batch_arg_name,
+         )
+
+         self.optimal_batch_size = new_size
+         if self._early_exit:
+             raise _TunerExitException()
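
For context, a minimal usage sketch (not part of the published diff) of the two callbacks added in quadra/callbacks/lightning.py: they can be attached to a PyTorch Lightning Trainer like any other callback. MyLightningModule and MyDataModule are hypothetical placeholder names, the callback arguments simply mirror the signatures shown above, and BatchSizeFinder assumes the LightningModule exposes its network as pl_module.model (see the isinstance checks in the diff).

import pytorch_lightning as pl

from quadra.callbacks.lightning import BatchSizeFinder, LightningTrainerBaseSetup

callbacks = [
    # Lower trainer.log_every_n_steps when the train dataloader is shorter than it
    LightningTrainerBaseSetup(log_every_n_steps=1),
    # Grow the training batch size by doubling ("power" mode) until an OOM error is hit
    BatchSizeFinder(
        find_train_batch_size=True,
        mode="power",
        steps_per_trial=3,
        init_val=2,
        max_trials=25,
        batch_arg_name="batch_size",
    ),
]

trainer = pl.Trainer(max_epochs=10, callbacks=callbacks)
# trainer.fit(MyLightningModule(), datamodule=MyDataModule())  # placeholder module/datamodule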