rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,284 @@
1
+ """Generalized symmetry resolution implementation, operating on the outputs of AtomWorks.io `parse` function."""
2
+
3
+ import logging
4
+ from typing import Any, Dict
5
+
6
+ import numpy as np
7
+ import torch
8
+ from atomworks.io.transforms.atom_array import ensure_atom_array_stack
9
+ from atomworks.ml.transforms.atom_array import AddGlobalTokenIdAnnotation
10
+ from atomworks.ml.transforms.atomize import AtomizeByCCDName
11
+ from atomworks.ml.transforms.base import Compose, convert_to_torch
12
+ from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
13
+ from biotite.structure import AtomArray, AtomArrayStack
14
+ from jaxtyping import Bool, Float, Int
15
+ from rf3.loss.af3_losses import (
16
+ ResidueSymmetryResolution,
17
+ SubunitSymmetryResolution,
18
+ )
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
def resolve_symmetries(
    predicted_atom_array: AtomArray | AtomArrayStack,
    ground_truth_atom_array: AtomArray | AtomArrayStack,
    resolve_residue_symmetries: bool = True,
    resolve_subunit_symmetries: bool = True,
) -> AtomArrayStack:
    """
    Generalized symmetry resolution for both residue- and subunit-level symmetries.

    Returns updated ground truth AtomArray with coordinates that minimize RMSD with the predicted structure.

    NOTE(review): `ensure_atom_array_stack` may return (a view of) the input, and the
    ground-truth coordinates are modified in place below — callers should not rely on
    `ground_truth_atom_array` being left untouched. TODO confirm and copy if needed.

    Args:
        predicted_atom_array: Predicted structure as AtomArray or AtomArrayStack
        ground_truth_atom_array: Ground truth structure as AtomArray or AtomArrayStack
        resolve_residue_symmetries: Whether to resolve residue-level symmetries
        resolve_subunit_symmetries: Whether to resolve subunit-level symmetries

    Returns:
        Updated ground truth AtomArrayStack with resolved coordinates

    Raises:
        ValueError: If the ground truth stack depth cannot be broadcast to the
            predicted stack depth (i.e., depths differ and ground truth depth != 1).
    """
    predicted_stack = ensure_atom_array_stack(predicted_atom_array)
    ground_truth_stack = ensure_atom_array_stack(ground_truth_atom_array)

    # Set ground truth coordinates to nan if they are nan in the predicted coordinates...
    ground_truth_stack.coord[np.isnan(predicted_stack.coord)] = np.nan

    # ... then nan-to-num the predicted coordinates (otherwise, the symmetry resolution may fail)
    # TODO: Update the symmetry resolution to handle NaNs in the predicted coordinates
    predicted_stack.coord = np.nan_to_num(predicted_stack.coord)

    # Extract predicted and ground truth coordinates
    X_pred: Float[torch.Tensor, "D L 3"] = torch.tensor(
        predicted_stack.coord, dtype=torch.float32
    )
    X_gt: Float[torch.Tensor, "D L 3"] = torch.tensor(
        ground_truth_stack.coord, dtype=torch.float32
    )

    # (Match dimensions: D = number of models in the stack, L = number of atoms)
    D_pred, L_pred = X_pred.shape[:2]
    D_gt, L_gt = X_gt.shape[:2]

    if D_pred != D_gt:
        if D_gt == 1:
            # Broadcast a single ground-truth model across all predicted models
            X_gt = X_gt.expand(D_pred, -1, -1)
        else:
            raise ValueError(
                f"Cannot broadcast ground truth of shape ({D_gt}) to prediction of shape ({D_pred})"
            )
    # BUG FIX: the message was a plain string, so {L_pred}/{L_gt} were never interpolated;
    # it is now an f-string and reports the actual lengths.
    assert L_pred == L_gt, f"Length mismatch: predicted {L_pred}, ground truth {L_gt}"

    # Generate symmetric features (e.g., automorphisms, entity information, etc.) inputs from ground truth
    symmetry_data = generate_symmetry_resolution_inputs_from_atom_array(
        ground_truth_stack
    )

    # Extract coordinate mask from ground truth stack
    crd_mask: Bool[torch.Tensor, "D L"]
    if "occupancy" in ground_truth_stack.get_annotation_categories():
        crd_mask = torch.tensor(ground_truth_stack.occupancy > 0.0, dtype=torch.bool)
    else:
        logger.warning(
            "No occupancy annotation found in ground truth, using coordinate validity mask (not NaN)"
        )
        crd_mask = ~torch.isnan(torch.tensor(ground_truth_stack.coord)).any(dim=-1)

    # Predicted coordinates were nan-to-num'd above, so this should always hold
    assert not torch.isnan(
        X_pred
    ).any(), "NaN coordinates found in predicted structure!"

    # Apply symmetry resolution (returns updated ground truth coordinates)
    X_gt_resolved: Float[torch.Tensor, "D L 3"] = apply_symmetry_resolution(
        X_pred=X_pred,
        X_gt=X_gt,
        crd_mask=crd_mask,
        automorphisms=symmetry_data["automorphisms"],
        molecule_entity=symmetry_data["molecule_entity"],
        molecule_iid=symmetry_data["molecule_iid"],
        crop_mask=symmetry_data["crop_mask"],
        coord_atom_lvl=symmetry_data["coord_atom_lvl"],
        mask_atom_lvl=symmetry_data["mask_atom_lvl"],
        resolve_residue=resolve_residue_symmetries,
        resolve_subunit=resolve_subunit_symmetries,
    )

    # Update (a copy of) the ground truth AtomArray with resolved coordinates
    result_stack = ground_truth_stack.copy()
    result_stack.coord = X_gt_resolved.cpu().numpy()

    return result_stack
