rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
import networkx as nx
|
|
2
|
+
import numpy as np
|
|
3
|
+
from biotite.structure.info import residue
|
|
4
|
+
from scipy.spatial.distance import cdist
|
|
5
|
+
|
|
6
|
+
from foundry.metrics.metric import Metric
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def collapsing_virtual_atoms_batched(
    atom_arrays, central_atom, threshold=0.5, return_virtual_index=False
):
    """
    Apply collapsing_virtual_atoms to a batch of atom arrays.

    For each residue, atoms lying within ``threshold`` of the residue's
    central atom (excluding the central atom itself) are treated as "virtual"
    padding atoms and removed from the returned array.

    Parameters:
        atom_arrays (List[AtomArray]): Batch of atom arrays.
        central_atom (str): Atom to compute distance from (e.g., "CA").
        threshold (float): Distance threshold to identify virtual atoms.
        return_virtual_index (bool): Whether to also return the virtual mask.

    Returns:
        List of filtered atom arrays or (atom_array, mask) tuples. Each mask
        is indexed over the ORIGINAL (unfiltered) atoms.

    NOTE(review): each input atom array is mutated in place — a
    "chain_iid_and_res_id" annotation is added as a side effect.
    """
    result = []
    for atom_array in atom_arrays:
        # Flags atoms (in original indexing) detected as virtual padding.
        virtual_atom_mask = np.zeros(len(atom_array), dtype=bool)

        # We need to select residues by the combination of chain_iid and res_id.
        # The "|" separator keeps concatenations unambiguous (e.g. "A"+"12" vs "A1"+"2").
        chain_iid_with_sep = np.char.add(atom_array.chain_iid, "|")
        chain_iid_and_res_id = np.char.add(
            chain_iid_with_sep, atom_array.res_id.astype(str)
        )
        atom_array.set_annotation("chain_iid_and_res_id", chain_iid_and_res_id)
        unique_residue_identifiers = np.unique(chain_iid_and_res_id)

        for res_identifier in unique_residue_identifiers:
            # ... Pick the current residue
            cur_mask = atom_array.chain_iid_and_res_id == res_identifier
            cur_residue = atom_array[cur_mask]
            cur_central_atom = central_atom

            # For Glycine: it doesn't have CB, so set the virtual atom as CA.
            # The current way to handle this is to check if predicted CA and CB
            # are too close, because in the case of glycine and we pad virtual
            # atoms based on CB, CB's coords are set as CA.
            # There might be a better way to do this.
            # NOTE(review): when CA or CB is absent the selected slice is empty
            # and np.linalg.norm of the empty difference evaluates to 0.0,
            # which is below the threshold and switches the center to "CA" —
            # confirm this fallback is intended for non-glycine residues.
            CA_coord = cur_residue.coord[cur_residue.atom_name == "CA"]
            CB_coord = cur_residue.coord[cur_residue.atom_name == "CB"]
            if np.linalg.norm(CA_coord - CB_coord) < threshold:
                cur_central_atom = "CA"

            central_mask = cur_residue.atom_name == cur_central_atom

            # Residues lacking the central atom are left untouched.
            if not np.any(central_mask):
                continue

            # ... Calculate the distance to the central atom
            central_coord = cur_residue.coord[central_mask][
                0
            ]  # Should only have one central atom anyway
            dists = np.linalg.norm(cur_residue.coord - central_coord, axis=-1)

            # ... Select virtual atom by the distance. Shouldn't count the central atom itself. (F)
            is_virtual = (dists < threshold) & ~central_mask

            # Map residue-local virtual flags back to global atom indices.
            virtual_atom_mask[np.where(cur_mask)[0][is_virtual]] = True

        filtered = atom_array[~virtual_atom_mask]
        if return_virtual_index:
            result.append((filtered, virtual_atom_mask))
        else:
            result.append(filtered)

    return result
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def construct_graph(coords, cutoff_min, cutoff_max):
    """
    Use coordinates to construct a NetworkX graph.
    Nodes = atom indices.
    Edges = distance-based inferred bonds.

    Parameters:
        coords: [n, 3] array of atom coordinates.
        cutoff_min: min distance to consider a bond (avoid self-loops).
        cutoff_max: max distance to consider a bond (e.g., typical covalent bond).

    Returns:
        G: A NetworkX graph with nodes 0..n-1 and an edge (i, j) for every
        pair with cutoff_min < dist(i, j) < cutoff_max (strict bounds).
    """
    dists = cdist(coords, coords)  # [N, N] pairwise distances
    n_atoms = coords.shape[0]

    G = nx.Graph()

    # ... Add nodes (one per atom, even if isolated)
    G.add_nodes_from(range(n_atoms))

    # ... Add edges based on distance.
    # Vectorized replacement of the original O(n^2) Python double loop: only
    # the upper triangle (i < j) is examined so each pair is tested once,
    # preserving the strict cutoff_min < d < cutoff_max bounds.
    iu, ju = np.triu_indices(n_atoms, k=1)
    pair_dists = dists[iu, ju]
    bonded = (pair_dists > cutoff_min) & (pair_dists < cutoff_max)
    # .tolist() yields plain Python ints for node labels, as before.
    G.add_edges_from(zip(iu[bonded].tolist(), ju[bonded].tolist()))

    return G
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def are_graphs_isomorphic(g1, g2):
    """
    Topology-only isomorphism test between two graphs.

    Atom and bond types are deliberately ignored: only the connectivity
    (node/edge structure) of the two graphs is compared, delegating the
    actual matching to NetworkX.
    """
    return nx.is_isomorphic(g1, g2)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _build_standard_aa_graphs(dist_threshold_min, dist_threshold_max):
    """
    Build distance-based topology graphs for the 20 standard amino acids.

    Reference residues come from biotite's component dictionary; OXT atoms
    and hydrogens are removed so references match heavy-atom predictions.

    Returns:
        (names, graphs): list of 3-letter codes and the matching list of
        NetworkX graphs (an entry is None if graph construction failed).
    """
    standard_aa = [
        "ALA",
        "ARG",
        "ASN",
        "ASP",
        "CYS",
        "GLU",
        "GLN",
        "GLY",
        "HIS",
        "ILE",
        "LEU",
        "LYS",
        "MET",
        "PHE",
        "PRO",
        "SER",
        "THR",
        "TRP",
        "TYR",
        "VAL",
    ]

    standard_aa_atom_array = [residue(aa) for aa in standard_aa]

    # ... Remove OXT atoms and hydrogens
    standard_aa_atom_array = [
        aa[(~np.isin(aa.atom_name, np.array(["OXT"]))) & (aa.element != "H")]
        for aa in standard_aa_atom_array
    ]

    # ... Convert standard AA to topology graphs
    standard_aa_graphs = []
    for aa in standard_aa_atom_array:
        try:
            g = construct_graph(
                aa.coord, cutoff_min=dist_threshold_min, cutoff_max=dist_threshold_max
            )
            standard_aa_graphs.append(g)
        except Exception as e:
            # Append a placeholder so graphs stay index-aligned with names.
            print(f"Failed to convert {aa} to graph: {e}")
            standard_aa_graphs.append(None)

    return standard_aa, standard_aa_graphs


