quadra 0.0.1__py3-none-any.whl → 2.2.7__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- hydra_plugins/quadra_searchpath_plugin.py +30 -0
- quadra/__init__.py +6 -0
- quadra/callbacks/__init__.py +0 -0
- quadra/callbacks/anomalib.py +289 -0
- quadra/callbacks/lightning.py +501 -0
- quadra/callbacks/mlflow.py +291 -0
- quadra/callbacks/scheduler.py +69 -0
- quadra/configs/__init__.py +0 -0
- quadra/configs/backbone/caformer_m36.yaml +8 -0
- quadra/configs/backbone/caformer_s36.yaml +8 -0
- quadra/configs/backbone/convnextv2_base.yaml +8 -0
- quadra/configs/backbone/convnextv2_femto.yaml +8 -0
- quadra/configs/backbone/convnextv2_tiny.yaml +8 -0
- quadra/configs/backbone/dino_vitb8.yaml +12 -0
- quadra/configs/backbone/dino_vits8.yaml +12 -0
- quadra/configs/backbone/dinov2_vitb14.yaml +12 -0
- quadra/configs/backbone/dinov2_vits14.yaml +12 -0
- quadra/configs/backbone/efficientnet_b0.yaml +8 -0
- quadra/configs/backbone/efficientnet_b1.yaml +8 -0
- quadra/configs/backbone/efficientnet_b2.yaml +8 -0
- quadra/configs/backbone/efficientnet_b3.yaml +8 -0
- quadra/configs/backbone/efficientnetv2_s.yaml +8 -0
- quadra/configs/backbone/levit_128s.yaml +8 -0
- quadra/configs/backbone/mnasnet0_5.yaml +9 -0
- quadra/configs/backbone/resnet101.yaml +8 -0
- quadra/configs/backbone/resnet18.yaml +8 -0
- quadra/configs/backbone/resnet18_ssl.yaml +8 -0
- quadra/configs/backbone/resnet50.yaml +8 -0
- quadra/configs/backbone/smp.yaml +9 -0
- quadra/configs/backbone/tiny_vit_21m_224.yaml +9 -0
- quadra/configs/backbone/unetr.yaml +15 -0
- quadra/configs/backbone/vit16_base.yaml +9 -0
- quadra/configs/backbone/vit16_small.yaml +9 -0
- quadra/configs/backbone/vit16_tiny.yaml +9 -0
- quadra/configs/backbone/xcit_tiny_24_p8_224.yaml +9 -0
- quadra/configs/callbacks/all.yaml +45 -0
- quadra/configs/callbacks/default.yaml +34 -0
- quadra/configs/callbacks/default_anomalib.yaml +64 -0
- quadra/configs/config.yaml +33 -0
- quadra/configs/core/default.yaml +11 -0
- quadra/configs/datamodule/base/anomaly.yaml +16 -0
- quadra/configs/datamodule/base/classification.yaml +21 -0
- quadra/configs/datamodule/base/multilabel_classification.yaml +23 -0
- quadra/configs/datamodule/base/segmentation.yaml +18 -0
- quadra/configs/datamodule/base/segmentation_multiclass.yaml +20 -0
- quadra/configs/datamodule/base/sklearn_classification.yaml +23 -0
- quadra/configs/datamodule/base/sklearn_classification_patch.yaml +17 -0
- quadra/configs/datamodule/base/ssl.yaml +21 -0
- quadra/configs/datamodule/generic/imagenette/classification/base.yaml +9 -0
- quadra/configs/datamodule/generic/imagenette/ssl/base.yaml +10 -0
- quadra/configs/datamodule/generic/mnist/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/mvtec/anomaly/base.yaml +14 -0
- quadra/configs/datamodule/generic/oxford_pet/segmentation/base.yaml +9 -0
- quadra/configs/experiment/base/anomaly/cfa.yaml +47 -0
- quadra/configs/experiment/base/anomaly/cflow.yaml +47 -0
- quadra/configs/experiment/base/anomaly/csflow.yaml +48 -0
- quadra/configs/experiment/base/anomaly/draem.yaml +51 -0
- quadra/configs/experiment/base/anomaly/efficient_ad.yaml +43 -0
- quadra/configs/experiment/base/anomaly/fastflow.yaml +46 -0
- quadra/configs/experiment/base/anomaly/inference.yaml +21 -0
- quadra/configs/experiment/base/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/base/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/base/classification/classification.yaml +73 -0
- quadra/configs/experiment/base/classification/classification_evaluation.yaml +25 -0
- quadra/configs/experiment/base/classification/multilabel_classification.yaml +41 -0
- quadra/configs/experiment/base/classification/sklearn_classification.yaml +27 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch.yaml +25 -0
- quadra/configs/experiment/base/classification/sklearn_classification_patch_test.yaml +18 -0
- quadra/configs/experiment/base/classification/sklearn_classification_test.yaml +25 -0
- quadra/configs/experiment/base/segmentation/smp.yaml +30 -0
- quadra/configs/experiment/base/segmentation/smp_evaluation.yaml +17 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass.yaml +26 -0
- quadra/configs/experiment/base/segmentation/smp_multiclass_evaluation.yaml +18 -0
- quadra/configs/experiment/base/ssl/barlow.yaml +48 -0
- quadra/configs/experiment/base/ssl/byol.yaml +43 -0
- quadra/configs/experiment/base/ssl/dino.yaml +46 -0
- quadra/configs/experiment/base/ssl/linear_eval.yaml +52 -0
- quadra/configs/experiment/base/ssl/simclr.yaml +48 -0
- quadra/configs/experiment/base/ssl/simsiam.yaml +53 -0
- quadra/configs/experiment/custom/cls.yaml +12 -0
- quadra/configs/experiment/default.yaml +15 -0
- quadra/configs/experiment/generic/imagenette/classification/default.yaml +20 -0
- quadra/configs/experiment/generic/imagenette/ssl/barlow.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/byol.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/dino.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simclr.yaml +14 -0
- quadra/configs/experiment/generic/imagenette/ssl/simsiam.yaml +14 -0
- quadra/configs/experiment/generic/mnist/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mnist/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mnist/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mnist/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mnist/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mnist/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cfa.yaml +34 -0
- quadra/configs/experiment/generic/mvtec/anomaly/cflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/csflow.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/draem.yaml +33 -0
- quadra/configs/experiment/generic/mvtec/anomaly/efficient_ad.yaml +38 -0
- quadra/configs/experiment/generic/mvtec/anomaly/fastflow.yaml +29 -0
- quadra/configs/experiment/generic/mvtec/anomaly/inference.yaml +27 -0
- quadra/configs/experiment/generic/mvtec/anomaly/padim.yaml +37 -0
- quadra/configs/experiment/generic/mvtec/anomaly/patchcore.yaml +37 -0
- quadra/configs/experiment/generic/oxford_pet/segmentation/smp.yaml +27 -0
- quadra/configs/export/default.yaml +13 -0
- quadra/configs/hydra/anomaly_custom.yaml +15 -0
- quadra/configs/hydra/default.yaml +14 -0
- quadra/configs/inference/default.yaml +26 -0
- quadra/configs/logger/comet.yaml +10 -0
- quadra/configs/logger/csv.yaml +5 -0
- quadra/configs/logger/mlflow.yaml +12 -0
- quadra/configs/logger/tensorboard.yaml +8 -0
- quadra/configs/loss/asl.yaml +7 -0
- quadra/configs/loss/barlow.yaml +2 -0
- quadra/configs/loss/bce.yaml +1 -0
- quadra/configs/loss/byol.yaml +1 -0
- quadra/configs/loss/cross_entropy.yaml +1 -0
- quadra/configs/loss/dino.yaml +8 -0
- quadra/configs/loss/simclr.yaml +2 -0
- quadra/configs/loss/simsiam.yaml +1 -0
- quadra/configs/loss/smp_ce.yaml +3 -0
- quadra/configs/loss/smp_dice.yaml +2 -0
- quadra/configs/loss/smp_dice_multiclass.yaml +2 -0
- quadra/configs/loss/smp_mcc.yaml +2 -0
- quadra/configs/loss/vicreg.yaml +5 -0
- quadra/configs/model/anomalib/cfa.yaml +35 -0
- quadra/configs/model/anomalib/cflow.yaml +30 -0
- quadra/configs/model/anomalib/csflow.yaml +34 -0
- quadra/configs/model/anomalib/dfm.yaml +19 -0
- quadra/configs/model/anomalib/draem.yaml +29 -0
- quadra/configs/model/anomalib/efficient_ad.yaml +31 -0
- quadra/configs/model/anomalib/fastflow.yaml +32 -0
- quadra/configs/model/anomalib/padim.yaml +32 -0
- quadra/configs/model/anomalib/patchcore.yaml +36 -0
- quadra/configs/model/barlow.yaml +16 -0
- quadra/configs/model/byol.yaml +25 -0
- quadra/configs/model/classification.yaml +10 -0
- quadra/configs/model/dino.yaml +26 -0
- quadra/configs/model/logistic_regression.yaml +4 -0
- quadra/configs/model/multilabel_classification.yaml +9 -0
- quadra/configs/model/simclr.yaml +18 -0
- quadra/configs/model/simsiam.yaml +24 -0
- quadra/configs/model/smp.yaml +4 -0
- quadra/configs/model/smp_multiclass.yaml +4 -0
- quadra/configs/model/vicreg.yaml +16 -0
- quadra/configs/optimizer/adam.yaml +5 -0
- quadra/configs/optimizer/adamw.yaml +3 -0
- quadra/configs/optimizer/default.yaml +4 -0
- quadra/configs/optimizer/lars.yaml +8 -0
- quadra/configs/optimizer/sgd.yaml +4 -0
- quadra/configs/scheduler/default.yaml +5 -0
- quadra/configs/scheduler/rop.yaml +5 -0
- quadra/configs/scheduler/step.yaml +3 -0
- quadra/configs/scheduler/warmrestart.yaml +2 -0
- quadra/configs/scheduler/warmup.yaml +6 -0
- quadra/configs/task/anomalib/cfa.yaml +5 -0
- quadra/configs/task/anomalib/cflow.yaml +5 -0
- quadra/configs/task/anomalib/csflow.yaml +5 -0
- quadra/configs/task/anomalib/draem.yaml +5 -0
- quadra/configs/task/anomalib/efficient_ad.yaml +5 -0
- quadra/configs/task/anomalib/fastflow.yaml +5 -0
- quadra/configs/task/anomalib/inference.yaml +3 -0
- quadra/configs/task/anomalib/padim.yaml +5 -0
- quadra/configs/task/anomalib/patchcore.yaml +5 -0
- quadra/configs/task/classification.yaml +6 -0
- quadra/configs/task/classification_evaluation.yaml +6 -0
- quadra/configs/task/default.yaml +1 -0
- quadra/configs/task/segmentation.yaml +9 -0
- quadra/configs/task/segmentation_evaluation.yaml +3 -0
- quadra/configs/task/sklearn_classification.yaml +13 -0
- quadra/configs/task/sklearn_classification_patch.yaml +11 -0
- quadra/configs/task/sklearn_classification_patch_test.yaml +8 -0
- quadra/configs/task/sklearn_classification_test.yaml +8 -0
- quadra/configs/task/ssl.yaml +2 -0
- quadra/configs/trainer/lightning_cpu.yaml +36 -0
- quadra/configs/trainer/lightning_gpu.yaml +35 -0
- quadra/configs/trainer/lightning_gpu_bf16.yaml +36 -0
- quadra/configs/trainer/lightning_gpu_fp16.yaml +36 -0
- quadra/configs/trainer/lightning_multigpu.yaml +37 -0
- quadra/configs/trainer/sklearn_classification.yaml +7 -0
- quadra/configs/transforms/byol.yaml +47 -0
- quadra/configs/transforms/byol_no_random_resize.yaml +61 -0
- quadra/configs/transforms/default.yaml +37 -0
- quadra/configs/transforms/default_numpy.yaml +24 -0
- quadra/configs/transforms/default_resize.yaml +22 -0
- quadra/configs/transforms/dino.yaml +63 -0
- quadra/configs/transforms/linear_eval.yaml +18 -0
- quadra/datamodules/__init__.py +20 -0
- quadra/datamodules/anomaly.py +180 -0
- quadra/datamodules/base.py +375 -0
- quadra/datamodules/classification.py +1003 -0
- quadra/datamodules/generic/__init__.py +0 -0
- quadra/datamodules/generic/imagenette.py +144 -0
- quadra/datamodules/generic/mnist.py +81 -0
- quadra/datamodules/generic/mvtec.py +58 -0
- quadra/datamodules/generic/oxford_pet.py +163 -0
- quadra/datamodules/patch.py +190 -0
- quadra/datamodules/segmentation.py +742 -0
- quadra/datamodules/ssl.py +140 -0
- quadra/datasets/__init__.py +17 -0
- quadra/datasets/anomaly.py +287 -0
- quadra/datasets/classification.py +241 -0
- quadra/datasets/patch.py +138 -0
- quadra/datasets/segmentation.py +239 -0
- quadra/datasets/ssl.py +110 -0
- quadra/losses/__init__.py +0 -0
- quadra/losses/classification/__init__.py +6 -0
- quadra/losses/classification/asl.py +83 -0
- quadra/losses/classification/focal.py +320 -0
- quadra/losses/classification/prototypical.py +148 -0
- quadra/losses/ssl/__init__.py +17 -0
- quadra/losses/ssl/barlowtwins.py +47 -0
- quadra/losses/ssl/byol.py +37 -0
- quadra/losses/ssl/dino.py +129 -0
- quadra/losses/ssl/hyperspherical.py +45 -0
- quadra/losses/ssl/idmm.py +50 -0
- quadra/losses/ssl/simclr.py +67 -0
- quadra/losses/ssl/simsiam.py +30 -0
- quadra/losses/ssl/vicreg.py +76 -0
- quadra/main.py +49 -0
- quadra/metrics/__init__.py +3 -0
- quadra/metrics/segmentation.py +251 -0
- quadra/models/__init__.py +0 -0
- quadra/models/base.py +151 -0
- quadra/models/classification/__init__.py +8 -0
- quadra/models/classification/backbones.py +149 -0
- quadra/models/classification/base.py +92 -0
- quadra/models/evaluation.py +322 -0
- quadra/modules/__init__.py +0 -0
- quadra/modules/backbone.py +30 -0
- quadra/modules/base.py +312 -0
- quadra/modules/classification/__init__.py +3 -0
- quadra/modules/classification/base.py +327 -0
- quadra/modules/ssl/__init__.py +17 -0
- quadra/modules/ssl/barlowtwins.py +59 -0
- quadra/modules/ssl/byol.py +172 -0
- quadra/modules/ssl/common.py +285 -0
- quadra/modules/ssl/dino.py +186 -0
- quadra/modules/ssl/hyperspherical.py +206 -0
- quadra/modules/ssl/idmm.py +98 -0
- quadra/modules/ssl/simclr.py +73 -0
- quadra/modules/ssl/simsiam.py +68 -0
- quadra/modules/ssl/vicreg.py +67 -0
- quadra/optimizers/__init__.py +4 -0
- quadra/optimizers/lars.py +153 -0
- quadra/optimizers/sam.py +127 -0
- quadra/schedulers/__init__.py +3 -0
- quadra/schedulers/base.py +44 -0
- quadra/schedulers/warmup.py +127 -0
- quadra/tasks/__init__.py +24 -0
- quadra/tasks/anomaly.py +582 -0
- quadra/tasks/base.py +397 -0
- quadra/tasks/classification.py +1263 -0
- quadra/tasks/patch.py +492 -0
- quadra/tasks/segmentation.py +389 -0
- quadra/tasks/ssl.py +560 -0
- quadra/trainers/README.md +3 -0
- quadra/trainers/__init__.py +0 -0
- quadra/trainers/classification.py +179 -0
- quadra/utils/__init__.py +0 -0
- quadra/utils/anomaly.py +112 -0
- quadra/utils/classification.py +618 -0
- quadra/utils/deprecation.py +31 -0
- quadra/utils/evaluation.py +474 -0
- quadra/utils/export.py +585 -0
- quadra/utils/imaging.py +32 -0
- quadra/utils/logger.py +15 -0
- quadra/utils/mlflow.py +98 -0
- quadra/utils/model_manager.py +320 -0
- quadra/utils/models.py +523 -0
- quadra/utils/patch/__init__.py +15 -0
- quadra/utils/patch/dataset.py +1433 -0
- quadra/utils/patch/metrics.py +449 -0
- quadra/utils/patch/model.py +153 -0
- quadra/utils/patch/visualization.py +217 -0
- quadra/utils/resolver.py +42 -0
- quadra/utils/segmentation.py +31 -0
- quadra/utils/tests/__init__.py +0 -0
- quadra/utils/tests/fixtures/__init__.py +1 -0
- quadra/utils/tests/fixtures/dataset/__init__.py +39 -0
- quadra/utils/tests/fixtures/dataset/anomaly.py +124 -0
- quadra/utils/tests/fixtures/dataset/classification.py +406 -0
- quadra/utils/tests/fixtures/dataset/imagenette.py +53 -0
- quadra/utils/tests/fixtures/dataset/segmentation.py +161 -0
- quadra/utils/tests/fixtures/models/__init__.py +3 -0
- quadra/utils/tests/fixtures/models/anomaly.py +89 -0
- quadra/utils/tests/fixtures/models/classification.py +45 -0
- quadra/utils/tests/fixtures/models/segmentation.py +33 -0
- quadra/utils/tests/helpers.py +70 -0
- quadra/utils/tests/models.py +27 -0
- quadra/utils/utils.py +525 -0
- quadra/utils/validator.py +115 -0
- quadra/utils/visualization.py +422 -0
- quadra/utils/vit_explainability.py +349 -0
- quadra-2.2.7.dist-info/LICENSE +201 -0
- quadra-2.2.7.dist-info/METADATA +381 -0
- quadra-2.2.7.dist-info/RECORD +300 -0
- {quadra-0.0.1.dist-info → quadra-2.2.7.dist-info}/WHEEL +1 -1
- quadra-2.2.7.dist-info/entry_points.txt +3 -0
- quadra-0.0.1.dist-info/METADATA +0 -14
- quadra-0.0.1.dist-info/RECORD +0 -4
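The configs/ tree listed above (backbone, datamodule, experiment, trainer, logger, loss, optimizer, scheduler, task, transforms) together with the Hydra search-path plugin suggests that quadra experiments are composed from Hydra config groups rooted at quadra/configs/config.yaml. Below is a minimal, hypothetical sketch of composing one of the listed experiment configs with Hydra's compose API; the group names are taken from the file paths above, but whether config.yaml actually declares them as overridable defaults (and which Hydra version_base the package expects) is an assumption, not something this diff confirms.

# Hypothetical sketch only: compose one of the configs listed above with Hydra.
# Assumes quadra/configs/config.yaml declares `experiment` (and the other
# directories above) as Hydra config groups; the diff listing alone does not confirm this.
from hydra import compose, initialize_config_module

with initialize_config_module(config_module="quadra.configs", version_base="1.3"):
    cfg = compose(
        config_name="config",  # quadra/configs/config.yaml
        overrides=["experiment=base/anomaly/padim"],  # quadra/configs/experiment/base/anomaly/padim.yaml
    )
    print(list(cfg.keys()))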
quadra/callbacks/lightning.py
@@ -0,0 +1,501 @@
+from __future__ import annotations
+
+import os
+import uuid
+from copy import deepcopy
+from typing import Any
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.callbacks.batch_size_finder import BatchSizeFinder as LightningBatchSizeFinder
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.exceptions import _TunerExitException
+from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
+from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_setattr
+from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
+from torch import nn
+
+from quadra.utils.utils import get_logger
+
+log = get_logger(__name__)
+
+# pylint: disable=protected-access
+
+
+def _scale_batch_size(
+    trainer: pl.Trainer,
+    mode: str = "power",
+    steps_per_trial: int = 3,
+    init_val: int = 2,
+    max_trials: int = 25,
+    batch_arg_name: str = "batch_size",
+) -> int | None:
+    """Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM)
+    error.
+
+    Args:
+        trainer: A Trainer instance.
+        mode: Search strategy to update the batch size:
+
+            - ``'power'``: Keep multiplying the batch size by 2, until we get an OOM error.
+            - ``'binsearch'``: Initially keep multiplying by 2 and after encountering an OOM error
+              do a binary search between the last successful batch size and the batch size that failed.
+
+        steps_per_trial: number of steps to run with a given batch size.
+            Ideally 1 should be enough to test if an OOM error occurs,
+            however in practise a few are needed
+        init_val: initial batch size to start the search with
+        max_trials: max number of increases in batch size done before
+            algorithm is terminated
+        batch_arg_name: name of the attribute that stores the batch size.
+            It is expected that the user has provided a model or datamodule that has a hyperparameter
+            with that name. We will look for this attribute name in the following places
+
+            - ``model``
+            - ``model.hparams``
+            - ``trainer.datamodule`` (the datamodule passed to the tune method)
+
+    """
+    if trainer.fast_dev_run:  # type: ignore[attr-defined]
+        rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
+        return None
+
+    # Save initial model, that is loaded after batch size is found
+    ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
+    trainer.save_checkpoint(ckpt_path)
+
+    # Arguments we adjust during the batch size finder, save for restoring
+    params = __scale_batch_dump_params(trainer)
+
+    # Set to values that are required by the algorithm
+    __scale_batch_reset_params(trainer, steps_per_trial)
+
+    if trainer.progress_bar_callback:
+        trainer.progress_bar_callback.disable()
+
+    lightning_setattr(trainer.lightning_module, batch_arg_name, init_val)
+
+    if mode == "power":
+        new_size = _run_power_scaling(trainer, init_val, batch_arg_name, max_trials, params)
+    elif mode == "binsearch":
+        new_size = _run_binary_scaling(trainer, init_val, batch_arg_name, max_trials, params)
+
+    garbage_collection_cuda()
+
+    log.info("Finished batch size finder, will continue with full run using batch size %d", new_size)
+
+    __scale_batch_restore_params(trainer, params)
+
+    if trainer.progress_bar_callback:
+        trainer.progress_bar_callback.enable()
+
+    trainer._checkpoint_connector.restore(ckpt_path)
+    trainer.strategy.remove_checkpoint(ckpt_path)
+
+    return new_size
+
+
+def __scale_batch_dump_params(trainer: pl.Trainer) -> dict[str, Any]:
+    """Dump the parameters that need to be reset after the batch size finder.."""
+    dumped_params = {
+        "loggers": trainer.loggers,
+        "callbacks": trainer.callbacks,  # type: ignore[attr-defined]
+    }
+    loop = trainer._active_loop
+    assert loop is not None
+    if isinstance(loop, pl.loops._FitLoop):
+        dumped_params["max_steps"] = trainer.max_steps
+        dumped_params["limit_train_batches"] = trainer.limit_train_batches
+        dumped_params["limit_val_batches"] = trainer.limit_val_batches
+    elif isinstance(loop, pl.loops._EvaluationLoop):
+        stage = trainer.state.stage
+        assert stage is not None
+        dumped_params["limit_eval_batches"] = getattr(trainer, f"limit_{stage.dataloader_prefix}_batches")
+        dumped_params["loop_verbose"] = loop.verbose
+
+    dumped_params["loop_state_dict"] = deepcopy(loop.state_dict())
+    return dumped_params
+
+
+def __scale_batch_reset_params(trainer: pl.Trainer, steps_per_trial: int) -> None:
+    """Reset the parameters that need to be reset after the batch size finder."""
+    from pytorch_lightning.loggers.logger import DummyLogger  # pylint: disable=import-outside-toplevel
+
+    trainer.logger = DummyLogger() if trainer.logger is not None else None
+    trainer.callbacks = []  # type: ignore[attr-defined]
+
+    loop = trainer._active_loop
+    assert loop is not None
+    if isinstance(loop, pl.loops._FitLoop):
+        trainer.limit_train_batches = 1.0
+        trainer.limit_val_batches = steps_per_trial
+        trainer.fit_loop.epoch_loop.max_steps = steps_per_trial
+    elif isinstance(loop, pl.loops._EvaluationLoop):
+        stage = trainer.state.stage
+        assert stage is not None
+        setattr(trainer, f"limit_{stage.dataloader_prefix}_batches", steps_per_trial)
+        loop.verbose = False
+
+
+def __scale_batch_restore_params(trainer: pl.Trainer, params: dict[str, Any]) -> None:
+    """Restore the parameters that need to be reset after the batch size finder."""
+    # TODO: There are more states that needs to be reset (#4512 and #4870)
+    trainer.loggers = params["loggers"]
+    trainer.callbacks = params["callbacks"]  # type: ignore[attr-defined]
+
+    loop = trainer._active_loop
+    assert loop is not None
+    if isinstance(loop, pl.loops._FitLoop):
+        loop.epoch_loop.max_steps = params["max_steps"]
+        trainer.limit_train_batches = params["limit_train_batches"]
+        trainer.limit_val_batches = params["limit_val_batches"]
+    elif isinstance(loop, pl.loops._EvaluationLoop):
+        stage = trainer.state.stage
+        assert stage is not None
+        setattr(trainer, f"limit_{stage.dataloader_prefix}_batches", params["limit_eval_batches"])
+
+    loop.load_state_dict(deepcopy(params["loop_state_dict"]))
+    loop.restarting = False
+    if isinstance(loop, pl.loops._EvaluationLoop) and "loop_verbose" in params:
+        loop.verbose = params["loop_verbose"]
+
+    # make sure the loop's state is reset
+    _reset_dataloaders(trainer)
+    loop.reset()
+
+
+def _run_power_scaling(
+    trainer: pl.Trainer,
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: dict[str, Any],
+) -> int:
+    """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
+    # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not
+    # if it was we exit, else we continue downscaling in case we haven't encountered a single optimal batch size
+    any_success = False
+    # In the original
+    for i in range(max_trials):
+        garbage_collection_cuda()
+
+        # reset after each try
+        _reset_progress(trainer)
+
+        try:
+            if i == 0:
+                rank_zero_info(f"Starting batch size finder with batch size {new_size}")
+                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=1.0, desc=None)
+                changed = True
+            else:
+                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+            # Force the train dataloader to reset as the batch size has changed
+            _reset_dataloaders(trainer)
+            _try_loop_run(trainer, params)
+
+            any_success = True
+
+            # In the original lightning implementation this is done before _reset_dataloaders
+            # As such the batch size is not checked for the last iteration!!!
+            if not changed:
+                break
+        except RuntimeError as exception:
+            if is_oom_error(exception):
+                # If we fail in power mode, half the size and return
+                garbage_collection_cuda()
+                if any_success:
+                    # In the original lightning code there's a line that doesn't halve the size properly if batch_size
+                    # is bigger than the dataset length
+                    rank_zero_info(f"Batch size {new_size} failed, using batch size {new_size // 2}")
+                    new_size = new_size // 2
+                    lightning_setattr(trainer.lightning_module, batch_arg_name, new_size)
+                else:
+                    # In this case it means the first iteration will fail already, probably due to a way to big
+                    # initial batch size, since the next iteration will start from (new_size // 2) * 2, which is the
+                    # same divide by 4 instead and retry
+                    rank_zero_info(f"Batch size {new_size} failed at first iteration, using batch size {new_size // 4}")
+                    new_size = new_size // 4
+                    lightning_setattr(trainer.lightning_module, batch_arg_name, new_size)
+
+                # Force the train dataloader to reset as the batch size has changed
+                _reset_dataloaders(trainer)
+                if any_success:
+                    break
+            else:
+                raise  # some other error not memory related
+
+    return new_size
+
+
+def _run_binary_scaling(
+    trainer: pl.Trainer,
+    new_size: int,
+    batch_arg_name: str,
+    max_trials: int,
+    params: dict[str, Any],
+) -> int:
+    """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered.
+
+    Hereafter, the batch size is further refined using a binary search
+
+    """
+    low = 1
+    high = None
+    count = 0
+    while True:
+        garbage_collection_cuda()
+
+        # reset after each try
+        _reset_progress(trainer)
+
+        try:
+            # run loop
+            _try_loop_run(trainer, params)
+            count += 1
+            if count > max_trials:
+                break
+            # Double in size
+            low = new_size
+            if high:
+                if high - low <= 1:
+                    break
+                midval = (high + low) // 2
+                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="succeeded")
+            else:
+                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc="succeeded")
+
+            if not changed:
+                break
+
+            # Force the train dataloader to reset as the batch size has changed
+            _reset_dataloaders(trainer)
+
+        except RuntimeError as exception:
+            # Only these errors should trigger an adjustment
+            if is_oom_error(exception):
+                # If we fail in power mode, half the size and return
+                garbage_collection_cuda()
+
+                high = new_size
+                midval = (high + low) // 2
+                new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc="failed")
+
+                # Force the train dataloader to reset as the batch size has changed
+                _reset_dataloaders(trainer)
+
+                if high - low <= 1:
+                    break
+            else:
+                raise  # some other error not memory related
+
+    return new_size
+
+
+def _adjust_batch_size(
+    trainer: pl.Trainer,
+    batch_arg_name: str = "batch_size",
+    factor: float = 1.0,
+    value: int | None = None,
+    desc: str | None = None,
+) -> tuple[int, bool]:
+    """Helper function for adjusting the batch size.
+
+    Args:
+        trainer: instance of pytorch_lightning.Trainer
+        batch_arg_name: name of the attribute that stores the batch size
+        factor: value which the old batch size is multiplied by to get the
+            new batch size
+        value: if a value is given, will override the batch size with this value.
+            Note that the value of `factor` will not have an effect in this case
+        desc: either ``"succeeded"`` or ``"failed"``. Used purely for logging
+
+    Returns:
+        The new batch size for the next trial and a bool that signals whether the
+        new value is different than the previous batch size.
+
+    """
+    model = trainer.lightning_module
+    batch_size = lightning_getattr(model, batch_arg_name)
+    assert batch_size is not None
+
+    loop = trainer._active_loop
+    assert loop is not None
+    loop.setup_data()
+    combined_loader = loop._combined_loader
+    assert combined_loader is not None
+    try:
+        combined_dataset_length = combined_loader._dataset_length()
+        if batch_size >= combined_dataset_length:
+            rank_zero_info(f"The batch size {batch_size} is greater or equal than the length of your dataset.")
+            return batch_size, False
+    except NotImplementedError:
+        # all datasets are iterable style
+        pass
+
+    new_size = value if value is not None else int(batch_size * factor)
+    if desc:
+        rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
+    changed = new_size != batch_size
+    lightning_setattr(model, batch_arg_name, new_size)
+
+    return new_size, changed
+
+
+def _reset_dataloaders(trainer: pl.Trainer) -> None:
+    """Reset the dataloaders to force a reload."""
+    loop = trainer._active_loop
+    assert loop is not None
+    loop._combined_loader = None  # force a reload
+    loop.setup_data()
+    if isinstance(loop, pl.loops._FitLoop):
+        loop.epoch_loop.val_loop._combined_loader = None
+        loop.epoch_loop.val_loop.setup_data()
+
+
+def _try_loop_run(trainer: pl.Trainer, params: dict[str, Any]) -> None:
+    """Try to run the loop with the current batch size."""
+    loop = trainer._active_loop
+    assert loop is not None
+    loop.load_state_dict(deepcopy(params["loop_state_dict"]))
+    loop.restarting = False
+    loop.run()
+
+
+def _reset_progress(trainer: pl.Trainer) -> None:
+    """Reset the progress of the trainer."""
+    if trainer.lightning_module.automatic_optimization:
+        trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.reset()
+    else:
+        trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.reset()
+
+    trainer.fit_loop.epoch_progress.reset()
+
+
+# Most of the code above is copied from the original lightning implementation since almost everything is private
+
+
+class LightningTrainerBaseSetup(Callback):
+    """Custom callback used to setup a lightning trainer with default options.
+
+    Args:
+        log_every_n_steps: Default value for trainer.log_every_n_steps if the dataloader is too small.
+    """
+
+    def __init__(self, log_every_n_steps: int = 1) -> None:
+        self.log_every_n_steps = log_every_n_steps
+
+    @rank_zero_only
+    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        """Called on every stage."""
+        if not hasattr(trainer, "datamodule") or not hasattr(trainer, "log_every_n_steps"):
+            raise ValueError("Trainer must have a datamodule and log_every_n_steps attribute.")
+
+        len_train_dataloader = len(trainer.datamodule.train_dataloader())
+        if len_train_dataloader <= trainer.log_every_n_steps:
+            if len_train_dataloader > self.log_every_n_steps:
+                trainer.log_every_n_steps = self.log_every_n_steps
+                log.info("`trainer.log_every_n_steps` is too high, setting it to %d", self.log_every_n_steps)
+            else:
+                trainer.log_every_n_steps = 1
+                log.warning(
+                    "The default log_every_n_steps %d is too high given the datamodule lenght %d, fallback to 1",
+                    self.log_every_n_steps,
+                    len_train_dataloader,
+                )
+
+
+class BatchSizeFinder(LightningBatchSizeFinder):
+    """Batch size finder setting the proper model training status as the current one from lightning seems bugged.
+    It also allows to skip some batch size finding steps.
+
+    Args:
+        find_train_batch_size: Whether to find the training batch size.
+        find_validation_batch_size: Whether to find the validation batch size.
+        find_test_batch_size: Whether to find the test batch size.
+        find_predict_batch_size: Whether to find the predict batch size.
+        mode: The mode to use for batch size finding. See `pytorch_lightning.callbacks.BatchSizeFinder` for more
+            details.
+        steps_per_trial: The number of steps per trial. See `pytorch_lightning.callbacks.BatchSizeFinder` for more
+            details.
+        init_val: The initial value for batch size. See `pytorch_lightning.callbacks.BatchSizeFinder` for more details.
+        max_trials: The maximum number of trials. See `pytorch_lightning.callbacks.BatchSizeFinder` for more details.
+        batch_arg_name: The name of the batch size argument. See `pytorch_lightning.callbacks.BatchSizeFinder` for more
+            details.
+    """

+    def __init__(
+        self,
+        find_train_batch_size: bool = True,
+        find_validation_batch_size: bool = False,
+        find_test_batch_size: bool = False,
+        find_predict_batch_size: bool = False,
+        mode: str = "power",
+        steps_per_trial: int = 3,
+        init_val: int = 2,
+        max_trials: int = 25,
+        batch_arg_name: str = "batch_size",
+    ) -> None:
+        super().__init__(mode, steps_per_trial, init_val, max_trials, batch_arg_name)
+        self.find_train_batch_size = find_train_batch_size
+        self.find_validation_batch_size = find_validation_batch_size
+        self.find_test_batch_size = find_test_batch_size
+        self.find_predict_batch_size = find_predict_batch_size
+
+    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        if not self.find_train_batch_size or trainer.state.stage is None:
+            # If called during validation skip it as it will be triggered during on_validation_start
+            return None
+
+        if trainer.state.stage.value != "train":
+            return None
+
+        if not isinstance(pl_module.model, nn.Module):
+            raise ValueError("The model must be a nn.Module")
+        pl_module.model.train()
+
+        return super().on_fit_start(trainer, pl_module)
+
+    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        if not self.find_validation_batch_size:
+            return None
+
+        if not isinstance(pl_module.model, nn.Module):
+            raise ValueError("The model must be a nn.Module")
+        pl_module.model.eval()
+
+        return super().on_validation_start(trainer, pl_module)
+
+    def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        if not self.find_test_batch_size:
+            return None
+
+        if not isinstance(pl_module.model, nn.Module):
+            raise ValueError("The model must be a nn.Module")
+        pl_module.model.eval()
+
+        return super().on_test_start(trainer, pl_module)
+
+    def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        if not self.find_predict_batch_size:
+            return None
+
+        if not isinstance(pl_module.model, nn.Module):
+            raise ValueError("The model must be a nn.Module")
+        pl_module.model.eval()
+
+        return super().on_predict_start(trainer, pl_module)
+
+    def scale_batch_size(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        """Scale the batch size."""
+        new_size = _scale_batch_size(
+            trainer,
+            self._mode,
+            self._steps_per_trial,
+            self._init_val,
+            self._max_trials,
+            self._batch_arg_name,
+        )
+
+        self.optimal_batch_size = new_size
+        if self._early_exit:
+            raise _TunerExitException()
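For context, below is a hedged usage sketch of the two callbacks defined in this file. It is not taken from the package: the trainer wiring and the placeholder module/datamodule names are assumptions. The only requirements implied by the code above are that the LightningModule (or its hparams or datamodule) exposes the attribute named by batch_arg_name, and that pl_module.model is an nn.Module, since BatchSizeFinder calls model.train()/model.eval() before delegating to the Lightning batch size finder.

# Hedged usage sketch (not from the quadra package): wiring the callbacks above
# into a PyTorch Lightning Trainer. `MyLightningModule` / `MyDataModule` are
# hypothetical placeholders.
import pytorch_lightning as pl

from quadra.callbacks.lightning import BatchSizeFinder, LightningTrainerBaseSetup

callbacks = [
    # Clamps trainer.log_every_n_steps when the train dataloader is shorter than it.
    LightningTrainerBaseSetup(log_every_n_steps=1),
    # Tunes only the training batch size, doubling it until an OOM error occurs.
    BatchSizeFinder(
        find_train_batch_size=True,
        mode="power",                 # or "binsearch"
        steps_per_trial=3,
        init_val=2,
        max_trials=25,
        batch_arg_name="batch_size",  # looked up on the module, its hparams, or the datamodule
    ),
]

trainer = pl.Trainer(max_epochs=1, callbacks=callbacks)
# trainer.fit(MyLightningModule(), datamodule=MyDataModule())  # placeholders, not part of quadra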