rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,349 @@
1
+ import networkx as nx
2
+ import numpy as np
3
+ from biotite.structure.info import residue
4
+ from scipy.spatial.distance import cdist
5
+
6
+ from foundry.metrics.metric import Metric
7
+
8
+
9
+ def collapsing_virtual_atoms_batched(
10
+ atom_arrays, central_atom, threshold=0.5, return_virtual_index=False
11
+ ):
12
+ """
13
+ Apply collapsing_virtual_atoms to a batch of atom arrays.
14
+
15
+ Parameters:
16
+ atom_arrays (List[AtomArray]): Batch of atom arrays.
17
+ central_atom (str): Atom to compute distance from (e.g., "CA").
18
+ threshold (float): Distance threshold to identify virtual atoms.
19
+ return_virtual_index (bool): Whether to also return the virtual mask.
20
+
21
+ Returns:
22
+ List of filtered atom arrays or (atom_array, mask) tuples
23
+ """
24
+ result = []
25
+ for atom_array in atom_arrays:
26
+ virtual_atom_mask = np.zeros(len(atom_array), dtype=bool)
27
+
28
+ # We need to select residues by the combination of chain_iid and res_id.
29
+ chain_iid_with_sep = np.char.add(atom_array.chain_iid, "|")
30
+ chain_iid_and_res_id = np.char.add(
31
+ chain_iid_with_sep, atom_array.res_id.astype(str)
32
+ )
33
+ atom_array.set_annotation("chain_iid_and_res_id", chain_iid_and_res_id)
34
+ unique_residue_identifiers = np.unique(chain_iid_and_res_id)
35
+
36
+ for res_identifier in unique_residue_identifiers:
37
+ # ... Pick the current residue
38
+ cur_mask = atom_array.chain_iid_and_res_id == res_identifier
39
+ cur_residue = atom_array[cur_mask]
40
+ cur_central_atom = central_atom
41
+
42
+ # For Glycine: it doesn't have CB, so set the virtual atom as CA.
43
+ # The current way to handle this is to check if predicted CA and CB are too close, because in the case of glycine and we pad virtual atoms based on CB, CB's coords are set as CA.
44
+ # There might be a better way to do this.
45
+ CA_coord = cur_residue.coord[cur_residue.atom_name == "CA"]
46
+ CB_coord = cur_residue.coord[cur_residue.atom_name == "CB"]
47
+ if np.linalg.norm(CA_coord - CB_coord) < threshold:
48
+ cur_central_atom = "CA"
49
+
50
+ central_mask = cur_residue.atom_name == cur_central_atom
51
+
52
+ if not np.any(central_mask):
53
+ continue
54
+
55
+ # ... Calculate the distance to the central atom
56
+ central_coord = cur_residue.coord[central_mask][
57
+ 0
58
+ ] # Should only have one central atom anyway
59
+ dists = np.linalg.norm(cur_residue.coord - central_coord, axis=-1)
60
+
61
+ # ... Select virtual atom by the distance. Shouldn't count the central atom itself. (F)
62
+ is_virtual = (dists < threshold) & ~central_mask
63
+
64
+ virtual_atom_mask[np.where(cur_mask)[0][is_virtual]] = True
65
+
66
+ filtered = atom_array[~virtual_atom_mask]
67
+ if return_virtual_index:
68
+ result.append((filtered, virtual_atom_mask))
69
+ else:
70
+ result.append(filtered)
71
+
72
+ return result
73
+
74
+
75
def construct_graph(coords, cutoff_min, cutoff_max):
    """
    Use coordinates to construct a NetworkX graph.

    Nodes = atom indices (0..n-1).
    Edges = distance-based inferred bonds: a pair (i, j) is bonded when
    ``cutoff_min < dist(i, j) < cutoff_max`` (strict on both sides).

    Parameters:
        coords: [n, 3] array-like of atom coordinates.
        cutoff_min: min distance to consider a bond (avoid self-loops)
        cutoff_max: max distance to consider a bond (e.g., typical covalent bond)

    Returns:
        G: A NetworkX graph
    """
    coords = np.asarray(coords)
    n_atoms = coords.shape[0]
    dists = cdist(coords, coords)  # [N, N]

    G = nx.Graph()
    # Add all atoms as nodes, including atoms that end up with no edges.
    G.add_nodes_from(range(n_atoms))

    # Vectorized edge selection over the upper triangle (i < j) — replaces the
    # original O(n^2) Python double loop with one NumPy pass.
    iu, ju = np.triu_indices(n_atoms, k=1)
    pair_dists = dists[iu, ju]
    bonded = (pair_dists > cutoff_min) & (pair_dists < cutoff_max)
    # .tolist() keeps node labels as plain Python ints, as before.
    G.add_edges_from(zip(iu[bonded].tolist(), ju[bonded].tolist()))

    return G
