rc_foundry-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,923 @@
1
+ """Generic training harness built atop PyTorch Lightning Fabric.
2
+
3
+ In addition to standard harness features (gradient accumulation, mixed precision, etc.), this harness includes native support for exponential moving average (EMA) weights.
4
+
5
+ References:
6
+ - PyTorch Lightning Trainer Example (https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/fabric/build_your_own_trainer/trainer.py)
7
+ - Lightning Hydra Template (https://github.com/ashleve/lightning-hydra-template)
8
+ """
9
+
10
+ import math
11
+ from abc import ABC, abstractmethod
12
+ from datetime import timedelta
13
+ from pathlib import Path
14
+ from typing import cast
15
+
16
+ import hydra
17
+ import lightning as L
18
+ import torch
19
+ from beartype.typing import Any, Literal, Mapping
20
+ from lightning.fabric.accelerators import Accelerator
21
+ from lightning.fabric.loggers import Logger
22
+ from lightning.fabric.strategies import DDPStrategy, Strategy
23
+ from lightning.fabric.wrappers import (
24
+ _FabricDataLoader,
25
+ _FabricModule,
26
+ _FabricOptimizer,
27
+ )
28
+
29
+ from foundry.callbacks.callback import BaseCallback
30
+ from foundry.training.EMA import EMA
31
+ from foundry.training.schedulers import SchedulerConfig
32
+ from foundry.utils.ddp import RankedLogger
33
+ from foundry.utils.weights import (
34
+ CheckpointConfig,
35
+ WeightLoadingConfig,
36
+ freeze_parameters_with_config,
37
+ load_weights_with_policies,
38
+ )
39
+
40
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
41
+
42
+
43
+ def is_interactive_environment() -> bool:
+ """Return True when running inside an IPython/Jupyter session."""
44
+ try:
45
+ from IPython import get_ipython
46
+
47
+ return get_ipython() is not None
48
+ except ImportError:
49
+ return False
50
+
51
+
52
+ class FabricTrainer(ABC):
53
+ def __init__(
54
+ self,
55
+ *,
56
+ accelerator: str | Accelerator = "auto",
57
+ strategy: str | Strategy = "ddp",
58
+ devices_per_node: list[int] | int | str = "auto",
59
+ num_nodes: int = 1,
60
+ precision: str | int = "bf16-mixed",
61
+ callbacks: BaseCallback | list[BaseCallback] | None = None,
62
+ loggers: Logger | list[Logger] | None = None,
63
+ max_epochs: int = 1000,
64
+ grad_accum_steps: int = 1,
65
+ validate_every_n_epochs: int = 1,
66
+ n_examples_per_epoch: int = 24_000,
67
+ output_dir: Path | str | None = None,
68
+ checkpoint_every_n_epochs: int = 1,
69
+ checkpoint_every_n_steps: int | None = None,
70
+ clip_grad_max_norm: float | None = None,
71
+ skip_nan_grad: bool = False,
72
+ error_if_grad_nonfinite: bool = False,
73
+ limit_train_batches: int | float = float("inf"),
74
+ limit_val_batches: int | float = float("inf"),
75
+ prevalidate: bool = False,
76
+ nccl_timeout: int = 3_200,
77
+ find_unused_parameters: bool = False,
78
+ skip_optimizer_loading: bool = False,
79
+ ) -> None:
80
+ """Base Trainer class built around Lightning Fabric.
81
+
82
+ Args:
83
+ accelerator: The hardware to run on. See (1) for details. Possible choices are:
84
+ ``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
85
+ strategy: Strategy for how to run across multiple devices. See (1) for details. Possible choices are:
86
+ ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"fsdp"``.
87
+ devices_per_node: Number of devices to train on per machine (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
88
+ See (1) for details.
89
+ EXAMPLE: If you run on 2 nodes with 8 GPUs each, you would set ``devices_per_node=8``, not ``16``.
90
+ num_nodes: Number of machines (nodes) for distributed training (default: 1). See (1) for details.
91
+ precision: Double precision (``"64"``), full precision (``"32"``), half precision AMP (``"16-mixed"``),
92
+ or bfloat16 precision AMP (``"bf16-mixed"``). See (2) for details.
93
+ callbacks: A single callback or a list of callbacks, each inheriting from the ``BaseCallback`` abstract base class.
94
+ loggers: A single logger or a list of loggers. See (3) for details.
95
+ max_epochs: Maximum number of epochs to train for (default: 1000).
96
+ grad_accum_steps: Number of batches to process before calling optimizer.step() (default: 1). See (4) for details on gradient accumulation in Fabric.
97
+ validate_every_n_epochs: Number of epochs between validation runs (default: 1).
98
+ n_examples_per_epoch: Number of examples to sample per epoch, across all GPUs, i.e., the number of distinct examples
+ "seen" by the model in a single epoch. If this exceeds the number of examples provided by the dataloader,
+ a warning is emitted and the smaller number is used.
101
+ output_dir: Directory to save checkpoints, metrics, intermediate validation structures, etc. (default: None).
102
+ checkpoint_every_n_epochs: Number of epochs between saving checkpoints (default: 1).
103
+ checkpoint_every_n_steps: Number of optimizer steps between saving checkpoints (default: None).
104
+ If set, checkpoints will be saved every N optimizer steps. If None, only epoch-based checkpointing is used.
105
+ clip_grad_max_norm: Maximum gradient norm to clip to (default: None). If None, no gradient clipping is performed.
106
+ skip_nan_grad: Whether to skip optimizer updates when gradients contain NaN or Inf values (default: False).
107
+ Useful for training stability, especially with mixed precision or challenging datasets.
108
+ error_if_grad_nonfinite: Whether to raise an error when gradient clipping detects NaN or Inf gradients (default: False).
109
+ limit_train_batches: Limit on the number of training batches per epoch (default: float("inf")).
110
+ Helpful for debugging; should NOT be used when training production models.
111
+ limit_val_batches: Limit on the number of validation batches per epoch (default: float("inf")).
112
+ Helpful for debugging; should NOT be used when training production models.
113
+ prevalidate: Whether to run validation before training starts (default: False).
114
+ nccl_timeout: Timeout in seconds for NCCL operations (default: 3200). Only used with the DDP strategy.
115
+ find_unused_parameters: Whether to let DDP find and skip gradient synchronization for unused parameters in the model (default: False).
+ NOTE: Setting this to True incurs a performance penalty, but allows training in bespoke use cases where parts of the model are frozen.
117
+ skip_optimizer_loading: Whether to skip loading the optimizer/scheduler state when restoring from checkpoints (default: False).
118
+
119
+ References:
120
+ (1) Fabric Arguments (https://lightning.ai/docs/fabric/stable/api/fabric_args.html)
121
+ (2) Fabric Precision Documentation (https://lightning.ai/docs/fabric/stable/fundamentals/precision.html)
122
+ (3) Fabric Loggers (https://lightning.ai/docs/fabric/2.4.0/api/loggers.html)
123
+ (4) Efficient Gradient Accumulation (https://lightning.ai/docs/fabric/2.4.0/advanced/gradient_accumulation.html)
124
+ """
125
+ # Use custom DDP strategy only for multi-device, non-interactive environments
126
+ if (
127
+ strategy == "ddp"
128
+ and not is_interactive_environment()
129
+ and not (num_nodes == 1 and devices_per_node == 1)
130
+ ):
131
+ strategy = DDPStrategy(
132
+ timeout=timedelta(seconds=nccl_timeout),
133
+ find_unused_parameters=find_unused_parameters,
134
+ )
135
+ else:
136
+ strategy = "auto" # type: ignore
137
+
138
+ # See (1) for initialization arguments for Fabric()
139
+ self.fabric = L.Fabric(
140
+ accelerator=accelerator,
141
+ strategy=strategy,
142
+ devices=devices_per_node,
143
+ num_nodes=num_nodes,
144
+ precision=precision,
145
+ callbacks=callbacks,
146
+ loggers=loggers,
147
+ )
148
+
149
+ # Training
150
+ self.clip_grad_max_norm = clip_grad_max_norm
151
+ self.skip_nan_grad = skip_nan_grad
152
+ self.error_if_grad_nonfinite = error_if_grad_nonfinite
153
+ self.grad_accum_steps = grad_accum_steps
154
+
155
+ # Stopping
156
+ self.max_epochs = max_epochs
157
+ self.should_stop = False
158
+ self.n_examples_per_epoch = n_examples_per_epoch
159
+ self.limit_train_batches = limit_train_batches
160
+ self.limit_val_batches = limit_val_batches
161
+
162
+ # Validation
163
+ self.validate_every_n_epochs = validate_every_n_epochs
164
+ self.prevalidate = prevalidate
165
+
166
+ # Checkpoints
167
+ self.output_dir = Path(output_dir) if output_dir else None
168
+ self.checkpoint_every_n_epochs = checkpoint_every_n_epochs
169
+ self.checkpoint_every_n_steps = checkpoint_every_n_steps
170
+ self.skip_optimizer_loading = skip_optimizer_loading
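To make the `devices_per_node` semantics documented above concrete, here is the arithmetic for the docstring's 2-node example (an illustrative aside, not part of the module):

    # devices_per_node is per machine; the resulting world size is num_nodes * devices_per_node.
    num_nodes, devices_per_node = 2, 8
    world_size = num_nodes * devices_per_node
    assert world_size == 16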
171
+
172
+ def initialize_or_update_trainer_state(
173
+ self,
174
+ updates: dict,
175
+ ):
176
+ """Initialize or update the state dictionary for the trainer.
177
+
178
+ State keys:
179
+ model: The model to train.
180
+ optimizer: The optimizer to use with the model. May be None for validation/inference.
181
+ scheduler_cfg: Learning rate SchedulerConfig (i.e., an LRScheduler together with its stepping interval and frequency). May be None for validation/inference or if no scheduler is used.
182
+ global_step: Global optimizer step; used by W&B logger, learning rate schedulers, etc. Default is 0.
183
+ current_epoch: Global epoch counter; used for validation, learning rate schedulers, checkpointing, etc. Default is 0.
184
+ train_cfg: The training configuration dictionary. Used for reinitializing the trainer with the same configuration
185
+ (for training or for inference). Default is an empty dictionary.
186
+ """
187
+ # Default values for the state
188
+ default_state = {
189
+ "model": None,
190
+ "optimizer": None,
191
+ "scheduler_cfg": None,
192
+ "global_step": 0,
193
+ "current_epoch": 0,
194
+ "train_cfg": {},
195
+ }
196
+
197
+ # Initialize self.state with default values if it doesn't exist
198
+ if not hasattr(self, "state"):
199
+ self.state = default_state.copy()
200
+ else:
201
+ # Ensure existing state has all default keys
202
+ for key, value in default_state.items():
203
+ self.state.setdefault(key, value)
204
+
205
+ # Merge the updates into the existing state
206
+ self.state.update(updates)
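The defaulting-plus-merge behavior above can be seen in isolation with plain dictionaries; this standalone sketch mirrors the defaults but is not part of the module:

    # Mirror of the default state used by initialize_or_update_trainer_state().
    state = {
        "model": None,
        "optimizer": None,
        "scheduler_cfg": None,
        "global_step": 0,
        "current_epoch": 0,
        "train_cfg": {},
    }
    # An update only overwrites the keys it contains; everything else is kept.
    state.update({"model": "my_model", "global_step": 100})
    assert state["current_epoch"] == 0 and state["global_step"] == 100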
207
+
208
+ def construct_optimizer(self) -> None:
209
+ """Instantiate the optimizer(s)
210
+
211
+ We provide a default implementation that instantiates the optimizer(s) from the Hydra configuration.
212
+ More complex models (e.g., GANs) may require custom implementations.
213
+ """
214
+ assert (
215
+ "model" in self.state and hasattr(self.state["model"], "parameters")
216
+ ), "Model not found in state dictionary! You must call `construct_model()` before constructing the optimizer."
217
+
218
+ if self.state["train_cfg"].model.optimizer:
219
+ # ... instantiate the optimizer
220
+ optimizer = hydra.utils.instantiate(
221
+ self.state["train_cfg"].model.optimizer,
222
+ params=self.state["model"].parameters(),
223
+ )
224
+ self.initialize_or_update_trainer_state({"optimizer": optimizer})
225
+
226
+ def construct_scheduler(self) -> None:
227
+ """Instantiate the learning rate scheduler(s)
228
+
229
+ As with the optimizer, we provide a default implementation that instantiates the scheduler(s) from the Hydra configuration.
230
+ More complex models (e.g., GANs) may require custom implementations.
231
+ """
232
+ assert (
233
+ "optimizer" in self.state and self.state["optimizer"]
234
+ ), "Optimizer not found in state dictionary! You must call `construct_optimizer()` before constructing the scheduler."
235
+
236
+ # ... instantiate the LR scheduler(s)
237
+ lr_scheduler = (
238
+ hydra.utils.instantiate(
239
+ self.state["train_cfg"].model.lr_scheduler,
240
+ optimizer=self.state["optimizer"],
241
+ )
242
+ if self.state["train_cfg"].model.lr_scheduler
243
+ else None
244
+ )
245
+
246
+ if lr_scheduler:
247
+ # We assume "interval = step" and "frequency = 1" for the default scheduler; custom implementations may override this method
248
+ scheduler_cfg = SchedulerConfig(
249
+ scheduler=lr_scheduler,
250
+ interval="step",
251
+ frequency=1,
252
+ )
253
+ self.initialize_or_update_trainer_state({"scheduler_cfg": scheduler_cfg})
254
+
255
+ def construct_model(self):
256
+ """Instantiate the model, updating the trainer state in-place.
257
+
258
+ This method must set the "model" key in the state dictionary using `self.initialize_or_update_trainer_state()`.
259
+ For an example, see the `construct_model` method in `AF3Trainer`, which constructs the model and optionally wraps it with EMA.
261
+ """
262
+ # ... instantiate model with Hydra and Fabric
263
+ with self.fabric.init_module():
264
+ ranked_logger.info("Instantiating model...")
265
+
266
+ model = hydra.utils.instantiate(
267
+ self.state["train_cfg"].model.net,
268
+ _recursive_=False,
269
+ )
270
+
271
+ # Optionally, wrap the model with EMA
272
+ if self.state["train_cfg"].model.ema is not None:
273
+ ranked_logger.info("Wrapping model with EMA...")
274
+ model = EMA(model, **self.state["train_cfg"].model.ema)
275
+
276
+ self.initialize_or_update_trainer_state({"model": model})
277
+
278
+ def setup_model_optimizers_and_schedulers(self) -> None:
279
+ """Setup the model, optimizer(s), and scheduler(s) with Fabric.
280
+
281
+ Note that we must call this method after constructing (instantiating) the model, optimizer(s), and scheduler(s).
282
+ For details on multi-model and multi-optimizer setups, see: https://lightning.ai/docs/fabric/2.2.3/advanced/multiple_setup.html
283
+ """
284
+ assert self.state[
285
+ "model"
286
+ ], "You must construct the model before setting up the model, optimizer, and scheduler."
287
+ model = self.state["model"]
288
+ optimizer = self.state["optimizer"]
289
+
290
+ # ... setup the model and optimizer
291
+ if optimizer:
292
+ model, optimizer = self.fabric.setup(model, optimizer)
293
+ else:
294
+ model = self.fabric.setup(model)
295
+
296
+ # ... update the state dictionary (we avoid updating the state dictionary in-place, which is an anti-pattern)
297
+ self.initialize_or_update_trainer_state(
298
+ {
299
+ "model": model,
300
+ "optimizer": optimizer,
301
+ }
302
+ )
303
+
304
+ def fit(
305
+ self,
306
+ train_loader: torch.utils.data.DataLoader,
307
+ val_loaders: dict[str, torch.utils.data.DataLoader] | None = None,
308
+ ckpt_config: CheckpointConfig | None = None,
309
+ ) -> None:
310
+ """Main entry point for training a model.
311
+
312
+ Args:
313
+ train_loader: Dataloader for training. Must have an iterable returning batches.
314
+ val_loaders: Dictionary of dataloaders for validation. The keys are the names of the loaders, and the values are the loaders themselves.
315
+ ckpt_config: Configuration for loading a checkpoint. May contain:
316
+ - path: Path to either:
317
+ (a) A previous checkpoint directory from which to resume training. In this case, we will automatically load
318
+ the latest checkpoint using `self.get_latest_checkpoint()`.
319
+ (b) A specific checkpoint file to load. In this case, we will load the checkpoint from the specified file.
320
+ If None, no checkpoint is loaded, and the model will be trained from scratch.
321
+ - weight_loading_config: Weight loading policies to apply to the checkpoint weights. If None, default policies are used
+ (copy weights, with re-initialization as a fallback if shapes do not match).
323
+ - reset_optimizer: Whether to reset the optimizer state when loading a checkpoint. If True, the optimizer will not be loaded from the checkpoint.
324
+ """
325
+ assert (
326
+ hasattr(self, "state") and "model" in self.state
327
+ ), "Model not found in state dictionary! You must call `instantiate_model()` before running fit()."
328
+
329
+ # (If we don't have enough examples to sample, we will log a warning and use the smaller number)
330
+ if len(train_loader) * self.fabric.world_size < self.n_examples_per_epoch:
331
+ ranked_logger.warning(
332
+ f"Number of examples per epoch ({self.n_examples_per_epoch}) exceeds the number of examples in the loader: "
333
+ f"({len(train_loader) * self.fabric.world_size}). Using the latter."
334
+ )
335
+ self.n_examples_per_epoch = len(train_loader) * self.fabric.world_size
336
+ self.n_batches_per_epoch = math.ceil(
337
+ self.n_examples_per_epoch / self.fabric.world_size
338
+ )
339
+
340
+ # ... setup training and validation dataloaders with Fabric
341
+ train_loader = self.fabric.setup_dataloaders(
342
+ # Our sampler is already distributed, so we don't need to wrap with a DistributedSampler
343
+ train_loader,
344
+ use_distributed_sampler=False,
345
+ )
346
+
347
+ if val_loaders is not None:
348
+ for key, loader in val_loaders.items():
349
+ val_loaders[key] = self.fabric.setup_dataloaders(
350
+ loader, use_distributed_sampler=False
351
+ )
352
+
353
+ self.setup_model_optimizers_and_schedulers()
354
+
355
+ if ckpt_config is not None:
356
+ assert hasattr(
357
+ ckpt_config, "path"
358
+ ), "Checkpoint path not found in checkpoint configuration!"
359
+ ckpt_path = Path(ckpt_config.path)
360
+
361
+ reset_optimizer = bool(
362
+ getattr(ckpt_config, "reset_optimizer", False)
363
+ or self.skip_optimizer_loading
364
+ )
365
+
366
+ if ckpt_path.is_dir():
367
+ # If given a directory, load the latest checkpoint from the directory
368
+ ranked_logger.info(
369
+ f"Loading latest checkpoint within the directory {ckpt_path}..."
370
+ )
371
+ self.load_checkpoint(
372
+ self.get_latest_checkpoint(ckpt_path),
373
+ weight_loading_config=ckpt_config.weight_loading_config,
374
+ reset_optimizer=reset_optimizer,
375
+ )
376
+ else:
377
+ # If given a specific checkpoint file, load that checkpoint
378
+ self.load_checkpoint(
379
+ ckpt_path,
380
+ weight_loading_config=ckpt_config.weight_loading_config,
381
+ reset_optimizer=reset_optimizer,
382
+ )
383
+
384
+ # Apply parameter freezing if a freezing config is provided
385
+ if getattr(ckpt_config, "parameter_freezing_config", None) is not None:
386
+ ranked_logger.info(
387
+ "Applying parameter freezing according to CheckpointConfig..."
388
+ )
389
+ freeze_parameters_with_config(
390
+ # We must access the model through "module", since the model may be wrapped
391
+ self.state["model"].module,
392
+ ckpt_config.parameter_freezing_config,
393
+ )
394
+
395
+ # Increment the global epoch (e.g., if we loaded a checkpoint from [the end of] epoch 5, we should start training at epoch 6)
396
+ self.state["current_epoch"] += 1
397
+ # Stopping conditions
398
+ if (
399
+ self.max_epochs is not None
400
+ and self.state["current_epoch"] >= self.max_epochs
401
+ ):
402
+ self.should_stop = True
403
+ else:
404
+ ranked_logger.info("No checkpoint provided; training from scratch.")
405
+
406
+ # Set the _num_iter_calls internal attribute of the wrapped loader to the current epoch
407
+ # (NOTE: This addresses a bug in Lightning Fabric, where the iter() method calls the `_set_sampler_epoch()` method,
408
+ # relying on the _num_iter_calls attribute to determine the current epoch)
409
+ train_loader._num_iter_calls = self.state["current_epoch"]
410
+
411
+ self.fabric.call("on_fit_start", trainer=self)
412
+
413
+ # Prevalidate
414
+ if self.prevalidate and val_loaders:
415
+ # Temporarily decrement the current epoch, since we haven't done any training this epoch
416
+ self.state["current_epoch"] -= 1 # (Will be -1 if training from scratch)
417
+ ranked_logger.info(
418
+ f"Prevalidating with epoch {self.state['current_epoch']} before training; to avoid this behavior, set `prevalidate=False` in the Trainer config."
419
+ )
420
+ self.validation_loop(
421
+ val_loaders=val_loaders,
422
+ limit_batches=self.limit_val_batches,
423
+ )
424
+ self.state["current_epoch"] += 1 # (Restore the current epoch)
425
+
426
+ while not self.should_stop:
427
+ # ... train for one epoch
428
+ ranked_logger.info(
429
+ f"\n+ Starting epoch {self.state['current_epoch']}/{self.max_epochs - 1}\n"
430
+ f"+ Total examples per epoch (across all GPU): {self.n_examples_per_epoch}\n"
431
+ f"+ Examples per GPU (batches per epoch): {self.n_batches_per_epoch}\n"
432
+ f"+ Gradient accumulation steps: {self.grad_accum_steps}\n"
433
+ f"+ Expected optimizer steps per epoch: {self.n_batches_per_epoch // self.grad_accum_steps}\n"
434
+ )
435
+
436
+ self.train_loop(
437
+ train_loader=train_loader,
438
+ limit_batches=self.limit_train_batches,
439
+ )
440
+
441
+ ranked_logger.info(f"Finished epoch {self.state['current_epoch']}!")
442
+
443
+ # ... validate, if we're at the validation interval
444
+ if self.should_validate and val_loaders:
445
+ ranked_logger.info(
446
+ f"Starting validation for epoch {self.state['current_epoch']}!"
447
+ )
448
+ self.validation_loop(
449
+ val_loaders=val_loaders,
450
+ limit_batches=self.limit_val_batches,
451
+ )
452
+
453
+ # ... step the scheduler, if we're adjusting the learning rate at the epoch-level
454
+ self.step_scheduler(
455
+ level="epoch", current_value=self.state["current_epoch"]
456
+ )
457
+
458
+ # ... save checkpoint, if we've reached the checkpoint interval
459
+ if self.state["current_epoch"] % self.checkpoint_every_n_epochs == 0:
460
+ self.save_checkpoint()
461
+
462
+ # ... increment the epoch
463
+ self.state["current_epoch"] += 1
464
+
465
+ # Stopping conditions
466
+ if (
467
+ self.max_epochs is not None
468
+ and self.state["current_epoch"] >= self.max_epochs
469
+ ):
470
+ self.should_stop = True
471
+
472
+ # Reset for next `fit()` call
473
+ self.should_stop = False
474
+
475
+ self.fabric.call("on_fit_end", trainer=self)
476
+
477
+ def train_loop(
478
+ self,
479
+ *,
480
+ train_loader: _FabricDataLoader,
481
+ limit_batches: int | float = float("inf"),
482
+ ):
483
+ """Train model for a single epoch.
484
+
485
+ Args:
486
+ train_loader: Dataloader for training.
487
+ limit_batches: Limit on the batches during this training epoch. If greater than the number of batches in the
488
+ `train_loader`, this argument has no effect. Helpful for debugging; should NOT be used when training production models.
489
+ """
490
+ self.fabric.call("on_train_epoch_start", trainer=self)
491
+
492
+ assert self.state["model"].training
493
+
494
+ # NOTE: When we call iter(), Fabric calls the `_set_sampler_epoch()` method on the sampler behind the scenes, so we don't need to call it explicitly
495
+ train_iter = iter(train_loader)
496
+ self.fabric.call("on_after_train_loader_iter", trainer=self)
497
+
498
+ for batch_idx in range(len(train_loader)):
499
+ # (End epoch if stopping training completely or maximum desired batches for this epoch reached)
500
+ if self.should_stop or batch_idx >= limit_batches:
501
+ break
502
+
503
+ self.fabric.call("on_before_train_loader_next", trainer=self)
504
+ batch = next(train_iter)
505
+
506
+ self.fabric.call(
507
+ "on_train_batch_start", trainer=self, batch=batch, batch_idx=batch_idx
508
+ )
509
+
510
+ # Optimizer should step if we've accumulated the desired number of gradients
511
+ should_optimizer_step = (batch_idx + 1) % self.grad_accum_steps == 0
512
+
513
+ self.training_step(
514
+ batch=batch,
515
+ batch_idx=batch_idx,
516
+ is_accumulating=not should_optimizer_step,  # when False, gradients are synced during the backward pass
517
+ )
518
+
519
+ self.fabric.call(
520
+ "on_train_batch_end",
521
+ trainer=self,
522
+ outputs=self._current_train_return,
523
+ batch=batch,
524
+ batch_idx=batch_idx,
525
+ )
526
+
527
+ if should_optimizer_step:
528
+ self.fabric.call(
529
+ "on_before_optimizer_step",
530
+ trainer=self,
531
+ optimizer=self.state["optimizer"],
532
+ )
533
+
534
+ # ... step the optimizer, clipping gradients and updating EMA parameters if applicable
535
+ # Note: step_optimizer() calls optimizer.step(), which internally triggers
536
+ # on_after_optimizer_step callbacks via _FabricOptimizer
537
+ self.step_optimizer()
538
+
539
+ # ... call optimizer_step hook (distinct from on_after_optimizer_step which is called by Fabric)
540
+ self.fabric.call(
541
+ "optimizer_step",
542
+ trainer=self,
543
+ optimizer=self.state["optimizer"],
544
+ )
545
+
546
+ # ... step the scheduler, if we're adjusting the learning rate at the optimizer step-level
547
+ self.step_scheduler(
548
+ level="step", current_value=self.state["global_step"]
549
+ )
550
+
551
+ # ... increment the global step, if optimizer stepped
552
+ # NOTE: Each node maintains its own global step
553
+ self.state["global_step"] += int(should_optimizer_step)
554
+
555
+ # ... save checkpoint if we've reached the step-based checkpoint interval
556
+ if (
557
+ should_optimizer_step
558
+ and self.checkpoint_every_n_steps is not None
559
+ and self.state["global_step"] % self.checkpoint_every_n_steps == 0
560
+ ):
561
+ self.save_checkpoint()
562
+
563
+ self.fabric.call("on_train_epoch_end", trainer=self)
564
+
565
+ def validation_loop(
566
+ self,
567
+ *,
568
+ val_loaders: dict[str, _FabricDataLoader],
569
+ limit_batches: int | float = float("inf"),
570
+ ):
571
+ """Run validation loop for a single validation epoch.
572
+
573
+ Args:
574
+ val_loaders: Dictionary of dataloaders (more precisely, _FabricDataLoader) for validation, keyed by loader name.
575
+ limit_batches: Limit on the batches during this validation epoch. If greater than the number of batches in the
576
+ `val_loader`, this argument has no effect. Helpful for debugging; should NOT be used for production.
577
+ """
578
+ # ... set model to evaluation mode
579
+ self.state["model"].eval()
580
+
581
+ with torch.no_grad():
582
+ # ... assert we're in evaluation mode
583
+ assert not self.state["model"].training
584
+
585
+ self.fabric.call("on_validation_epoch_start", trainer=self)
586
+
587
+ # ... iterate over all validation loaders
588
+ for val_loader_name, val_loader in val_loaders.items():
589
+ ranked_logger.info(
590
+ f"Running validation on dataset: {val_loader_name}, with {len(val_loader)} batches, with world_size={self.fabric.world_size}."
591
+ )
592
+
593
+ for batch_idx, batch in enumerate(val_loader):
594
+ # ... end validation epoch if stopping training completely or maximum desired batches for this epoch reached
595
+ if self.should_stop or batch_idx >= limit_batches:
596
+ break
597
+
598
+ self.fabric.call(
599
+ "on_validation_batch_start",
600
+ trainer=self,
601
+ batch=batch,
602
+ batch_idx=batch_idx,
603
+ num_batches=len(val_loader),
604
+ dataset_name=val_loader_name,
605
+ )
606
+
607
+ validation_result = self.validation_step(
608
+ batch=batch,
609
+ batch_idx=batch_idx,
610
+ )
611
+
612
+ self.fabric.call(
613
+ "on_validation_batch_end",
614
+ trainer=self,
615
+ outputs=validation_result,
616
+ batch=batch,
617
+ batch_idx=batch_idx,
618
+ num_batches=len(val_loader),
619
+ dataset_name=val_loader_name,
620
+ )
621
+
622
+ self.fabric.call("on_validation_epoch_end", trainer=self)
623
+
624
+ # ... reset the model to training mode
625
+ self.state["model"].train()
626
+
627
+ @abstractmethod
628
+ def training_step(
629
+ self,
630
+ batch: Any,
631
+ batch_idx: int,
632
+ is_accumulating: bool,
633
+ ) -> None:
634
+ """Training step, running forward and backward passes.
635
+
636
+ Args:
637
+ batch: The current batch; can be of any form.
638
+ batch_idx: The index of the current batch.
639
+ is_accumulating: Whether we are accumulating gradients (i.e., not yet calling optimizer.step()).
640
+ If this is the case, we should skip the synchronization during the backward pass.
641
+
642
+ Returns:
643
+ None. Implementations should store the step output (e.g., the loss tensor or a mapping containing it) in `self._current_train_return`, which is passed to the `on_train_batch_end` callbacks.
644
+ """
645
+ pass
646
+
647
+ @abstractmethod
648
+ def validation_step(
649
+ self,
650
+ batch: Any,
651
+ batch_idx: int,
652
+ val_loader_name: str | None = None,
653
+ ) -> dict:
654
+ """Validation step, running forward pass.
655
+
656
+ Args:
657
+ batch: The current batch; can be of any form.
658
+ batch_idx: The index of the current batch (within that validation loader).
659
+ val_loader_name: The name of the validation loader, if applicable.
660
+
661
+ Returns:
662
+ dict: A dictionary containing the output of the designated validation metrics.
663
+ """
664
+ pass
665
+
666
+ def validate(
667
+ self,
668
+ val_loaders: dict,
669
+ ckpt_path: Path | str,
670
+ ) -> None:
671
+ """Validate a model using the given dataloaders and checkpoint.
672
+
673
+ Args:
674
675
+ val_loaders: A dictionary of dataloaders for validation, where keys are names and values are dataloaders.
676
+ ckpt_path: Path to a specific checkpoint file to load before running validation.
677
+ """
678
+ assert (
679
+ hasattr(self, "state") and "model" in self.state
680
+ ), "Model not found in state dictionary! You must call `instantiate_model()` before running validate()."
681
+
682
+ self.setup_model_optimizers_and_schedulers()
683
+
684
+ self.load_checkpoint(ckpt_path)
685
+
686
+ # Setup validation dataloaders with Fabric
687
+ for key, loader in val_loaders.items():
688
+ val_loaders[key] = self.fabric.setup_dataloaders(
689
+ loader, use_distributed_sampler=False
690
+ )
691
+
692
+ # Run the validation loop
693
+ self.validation_loop(
694
+ val_loaders=val_loaders, limit_batches=self.limit_val_batches
695
+ )
696
+
697
+ def step_optimizer(self):
698
+ """Step the optimizer.
699
+
700
+ This method must be called only when the optimizer is stepped (i.e., after accumulating the desired number of gradients).
701
+
702
+ We then perform the following steps:
703
+ 1. Check for NaN/Inf gradients (skip update if skip_nan_grad=True and found).
704
+ 2. Clip gradients, if applicable.
705
+ 3. Step the optimizer.
706
+ 4. Zero the gradients.
707
+ 5. Update the EMA parameters, if applicable.
708
+ """
709
+ assert "optimizer" in self.state and isinstance(
710
+ self.state["optimizer"], _FabricOptimizer
711
+ )
712
+ assert "model" in self.state and isinstance(
713
+ self.state["model"], _FabricModule | EMA
714
+ )
715
+
716
+ optimizer = self.state["optimizer"]
717
+ model = self.state["model"]
718
+
719
+ # ... check for NaN/Inf gradients, if applicable
720
+ if self.skip_nan_grad:
721
+ has_nan_or_inf = False
722
+ for param in model.parameters():
723
+ if param.grad is not None:
724
+ if not torch.isfinite(param.grad).all():
725
+ has_nan_or_inf = True
726
+ break
727
+
728
+ if has_nan_or_inf:
729
+ ranked_logger.warning(
730
+ f"Skipping optimizer step at global_step={self.state['global_step']} due to NaN/Inf gradients"
731
+ )
732
+ optimizer.zero_grad()
733
+ return # Skip this update
734
+
735
+ # ... clip gradients, if applicable
736
+ if self.clip_grad_max_norm is not None:
737
+ self.fabric.clip_gradients(
738
+ module=model,
739
+ optimizer=optimizer,
740
+ max_norm=self.clip_grad_max_norm,
741
+ error_if_nonfinite=self.error_if_grad_nonfinite,
742
+ )
743
+
744
+ # ... step the optimizer
745
+ optimizer.step()
746
+
747
+ # ... zero gradients
748
+ optimizer.zero_grad()
749
+
750
+ # ... update EMA parameters, if applicable
751
+ if hasattr(model, "update"):
752
+ model.update()
753
+
754
+ def step_scheduler(
755
+ self,
756
+ level: Literal["epoch", "step"],
757
+ current_value: int,
758
+ ):
759
+ """Step the learning rate scheduler.
760
+
761
+ Args:
762
+ level: The level at which to step the scheduler. Either "epoch" or "step".
763
+ current_value: The current epoch or step value.
764
+ """
765
+ # (No scheduler)
766
+ if "scheduler_cfg" not in self.state or self.state["scheduler_cfg"] is None:
767
+ return
768
+ else:
769
+ scheduler_cfg = self.state["scheduler_cfg"]
770
+
771
+ # (Wrong interval; e.g., we adjust learning rate every epoch, but we are stepping at the step level)
772
+ if scheduler_cfg.interval != level:
773
+ return
774
+
775
+ # (Right interval, but wrong frequency)
776
+ if current_value % cast(int, scheduler_cfg.frequency) != 0:
777
+ return
778
+
779
+ # ... step the scheduler
780
+ scheduler_cfg.scheduler.step()
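The interval/frequency gating above, reduced to a standalone predicate (illustrative only; the real method also checks that a scheduler is configured):

    def should_step(interval: str, frequency: int, level: str, current_value: int) -> bool:
        # Step only at the configured level, and only every `frequency` epochs/steps.
        return interval == level and current_value % frequency == 0

    assert should_step("epoch", 2, "epoch", 4) is True
    assert should_step("epoch", 2, "epoch", 5) is False
    assert should_step("epoch", 2, "step", 4) is False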
781
+
782
+ def save_checkpoint(self) -> None:
783
+ """Saves a checkpoint with current state to `self.output_dir/ckpt`.
784
+
785
+ If no output directory is specified, then no checkpoint is saved.
786
+ """
787
+ # No checkpoint directory; skip saving
788
+ if not self.output_dir:
789
+ ranked_logger.warning(
790
+ "No output directory specified; skipping model checkpointing of state dictionary."
791
+ )
792
+ return
793
+
794
+ # (Provide a hook to modify the state before saving)
795
+ self.fabric.call("on_save_checkpoint", trainer=self, state=self.state)
796
+
797
+ # ... construct the checkpoint file path using Path
798
+ checkpoint_file = (
799
+ self.output_dir / "ckpt" / f"epoch-{self.state['current_epoch']:04d}.ckpt"
800
+ )
801
+
802
+ # NOTE: Fabric's `save()` will call the `state_dict()` method on the model, optimizer, and scheduler_cfg
803
+ self.fabric.save(checkpoint_file, self.state)
804
+ ranked_logger.info(f"Saved checkpoint to: {checkpoint_file}")
805
+
806
+ def _load_optimizer(self, ckpt: Mapping) -> None:
807
+ """Loads the optimizer state from the checkpoint."""
808
+ if "optimizer" in ckpt and self.state["optimizer"]:
809
+ self.state["optimizer"].load_state_dict(ckpt["optimizer"])
810
+ else:
811
+ ranked_logger.warning("Skipping optimizer loading...")
812
+
813
+ def _load_scheduler(self, ckpt: Mapping) -> None:
814
+ """Loads the learning rate scheduler state from the checkpoint."""
815
+ if "scheduler_cfg" in ckpt and self.state["scheduler_cfg"]:
816
+ self.state["scheduler_cfg"].load_state_dict(ckpt["scheduler_cfg"])
817
+ else:
818
+ ranked_logger.warning("Skipping scheduler loading...")
819
+
820
+ def _load_model(
821
+ self, ckpt: Mapping, weight_loading_config: WeightLoadingConfig | None = None
822
+ ) -> None:
823
+ """Loads the model state from the checkpoint, handling EMA and size mismatches."""
824
+ # ... load pre-trained weights from the CHECKPOINT into the MODEL (that at this point has random weights)
825
+ model = self.state["model"]
826
+ model.load_state_dict(
827
+ load_weights_with_policies(
828
+ model=self.state["model"],
829
+ ckpt=ckpt["model"],
830
+ config=weight_loading_config,
831
+ ),
832
+ strict=True,
833
+ )
834
+
835
+ def load_checkpoint(
836
+ self,
837
+ checkpoint: Path | str | dict,
838
+ weight_loading_config: WeightLoadingConfig | None = None,
839
+ reset_optimizer: bool = False,
840
+ ) -> None:
841
+ """Loads a checkpoint from the specified path or uses a pre-loaded checkpoint dict.
842
+
843
+ Args:
844
+ checkpoint: Either a path to a checkpoint file or a pre-loaded checkpoint dict.
845
+ weight_loading_config: Weight loading policies to apply. Defaults to ``None``.
846
+ reset_optimizer: Whether to reset the optimizer state. Defaults to ``False``.
847
+ """
848
+ # ... load the checkpoint or use the provided dict
849
+ if isinstance(checkpoint, dict):
850
+ ranked_logger.info("Using pre-loaded checkpoint...")
851
+ ckpt = checkpoint
852
+ else:
853
+ ranked_logger.info(f"Loading checkpoint from: {checkpoint}...")
854
+ ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False)
855
+
856
+ try:
857
+ # ... optimizer, scheduler
858
+ if not reset_optimizer:
859
+ self._load_optimizer(ckpt)
860
+ self._load_scheduler(ckpt)
861
+ # ... model
862
+ self._load_model(ckpt, weight_loading_config)
863
+
864
+ # ... stateless keys
865
+ # (We do not want to load the `train_cfg` in this instance, as it may contain different configurations)
866
+ keys_to_ignore = {"model", "optimizer", "scheduler_cfg", "train_cfg"}
867
+ self.state.update(
868
+ {
869
+ key: value
870
+ for key, value in ckpt.items()
871
+ if key not in keys_to_ignore and key in self.state
872
+ }
873
+ )
874
+
875
+ # Log warnings for missing and extra keys
876
+ state_keys = set(self.state) - keys_to_ignore
877
+ ckpt_keys = set(ckpt) - keys_to_ignore
878
+
879
+ if missing := state_keys - ckpt_keys:
880
+ ranked_logger.warning(
881
+ f"Keys found in STATE but not CKPT: {sorted(missing)}"
882
+ )
883
+ if extra := ckpt_keys - state_keys:
884
+ ranked_logger.warning(
885
+ f"Keys found in CKPT but not STATE: {sorted(extra)}"
886
+ )
887
+
888
+ ranked_logger.info(
889
+ f"Loaded checkpoint. Current epoch: {self.state['current_epoch']}, global step: {self.state['global_step']}"
890
+ )
891
+ except Exception as e:
892
+ ranked_logger.exception(
893
+ f"Error loading checkpoint: {e}. Please ensure that the model architecture matches the checkpoint."
894
+ )
895
+ raise
896
+
897
+ @staticmethod
898
+ def get_latest_checkpoint(ckpt_load_dir: Path) -> Path | None:
899
+ """Returns the latest checkpoint file from the given directory.
900
+
901
+ Assumes that checkpoints are stored with filenames such that a standard string-based
902
+ sort will correctly order them by creation time (e.g., with epoch numbers, or timestamps).
903
+
904
+ Args:
905
+ ckpt_load_dir (Path): The directory to search for checkpoint files.
906
+
907
+ Returns:
908
+ Path: The path to the latest checkpoint file, or None if no checkpoints are found
909
+ or if the directory does not exist.
910
+ """
911
+ if not ckpt_load_dir.is_dir():
912
+ return None
913
+
914
+ # List all files in the directory and sort them
915
+ items = sorted(ckpt_load_dir.iterdir())
916
+
917
+ # Return the last item in the sorted list, if any
918
+ return items[-1] if items else None
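Because the selection relies on a plain string sort, the zero-padded epoch numbers written by `save_checkpoint()` (e.g., `epoch-0004.ckpt`) are what keep the ordering correct; a quick illustration:

    names = ["epoch-0002.ckpt", "epoch-0010.ckpt", "epoch-0001.ckpt"]
    # Zero-padding makes lexicographic order match numeric order.
    assert sorted(names)[-1] == "epoch-0010.ckpt"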
919
+
920
+ @property
921
+ def should_validate(self) -> bool:
922
+ """Whether to currently run validation."""
923
+ return self.state["current_epoch"] % self.validate_every_n_epochs == 0