113
+
114
+
115
def generate_symmetry_resolution_inputs_from_atom_array(
    atom_array: AtomArray | AtomArrayStack,
) -> Dict[str, Any]:
    """
    Generate all inputs needed for symmetry resolution from an AtomArray.

    Args:
        atom_array: Input AtomArray or AtomArrayStack

    Returns:
        Dictionary containing:
            - automorphisms: List[np.ndarray]
            - molecule_entity: torch.Tensor [N_atoms]
            - molecule_iid: torch.Tensor [N_atoms]
            - coord_atom_lvl: torch.Tensor [N_atoms, 3]
            - mask_atom_lvl: torch.Tensor [N_atoms]
            - atom_to_token_map: torch.Tensor [N_atoms]
            - crop_mask: torch.Tensor [N_atoms]
    """
    # Work on the first model of the stack only, and on a copy so the caller's
    # array is never modified by the transform pipeline.
    atom_array = ensure_atom_array_stack(atom_array)[0].copy()

    # Assemble the feature-generation pipeline: atomize, (optionally) add token
    # ids, then compute automorphism groups.
    steps: list = [AtomizeByCCDName(atomize_by_default=True)]
    if "token_id" not in atom_array.get_annotation_categories():
        steps.append(AddGlobalTokenIdAnnotation())
    steps.append(FindAutomorphismsWithNetworkX())

    data = Compose(steps)({"atom_array": atom_array})
    atom_array = data["atom_array"]

    annotations = atom_array.get_annotation_categories()

    # Molecule-level annotations are mandatory for subunit symmetry resolution.
    assert "molecule_entity" in annotations, "molecule_entity annotation required"
    assert "molecule_iid" in annotations, "molecule_iid annotation required"

    # Atom-level validity mask (like in lddt.py) — kept 1-D [N_atoms] for
    # SubunitSymmetryResolution compatibility. Prefer occupancy; fall back to
    # coordinate validity (non-NaN) when occupancy is absent.
    mask: np.ndarray
    if "occupancy" in annotations:
        mask = atom_array.occupancy > 0.0
    else:
        mask = ~np.isnan(atom_array.coord).any(axis=-1)

    # Token ids must exist at this point (AddGlobalTokenIdAnnotation ran if missing).
    if "token_id" not in annotations:
        raise ValueError(
            "token_id annotation not found after AddGlobalTokenIdAnnotation"
        )

    # Gather the numeric inputs and convert them all to torch tensors in one pass.
    tensor_inputs: Dict[str, Any] = {
        "molecule_entity": atom_array.molecule_entity,
        "molecule_iid": atom_array.molecule_iid,
        "coord_atom_lvl": atom_array.coord,
        "mask_atom_lvl": mask,
        "atom_to_token_map": atom_array.token_id.astype(np.int32),
        # crop_mask covers the full atom range (no cropping applied here)
        "crop_mask": np.arange(len(atom_array), dtype=np.int32),
    }
    tensor_inputs = convert_to_torch(tensor_inputs, list(tensor_inputs.keys()))

    # Automorphisms stay as a plain list of numpy arrays.
    return {"automorphisms": data.get("automorphisms", []), **tensor_inputs}