105
+
106
+
107
def are_graphs_isomorphic(g1, g2):
    """
    Check if two graphs are topologically isomorphic (ignoring atom/bond types).

    Parameters:
        g1, g2: NetworkX graphs.

    Returns:
        bool: True if the two graphs are isomorphic.
    """
    # Structure-only comparison: no node_match/edge_match callbacks are passed,
    # so any node or edge attributes are ignored.
    return nx.is_isomorphic(g1, g2)
112
+
113
+
114
# The 20 canonical amino acids used as reference topologies for matching.
_STANDARD_AA = (
    "ALA",
    "ARG",
    "ASN",
    "ASP",
    "CYS",
    "GLU",
    "GLN",
    "GLY",
    "HIS",
    "ILE",
    "LEU",
    "LYS",
    "MET",
    "PHE",
    "PRO",
    "SER",
    "THR",
    "TRP",
    "TYR",
    "VAL",
)

# Memo of reference graphs keyed by (cutoff_min, cutoff_max). The idealized
# residues never change, so rebuilding all 20 graphs on every call is wasted
# work (the original implementation did exactly that).
_STANDARD_AA_GRAPH_CACHE = {}


def _standard_aa_graphs(cutoff_min, cutoff_max):
    """Return (cached) distance-based topology graphs for the standard AAs.

    Graphs are built from biotite's idealized reference residues with terminal
    OXT atoms and hydrogens removed, so they match heavy-atom-only predicted
    sidechains. A list entry is ``None`` when graph construction failed (the
    failure is printed once, on first build for a given threshold pair).
    """
    key = (cutoff_min, cutoff_max)
    if key not in _STANDARD_AA_GRAPH_CACHE:
        graphs = []
        for aa_name in _STANDARD_AA:
            aa = residue(aa_name)
            # ... Remove OXT atoms and hydrogens
            aa = aa[(~np.isin(aa.atom_name, np.array(["OXT"]))) & (aa.element != "H")]
            try:
                graphs.append(
                    construct_graph(
                        aa.coord, cutoff_min=cutoff_min, cutoff_max=cutoff_max
                    )
                )
            except Exception as e:
                print(f"Failed to convert {aa} to graph: {e}")
                graphs.append(None)
        _STANDARD_AA_GRAPH_CACHE[key] = graphs
    return _STANDARD_AA_GRAPH_CACHE[key]