def _match_residues_to_standard_aa(
    atom_array,
    unique_residue_identifiers,
    standard_aa,
    standard_aa_graphs,
    dist_threshold_min,
    dist_threshold_max,
):
    """
    Map each predicted residue to all standard amino acids whose topology is
    isomorphic to the residue's distance-inferred topology graph.

    Returns:
        dict: residue identifier -> list of matching 3-letter codes (empty
        when nothing matches or the residue graph could not be built).
    """
    matches = {}
    for res_identifier in unique_residue_identifiers:
        matches[res_identifier] = []
        cur_res_coords = atom_array.coord[
            atom_array.chain_iid_and_res_id == res_identifier
        ]

        try:
            cur_graph = construct_graph(
                cur_res_coords,
                cutoff_min=dist_threshold_min,
                cutoff_max=dist_threshold_max,
            )
        except Exception as e:
            print(
                f"[WARN] Could not build graph for chain_iid|res_id {res_identifier}: {e}"
            )
            continue

        for aa_idx, aa_graph in enumerate(standard_aa_graphs):
            if aa_graph is None:
                continue
            if are_graphs_isomorphic(cur_graph, aa_graph):
                matches[res_identifier].append(standard_aa[aa_idx])

    return matches


def _find_unintended_bonds_and_clashes(
    atom_array, unique_residue_identifiers, dist_threshold_min, dist_threshold_max
):
    """
    Flag residues whose sidechain atoms form bond-range contacts with OTHER
    residues' sidechains (unintended bonds), or lie closer than
    dist_threshold_min to any other sidechain atom (clashes).

    Returns:
        (unintended_bonds, clash_residues): two dicts mapping every residue
        identifier to a bool.
    """
    # ... Mask sidechain atoms. Now the sidechain is any atoms except four backbone atoms
    is_sidechain = ~np.isin(atom_array.atom_name, np.array(["N", "CA", "C", "O"]))

    coords_sc = atom_array.coord[is_sidechain]
    residue_identifiers_sc = atom_array.chain_iid_and_res_id[is_sidechain]

    # ... Calculate pairwise distances between sidechain atoms
    dists = cdist(coords_sc, coords_sc)

    unintended_bonds = {
        res_identifier: False for res_identifier in unique_residue_identifiers
    }
    clash_residues = {
        res_identifier: False for res_identifier in unique_residue_identifiers
    }

    # Only look at the upper triangle (exclude diagonal) so each unordered
    # pair is examined once.
    iu, ju = np.triu_indices(dists.shape[0], k=1)
    pair_dists = dists[iu, ju]

    # Bond-range pairs joining two DIFFERENT residues are unintended bonds.
    potential_bonds = (pair_dists > dist_threshold_min) & (
        pair_dists < dist_threshold_max
    )
    diff_res_mask = residue_identifiers_sc[iu] != residue_identifiers_sc[ju]
    for idx in np.nonzero(potential_bonds & diff_res_mask)[0]:
        unintended_bonds[residue_identifiers_sc[iu[idx]]] = True
        unintended_bonds[residue_identifiers_sc[ju[idx]]] = True

    # Pairs closer than the minimum bond distance are physically implausible.
    # Note: unlike bonds, clashes are counted whether or not the two atoms
    # belong to the same residue.
    clash_mask = pair_dists < dist_threshold_min
    for idx in np.nonzero(clash_mask)[0]:
        clash_residues[residue_identifiers_sc[iu[idx]]] = True
        clash_residues[residue_identifiers_sc[ju[idx]]] = True

    return unintended_bonds, clash_residues


