rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,305 @@
1
+ """
2
+ Virtual-atom transforms for Atom14
3
+ """
4
+
5
+ import biotite.structure as struc
6
+ import numpy as np
7
+ from atomworks.io.utils.atom_array_plus import insert_atoms
8
+ from atomworks.ml.transforms.base import (
9
+ Transform,
10
+ )
11
+ from atomworks.ml.utils.token import get_token_starts
12
+ from rfd3.constants import (
13
+ ATOM14_ATOM_NAME_TO_ELEMENT,
14
+ ATOM14_ATOM_NAMES,
15
+ VIRTUAL_ATOM_ELEMENT_NAME,
16
+ association_schemes,
17
+ association_schemes_stripped,
18
+ ccd_ordering_atomchar,
19
+ )
20
+ from rfd3.transforms.conditioning_base import (
21
+ UnindexFlaggedTokens,
22
+ )
23
+ from rfd3.transforms.util_transforms import (
24
+ assert_single_representative,
25
+ get_af3_token_representative_masks,
26
+ )
27
+
28
+ from foundry.common import exists
29
+
30
+
31
def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14"):
    """
    Map atom names of one residue onto the atom14 naming scheme.

    Each name's position in ``association_schemes_stripped[scheme][res_name]`` selects the
    corresponding entry of ``ATOM14_ATOM_NAMES``. NB: this is more general than a
    whole-residue mapping since it is also used to handle tip atoms (subsets of a residue).

    Args:
        atom_names: A single atom name (``str`` / ``np.str_``) or a sequence of names
            (list or numpy array).
        res_name: Residue name selecting the entry within the scheme.
        scheme: Key into ``association_schemes_stripped``.

    Returns:
        Array of atom14 names, one per input name (empty when no names are given).

    Raises:
        ValueError: If ``scheme`` is unknown, or a name is not present in the residue's
            scheme entry (from ``.index``). An unknown ``res_name`` fails the lookup.
    """
    if scheme not in association_schemes_stripped:
        raise ValueError(
            f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}"
        )
    if isinstance(atom_names, (str, np.str_)):
        atom_names = [str(atom_names)]

    scheme_names = association_schemes_stripped[scheme][res_name]
    # dtype=int keeps an empty selection a valid integer index array; a bare np.array([])
    # would be float64 and indexing ATOM14_ATOM_NAMES with it raises IndexError.
    idxs = np.array([scheme_names.index(name) for name in atom_names], dtype=int)
    return ATOM14_ATOM_NAMES[idxs]
50
+
51
+
52
def map_names_to_elements(
    atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
) -> np.ndarray:
    """
    Look up the element symbol for each ATOM14 atom name.

    Names that are not keys of ``ATOM14_ATOM_NAME_TO_ELEMENT`` (e.g. a virtual-atom
    name such as VX) are mapped to ``default``.

    Args:
        atom_names: A single atom name or a list of names.
        default: Element returned for names missing from the lookup table.

    Returns:
        numpy array holding one element symbol per input name.
    """
    names = (atom_names,) if isinstance(atom_names, str) else atom_names
    lookup = ATOM14_ATOM_NAME_TO_ELEMENT.get
    return np.array([lookup(name, default) for name in names])
63
+
64
+
65
def generate_atom_mappings_(scheme="atom14"):
    """
    Build per-residue index permutations from an association scheme to the CCD ordering.

    For every residue ``aaa`` in ``ccd_ordering_atomchar`` a list ``mapping`` is computed
    such that ``np.array(association_schemes[scheme][aaa])[mapping]`` equals
    ``np.array(ccd_ordering_atomchar[aaa])``; this property is asserted before returning.

    Args:
        scheme: Key into ``association_schemes`` selecting the naming scheme.

    Returns:
        (atom_mapping, symmetry_mapping): ``atom_mapping`` maps residue name -> list of
        indices. ``symmetry_mapping`` is currently always empty because symmetric-atom
        swapping is disabled.
    """
    scheme_table = association_schemes[scheme]

    atom_mapping = {}
    # Symmetric-atom swap groups are temporarily disabled (they used to be derived from
    # ``symmetric_atomchar``), so no groups are produced.
    symmetry_mapping = {}

    for aaa, atom14_names in ccd_ordering_atomchar.items():
        scheme_names = scheme_table[aaa]
        # Start from the identity over the 14 atom14 slots and swap entries so that, once
        # position ``ccd_index`` is processed,
        # ``scheme_names[mapping[ccd_index]] == atom14_names[ccd_index]``.
        mapping = list(range(14))

        for ccd_index, atom14_name in enumerate(atom14_names):
            if atom14_name is None:
                continue  # skip empty (None) CCD slots
            assert (
                atom14_name in scheme_names
            ), f"{atom14_name} (from CCD ordering) not in {scheme} association scheme for {aaa}"
            scheme_index = scheme_names.index(atom14_name)
            current_pos = mapping.index(scheme_index)
            mapping[ccd_index], mapping[current_pos] = (
                mapping[current_pos],
                mapping[ccd_index],
            )

        # Only holds when the scheme lists exactly 14 names for this residue.
        assert set(mapping) == set(range(len(scheme_names)))
        atom_mapping[aaa] = mapping

    # Validate: re-indexing the scheme with each mapping must reproduce the CCD ordering.
    for aaa, idxs in atom_mapping.items():
        assert len(idxs) == len(set(idxs)), f"Duplicate indices in mapping for {aaa}"

        atom_mapping_expected = np.array(scheme_table[aaa])[idxs]
        atom_mapping_actual = np.array(ccd_ordering_atomchar[aaa])

        assert np.array_equal(
            atom_mapping_expected, atom_mapping_actual
        ), f"Mapping mismatch for {aaa}: {atom_mapping_expected} != {atom_mapping_actual}"

    return atom_mapping, symmetry_mapping