def check_sidechain_quality(atom_array, dist_threshold_min=1, dist_threshold_max=2):
    """
    Check sidechain quality. This is done by checking:
    (1) if a sidechain can map to a standard amino acid based on the topology;
    (2) if two sidechains have an unexpected bond connection;
    (3) if a sidechain itself has a collapse (atoms closer than
        ``dist_threshold_min``).
    A valid sidechain is defined by satisfying all three rules.

    Side effect: ``atom_array`` gets a ``chain_iid_and_res_id`` annotation.

    Return:
        - matches (dict): per residue identifier, all standard amino acids its
          sidechain topology is isomorphic to.
        - valid_sidechain_percent (float): fraction (in [0, 1], despite the
          name) of valid sidechains.
        - unintended_bonds_percent (float): fraction of residues whose
          sidechain forms a bond-range contact with another residue's
          sidechain.
        - clash_percent (float): fraction of residues involved in a
          sidechain-atom pair closer than ``dist_threshold_min``.
    """
    # Step 1: reference amino-acid graphs (memoized — see _standard_aa_graphs).
    standard_aa_graphs = _standard_aa_graphs(dist_threshold_min, dist_threshold_max)

    # We need to select residues by the combination of chain_iid and res_id,
    # since res_id alone is not unique across chains.
    chain_iid_with_sep = np.char.add(atom_array.chain_iid, "|")
    chain_iid_and_res_id = np.char.add(
        chain_iid_with_sep, atom_array.res_id.astype(str)
    )
    atom_array.set_annotation("chain_iid_and_res_id", chain_iid_and_res_id)
    unique_residue_identifiers = np.unique(chain_iid_and_res_id)
    # NOTE(review): an empty atom_array would make the final divisions raise
    # ZeroDivisionError — same as the original behavior, left unchanged.
    matches = {}

    # ... Map each predicted sidechain to any standard amino acids
    for res_identifier in unique_residue_identifiers:
        matches[res_identifier] = []
        cur_res_coords = atom_array.coord[
            atom_array.chain_iid_and_res_id == res_identifier
        ]

        try:
            cur_graph = construct_graph(
                cur_res_coords,
                cutoff_min=dist_threshold_min,
                cutoff_max=dist_threshold_max,
            )
        except Exception as e:
            print(
                f"[WARN] Could not build graph for chain_iid|res_id {res_identifier}: {e}"
            )
            continue

        for aa_idx, aa_graph in enumerate(standard_aa_graphs):
            if aa_graph is None:
                continue
            if are_graphs_isomorphic(cur_graph, aa_graph):
                matches[res_identifier].append(_STANDARD_AA[aa_idx])

    # Step 2: Check if the inter- and intra-residue quality is good.
    # (1) Check if there are potential bonds between sidechains of different
    #     residues.
    # (2) Check if atoms are too close (collapse).

    # ... Mask sidechain atoms: anything except the four backbone atoms.
    is_sidechain = ~np.isin(atom_array.atom_name, np.array(["N", "CA", "C", "O"]))

    coords_sc = atom_array.coord[is_sidechain]
    residue_identifiers_sc = atom_array.chain_iid_and_res_id[is_sidechain]

    # ... Pairwise sidechain-atom distances; only the upper triangle (i < j,
    # diagonal excluded) is examined.
    dists = cdist(coords_sc, coords_sc)
    iu, ju = np.triu_indices(dists.shape[0], k=1)
    pair_dists = dists[iu, ju]

    # Bond-range contacts between atoms of DIFFERENT residues.
    potential_bonds = (pair_dists > dist_threshold_min) & (
        pair_dists < dist_threshold_max
    )
    diff_res_mask = residue_identifiers_sc[iu] != residue_identifiers_sc[ju]
    bonds_mask = potential_bonds & diff_res_mask

    # ... Annotate residues with unintended bonds (vectorized: flag every
    # residue appearing on either side of a flagged pair).
    unintended_bonds = {
        res_identifier: False for res_identifier in unique_residue_identifiers
    }
    for rid in residue_identifiers_sc[iu[bonds_mask]]:
        unintended_bonds[rid] = True
    for rid in residue_identifiers_sc[ju[bonds_mask]]:
        unintended_bonds[rid] = True

    # ... Check if atoms are too close to be real. Note: intra-residue pairs
    # count too (no diff_res_mask here), matching the original behavior.
    clash_residues = {
        res_identifier: False for res_identifier in unique_residue_identifiers
    }
    clash_mask = pair_dists < dist_threshold_min
    for rid in residue_identifiers_sc[iu[clash_mask]]:
        clash_residues[rid] = True
    for rid in residue_identifiers_sc[ju[clash_mask]]:
        clash_residues[rid] = True

    # ... Output the final valid sidechains. (`and not` replaces the original
    # `&`/`~` on Python bools, which only worked by accident via int bitwise
    # semantics: ~True == -2, so `1 & -2 == 0`.)
    if_valid_sidechains = [
        bool(matches[res_identifier])
        and not unintended_bonds[res_identifier]
        and not clash_residues[res_identifier]
        for res_identifier in unique_residue_identifiers
    ]
    if_unintended_bonds = [
        unintended_bonds[res_identifier]
        for res_identifier in unique_residue_identifiers
    ]
    if_clash = [
        clash_residues[res_identifier] for res_identifier in unique_residue_identifiers
    ]

    n_residues = len(unique_residue_identifiers)
    valid_sidechain_percent = sum(if_valid_sidechains) / n_residues
    unintended_bonds_percent = sum(if_unintended_bonds) / n_residues
    clash_percent = sum(if_clash) / n_residues

    return matches, valid_sidechain_percent, unintended_bonds_percent, clash_percent