212
+
213
+
214
def apply_symmetry_resolution(
    X_pred: Float[torch.Tensor, "D L 3"],
    X_gt: Float[torch.Tensor, "D L 3"],
    crd_mask: Bool[torch.Tensor, "D L"],
    automorphisms: list,
    molecule_entity: Int[torch.Tensor, "N_atoms"],
    molecule_iid: Int[torch.Tensor, "N_atoms"],
    crop_mask: Int[torch.Tensor, "N_atoms"],
    coord_atom_lvl: Float[torch.Tensor, "N_atoms 3"],
    mask_atom_lvl: Bool[torch.Tensor, "N_atoms"],
    resolve_residue: bool = True,
    resolve_subunit: bool = True,
) -> Float[torch.Tensor, "D L 3"]:
    """
    Apply the actual symmetry resolution using the existing classes and return updated coordinates.

    Args:
        X_pred: Predicted coordinates [D, L, 3]
        X_gt: Ground truth coordinates [D, L, 3]
        crd_mask: Coordinate mask [D, L]
        automorphisms: List of automorphism groups
        molecule_entity: Molecule entity IDs [N_atoms]
        molecule_iid: Molecule instance IDs [N_atoms]
        crop_mask: Crop mask indices [N_atoms]
        coord_atom_lvl: Atom-level coordinates [N_atoms, 3]
        mask_atom_lvl: Atom-level mask [N_atoms]
        resolve_residue: Whether to resolve residue symmetries
        resolve_subunit: Whether to resolve subunit symmetries

    Returns:
        Updated ground truth coordinates [D, L, 3]
    """
    # Working dict that the resolver modules consume and update in turn.
    # Clone so the caller's tensors are never mutated.
    loss_input: Dict[str, torch.Tensor] = {
        "X_gt_L": X_gt.clone(),
        "crd_mask_L": crd_mask.clone(),
    }

    # Both resolvers read the prediction from the same network-output mapping,
    # so build it once up front.
    network_output: Dict[str, torch.Tensor] = {"X_L": X_pred}

    # Subunit-level (chain swap) resolution first, if requested.
    if resolve_subunit:
        symmetry_resolution: Dict[str, torch.Tensor] = {
            "molecule_entity": molecule_entity,
            "molecule_iid": molecule_iid,
            "crop_mask": crop_mask,
            "coord_atom_lvl": coord_atom_lvl,
            "mask_atom_lvl": mask_atom_lvl,
        }
        logger.info("Applying subunit symmetry resolution")
        loss_input = SubunitSymmetryResolution()(
            network_output, loss_input, symmetry_resolution
        )

    # Residue-level resolution second — skipped when no automorphisms exist.
    if resolve_residue and automorphisms:
        logger.info("Applying residue symmetry resolution")
        loss_input = ResidueSymmetryResolution()(
            network_output, loss_input, automorphisms
        )

    # The resolvers leave the (possibly permuted) ground truth under "X_gt_L".
    return loss_input["X_gt_L"]
