rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
rf3/symmetry/resolve.py
ADDED
@@ -0,0 +1,284 @@
"""Generalized symmetry resolution implementation, operating on the outputs of the AtomWorks.io `parse` function."""

import logging
from typing import Any, Dict

import numpy as np
import torch
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
from atomworks.ml.transforms.atom_array import AddGlobalTokenIdAnnotation
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.ml.transforms.base import Compose, convert_to_torch
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
from biotite.structure import AtomArray, AtomArrayStack
from jaxtyping import Bool, Float, Int
from rf3.loss.af3_losses import (
    ResidueSymmetryResolution,
    SubunitSymmetryResolution,
)

logger = logging.getLogger(__name__)


def resolve_symmetries(
    predicted_atom_array: AtomArray | AtomArrayStack,
    ground_truth_atom_array: AtomArray | AtomArrayStack,
    resolve_residue_symmetries: bool = True,
    resolve_subunit_symmetries: bool = True,
) -> AtomArrayStack:
    """
    Generalized symmetry resolution for both residue- and subunit-level symmetries.

    Returns an updated ground-truth stack whose coordinates minimize the RMSD to the predicted structure.

    Args:
        predicted_atom_array: Predicted structure as AtomArray or AtomArrayStack
        ground_truth_atom_array: Ground truth structure as AtomArray or AtomArrayStack
        resolve_residue_symmetries: Whether to resolve residue-level symmetries
        resolve_subunit_symmetries: Whether to resolve subunit-level symmetries

    Returns:
        Updated ground truth AtomArrayStack with resolved coordinates
    """
    predicted_stack = ensure_atom_array_stack(predicted_atom_array)
    ground_truth_stack = ensure_atom_array_stack(ground_truth_atom_array)

    # Set ground truth coordinates to NaN wherever they are NaN in the predicted coordinates...
    ground_truth_stack.coord[np.isnan(predicted_stack.coord)] = np.nan

    # ... then nan-to-num the predicted coordinates (otherwise, the symmetry resolution may fail)
    # TODO: Update the symmetry resolution to handle NaNs in the predicted coordinates
    predicted_stack.coord = np.nan_to_num(predicted_stack.coord)

    # Extract predicted and ground truth coordinates
    X_pred: Float[torch.Tensor, "D L 3"] = torch.tensor(
        predicted_stack.coord, dtype=torch.float32
    )
    X_gt: Float[torch.Tensor, "D L 3"] = torch.tensor(
        ground_truth_stack.coord, dtype=torch.float32
    )

    # (Match dimensions: broadcast a single ground-truth model across all predicted models)
    D_pred, L_pred = X_pred.shape[:2]
    D_gt, L_gt = X_gt.shape[:2]

    if D_pred != D_gt:
        if D_gt == 1:
            X_gt = X_gt.expand(D_pred, -1, -1)
        else:
            raise ValueError(
                f"Cannot broadcast ground truth of shape ({D_gt}) to prediction of shape ({D_pred})"
            )
    assert L_pred == L_gt, f"Length mismatch: predicted {L_pred}, ground truth {L_gt}"

    # Generate symmetry-resolution inputs (e.g., automorphisms, entity information) from the ground truth
    symmetry_data = generate_symmetry_resolution_inputs_from_atom_array(
        ground_truth_stack
    )

    # Extract coordinate mask from the ground truth stack
    crd_mask: Bool[torch.Tensor, "D L"]
    if "occupancy" in ground_truth_stack.get_annotation_categories():
        crd_mask = torch.tensor(ground_truth_stack.occupancy > 0.0, dtype=torch.bool)
    else:
        logger.warning(
            "No occupancy annotation found in ground truth, using coordinate validity mask (not NaN)"
        )
        crd_mask = ~torch.isnan(torch.tensor(ground_truth_stack.coord)).any(dim=-1)

    assert not torch.isnan(X_pred).any(), "NaN coordinates found in predicted structure!"

    # Apply symmetry resolution (returns updated ground truth coordinates)
    X_gt_resolved: Float[torch.Tensor, "D L 3"] = apply_symmetry_resolution(
        X_pred=X_pred,
        X_gt=X_gt,
        crd_mask=crd_mask,
        automorphisms=symmetry_data["automorphisms"],
        molecule_entity=symmetry_data["molecule_entity"],
        molecule_iid=symmetry_data["molecule_iid"],
        crop_mask=symmetry_data["crop_mask"],
        coord_atom_lvl=symmetry_data["coord_atom_lvl"],
        mask_atom_lvl=symmetry_data["mask_atom_lvl"],
        resolve_residue=resolve_residue_symmetries,
        resolve_subunit=resolve_subunit_symmetries,
    )

    # Update the ground truth AtomArray with resolved coordinates
    result_stack = ground_truth_stack.copy()
    result_stack.coord = X_gt_resolved.cpu().numpy()

    return result_stack
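A minimal usage sketch of the entry point above (hedged: `pred_structure` and `gt_structure` are assumed to be AtomWorks-parsed biotite structures, so the ground truth carries the `molecule_entity`/`molecule_iid` annotations that the helper below asserts on):

from rf3.symmetry.resolve import resolve_symmetries

# Hypothetical inputs: the model's prediction and the parsed reference, as AtomArray(Stack)s
resolved_gt = resolve_symmetries(
    predicted_atom_array=pred_structure,
    ground_truth_atom_array=gt_structure,
)
# resolved_gt is an AtomArrayStack whose coordinates are permuted to minimize RMSD to the prediction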
def generate_symmetry_resolution_inputs_from_atom_array(
    atom_array: AtomArray | AtomArrayStack,
) -> Dict[str, Any]:
    """
    Generate all inputs needed for symmetry resolution from an AtomArray.

    Args:
        atom_array: Input AtomArray or AtomArrayStack

    Returns:
        Dictionary containing:
            - automorphisms: List[np.ndarray]
            - molecule_entity: torch.Tensor [N_atoms]
            - molecule_iid: torch.Tensor [N_atoms]
            - coord_atom_lvl: torch.Tensor [N_atoms, 3]
            - mask_atom_lvl: torch.Tensor [N_atoms]
            - atom_to_token_map: torch.Tensor [N_atoms]
            - crop_mask: torch.Tensor [N_atoms]
    """
    # (Take the first model)
    atom_array_stack = ensure_atom_array_stack(atom_array)
    atom_array = atom_array_stack[0]

    # (Avoid modifying the original)
    atom_array = atom_array.copy()

    # Prepare the transform pipeline that generates the features
    transforms = [AtomizeByCCDName(atomize_by_default=True)]

    if "token_id" not in atom_array.get_annotation_categories():
        transforms.append(AddGlobalTokenIdAnnotation())

    transforms.append(FindAutomorphismsWithNetworkX())

    pipeline = Compose(transforms)
    data = pipeline({"atom_array": atom_array})
    atom_array = data["atom_array"]

    result: Dict[str, Any] = {}
    # Extract automorphisms
    result["automorphisms"] = data.get("automorphisms", [])

    # Extract molecule annotations (assert they exist)
    assert (
        "molecule_entity" in atom_array.get_annotation_categories()
    ), "molecule_entity annotation required"
    assert (
        "molecule_iid" in atom_array.get_annotation_categories()
    ), "molecule_iid annotation required"

    result["molecule_entity"] = atom_array.molecule_entity
    result["molecule_iid"] = atom_array.molecule_iid

    # Extract atom-level coordinates
    coords: np.ndarray = atom_array.coord
    result["coord_atom_lvl"] = coords

    # Extract mask from occupancy (as in lddt.py) - no batch dimension for SubunitSymmetryResolution
    mask: np.ndarray
    if "occupancy" in atom_array.get_annotation_categories():
        mask = atom_array.occupancy > 0.0
    else:
        # Fall back to coordinate validity
        mask = ~np.isnan(atom_array.coord).any(axis=-1)

    # Keep mask as [N_atoms] for SubunitSymmetryResolution compatibility
    result["mask_atom_lvl"] = mask

    # Extract the atom-to-token map (as in lddt.py)
    if "token_id" in atom_array.get_annotation_categories():
        result["atom_to_token_map"] = atom_array.token_id.astype(np.int32)
    else:
        # This should not happen, since AddGlobalTokenIdAnnotation was applied above
        raise ValueError(
            "token_id annotation not found after AddGlobalTokenIdAnnotation"
        )

    # Create the crop mask (full index range, i.e., no cropping)
    result["crop_mask"] = np.arange(len(atom_array), dtype=np.int32)

    # Convert the numpy-array entries to torch tensors using convert_to_torch
    torch_data = {
        "molecule_entity": result["molecule_entity"],
        "molecule_iid": result["molecule_iid"],
        "coord_atom_lvl": result["coord_atom_lvl"],
        "mask_atom_lvl": result["mask_atom_lvl"],
        "atom_to_token_map": result["atom_to_token_map"],
        "crop_mask": result["crop_mask"],
    }
    torch_data = convert_to_torch(torch_data, list(torch_data.keys()))

    # Update the result with the torch tensors
    result.update(torch_data)

    return result
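As a reading aid, the transform pipeline assembled above can be run standalone to inspect the automorphisms it produces. A sketch reusing only the transforms imported at the top of this file; `parsed_atom_array` is an assumed AtomWorks-parsed biotite AtomArray, and `AddGlobalTokenIdAnnotation` is included unconditionally here for brevity:

pipeline = Compose(
    [
        AtomizeByCCDName(atomize_by_default=True),
        AddGlobalTokenIdAnnotation(),
        FindAutomorphismsWithNetworkX(),
    ]
)
data = pipeline({"atom_array": parsed_atom_array})
automorphisms = data.get("automorphisms", [])  # one entry per symmetric group, as consumed above
print(f"Found {len(automorphisms)} automorphism groups")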
def apply_symmetry_resolution(
    X_pred: Float[torch.Tensor, "D L 3"],
    X_gt: Float[torch.Tensor, "D L 3"],
    crd_mask: Bool[torch.Tensor, "D L"],
    automorphisms: list,
    molecule_entity: Int[torch.Tensor, "N_atoms"],
    molecule_iid: Int[torch.Tensor, "N_atoms"],
    crop_mask: Int[torch.Tensor, "N_atoms"],
    coord_atom_lvl: Float[torch.Tensor, "N_atoms 3"],
    mask_atom_lvl: Bool[torch.Tensor, "N_atoms"],
    resolve_residue: bool = True,
    resolve_subunit: bool = True,
) -> Float[torch.Tensor, "D L 3"]:
    """
    Apply the actual symmetry resolution using the existing resolver classes and return the updated coordinates.

    Args:
        X_pred: Predicted coordinates [D, L, 3]
        X_gt: Ground truth coordinates [D, L, 3]
        crd_mask: Coordinate mask [D, L]
        automorphisms: List of automorphism groups
        molecule_entity: Molecule entity IDs [N_atoms]
        molecule_iid: Molecule instance IDs [N_atoms]
        crop_mask: Crop mask indices [N_atoms]
        coord_atom_lvl: Atom-level coordinates [N_atoms, 3]
        mask_atom_lvl: Atom-level mask [N_atoms]
        resolve_residue: Whether to resolve residue symmetries
        resolve_subunit: Whether to resolve subunit symmetries

    Returns:
        Updated ground truth coordinates [D, L, 3]
    """
    # Prepare the loss_input dictionary expected by the resolver classes
    loss_input: Dict[str, torch.Tensor] = {
        "X_gt_L": X_gt.clone(),
        "crd_mask_L": crd_mask.clone(),
    }

    # Wrap the prediction as the network-output dict expected by both resolvers
    network_output: Dict[str, torch.Tensor] = {"X_L": X_pred}

    # Apply subunit symmetry resolution
    if resolve_subunit:
        logger.info("Applying subunit symmetry resolution")
        subunit_resolver = SubunitSymmetryResolution()

        # Create the symmetry-resolution input
        symmetry_resolution: Dict[str, torch.Tensor] = {
            "molecule_entity": molecule_entity,
            "molecule_iid": molecule_iid,
            "crop_mask": crop_mask,
            "coord_atom_lvl": coord_atom_lvl,
            "mask_atom_lvl": mask_atom_lvl,
        }

        loss_input = subunit_resolver(network_output, loss_input, symmetry_resolution)

    # Apply residue symmetry resolution
    if resolve_residue and automorphisms:
        logger.info("Applying residue symmetry resolution")
        residue_resolver = ResidueSymmetryResolution()
        loss_input = residue_resolver(network_output, loss_input, automorphisms)

    # Return the updated ground truth coordinates
    return loss_input["X_gt_L"]
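One subtlety worth noting in `resolve_symmetries`: when a single ground-truth model is broadcast across D predicted models, `expand` creates a view rather than a copy; it is the `.clone()` inside `apply_symmetry_resolution` that materializes writable memory before the resolvers update coordinates in place. A torch-only toy illustration (shapes hypothetical):

import torch

X_pred = torch.randn(5, 100, 3)              # D=5 predicted models, L=100 atoms
X_gt = torch.randn(1, 100, 3)                # a single ground-truth model
X_gt = X_gt.expand(X_pred.shape[0], -1, -1)  # broadcast view: (5, 100, 3), no data copied
assert X_gt.shape == X_pred.shape
X_gt_writable = X_gt.clone()                 # materialize before in-place edits, as the resolver does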
rf3/train.py
ADDED
@@ -0,0 +1,194 @@
#!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/../../../../.ipd/shebang/rf3_exec.sh" "$0" "$@"'

import logging
import os

import hydra
import rootutils
from dotenv import load_dotenv
from omegaconf import DictConfig

from foundry.utils.logging import suppress_warnings
from foundry.utils.weights import CheckpointConfig

# Set up the root dir and environment variables (more info: https://github.com/ashleve/rootutils)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

load_dotenv(override=True)

_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")

_spawning_process_logger = logging.getLogger(__name__)


@hydra.main(config_path=_config_path, config_name="train", version_base="1.3")
def train(cfg: DictConfig) -> None:
    # ==============================================================================
    # Import dependencies and resolve Hydra configuration
    # ==============================================================================

    _spawning_process_logger.info("Importing dependencies...")

    # Lazy imports to make config generation fast
    import torch
    from lightning.fabric import seed_everything
    from lightning.fabric.loggers import Logger

    # If training on DIGS L40, set the precision of matrix multiplication to balance speed and accuracy
    # Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
    torch.set_float32_matmul_precision("medium")

    from foundry.callbacks.callback import BaseCallback  # noqa
    from foundry.utils.instantiators import instantiate_loggers, instantiate_callbacks  # noqa
    from foundry.utils.logging import (
        print_config_tree,
        log_hyperparameters_with_all_loggers,
    )  # noqa
    from foundry.utils.ddp import RankedLogger  # noqa
    from foundry.utils.ddp import is_rank_zero, set_accelerator_based_on_availability  # noqa
    from foundry.utils.datasets import (
        recursively_instantiate_datasets_and_samplers,
        assemble_distributed_loader,
        subset_dataset_to_example_ids,
        assemble_val_loader_dict,
    )  # noqa

    set_accelerator_based_on_availability(cfg)

    ranked_logger = RankedLogger(__name__, rank_zero_only=True)
    _spawning_process_logger.info("Completed dependency imports...")

    # ... print the configuration tree (NOTE: Only prints for rank 0)
    print_config_tree(cfg, resolve=True)

    # ==============================================================================
    # Logging and callback instantiation
    # ==============================================================================

    # Reduce the logging level for all dataset and sampler loggers (unless rank 0)
    # We will still see messages from rank 0; they are identical, since all ranks load and sample from the same datasets
    if not is_rank_zero():
        dataset_logger = logging.getLogger("datasets")
        sampler_logger = logging.getLogger("atomworks.ml.samplers")
        dataset_logger.setLevel(logging.WARNING)
        sampler_logger.setLevel(logging.ERROR)

    # ... seed everything (NOTE: By setting `workers=True`, we ensure that the dataloaders are seeded as well)
    # (The `PL_GLOBAL_SEED` environment variable will be passed to the spawned subprocesses; e.g., through the `ddp_spawn` backend)
    if cfg.get("seed"):
        ranked_logger.info(f"Seeding everything with seed={cfg.seed}...")
        seed_everything(cfg.seed, workers=True, verbose=True)
    else:
        ranked_logger.warning("No seed provided - not seeding anything!")

    ranked_logger.info("Instantiating loggers...")
    loggers: list[Logger] = instantiate_loggers(cfg.get("logger"))

    ranked_logger.info("Instantiating callbacks...")
    callbacks: list[BaseCallback] = instantiate_callbacks(cfg.get("callbacks"))

    # ==============================================================================
    # Trainer and model instantiation
    # ==============================================================================

    # ... instantiate the trainer
    ranked_logger.info("Instantiating trainer...")
    trainer = hydra.utils.instantiate(
        cfg.trainer,
        loggers=loggers or None,
        callbacks=callbacks or None,
        _convert_="partial",
        _recursive_=False,
    )
    # (Store the Hydra configuration in the trainer state)
    trainer.initialize_or_update_trainer_state({"train_cfg": cfg})

    # ... spawn processes for distributed training
    # (We spawn here, rather than within `fit`, so we can use Fabric's `init_module` to efficiently initialize the model on the appropriate device)
    ranked_logger.info(
        f"Spawning {trainer.fabric.world_size} processes from {trainer.fabric.global_rank}..."
    )
    trainer.fabric.launch()

    # ... construct the model
    trainer.construct_model()

    # ... construct the optimizer and scheduler (which requires the model to be constructed)
    trainer.construct_optimizer()
    trainer.construct_scheduler()

    # ==============================================================================
    # Dataset instantiation
    # ==============================================================================

    # Number of examples per epoch (across all GPUs)
    # (We must sample this many indices from our sampler)
    n_examples_per_epoch = cfg.trainer.n_examples_per_epoch

    # ... build the train dataset
    assert (
        "train" in cfg.datasets and cfg.datasets.train
    ), "No 'train' dataloader configuration provided! If only performing validation, use `validate.py` instead."
    dataset_and_sampler = recursively_instantiate_datasets_and_samplers(cfg.datasets.train)
    train_dataset, train_sampler = (
        dataset_and_sampler["dataset"],
        dataset_and_sampler["sampler"],
    )

    # ... compose the train loader
    if "subset_to_example_ids" in cfg.datasets:
        # Backdoor for debugging and overfitting: subset the dataset to a specific set of example IDs
        train_dataset = subset_dataset_to_example_ids(train_dataset, cfg.datasets.subset_to_example_ids)
        train_sampler = None  # The sampler is no longer valid, since we are using a subset of the dataset

    train_loader = assemble_distributed_loader(
        dataset=train_dataset,
        sampler=train_sampler,
        rank=trainer.fabric.global_rank,
        world_size=trainer.fabric.world_size,
        n_examples_per_epoch=n_examples_per_epoch,
        loader_cfg=cfg.dataloader["train"],
    )

    # ... compose the validation loader(s)
    if "val" in cfg.datasets and cfg.datasets.val:
        val_loaders = assemble_val_loader_dict(
            cfg=cfg.datasets.val,
            rank=trainer.fabric.global_rank,
            world_size=trainer.fabric.world_size,
            loader_cfg=cfg.dataloader["val"],
        )
    else:
        ranked_logger.warning("No validation datasets provided! Skipping validation...")
        val_loaders = None

    ranked_logger.info("Logging hyperparameters...")
    log_hyperparameters_with_all_loggers(trainer=trainer, cfg=cfg, model=trainer.state["model"])

    # ... load the checkpoint configuration
    ckpt_config = None
    if "ckpt_config" in cfg and cfg.ckpt_config:
        ckpt_config = hydra.utils.instantiate(cfg.ckpt_config)
    elif "ckpt_path" in cfg and cfg.ckpt_path:
        # Just a checkpoint path (the `elif` condition already guarantees it is non-empty)
        ckpt_config = CheckpointConfig(path=cfg.ckpt_path)

    # ... train the model
    ranked_logger.info("Training model...")

    with suppress_warnings():
        trainer.fit(train_loader=train_loader, val_loaders=val_loaders, ckpt_config=ckpt_config)


if __name__ == "__main__":
    train()
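Since this entry point is a standard `@hydra.main` script, training runs are launched with Hydra-style overrides. A hypothetical invocation (values invented; the keys `seed`, `ckpt_path`, and `trainer.n_examples_per_epoch` are ones the script actually reads, while the available config groups depend on the contents of `models/rf3/configs`):

python rf3/train.py seed=42 trainer.n_examples_per_epoch=25000 ckpt_path=/path/to/checkpoint.ckpt

Passing a structured `ckpt_config` instead of a bare `ckpt_path` routes through `hydra.utils.instantiate`, which allows the richer `CheckpointConfig` options defined in `foundry/utils/weights.py`.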