280
+
281
+
282
def compute_batched_sidechain_quality(
    predicted_atom_array_stack,
    central_atom,
    dist_threshold_min=1.0,
    dist_threshold_max=2.0,
    already_removed_virtual_atoms=False,
):
    """
    Compute sidechain quality metrics for each structure in a batch.

    Parameters:
        predicted_atom_array_stack: iterable of predicted atom arrays.
        central_atom: central-atom name (see NOTE below).
        dist_threshold_min / dist_threshold_max: bond-distance window passed
            through to ``check_sidechain_quality``.
        already_removed_virtual_atoms: see NOTE below.

    Returns:
        list[dict]: one metrics dict per structure with keys
        ``mapped_restype``, ``valid_sidechain_percent``,
        ``unintended_bonds_percent``, ``clash_percent``.

    NOTE(review): ``central_atom`` and ``already_removed_virtual_atoms`` are
    accepted but never used — confirm whether virtual-atom collapsing was
    meant to happen here before scoring.
    """
    batch_metrics = []
    for atom_array in predicted_atom_array_stack:
        mapped, valid_frac, unintended_frac, clash_frac = check_sidechain_quality(
            atom_array, dist_threshold_min, dist_threshold_max
        )
        batch_metrics.append(
            {
                "mapped_restype": mapped,
                "valid_sidechain_percent": valid_frac,
                "unintended_bonds_percent": unintended_frac,
                "clash_percent": clash_frac,
            }
        )
    return batch_metrics
305
+
306
+
307
class SidechainMetrics(Metric):
    """Batch-level sidechain quality metric.

    Runs the per-structure sidechain checks over a batch of predicted atom
    arrays and reports the batch means of the valid-sidechain,
    unintended-bond, and clash fractions.
    """

    # Per-structure metric keys that get averaged over the batch.
    _AVERAGED_KEYS = (
        "valid_sidechain_percent",
        "unintended_bonds_percent",
        "clash_percent",
    )

    def __init__(
        self,
        dist_threshold_min,
        dist_threshold_max,
        central_atom,
        already_removed_virtual_atoms=False,
    ):
        super().__init__()
        # Bond-distance window forwarded to the quality checks.
        self.dist_threshold_min = dist_threshold_min
        self.dist_threshold_max = dist_threshold_max
        # Central-atom name forwarded to compute_batched_sidechain_quality.
        self.central_atom = central_atom
        self.already_removed_virtual_atoms = already_removed_virtual_atoms

    @property
    def kwargs_to_compute_args(self):
        # Maps trainer-provided kwargs onto `compute` argument names.
        return {
            "predicted_atom_array_stack": ("predicted_atom_array_stack",),
        }

    def compute(self, predicted_atom_array_stack):
        """Score every structure, then average the fractions over the batch."""
        per_structure = compute_batched_sidechain_quality(
            predicted_atom_array_stack,
            self.central_atom,
            self.dist_threshold_min,
            self.dist_threshold_max,
            self.already_removed_virtual_atoms,
        )

        # Aggregate output for batch-level metrics.
        return {
            f"mean_{key}": float(np.mean([m[key] for m in per_structure]))
            for key in self._AVERAGED_KEYS
        }