117
+
118
+
119
+ def permute_symmetric_atom_names_(
120
+ atom_names: list, res_name: str, association_map: dict, symmetry_map: dict
121
+ ) -> list:
122
+ # NB: Can leak GT sequence if the model receives the canconical ordering of atoms as input
123
+ # With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
124
+ if res_name in association_map:
125
+ idx_to_swap = association_map[res_name]
126
+ atom_names = atom_names[idx_to_swap]
127
+ if res_name in symmetry_map:
128
+ for group in symmetry_map[res_name]:
129
+ if np.random.rand() < 0.5: # random swap
130
+ atom_names[group] = atom_names[group[::-1]]
131
+ return atom_names
132
+
133
+
134
+ #####################################################################################################
135
+ # Virtual atom transforms
136
+ #####################################################################################################
137
+
138
+
139
class PadTokensWithVirtualAtoms(Transform):
    """
    Pad residue tokens with virtual atoms so paddable tokens reach a fixed atom count.

    A token is paddable when it is a non-atomized protein residue (or an unindexed motif
    token) that is neither a fixed-sequence motif nor unindexed. Paddable tokens with fewer
    than ``n_atoms_per_token`` atoms receive copies of their representative atom
    (element set to ``VIRTUAL_ATOM_ELEMENT_NAME``) appended at the end of the token, and
    their atom names are set to ``ATOM14_ATOM_NAMES``. During training
    (``data["is_inference"]`` is False) with an association scheme configured, these names
    are permuted via the association / symmetry maps.

    Fixed-sequence and unindexed residue tokens are never padded. Their original names are
    stored in ``gt_atom_name`` and their ``atom_name`` is mapped onto the association
    scheme. All other tokens keep their names and copy them into ``gt_atom_name``.

    Args:
        n_atoms_per_token: Target number of atoms for paddable tokens.
        atom_to_pad_from: Passed as ``central_atom`` to select the representative atom
            that is copied to build virtual atoms.
        association_scheme: Association scheme key (e.g. ``"atom14"``) or ``None``.
    """

    requires_previous_transforms = [UnindexFlaggedTokens]

    def __init__(
        self,
        n_atoms_per_token,
        atom_to_pad_from,
        association_scheme,
    ):
        self.n_atoms_per_token = n_atoms_per_token
        self.atom_to_pad_from = atom_to_pad_from
        self.association_scheme = association_scheme
        # Maps are only built (and only read in forward) when a scheme is configured.
        if exists(association_scheme):
            self.association_map_, self.symmetry_map_ = generate_atom_mappings_(
                association_scheme
            )

    @staticmethod
    def _fix_multidimensional_annotations_in_pad_array(atomarray, padarray):
        """
        Re-stack multidimensional annotations of ``padarray`` as float arrays.

        For every annotation whose array in ``atomarray`` has more than one dimension, the
        corresponding ``padarray`` annotation is combined with ``np.stack``, cast to float,
        and set back in place of the old annotation.
        """
        for annotation in atomarray.get_annotation_categories():
            if len(atomarray.get_annotation(annotation).shape) > 1:
                stacked = np.stack(padarray.get_annotation(annotation)).astype(float)
                padarray.del_annotation(annotation)
                padarray.set_annotation(annotation, stacked)
        return padarray

    def _build_pad_array(self, token, n_pad):
        """Create ``n_pad`` virtual atoms copied from the representative atom of ``token``."""
        mask = get_af3_token_representative_masks(
            token, central_atom=self.atom_to_pad_from
        )
        assert_single_representative(token)

        # ... Create the virtual atom template from the representative atom
        pad_atom = token[mask].copy()
        pad_atom = pad_atom[0] if isinstance(pad_atom, struc.AtomArray) else pad_atom
        pad_atom.element = VIRTUAL_ATOM_ELEMENT_NAME

        # ... Expand to desired number of atoms
        pad_array = struc.array([pad_atom] * n_pad)

        # ... Occupancy is 1.0 if the representative atom has occupancy, else 0.0.
        # NOTE(review): an earlier comment said "if any atom in the token has occupancy",
        # but only the representative atom is inspected here; confirm which is intended.
        occ = 1.0 if pad_atom.occupancy.sum() > 0.0 else 0.0
        pad_array.occupancy = np.full(n_pad, occ)

        # ... Even if the source atom is motif, padded atoms are never motif
        pad_array.is_motif_atom = np.zeros(n_pad, dtype=bool)

        return self._fix_multidimensional_annotations_in_pad_array(token, pad_array)

    def forward(self, data: dict) -> dict:
        """
        Pad ``data["atom_array"]`` and assign ``atom_name`` / ``gt_atom_name`` per token.

        Reads ``data["atom_array"]`` and ``data["is_inference"]``. Stores the padded array
        in ``data["atom_array"]`` and returns ``data``.
        """
        atom_array = data["atom_array"]
        starts = get_token_starts(atom_array, add_exclusive_stop=True)
        token_level_array = atom_array[starts[:-1]]
        is_motif_atom_with_fixed_seq = token_level_array.is_motif_atom_with_fixed_seq
        is_motif_token_unindexed = token_level_array.is_motif_atom_unindexed

        token_ids = np.unique(atom_array.token_id)
        assert len(token_ids) == len(
            is_motif_atom_with_fixed_seq
        ), "Token ids and token level array have different lengths!"

        # Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
        is_residue = (
            token_level_array.is_protein & ~token_level_array.atomize
        ) | is_motif_token_unindexed

        # Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
        has_fixed_identity = is_motif_atom_with_fixed_seq | is_motif_token_unindexed
        is_paddable = is_residue & ~has_fixed_identity
        is_non_paddable_residue = is_residue & has_fixed_identity

        # First pass: build virtual atoms for each paddable token that is short of atoms.
        # They are inserted all at once afterwards, each block at the end of its token.
        virtual_atoms_to_insert = []
        insert_positions = []
        for token_id, (start, end) in enumerate(zip(starts[:-1], starts[1:])):
            if not is_paddable[token_id]:
                continue
            token = atom_array[start:end]
            n_pad = self.n_atoms_per_token - len(token)
            if n_pad > 0:
                virtual_atoms_to_insert.append(self._build_pad_array(token, n_pad))
                insert_positions.append(end)

        if virtual_atoms_to_insert:
            atom_array_padded = insert_atoms(
                atom_array, virtual_atoms_to_insert, insert_positions
            )
        else:
            atom_array_padded = atom_array

        # Initialize gt_atom_name if it doesn't exist. Reuse the dtype of atom_name instead
        # of a fixed "U4" so names longer than four characters are not silently truncated.
        if "gt_atom_name" not in atom_array_padded.get_annotation_categories():
            atom_array_padded.set_annotation(
                "gt_atom_name",
                np.empty(
                    len(atom_array_padded), dtype=atom_array_padded.atom_name.dtype
                ),
            )
        gt_atom_name = atom_array_padded.get_annotation("gt_atom_name")

        # Second pass: assign names on the padded array. Padding only adds atoms inside
        # tokens, so token_id indexes the same masks computed above.
        starts_padded = get_token_starts(atom_array_padded, add_exclusive_stop=True)
        for token_id, (start, end) in enumerate(
            zip(starts_padded[:-1], starts_padded[1:])
        ):
            if is_paddable[token_id]:
                # ... Permutation of atom names during training
                if not data["is_inference"] and exists(self.association_scheme):
                    atom_names = permute_symmetric_atom_names_(
                        ATOM14_ATOM_NAMES,
                        atom_array_padded.res_name[start],
                        association_map=self.association_map_,
                        symmetry_map=self.symmetry_map_,
                    )
                else:
                    atom_names = ATOM14_ATOM_NAMES
                atom_array_padded.atom_name[start:end] = atom_names
                gt_atom_name[start:end] = atom_names

            elif is_non_paddable_residue[token_id]:
                # Sequence-constrained: keep the original names as ground truth, then map
                # them directly onto the association scheme for this residue.
                # NOTE(review): with association_scheme=None, map_to_association_scheme
                # presumably raises (None is checked against the scheme table); confirm
                # a scheme is always configured when such tokens exist.
                res_name = atom_array_padded.res_name[start]
                gt_atom_name[start:end] = atom_array_padded.atom_name[start:end]
                atom_array_padded.atom_name[start:end] = map_to_association_scheme(
                    atom_array_padded.atom_name[start:end],
                    res_name,
                    scheme=self.association_scheme,
                )
            else:
                # ... Other tokens keep their names; record them as gt_atom_name
                gt_atom_name[start:end] = atom_array_padded.atom_name[start:end]

            # ... Sanity check: no token may consist solely of virtual atoms
            token_elements = set(atom_array_padded.element[start:end].tolist())
            assert {VIRTUAL_ATOM_ELEMENT_NAME} != token_elements, (
                "Token should not consist solely of virtual atoms, but found elements: "
                f"{token_elements}"
            )

        data["atom_array"] = atom_array_padded
        return data