def check_sidechain_quality(atom_array, dist_threshold_min=1, dist_threshold_max=2):
    """
    Check sidechain quality. This is done by checking:
    (1) if a sidechain can map to a standard amino acid based on the topology;
    (2) if two sidechains have an unexpected bond connection;
    (3) if a sidechain clashes with any other sidechain atom.
    A valid sidechain is defined by satisfying all three rules.

    Parameters:
        atom_array (AtomArray): predicted structure. NOTE: mutated in place —
            a "chain_iid_and_res_id" annotation is added.
        dist_threshold_min (float): min distance for an inferred bond; pairs
            below it count as clashes.
        dist_threshold_max (float): max distance for an inferred bond.

    Return:
        - matches (dict): all possible standard amino acids that a sidechain can map to.
        - valid_sidechain_percent (float): fraction of valid sidechains.
        - unintended_bonds_percent (float): fraction of sidechains with unintended bonds with other sidechains.
        - clash_percent (float): fraction of sidechains involved in a clash.
    """
    # Step 1: Build standard amino acid reference graphs.
    standard_aa, standard_aa_graphs = _build_standard_aa_graphs(
        dist_threshold_min, dist_threshold_max
    )

    # We need to select residues by the combination of chain_iid and res_id.
    chain_iid_with_sep = np.char.add(atom_array.chain_iid, "|")
    chain_iid_and_res_id = np.char.add(
        chain_iid_with_sep, atom_array.res_id.astype(str)
    )
    atom_array.set_annotation("chain_iid_and_res_id", chain_iid_and_res_id)
    unique_residue_identifiers = np.unique(chain_iid_and_res_id)

    # Robustness: an empty structure has no residues to score; previously this
    # raised ZeroDivisionError in the fraction computations below.
    if len(unique_residue_identifiers) == 0:
        return {}, 0.0, 0.0, 0.0

    # ... Map predicted sidechains to standard amino acids by topology.
    matches = _match_residues_to_standard_aa(
        atom_array,
        unique_residue_identifiers,
        standard_aa,
        standard_aa_graphs,
        dist_threshold_min,
        dist_threshold_max,
    )

    # Step 2: Check inter- and intra-residue geometry.
    unintended_bonds, clash_residues = _find_unintended_bonds_and_clashes(
        atom_array, unique_residue_identifiers, dist_threshold_min, dist_threshold_max
    )

    # ... A sidechain is valid iff it matches some standard AA and is free of
    # unintended bonds and clashes.
    if_valid_sidechains = [
        (len(matches[res_identifier]) > 0)
        and (not unintended_bonds[res_identifier])
        and (not clash_residues[res_identifier])
        for res_identifier in unique_residue_identifiers
    ]

    n_residues = len(unique_residue_identifiers)
    valid_sidechain_percent = sum(if_valid_sidechains) / n_residues
    unintended_bonds_percent = (
        sum(unintended_bonds[r] for r in unique_residue_identifiers) / n_residues
    )
    clash_percent = (
        sum(clash_residues[r] for r in unique_residue_identifiers) / n_residues
    )

    return matches, valid_sidechain_percent, unintended_bonds_percent, clash_percent
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def compute_batched_sidechain_quality(
    predicted_atom_array_stack,
    central_atom,
    dist_threshold_min=1.0,
    dist_threshold_max=2.0,
    already_removed_virtual_atoms=False,
):
    """
    Compute sidechain metrics for each structure in a batch.

    Parameters:
        predicted_atom_array_stack: iterable of AtomArray structures to score.
        central_atom (str): NOTE(review): accepted but currently unused —
            presumably intended for a virtual-atom collapsing step (see
            collapsing_virtual_atoms_batched); confirm before relying on it.
        dist_threshold_min (float): min distance for an inferred bond.
        dist_threshold_max (float): max distance for an inferred bond.
        already_removed_virtual_atoms (bool): NOTE(review): accepted but
            currently unused; structures are scored as given either way.

    Returns:
        list[dict]: one dict per structure with keys "mapped_restype",
        "valid_sidechain_percent", "unintended_bonds_percent",
        "clash_percent" (values come from check_sidechain_quality).
    """
    batch_metrics = []

    for atom_array in predicted_atom_array_stack:
        metrics = {}
        matches, valid, unintended, clash = check_sidechain_quality(
            atom_array, dist_threshold_min, dist_threshold_max
        )
        metrics["mapped_restype"] = matches
        metrics["valid_sidechain_percent"] = valid
        metrics["unintended_bonds_percent"] = unintended
        metrics["clash_percent"] = clash
        batch_metrics.append(metrics)
    return batch_metrics
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
class SidechainMetrics(Metric):
    """
    Metric wrapper that scores predicted sidechain quality over a batch.

    Per structure, sidechains are checked for (1) topological mapping to a
    standard amino acid, (2) unintended inter-sidechain bonds, and (3)
    clashes; the batch means of those three rates are reported.
    """

    def __init__(
        self,
        dist_threshold_min,
        dist_threshold_max,
        central_atom,
        already_removed_virtual_atoms=False,
    ):
        super().__init__()
        self.dist_threshold_min = dist_threshold_min
        self.dist_threshold_max = dist_threshold_max
        self.central_atom = central_atom
        self.already_removed_virtual_atoms = already_removed_virtual_atoms

    @property
    def kwargs_to_compute_args(self):
        # Single pass-through argument: the batch of predicted structures.
        return {
            "predicted_atom_array_stack": ("predicted_atom_array_stack",),
        }

    def compute(self, predicted_atom_array_stack):
        per_structure = compute_batched_sidechain_quality(
            predicted_atom_array_stack,
            self.central_atom,
            self.dist_threshold_min,
            self.dist_threshold_max,
            self.already_removed_virtual_atoms,
        )

        def batch_mean(key):
            # Average one scalar metric across the batch as a plain float.
            return float(np.mean([entry[key] for entry in per_structure]))

        # Aggregate output for batch-level metrics
        return {
            "mean_valid_sidechain_percent": batch_mean("valid_sidechain_percent"),
            "mean_unintended_bonds_percent": batch_mean("unintended_bonds_percent"),
            "mean_clash_percent": batch_mean("clash_percent"),
            # "mapped_restype": [entry["mapped_restype"] for entry in per_structure],
        }