rfd3/model/RFD3.py ADDED
@@ -0,0 +1,105 @@
1
+ import os
2
+
3
+ import hydra
4
+ import torch
5
+ from omegaconf import DictConfig
6
+ from rfd3.model.cfg_utils import (
7
+ strip_f,
8
+ )
9
+ from rfd3.model.inference_sampler import ConditionalDiffusionSampler
10
+ from rfd3.model.layers.encoders import TokenInitializer
11
+ from torch import nn
12
+
13
+ from foundry.utils.ddp import RankedLogger
14
+
15
+ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
16
+
17
+
18
class RFD3(nn.Module):
    """
    Simplified model for generation.

    This module wraps the AF3-style diffusion module so it is roughly
    equivalent to the AF3 model without trunk processing, which allows the
    same sampler to be used.

    In training mode, ``forward`` performs a single denoising step; in eval
    mode it runs a full diffusion rollout via the inference sampler, with
    optional classifier-free guidance (CFG).
    """

    def __init__(
        self,
        *,
        # Channel dimensions ('global' features)
        c_s: int,
        c_z: int,
        c_atom: int,
        c_atompair: int,
        # Arguments for modules that will be instantiated
        token_initializer: DictConfig | dict,
        diffusion_module: DictConfig | dict,
        inference_sampler: DictConfig | dict,
        **_,
    ):
        super().__init__()
        # Check for chunked P_LL (low-memory) mode via environment variable;
        # only the exact value "1" enables it.
        use_chunked_pll = os.environ.get("RFD3_LOW_MEMORY_MODE", None) == "1"
        ranked_logger.info(f"RFD3 initialized with chunked_pll={use_chunked_pll}")

        # Simple constant-feature initializer
        self.token_initializer = TokenInitializer(
            c_s=c_s,
            c_z=c_z,
            c_atom=c_atom,
            c_atompair=c_atompair,
            use_chunked_pll=use_chunked_pll,
            **token_initializer,
        )

        # Diffusion module instantiated via hydra to allow for config scripting
        self.diffusion_module = hydra.utils.instantiate(
            diffusion_module, c_atom=c_atom, c_atompair=c_atompair, c_s=c_s, c_z=c_z
        )

        # CFG is only active when explicitly enabled AND the scale actually
        # differs from 1.0 (a scale of exactly 1.0 is a no-op).
        self.use_classifier_free_guidance = (
            inference_sampler["use_classifier_free_guidance"]
            and inference_sampler["cfg_scale"] != 1.0
        )
        # NOTE: .pop() mutates the caller's inference_sampler config in place;
        # it must run BEFORE the **inference_sampler expansion below so that
        # ConditionalDiffusionSampler does not receive "cfg_features".
        self.cfg_features = inference_sampler.pop("cfg_features", [])

        # ... initialize the inference sampler, which performs a full diffusion rollout during inference
        self.inference_sampler = ConditionalDiffusionSampler(**inference_sampler)

    def forward(
        self,
        # NOTE(review): `input` shadows the builtin; kept because it is part of
        # the public keyword interface. Expected keys (from usage below):
        # "f" always; "X_noisy_L" and "t" in training mode.
        input: dict,
        # Atom-level coordinates to be noised. Despite the None default, this
        # is required in eval mode (its .shape[0] is read below) — presumably
        # optional only during training; TODO confirm with callers.
        coord_atom_lvl_to_be_noised: torch.Tensor = None,
        # Number of recycling iterations forwarded to the diffusion module.
        n_cycle=None,
        **_,
    ) -> dict:
        # Constant token/atom features derived from the input features "f".
        initializer_outputs = self.token_initializer(input["f"])

        if self.training:
            # Single denoising step
            return self.diffusion_module(
                X_noisy_L=input["X_noisy_L"],
                t=input["t"],
                f=input["f"],
                n_recycle=n_cycle,
                **initializer_outputs,
            )  # [D, L, 3]
        else:
            if self.use_classifier_free_guidance:
                # Unconditional branch for CFG: strip the conditioning
                # features and re-run the initializer on the stripped set.
                f_ref = strip_f(input["f"], self.cfg_features)
                ref_initializer_outputs = self.token_initializer(f_ref)
            else:
                f_ref = None
                ref_initializer_outputs = None

            # Full diffusion rollout (inference only).
            return self.inference_sampler.sample_diffusion_like_af3(
                f=input["f"],
                f_ref=f_ref,  # for cfg
                diffusion_module=self.diffusion_module,
                diffusion_batch_size=coord_atom_lvl_to_be_noised.shape[0],
                coord_atom_lvl_to_be_noised=coord_atom_lvl_to_be_noised,
                # Forwarded as **kwargs:
                initializer_outputs=initializer_outputs,
                ref_initializer_outputs=ref_initializer_outputs,  # for cfg
            )