rc_foundry-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,923 @@
"""Generic training harness built atop PyTorch Lightning Fabric.

In addition to standard harness features (gradient accumulation, mixed precision, etc.), includes native support for EMA.

References:
    - PyTorch Lightning Trainer Example (https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/fabric/build_your_own_trainer/trainer.py)
    - Lightning Hydra Template (https://github.com/ashleve/lightning-hydra-template)
"""

import math
import time
from abc import ABC, abstractmethod
from datetime import timedelta
from pathlib import Path
from typing import cast

import hydra
import lightning as L
import numpy as np
import torch
from beartype.typing import Any, Mapping
from lightning.fabric.accelerators import Accelerator
from lightning.fabric.loggers import Logger
from lightning.fabric.strategies import DDPStrategy, Strategy
from lightning.fabric.wrappers import (
    _FabricDataLoader,
    _FabricModule,
    _FabricOptimizer,
)

from foundry.callbacks.callback import BaseCallback
from foundry.training.EMA import EMA
from foundry.training.schedulers import SchedulerConfig
from foundry.utils.ddp import RankedLogger

ranked_logger = RankedLogger(__name__, rank_zero_only=True)
logger = RankedLogger(__name__, rank_zero_only=False)


class FabricTrainer(ABC):
    def __init__(
        self,
        *,
        accelerator: str | Accelerator = "auto",
        strategy: str | Strategy = "ddp",
        devices_per_node: list[int] | int | str = "auto",
        num_nodes: int = 1,
        precision: str | int = "32-true",
        callbacks: BaseCallback | list[BaseCallback] | None = None,
        loggers: Logger | list[Logger] | None = None,
        max_epochs: int = 1000,
        grad_accum_steps: int = 1,
        validate_every_n_epochs: int = 1,
        n_examples_per_epoch: int = 24_000,
        output_dir: Path | str | None = None,
        checkpoint_every_n_epochs: int = 1,
        clip_grad_max_norm: float | None = None,
        error_if_grad_nonfinite: bool = True,
        limit_train_batches: int | float = float("inf"),
        limit_val_batches: int | float = float("inf"),
        prevalidate: bool = False,
        nccl_timeout: int = 3200,
        skip_optimizer_loading: bool = False,
    ) -> None:
        """Base Trainer class built around Lightning Fabric.

        Args:
            accelerator: The hardware to run on. See (1) for details. Possible choices are:
                ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
            strategy: Strategy for how to run across multiple devices. See (1) for details. Possible choices are:
                ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
            devices_per_node: Number of devices to train on per machine (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
                See (1) for details.
                EXAMPLE: If you run on 2 nodes with 8 GPUs each, you would set ``devices_per_node=8``, not ``16``.
            num_nodes: Number of machines (nodes) for distributed training (default: 1). See (1) for details.
            precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
                or bfloat16 precision AMP (``"bf16-mixed"``). See (2) for details.
            callbacks: A single callback or a list of callbacks, each inheriting the BaseCallback Abstract Base Class.
            loggers: A single logger or a list of loggers. See (3) for details.
            max_epochs: Maximum number of epochs to train for (default: 1000).
            grad_accum_steps: Number of batches to process before calling optimizer.step() (default: 1). See (4) for details on gradient accumulation in Fabric.
            validate_every_n_epochs: Number of epochs between validation runs (default: 1).
            n_examples_per_epoch: Number of examples to sample per epoch, across all GPUs. E.g., number of distinct examples that will
                be "seen" by the model in a single epoch. If larger than the number implied by the dataloader, we will
                log a warning and use the smaller number.
            output_dir: Directory to save checkpoints, metrics, intermediate validation structures, etc. (default: None).
            checkpoint_every_n_epochs: Number of epochs between saving checkpoints (default: 1).
            clip_grad_max_norm: Maximum gradient norm to clip to (default: None). If None, no gradient clipping is performed.
            error_if_grad_nonfinite: Whether to raise an error in gradient clipping if gradients are non-finite (default: True).
            limit_train_batches: Limit on the number of training batches per epoch (default: float("inf")).
                Helpful for debugging; should NOT be used when training production models.
            limit_val_batches: Limit on the number of validation batches per epoch (default: float("inf")).
                Helpful for debugging; should NOT be used when training production models.
            prevalidate: Whether to run validation before training starts (default: False).
            nccl_timeout: Timeout for NCCL operations (default: 3200). Only used with DDP strategy.
            skip_optimizer_loading: Whether to skip loading the optimizer state when resuming from a checkpoint (default: False).

        References:
            (1) Fabric Arguments (https://lightning.ai/docs/fabric/stable/api/fabric_args.html)
            (2) Fabric Precision Documentation (https://lightning.ai/docs/fabric/stable/fundamentals/precision.html)
            (3) Fabric Loggers (https://lightning.ai/docs/fabric/2.4.0/api/loggers.html)
            (4) Efficient Gradient Accumulation (https://lightning.ai/docs/fabric/2.4.0/advanced/gradient_accumulation.html)
        """
        # DDP strategy requires a manual timeout higher than the default
        if strategy == "ddp":
            strategy = DDPStrategy(timeout=timedelta(seconds=nccl_timeout))

        # See (1) for initialization arguments for Fabric()
        self.fabric = L.Fabric(
            accelerator=accelerator,
            strategy=strategy,
            devices=devices_per_node,
            num_nodes=num_nodes,
            precision=precision,
            callbacks=callbacks,
            loggers=loggers,
        )

        # Training
        self.clip_grad_max_norm = clip_grad_max_norm
        self.grad_accum_steps = grad_accum_steps
        self.error_if_grad_nonfinite = error_if_grad_nonfinite

        # Stopping
        self.max_epochs = max_epochs
        self.should_stop = False
        self.n_examples_per_epoch = n_examples_per_epoch
        self.limit_train_batches = limit_train_batches
        self.limit_val_batches = limit_val_batches

        # Validation
        self.validate_every_n_epochs = validate_every_n_epochs
        self.prevalidate = prevalidate

        # Checkpoints
        self.output_dir = Path(output_dir) if output_dir else None
        self.checkpoint_every_n_epochs = checkpoint_every_n_epochs
        self.skip_optimizer_loading = skip_optimizer_loading

    def initialize_or_update_trainer_state(
        self,
        updates: dict,
    ):
        """Initialize or update the state dictionary for the trainer.

        State keys:
            model: The model to train.
            optimizer: The optimizer to use with the model. May be None for validation/inference.
            scheduler_cfg: Learning rate SchedulerConfig (e.g., a LRScheduler with intervals/frequency). May be None for validation/inference or if no scheduler is used.
            global_step: Global optimizer step; used by W&B logger, learning rate schedulers, etc. Default is 0.
            current_epoch: Global epoch counter; used for validation, learning rate schedulers, checkpointing, etc. Default is 0.
            train_cfg: The training configuration dictionary. Used for reinitializing the trainer with the same configuration
                (for training or for inference). Default is an empty dictionary.
        """
        # Default values for the state
        default_state = {
            "model": None,
            "optimizer": None,
            "scheduler_cfg": None,
            "global_step": 0,
            "current_epoch": 0,
            "train_cfg": {},
        }

        # Initialize self.state with default values if it doesn't exist
        if not hasattr(self, "state"):
            self.state = default_state.copy()
        else:
            # Ensure existing state has all default keys
            for key, value in default_state.items():
                self.state.setdefault(key, value)

        # Merge the updates into the existing state
        self.state.update(updates)

    def construct_optimizer(self) -> None:
        """Instantiate the optimizer(s).

        We provide a default implementation that instantiates the optimizer(s) from the Hydra configuration.
        More complex models (e.g., GANs) may require custom implementations.
        """
        assert (
            "model" in self.state and hasattr(self.state["model"], "parameters")
        ), "Model not found in state dictionary! You must call `construct_model()` before constructing the optimizer."

        if self.state["train_cfg"].model.optimizer:
            # ... instantiate the optimizer
            optimizer = hydra.utils.instantiate(
                self.state["train_cfg"].model.optimizer,
                params=self.state["model"].parameters(),
            )
            self.initialize_or_update_trainer_state({"optimizer": optimizer})

    def construct_scheduler(self) -> None:
        """Instantiate the learning rate scheduler(s).

        Like optimizers, we provide a default implementation that instantiates the scheduler(s) from the Hydra configuration.
        More complex models (e.g., GANs) may require custom implementations.
        """
        assert (
            "optimizer" in self.state and self.state["optimizer"]
        ), "Optimizer not found in state dictionary! You must call `construct_optimizer()` before constructing the scheduler."

        # ... instantiate the LR scheduler(s)
        lr_scheduler = (
            hydra.utils.instantiate(
                self.state["train_cfg"].model.lr_scheduler,
                optimizer=self.state["optimizer"],
            )
            if self.state["train_cfg"].model.lr_scheduler
            else None
        )

        if lr_scheduler:
            # We assume "interval = step" and "frequency = 1" for the default scheduler; custom implementations may override this method
            scheduler_cfg = SchedulerConfig(
                scheduler=lr_scheduler,
                interval="step",
                frequency=1,
            )
            self.initialize_or_update_trainer_state({"scheduler_cfg": scheduler_cfg})

    def construct_model(self):
        """Instantiate the model, updating the trainer state in-place.

        This method must set the "model" key in the state dictionary using `self.initialize_or_update_trainer_state()`.
        For an example, see the `construct_model` method in the `AF3Trainer`.
        """
        raise NotImplementedError

    def setup_model_optimizers_and_schedulers(self) -> None:
        """Setup the model, optimizer(s), and scheduler(s) with Fabric.

        Note that we must call this method after constructing (instantiating) the model, optimizer(s), and scheduler(s).
        For details on multi-model and multi-optimizer setups, see: https://lightning.ai/docs/fabric/2.2.3/advanced/multiple_setup.html
        """
        assert self.state[
            "model"
        ], "You must construct the model before setting up the model, optimizer, and scheduler."
        model = self.state["model"]
        optimizer = self.state["optimizer"]

        # ... setup the model and optimizer
        if optimizer:
            model, optimizer = self.fabric.setup(model, optimizer)
        else:
            model = self.fabric.setup(model)

        # ... update the state dictionary (we avoid updating the state dictionary in-place, which is an anti-pattern)
        self.initialize_or_update_trainer_state(
            {
                "model": model,
                "optimizer": optimizer,
            }
        )

    def fit(
        self,
        train_loader: torch.utils.data.DataLoader,
        val_loaders: dict[str, torch.utils.data.DataLoader] | None = None,
        ckpt_path: Path | str | None = None,
    ) -> None:
        """Main entry point for training a model.

        Args:
            train_loader: Dataloader for training. Must have an iterable returning batches.
            val_loaders: Dictionary of dataloaders for validation. The keys are the names of the loaders, and the values are the loaders themselves.
            ckpt_path: Path to either:
                (a) A previous checkpoint directory from which to resume training. In this case, we will automatically load
                    the latest checkpoint using `self.get_latest_checkpoint()`.
                (b) A specific checkpoint file to load. In this case, we will load the checkpoint from the specified file.
                If None, no checkpoint is loaded, and the model will be trained from scratch.
        """
        assert (
            hasattr(self, "state") and "model" in self.state
        ), "Model not found in state dictionary! You must call `instantiate_model()` before running fit()."

        # (If we don't have enough examples to sample, we will log a warning and use the smaller number)
        if len(train_loader) * self.fabric.world_size < self.n_examples_per_epoch:
            ranked_logger.warning(
                f"Number of examples per epoch ({self.n_examples_per_epoch}) exceeds the number of examples in the loader: "
                f"({len(train_loader) * self.fabric.world_size}). Using the latter."
            )
            self.n_examples_per_epoch = len(train_loader) * self.fabric.world_size
        self.n_batches_per_epoch = math.ceil(
            self.n_examples_per_epoch / self.fabric.world_size
        )

        # ... setup training and validation dataloaders with Fabric
        train_loader = self.fabric.setup_dataloaders(
            # Our sampler is already distributed, so we don't need to wrap with a DistributedSampler
            train_loader,
            use_distributed_sampler=False,
        )

        if val_loaders is not None:
            for key, loader in val_loaders.items():
                val_loaders[key] = self.fabric.setup_dataloaders(
                    loader, use_distributed_sampler=False
                )

        self.setup_model_optimizers_and_schedulers()

        if ckpt_path is not None:
            ckpt_path = Path(ckpt_path)
            if ckpt_path.is_dir():
                # If given a directory, load the latest checkpoint from the directory
                ranked_logger.info(
                    f"Loading latest checkpoint within the directory {ckpt_path}..."
                )
                self.load_checkpoint(self.get_latest_checkpoint(ckpt_path))
            else:
                # If given a specific checkpoint file, load that checkpoint
                self.load_checkpoint(ckpt_path)

            # Increment the global epoch (e.g., if we loaded a checkpoint from [the end of] epoch 5, we should start training at epoch 6)
            self.state["current_epoch"] += 1
            # Stopping conditions
            if (
                self.max_epochs is not None
                and self.state["current_epoch"] >= self.max_epochs
            ):
                self.should_stop = True
        else:
            ranked_logger.info("No checkpoint provided; training from scratch.")

        # Set the _num_iter_calls internal attribute of the wrapped loader to the current epoch
        # (NOTE: This addresses a bug in Lightning Fabric, where the iter() method calls the `_set_sampler_epoch()` method,
        # relying on the _num_iter_calls attribute to determine the current epoch)
        train_loader._num_iter_calls = self.state["current_epoch"]

        self.fabric.call("on_fit_start", trainer=self, model=self.state["model"])

        # Prevalidate
        if self.prevalidate and val_loaders:
            # Temporarily decrement the current epoch, since we haven't done any training this epoch
            self.state["current_epoch"] -= 1  # (Will be -1 if training from scratch)
            ranked_logger.info(
                f"Prevalidating with epoch {self.state['current_epoch']} before training; to avoid this behavior, set `prevalidate=False` in the Trainer config."
            )
            self.validation_loop(
                val_loaders=val_loaders,
                limit_batches=self.limit_val_batches,
            )
            self.state["current_epoch"] += 1  # (Restore the current epoch)

        while not self.should_stop:
            # ... train for one epoch
            ranked_logger.info(
                f"\n+ Starting epoch {self.state['current_epoch']}/{self.max_epochs - 1}\n"
                f"+ Total examples per epoch (across all GPUs): {self.n_examples_per_epoch}\n"
                f"+ Examples per GPU (batches per epoch): {self.n_batches_per_epoch}\n"
                f"+ Gradient accumulation steps: {self.grad_accum_steps}\n"
                f"+ Expected optimizer steps per epoch: {(self.n_batches_per_epoch // self.grad_accum_steps if self.grad_accum_steps > 0 else 0)}\n"
            )

            self.train_loop(
                train_loader=train_loader,
                limit_batches=self.limit_train_batches,
            )

            ranked_logger.info(f"Finished epoch {self.state['current_epoch']}!")

            # ... validate, if we're at the validation interval
            if self.should_validate and val_loaders:
                ranked_logger.info(
                    f"Starting validation for epoch {self.state['current_epoch']}!"
                )
                self.validation_loop(
                    val_loaders=val_loaders,
                    limit_batches=self.limit_val_batches,
                )

            # ... step the scheduler, if we're adjusting the learning rate at the epoch-level
            self.step_scheduler(
                level="epoch", current_value=self.state["current_epoch"]
            )

            # ... save checkpoint, if we've reached the checkpoint interval
            if self.state["current_epoch"] % self.checkpoint_every_n_epochs == 0:
                self.save_checkpoint()

            # ... increment the epoch
            self.state["current_epoch"] += 1

            # Stopping conditions
            if (
                self.max_epochs is not None
                and self.state["current_epoch"] >= self.max_epochs
            ):
                self.should_stop = True

        # Reset for next `fit()` call
        self.should_stop = False

        self.fabric.call("on_fit_end", trainer=self)

    def train_loop(
        self,
        *,
        train_loader: _FabricDataLoader,
        limit_batches: int | float = float("inf"),
    ):
        """Train model for a single epoch.

        Args:
            train_loader: Dataloader for training.
            limit_batches: Limit on the batches during this training epoch. If greater than the number of batches in the
                `train_loader`, this argument has no effect. Helpful for debugging; should NOT be used when training production models.
        """
        self.fabric.call("on_train_epoch_start", trainer=self)

        assert self.state["model"].training

        # NOTE: When we call iter(), Fabric calls the `set_sampler_epoch()` method on the sampler behind the scenes, so we don't need to call it explicitly
        train_iter = iter(train_loader)
        self.fabric.call("on_after_train_loader_iter", trainer=self)

        t_dataloaders = []
        t_trains = []
        t_step_end = time.time()
        for batch_idx in range(len(train_loader)):
            # (End epoch if stopping training completely or maximum desired batches for this epoch reached)
            if self.should_stop or batch_idx >= limit_batches:
                break

            batch = next(train_iter)

            self.fabric.call(
                "on_train_batch_start", batch=batch, batch_idx=batch_idx, trainer=self
            )
            # Optimizer should step if we've accumulated the desired number of gradients
            should_optimizer_step = (batch_idx + 1) % self.grad_accum_steps == 0

            t_step_start = time.time()
            t_dataloader = t_step_start - t_step_end
            t_dataloaders.append(t_dataloader)

            self.training_step(
                batch=batch,
                batch_idx=batch_idx,
                is_accumulating=not should_optimizer_step,
            )
            t_step_end = time.time()
            t_train = t_step_end - t_step_start
            t_trains.append(t_train)

            if t_train > 100 or t_dataloader > 100:
                ranked_logger.warning(
                    f"Training step took {t_train:.3f} seconds, dataloader took {t_dataloader:.3f} seconds. "
                    "This may indicate a performance issue with the data loading or training step."
                )
                # Uncomment these lines to dump the slow example to disk for debugging
                # fout = f"local/debug/batch_{batch_idx}_slow_example.pt"
                # ranked_logger.info("Dumping batch example to disk: {}".format(fout))
                # torch.save(batch, fout)

            if batch_idx % 100 == 0:
                logger.info(
                    f"Dataloading times for batch {batch_idx} - Avg time: {np.mean(t_dataloaders):.3f} (s, Wall), Max: {np.max(t_dataloaders):.3f} (s, Wall)",
                    rank=0,
                )
                logger.info(
                    f"Finished training step in {t_train} (s, Wall), Avg time: {np.mean(t_trains):.3f} (s, Wall), Max: {np.max(t_trains):.3f} (s, Wall)",
                    rank=0,
                )
                t_trains = []
                t_dataloaders = []

            if should_optimizer_step:
                self.fabric.call(
                    "on_before_optimizer_step",
                    optimizer=self.state["optimizer"],
                    trainer=self,
                )

                # ... step the optimizer, clipping gradients and updating EMA parameters if applicable
                self.step_optimizer()

                self.fabric.call(
                    "optimizer_step", optimizer=self.state["optimizer"], trainer=self
                )

            self.fabric.call(
                "on_train_batch_end",
                outputs=self._current_train_return,
                batch=batch,
                batch_idx=batch_idx,
                trainer=self,
            )

            if should_optimizer_step:
                # ... step the scheduler, if we're adjusting the learning rate at the optimizer step-level
                self.step_scheduler(
                    level="step", current_value=self.state["global_step"]
                )

            # ... increment the global step, if optimizer stepped
            # NOTE: Each node maintains its own global step
            self.state["global_step"] += int(should_optimizer_step)

        self.fabric.call("on_train_epoch_end", trainer=self)

    def validation_loop(
        self,
        *,
        val_loaders: dict[str, _FabricDataLoader],
        limit_batches: int | float = float("inf"),
    ):
        """Run validation loop for a single validation epoch.

        Args:
            val_loaders: Dictionary of dataloaders (more precisely, _FabricDataLoader) for validation.
            limit_batches: Limit on the batches during this validation epoch. If greater than the number of batches in
                each loader, this argument has no effect. Helpful for debugging; should NOT be used for production.
        """
        # ... set model to evaluation mode
        self.state["model"].eval()

        with torch.no_grad():
            # ... assert we're in evaluation mode
            assert not self.state["model"].training

            self.fabric.call("on_validation_epoch_start", trainer=self)

            # ... iterate over all validation loaders
            for val_loader_name, val_loader in val_loaders.items():
                if (
                    hasattr(val_loader, "eval_every_n")
                    and self.state["global_step"] > 0
                    and self.state["global_step"] % val_loader.eval_every_n != 0
                ):
                    ranked_logger.info(
                        f"Skipping validation on dataset: {val_loader_name}, with {len(val_loader)} batches, with world_size={self.fabric.world_size}."
                    )
                    continue
                else:
                    ranked_logger.info(
                        f"Running validation on dataset: {val_loader_name}, with {len(val_loader)} batches, with world_size={self.fabric.world_size}."
                    )

                for batch_idx, batch in enumerate(val_loader):
                    # ... end validation epoch if stopping training completely or maximum desired batches for this epoch reached
                    if self.should_stop or batch_idx >= limit_batches:
                        break

                    self.fabric.call(
                        "on_validation_batch_start",
                        batch=batch,
                        batch_idx=batch_idx,
                        num_batches=len(val_loader),
                        trainer=self,
                        dataset_name=val_loader_name,
                    )

                    validation_result = self.validation_step(
                        batch=batch,
                        batch_idx=batch_idx,
                    )

                    self.fabric.call(
                        "on_validation_batch_end",
                        outputs=validation_result,
                        batch=batch,
                        batch_idx=batch_idx,
                        num_batches=len(val_loader),
                        dataset_name=val_loader_name,
                        current_epoch=self.state["current_epoch"],
                    )

            self.fabric.call("on_validation_epoch_end", trainer=self)

        # ... reset the model to training mode
        self.state["model"].train()

    @abstractmethod
    def training_step(
        self,
        batch: Any,
        batch_idx: int,
        is_accumulating: bool,
    ) -> None:
        """Training step, running forward and backward passes.

        Args:
            batch: The current batch; can be of any form.
            batch_idx: The index of the current batch.
            is_accumulating: Whether we are accumulating gradients (i.e., not yet calling optimizer.step()).
                If this is the case, we should skip the synchronization during the backward pass.

        Returns:
            torch.Tensor | Mapping[str, Any]: The loss tensor or a dictionary containing the loss tensor.
        """
        pass

    @abstractmethod
    def validation_step(
        self,
        batch: Any,
        batch_idx: int,
        val_loader_name: str | None = None,
    ) -> dict:
        """Validation step, running forward pass.

        Args:
            batch: The current batch; can be of any form.
            batch_idx: The index of the current batch (within that validation loader).
            val_loader_name: The name of the validation loader, if applicable.

        Returns:
            dict: A dictionary containing the output of the designated validation metrics.
        """
        pass

    def validate(
        self,
        val_loaders: dict,
        ckpt_path: Path | str,
    ) -> None:
        """Validate a model using the given dataloaders and checkpoint.

        Args:
            val_loaders: A dictionary of dataloaders for validation, where keys are names and values are dataloaders.
            ckpt_path: Path to a specific checkpoint file to load. If None, the model will be validated as is.
        """
        assert (
            hasattr(self, "state") and "model" in self.state
        ), "Model not found in state dictionary! You must call `instantiate_model()` before running validate()."

        self.setup_model_optimizers_and_schedulers()

        self.load_checkpoint(ckpt_path)

        # Setup validation dataloaders with Fabric
        for key, loader in val_loaders.items():
            val_loaders[key] = self.fabric.setup_dataloaders(
                loader, use_distributed_sampler=False
            )

        # Run the validation loop
        self.state["model"].eval()
        self.validation_loop(
            val_loaders=val_loaders, limit_batches=self.limit_val_batches
        )

    def step_optimizer(self):
        """Step the optimizer.

        This method must be called only when the optimizer is stepped (i.e., after accumulating the desired number of gradients).

        We then perform the following steps:
            1. Clip gradients, if applicable.
            2. Step the optimizer.
            3. Zero the gradients.
            4. Update the EMA parameters, if applicable.
        """
        assert "optimizer" in self.state and isinstance(
            self.state["optimizer"], _FabricOptimizer
        )
        assert "model" in self.state and isinstance(
            self.state["model"], _FabricModule | EMA
        )

        optimizer = self.state["optimizer"]
        model = self.state["model"]

        # ... clip gradients, if applicable
        if self.clip_grad_max_norm is not None:
            self.fabric.clip_gradients(
                module=model,
                optimizer=optimizer,
                max_norm=self.clip_grad_max_norm,
                error_if_nonfinite=self.error_if_grad_nonfinite,  ## Ultimately, should find the root cause of nonfinite gradients!
            )

        # ... step the optimizer
        optimizer.step()

        # ... zero gradients
        optimizer.zero_grad()

        # ... update EMA parameters, if applicable
        if hasattr(model, "update"):
            model.update()

    def step_scheduler(
        self,
        level,  #: Literal["epoch", "step"],
        current_value: int,
    ):
        """Step the learning rate scheduler.

        Args:
            level: The level at which to step the scheduler. Either "epoch" or "step".
            current_value: The current epoch or step value.
        """
        # (No scheduler)
        if "scheduler_cfg" not in self.state or self.state["scheduler_cfg"] is None:
            return
        else:
            scheduler_cfg = self.state["scheduler_cfg"]

        # (Wrong interval; e.g., we adjust learning rate every epoch, but we are stepping at the step level)
        if scheduler_cfg.interval != level:
            return

        # (Right interval, but wrong frequency)
        if current_value % cast(int, scheduler_cfg.frequency) != 0:
            return

        # ... step the scheduler
        scheduler_cfg.scheduler.step()

    def save_checkpoint(self) -> None:
        """Saves a checkpoint with current state to `self.output_dir/ckpt`.

        If no output directory is specified, then no checkpoint is saved.
        """
        # No checkpoint directory; skip saving
        if not self.output_dir:
            ranked_logger.warning(
                "No output directory specified; skipping model checkpointing of state dictionary."
            )
            return

        # (Provide a hook to modify the state before saving)
        self.fabric.call("on_save_checkpoint", state=self.state, trainer=self)

        # ... construct the checkpoint file path using Path
        checkpoint_file = (
            self.output_dir / "ckpt" / f"epoch-{self.state['current_epoch']:04d}.ckpt"
        )

        # NOTE: Fabric's `save()` will call the `state_dict()` method on the model, optimizer, and scheduler_cfg
        self.fabric.save(checkpoint_file, self.state)
        ranked_logger.info(f"Saved checkpoint to: {checkpoint_file}")

    def _load_optimizer(self, ckpt: Mapping) -> None:
        """Loads the optimizer state from the checkpoint."""
        if (
            "optimizer" in ckpt
            and self.state["optimizer"]
            and not self.skip_optimizer_loading
        ):
            self.state["optimizer"].load_state_dict(ckpt["optimizer"])
        else:
            ranked_logger.warning("Skipping optimizer loading...")

    def _load_scheduler(self, ckpt: Mapping) -> None:
        """Loads the learning rate scheduler state from the checkpoint."""
        if "scheduler_cfg" in ckpt and self.state["scheduler_cfg"]:
            self.state["scheduler_cfg"].load_state_dict(ckpt["scheduler_cfg"])
        else:
            ranked_logger.warning("Skipping scheduler loading...")

    def _load_model(self, ckpt: Mapping) -> None:
        """Loads the model state from the checkpoint, handling EMA and size mismatches."""

        def _subset_state_dict_to_valid_params(
            current_dict: Mapping, ckpt_dict: Mapping, log_prefix: str = ""
        ) -> dict:
            """Subset checkpoint to parameters with matching sizes, warn on mismatches."""
            valid_state_dict = {}
            for key, ckpt_tensor in ckpt_dict.items():
                if key not in current_dict:
                    continue  # Let strict=False handle missing keys

                if ckpt_tensor.size() != current_dict[key].size():
                    ranked_logger.warning(
                        f"{log_prefix}Size mismatch for '{key}': "
                        f"model size {tuple(current_dict[key].size())} vs "
                        f"checkpoint size {tuple(ckpt_tensor.size())}. "
                        "Skipping this parameter."
                    )
                else:
                    valid_state_dict[key] = ckpt_tensor

            return valid_state_dict

        # ... load the model, subsetting to parameters with matching sizes
        model = self.state["model"]
        model.load_state_dict(
            _subset_state_dict_to_valid_params(model.state_dict(), ckpt["model"]),
            strict=False,
        )

    def load_checkpoint(
        self, ckpt_path: Path | str, is_inference: bool = False
    ) -> None:
        """Loads a checkpoint from the specified path."""
        # ... load the checkpoint (replaces the state dictionary in-place)
        ranked_logger.info(f"Loading checkpoint from: {ckpt_path}...")
        ckpt = self.fabric.load(ckpt_path)

        try:
            # ... optimizer, scheduler, model
            if not is_inference:
                self._load_optimizer(ckpt)
                self._load_scheduler(ckpt)
            self._load_model(ckpt)

            # ... stateless keys
            # (We do not want to load the `train_cfg` in this instance, as it may contain different configurations)
            keys_to_ignore = {"model", "optimizer", "scheduler_cfg", "train_cfg"}
            self.state.update(
                {
                    key: value
                    for key, value in ckpt.items()
                    if key not in keys_to_ignore and key in self.state
                }
            )

            # Log warnings for missing and extra keys
            state_keys = set(self.state) - keys_to_ignore
            ckpt_keys = set(ckpt) - keys_to_ignore

            if missing := state_keys - ckpt_keys:
                ranked_logger.warning(
                    f"Keys found in STATE but not CKPT: {sorted(missing)}"
                )
            if extra := ckpt_keys - state_keys:
                ranked_logger.warning(
                    f"Keys found in CKPT but not STATE: {sorted(extra)}"
                )

            ranked_logger.info(
                f"Loaded checkpoint. Current epoch: {self.state['current_epoch']}, global step: {self.state['global_step']}"
            )
        except Exception as e:
            ranked_logger.error(
                f"Error loading checkpoint: {e}. Trying to load with legacy settings..."
            )
            self.load_legacy_checkpoint(ckpt)

    def load_legacy_checkpoint(self, ckpt: dict) -> dict:
        # TODO: Remove when no longer needed
        """Backwards-compatibility function for loading checkpoints with legacy state formats."""
        new_model_state = {}
        prefixes = {key.split(".")[0] for key in ckpt["final_state_dict"].keys()}

        if "model" not in prefixes:
            # (Model-only checkpoints from training, without confidence head)
            model_state_dict = {
                f"model.{k}": v for k, v in ckpt["final_state_dict"].items()
            }
            shadow_state_dict = {
                f"shadow.{k}": v for k, v in ckpt["model_state_dict"].items()
            }
            full_state_dict = {**model_state_dict, **shadow_state_dict}

        elif "confidence" in prefixes:
            # (Checkpoints with confidence head)
            ranked_logger.info("Detected confidence module in checkpoint...")

            # ... replace confidence head keys with model and shadow prefixes
            model_state_dict = {
                f"model.confidence_head{key[len('confidence') :]}"
                if key.startswith("confidence")
                else key: value
                for key, value in ckpt["final_state_dict"].items()
            }

            shadow_state_dict = {
                (
                    f"shadow.confidence_head{key[len('confidence') :]}"
                    if key.startswith("confidence")
                    else f"shadow{key[len('model') :]}"
                    if key.startswith("model")
                    else key
                ): value
                for key, value in ckpt["model_state_dict"].items()
            }
            full_state_dict = {**model_state_dict, **shadow_state_dict}
        else:
            raise ValueError("Unknown checkpoint format")

        # ... check shapes (we only load matching shapes to support fine-tuning or adding channels)
        state_dict = self.state["model"].state_dict()
        for param in state_dict:
            if param not in full_state_dict:
                ranked_logger.error(f"missing: {param}")
            elif full_state_dict[param].shape == state_dict[param].shape:
                new_model_state[param] = full_state_dict[param]
            else:
                ranked_logger.error(
                    f"wrong size: {param} {full_state_dict[param].shape} {state_dict[param].shape}"
                )

        # ... update the state
        self.state["model"].load_state_dict(new_model_state, strict=False)
        self.state["current_epoch"] = ckpt["epoch"]

        ranked_logger.info(
            f"Loaded internal AF3 clone checkpoint into model. Current epoch: {self.state['current_epoch']}, global step: {self.state['global_step']}"
        )

    @staticmethod
    def get_latest_checkpoint(ckpt_load_dir: Path) -> Path | None:
        """Returns the latest checkpoint file from the given directory.

        Assumes that checkpoints are stored with filenames such that a standard string-based
        sort will correctly order them by creation time (e.g., with epoch numbers, or timestamps).

        Args:
            ckpt_load_dir (Path): The directory to search for checkpoint files.

        Returns:
            Path: The path to the latest checkpoint file, or None if no checkpoints are found
                or if the directory does not exist.
        """
        if not ckpt_load_dir.is_dir():
            return None

        # List all files in the directory and sort them
        items = sorted(ckpt_load_dir.iterdir())

        # Return the last item in the sorted list, if any
        return items[-1] if items else None

    @property
    def should_validate(self) -> bool:
        """Whether to currently run validation."""
        return self.state["current_epoch"] % self.validate_every_n_epochs == 0