rf3/train.py ADDED
@@ -0,0 +1,194 @@
1
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'
# Training entry-point script for the RF3 model. The shebang delegates to a
# repo-local wrapper (`rf3_exec.sh`) that selects the correct interpreter/env.

import logging
import os

import hydra
import rootutils
from dotenv import load_dotenv
from omegaconf import DictConfig

from foundry.utils.logging import suppress_warnings
from foundry.utils.weights import CheckpointConfig

# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Load variables from a `.env` file, overriding any already set in the environment.
load_dotenv(override=True)

# Hydra config directory, resolved relative to the project root established above.
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")

# Plain module logger used by the parent ("spawning") process before the
# rank-aware logger is available inside `train`.
_spawning_process_logger = logging.getLogger(__name__)
23
+
24
+
25
@hydra.main(config_path=_config_path, config_name="train", version_base="1.3")
def train(cfg: DictConfig) -> None:
    """Train the RF3 model from a composed Hydra configuration.

    Orchestration: imports heavy dependencies lazily, instantiates loggers and
    callbacks, builds the trainer, spawns distributed processes via Fabric,
    constructs the model/optimizer/scheduler, assembles train/val dataloaders,
    resolves an optional checkpoint configuration, and runs `trainer.fit`.

    Args:
        cfg: Hydra-composed training configuration (see `models/rf3/configs/train`).
    """
    # ==============================================================================
    # Import dependencies and resolve Hydra configuration
    # ==============================================================================

    _spawning_process_logger.info("Importing dependencies...")

    # Lazy imports to make config generation fast
    import torch
    from lightning.fabric import seed_everything
    from lightning.fabric.loggers import Logger

    # If training on DIGS L40, set precision of matrix multiplication to balance speed and accuracy
    # Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
    torch.set_float32_matmul_precision("medium")

    from foundry.callbacks.callback import BaseCallback  # noqa
    from foundry.utils.instantiators import instantiate_loggers, instantiate_callbacks  # noqa
    from foundry.utils.logging import (
        print_config_tree,
        log_hyperparameters_with_all_loggers,
    )  # noqa
    from foundry.utils.ddp import RankedLogger  # noqa
    from foundry.utils.ddp import is_rank_zero, set_accelerator_based_on_availability  # noqa
    from foundry.utils.datasets import (
        recursively_instantiate_datasets_and_samplers,
        assemble_distributed_loader,
        subset_dataset_to_example_ids,
        assemble_val_loader_dict,
    )  # noqa

    set_accelerator_based_on_availability(cfg)

    ranked_logger = RankedLogger(__name__, rank_zero_only=True)
    _spawning_process_logger.info("Completed dependency imports ...")

    # ... print the configuration tree (NOTE: Only prints for rank 0)
    print_config_tree(cfg, resolve=True)

    # ==============================================================================
    # Logging and Callback instantiation
    # ==============================================================================

    # Reduce the logging level for all dataset and sampler loggers (unless rank 0)
    # We will still see messages from Rank 0; they are identical, since all ranks load and sample from the same datasets
    if not is_rank_zero():
        dataset_logger = logging.getLogger("datasets")
        sampler_logger = logging.getLogger("atomworks.ml.samplers")
        dataset_logger.setLevel(logging.WARNING)
        sampler_logger.setLevel(logging.ERROR)

    # ... seed everything (NOTE: By setting `workers=True`, we ensure that the dataloaders are seeded as well)
    # (`PL_GLOBAL_SEED` environment variable will be passed to the spawned subprocesses; e.g., through `ddp_spawn` backend)
    if cfg.get("seed"):
        ranked_logger.info(f"Seeding everything with seed={cfg.seed}...")
        seed_everything(cfg.seed, workers=True, verbose=True)
    else:
        ranked_logger.warning("No seed provided - Not seeding anything!")

    ranked_logger.info("Instantiating loggers...")
    loggers: list[Logger] = instantiate_loggers(cfg.get("logger"))

    ranked_logger.info("Instantiating callbacks...")
    callbacks: list[BaseCallback] = instantiate_callbacks(cfg.get("callbacks"))

    # ==============================================================================
    # Trainer and model instantiation
    # ==============================================================================

    # ... instantiate the trainer
    ranked_logger.info("Instantiating trainer...")
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        loggers=loggers or None,
        callbacks=callbacks or None,
        _convert_="partial",
        _recursive_=False,
    )
    # (Store the Hydra configuration in the trainer state)
    trainer.initialize_or_update_trainer_state({"train_cfg": cfg})

    # ... spawn processes for distributed training
    # (We spawn here, rather than within `fit`, so we can use Fabric's `init_module` to efficiently initialize the model on the appropriate device)
    ranked_logger.info(
        f"Spawning {trainer.fabric.world_size} processes from {trainer.fabric.global_rank}..."
    )
    trainer.fabric.launch()

    # ... construct the model
    trainer.construct_model()

    # ... construct the optimizer and schedule (which requires the model to be constructed)
    trainer.construct_optimizer()
    trainer.construct_scheduler()

    # ==============================================================================
    # Dataset instantiation
    # ==============================================================================

    # Number of examples per epoch (across all GPUs)
    # (We must sample this many indices from our sampler)
    n_examples_per_epoch = cfg.trainer.n_examples_per_epoch

    # ... build the train dataset
    assert (
        "train" in cfg.datasets and cfg.datasets.train
    ), "No 'train' dataloader configuration provided! If only performing validation, use `validate.py` instead."
    dataset_and_sampler = recursively_instantiate_datasets_and_samplers(
        cfg.datasets.train
    )
    train_dataset, train_sampler = (
        dataset_and_sampler["dataset"],
        dataset_and_sampler["sampler"],
    )

    # ... compose the train loader
    if "subset_to_example_ids" in cfg.datasets:
        # Backdoor for debugging and overfitting: subset the dataset to a specific set of example IDs
        train_dataset = subset_dataset_to_example_ids(
            train_dataset, cfg.datasets.subset_to_example_ids
        )
        train_sampler = None  # Sampler is no longer valid, since we are using a subset of the dataset

    train_loader = assemble_distributed_loader(
        dataset=train_dataset,
        sampler=train_sampler,
        rank=trainer.fabric.global_rank,
        world_size=trainer.fabric.world_size,
        n_examples_per_epoch=n_examples_per_epoch,
        loader_cfg=cfg.dataloader["train"],
    )

    # ... compose the validation loader(s)
    if "val" in cfg.datasets and cfg.datasets.val:
        val_loaders = assemble_val_loader_dict(
            cfg=cfg.datasets.val,
            rank=trainer.fabric.global_rank,
            world_size=trainer.fabric.world_size,
            loader_cfg=cfg.dataloader["val"],
        )
    else:
        ranked_logger.warning("No validation datasets provided! Skipping validation...")
        val_loaders = None

    ranked_logger.info("Logging hyperparameters...")
    log_hyperparameters_with_all_loggers(
        trainer=trainer, cfg=cfg, model=trainer.state["model"]
    )

    # ... load the checkpoint configuration
    # A full `ckpt_config` object takes precedence; a bare `ckpt_path` string is
    # wrapped in a CheckpointConfig. (The redundant inner `is not None` check was
    # removed: the `elif` guard already ensures `cfg.ckpt_path` is truthy.)
    ckpt_config = None
    if "ckpt_config" in cfg and cfg.ckpt_config:
        ckpt_config = hydra.utils.instantiate(cfg.ckpt_config)
    elif "ckpt_path" in cfg and cfg.ckpt_path:
        # Just a checkpoint path
        ckpt_config = CheckpointConfig(path=cfg.ckpt_path)

    # ... train the model
    ranked_logger.info("Training model...")

    with suppress_warnings():
        trainer.fit(
            train_loader=train_loader, val_loaders=val_loaders, ckpt_config=ckpt_config
        )
191
+
192
+
193
if __name__ == "__main__":
    # Hydra entry point: parses CLI overrides and invokes `train` with the composed config.
    train()