|
rfd3/model/RFD3.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
import hydra
|
|
4
|
+
import torch
|
|
5
|
+
from omegaconf import DictConfig
|
|
6
|
+
from rfd3.model.cfg_utils import (
|
|
7
|
+
strip_f,
|
|
8
|
+
)
|
|
9
|
+
from rfd3.model.inference_sampler import ConditionalDiffusionSampler
|
|
10
|
+
from rfd3.model.layers.encoders import TokenInitializer
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from foundry.utils.ddp import RankedLogger
|
|
14
|
+
|
|
15
|
+
# Module-level logger; rank_zero_only=True presumably restricts emission to
# the rank-0 process in distributed runs — confirm RankedLogger semantics.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class RFD3(nn.Module):
    """
    Simplified model for generation.

    This module level serves to wrap the diffusion module of AF3
    to be roughly equivalent to the AF3 model w/o trunk processing.

    Allows the same sampler to be used.
    """

    def __init__(
        self,
        *,
        # Channel dimensions ('global' features)
        c_s: int,
        c_z: int,
        c_atom: int,
        c_atompair: int,
        # Arguments for modules that will be instantiated
        token_initializer: DictConfig | dict,
        diffusion_module: DictConfig | dict,
        inference_sampler: DictConfig | dict,
        **_,
    ):
        super().__init__()
        # Check for chunked P_LL mode via environment variable
        use_chunked_pll = os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"
        ranked_logger.info(f"RFD3 initialized with chunked_pll={use_chunked_pll}")

        # Simple constant-feature initializer
        self.token_initializer = TokenInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            use_chunked_pll=use_chunked_pll,
            **token_initializer,
        )

        # Diffusion module instantiated to allow for config scripting
        self.diffusion_module = hydra.utils.instantiate(
            diffusion_module, c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z
        )

        # Classifier-free guidance is only active when requested AND the scale
        # differs from 1.0 (a scale of exactly 1.0 reduces to the plain
        # conditional prediction).
        # NOTE(review): direct key access assumes the sampler config always
        # carries "use_classifier_free_guidance" and "cfg_scale" — confirm.
        self.use_classifier_free_guidance = (
            inference_sampler["use_classifier_free_guidance"]
            and inference_sampler["cfg_scale"] != 1.0
        )
        # NOTE(review): pop() mutates the caller's config object; it is needed
        # so "cfg_features" is not forwarded to ConditionalDiffusionSampler.
        self.cfg_features = inference_sampler.pop("cfg_features", [])

        # ... initialize the inference sampler, which performs a full diffusion rollout during inference
        self.inference_sampler = ConditionalDiffusionSampler(**inference_sampler)

    def forward(
        self,
        input: dict,
        coord_atom_lvl_to_be_noised: torch.Tensor | None = None,
        n_cycle=None,
        **_,
    ) -> dict:
        """
        Training: run one denoising step of the diffusion module on the
        pre-noised coordinates in ``input``. Inference: run a full diffusion
        rollout via the inference sampler (optionally with classifier-free
        guidance using a feature-stripped reference pass).

        NOTE(review): the inference branch dereferences
        ``coord_atom_lvl_to_be_noised.shape[0]``, so the ``None`` default is
        only valid during training — confirm callers always pass it in eval.
        """
        initializer_outputs = self.token_initializer(input["f"])

        if self.training:
            # Single denoising step
            return self.diffusion_module(
                X_noisy_L=input["X_noisy_L"],
                t=input["t"],
                f=input["f"],
                n_recycle=n_cycle,
                **initializer_outputs,
            )  # [D, L, 3]
        else:
            if self.use_classifier_free_guidance:
                # Second, unconditional-style pass: strip the CFG-conditioned
                # features and re-run the initializer on the stripped set.
                f_ref = strip_f(input["f"], self.cfg_features)
                ref_initializer_outputs = self.token_initializer(f_ref)
            else:
                f_ref = None
                ref_initializer_outputs = None

            return self.inference_sampler.sample_diffusion_like_af3(
                f=input["f"],
                f_ref=f_ref,  # for cfg
                diffusion_module=self.diffusion_module,
                diffusion_batch_size=coord_atom_lvl_to_be_noised.shape[0],
                coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
                # Forwarded as **kwargs:
                initializer_outputs=initializer_outputs,
                ref_initializer_outputs=ref_initializer_outputs,  # for cfg
            )
|