rc-foundry 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. foundry/__init__.py +57 -0
  2. foundry/callbacks/__init__.py +5 -0
  3. foundry/callbacks/callback.py +116 -0
  4. foundry/callbacks/health_logging.py +419 -0
  5. foundry/callbacks/metrics_logging.py +211 -0
  6. foundry/callbacks/timing_logging.py +67 -0
  7. foundry/callbacks/train_logging.py +278 -0
  8. foundry/common.py +108 -0
  9. foundry/constants.py +28 -0
  10. foundry/hydra/resolvers.py +77 -0
  11. foundry/inference_engines/base.py +235 -0
  12. foundry/inference_engines/checkpoint_registry.py +66 -0
  13. foundry/metrics/__init__.py +12 -0
  14. foundry/metrics/losses.py +30 -0
  15. foundry/metrics/metric.py +319 -0
  16. foundry/model/layers/blocks.py +47 -0
  17. foundry/testing/__init__.py +6 -0
  18. foundry/testing/fixtures.py +19 -0
  19. foundry/testing/pytest_hooks.py +15 -0
  20. foundry/trainers/fabric.py +923 -0
  21. foundry/training/EMA.py +67 -0
  22. foundry/training/checkpoint.py +61 -0
  23. foundry/training/schedulers.py +91 -0
  24. foundry/utils/alignment.py +86 -0
  25. foundry/utils/components.py +415 -0
  26. foundry/utils/datasets.py +405 -0
  27. foundry/utils/ddp.py +103 -0
  28. foundry/utils/instantiators.py +72 -0
  29. foundry/utils/logging.py +279 -0
  30. foundry/utils/rigid.py +1460 -0
  31. foundry/utils/rotation_augmentation.py +65 -0
  32. foundry/utils/squashfs.py +172 -0
  33. foundry/utils/torch.py +317 -0
  34. foundry/utils/weights.py +271 -0
  35. foundry/version.py +34 -0
  36. foundry_cli/__init__.py +3 -0
  37. foundry_cli/download_checkpoints.py +281 -0
  38. mpnn/__init__.py +1 -0
  39. mpnn/collate/feature_collator.py +265 -0
  40. mpnn/inference.py +53 -0
  41. mpnn/inference_engines/mpnn.py +549 -0
  42. mpnn/loss/nll_loss.py +122 -0
  43. mpnn/metrics/nll.py +369 -0
  44. mpnn/metrics/sequence_recovery.py +440 -0
  45. mpnn/model/layers/graph_embeddings.py +2372 -0
  46. mpnn/model/layers/message_passing.py +332 -0
  47. mpnn/model/layers/position_wise_feed_forward.py +44 -0
  48. mpnn/model/layers/positional_encoding.py +98 -0
  49. mpnn/model/mpnn.py +2632 -0
  50. mpnn/pipelines/mpnn.py +162 -0
  51. mpnn/samplers/samplers.py +167 -0
  52. mpnn/train.py +341 -0
  53. mpnn/trainers/mpnn.py +193 -0
  54. mpnn/transforms/feature_aggregation/mpnn.py +184 -0
  55. mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
  56. mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
  57. mpnn/transforms/feature_aggregation/user_settings.py +347 -0
  58. mpnn/transforms/polymer_ligand_interface.py +164 -0
  59. mpnn/utils/inference.py +2397 -0
  60. mpnn/utils/probability.py +37 -0
  61. mpnn/utils/weights.py +309 -0
  62. rc_foundry-0.1.1.dist-info/METADATA +239 -0
  63. rc_foundry-0.1.1.dist-info/RECORD +180 -0
  64. rc_foundry-0.1.1.dist-info/WHEEL +4 -0
  65. rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
  66. rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
  67. rf3/__init__.py +3 -0
  68. rf3/_version.py +33 -0
  69. rf3/alignment.py +79 -0
  70. rf3/callbacks/dump_validation_structures.py +101 -0
  71. rf3/callbacks/metrics_logging.py +324 -0
  72. rf3/chemical.py +1529 -0
  73. rf3/cli.py +77 -0
  74. rf3/data/cyclic_transform.py +78 -0
  75. rf3/data/extra_xforms.py +36 -0
  76. rf3/data/ground_truth_template.py +463 -0
  77. rf3/data/paired_msa.py +206 -0
  78. rf3/data/pipeline_utils.py +128 -0
  79. rf3/data/pipelines.py +558 -0
  80. rf3/diffusion_samplers/inference_sampler.py +222 -0
  81. rf3/inference.py +65 -0
  82. rf3/inference_engines/__init__.py +5 -0
  83. rf3/inference_engines/rf3.py +735 -0
  84. rf3/kinematics.py +354 -0
  85. rf3/loss/af3_confidence_loss.py +515 -0
  86. rf3/loss/af3_losses.py +655 -0
  87. rf3/loss/loss.py +179 -0
  88. rf3/metrics/chiral.py +179 -0
  89. rf3/metrics/clashing_chains.py +68 -0
  90. rf3/metrics/distogram.py +421 -0
  91. rf3/metrics/lddt.py +523 -0
  92. rf3/metrics/metadata.py +43 -0
  93. rf3/metrics/metric_utils.py +192 -0
  94. rf3/metrics/predicted_error.py +134 -0
  95. rf3/metrics/rasa.py +108 -0
  96. rf3/metrics/selected_distances.py +91 -0
  97. rf3/model/RF3.py +527 -0
  98. rf3/model/RF3_blocks.py +92 -0
  99. rf3/model/RF3_structure.py +303 -0
  100. rf3/model/layers/af3_auxiliary_heads.py +255 -0
  101. rf3/model/layers/af3_diffusion_transformer.py +544 -0
  102. rf3/model/layers/attention.py +313 -0
  103. rf3/model/layers/layer_utils.py +127 -0
  104. rf3/model/layers/mlff.py +118 -0
  105. rf3/model/layers/outer_product.py +59 -0
  106. rf3/model/layers/pairformer_layers.py +783 -0
  107. rf3/model/layers/structure_bias.py +56 -0
  108. rf3/scoring.py +1787 -0
  109. rf3/symmetry/resolve.py +284 -0
  110. rf3/train.py +194 -0
  111. rf3/trainers/rf3.py +570 -0
  112. rf3/util_module.py +47 -0
  113. rf3/utils/frames.py +109 -0
  114. rf3/utils/inference.py +665 -0
  115. rf3/utils/io.py +198 -0
  116. rf3/utils/loss.py +72 -0
  117. rf3/utils/predict_and_score.py +165 -0
  118. rf3/utils/predicted_error.py +673 -0
  119. rf3/utils/recycling.py +42 -0
  120. rf3/validate.py +140 -0
  121. rfd3/.gitignore +7 -0
  122. rfd3/Makefile +76 -0
  123. rfd3/__init__.py +12 -0
  124. rfd3/callbacks.py +66 -0
  125. rfd3/cli.py +41 -0
  126. rfd3/constants.py +212 -0
  127. rfd3/engine.py +543 -0
  128. rfd3/inference/datasets.py +193 -0
  129. rfd3/inference/input_parsing.py +1123 -0
  130. rfd3/inference/legacy_input_parsing.py +717 -0
  131. rfd3/inference/parsing.py +165 -0
  132. rfd3/inference/symmetry/atom_array.py +298 -0
  133. rfd3/inference/symmetry/checks.py +241 -0
  134. rfd3/inference/symmetry/contigs.py +63 -0
  135. rfd3/inference/symmetry/frames.py +355 -0
  136. rfd3/inference/symmetry/symmetry_utils.py +398 -0
  137. rfd3/metrics/design_metrics.py +465 -0
  138. rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
  139. rfd3/metrics/hbonds_metrics.py +389 -0
  140. rfd3/metrics/losses.py +325 -0
  141. rfd3/metrics/metrics_utils.py +118 -0
  142. rfd3/metrics/sidechain_metrics.py +349 -0
  143. rfd3/model/RFD3.py +105 -0
  144. rfd3/model/RFD3_diffusion_module.py +387 -0
  145. rfd3/model/cfg_utils.py +81 -0
  146. rfd3/model/inference_sampler.py +635 -0
  147. rfd3/model/layers/attention.py +577 -0
  148. rfd3/model/layers/block_utils.py +580 -0
  149. rfd3/model/layers/blocks.py +777 -0
  150. rfd3/model/layers/chunked_pairwise.py +377 -0
  151. rfd3/model/layers/encoders.py +417 -0
  152. rfd3/model/layers/layer_utils.py +197 -0
  153. rfd3/model/layers/pairformer_layers.py +128 -0
  154. rfd3/run_inference.py +45 -0
  155. rfd3/testing/debug.py +139 -0
  156. rfd3/testing/debug_utils.py +73 -0
  157. rfd3/testing/testing_utils.py +356 -0
  158. rfd3/train.py +194 -0
  159. rfd3/trainer/dump_validation_structures.py +154 -0
  160. rfd3/trainer/fabric_trainer.py +923 -0
  161. rfd3/trainer/recycling.py +42 -0
  162. rfd3/trainer/rfd3.py +485 -0
  163. rfd3/trainer/trainer_utils.py +502 -0
  164. rfd3/transforms/conditioning_base.py +508 -0
  165. rfd3/transforms/conditioning_utils.py +200 -0
  166. rfd3/transforms/design_transforms.py +807 -0
  167. rfd3/transforms/dna_crop.py +523 -0
  168. rfd3/transforms/hbonds.py +407 -0
  169. rfd3/transforms/hbonds_hbplus.py +246 -0
  170. rfd3/transforms/ncaa_transforms.py +153 -0
  171. rfd3/transforms/pipelines.py +632 -0
  172. rfd3/transforms/ppi_transforms.py +541 -0
  173. rfd3/transforms/rasa.py +116 -0
  174. rfd3/transforms/symmetry.py +76 -0
  175. rfd3/transforms/training_conditions.py +552 -0
  176. rfd3/transforms/util_transforms.py +498 -0
  177. rfd3/transforms/virtual_atoms.py +305 -0
  178. rfd3/utils/inference.py +648 -0
  179. rfd3/utils/io.py +245 -0
  180. rfd3/utils/vizualize.py +276 -0
@@ -0,0 +1,498 @@
1
+ # see atomworks.ml.transforms.feature_aggregation
2
+ import time
3
+ from typing import Any, Dict
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from atomworks.constants import STANDARD_AA
9
+ from atomworks.enums import ChainTypeInfo
10
+ from atomworks.io.utils.sequence import (
11
+ is_purine,
12
+ is_pyrimidine,
13
+ )
14
+ from atomworks.ml.encoding_definitions import AF3SequenceEncoding
15
+ from atomworks.ml.transforms._checks import (
16
+ check_atom_array_annotation,
17
+ check_contains_keys,
18
+ check_is_instance,
19
+ )
20
+ from atomworks.ml.transforms.atom_array import get_within_entity_idx
21
+ from atomworks.ml.transforms.base import Transform
22
+ from atomworks.ml.utils.token import (
23
+ get_token_count,
24
+ get_token_starts,
25
+ is_glycine,
26
+ is_protein_unknown,
27
+ is_standard_aa_not_glycine,
28
+ is_unknown_nucleotide,
29
+ spread_token_wise,
30
+ )
31
+ from biotite.structure import AtomArray
32
+
33
# Module-level AF3 sequence-encoding instance shared across this module.
af3_sequence_encoding = AF3SequenceEncoding()
34
+
35
+
36
def assert_single_representative(token, central_atom="CB"):
    """Assert that *token* contains exactly one AF3 representative atom.

    Args:
        token: AtomArray holding a single token's atoms.
        central_atom: Atom name used as the representative for standard
            non-glycine amino acids. Defaults to "CB".

    Raises:
        AssertionError: If the number of representative atoms is not one.
    """
    mask = get_af3_token_representative_masks(token, central_atom=central_atom)
    n_rep = int(np.sum(mask))
    # Original message hard-coded "(CB)" and claimed "No representative" even
    # when more than one was found; report the configured atom and the count.
    assert n_rep == 1, (
        f"Expected exactly one representative atom ({central_atom}), "
        f"found {n_rep}. mask: {mask}\nToken: {token}"
    )
41
+
42
+
43
def assert_single_token(token):
    """Assert that *token* is a single token with a single representative atom."""
    token_count = get_token_count(token)
    assert token_count == 1, f"Token is not a single token: {token}"
    assert_single_representative(token)
46
+
47
+
48
def add_representative_atom(token, central_atom="CB"):
    """Ensure *token* has exactly one representative atom.

    If the token already has a single representative, it is returned
    unchanged; otherwise the first atom is flagged as atomized so it
    becomes the representative.
    """
    rep_mask = get_af3_token_representative_masks(token, central_atom=central_atom)
    if rep_mask.sum() == 1:
        return token
    atomize_flags = np.zeros(token.array_length(), dtype=bool)
    atomize_flags[0] = True
    token.atomize = atomize_flags
    assert_single_representative(token)
    return token
55
+
56
+
57
class TimerWrapper(Transform):
    """Wraps another transform and prints its wall-clock execution time."""

    def __init__(self, transform):
        self.transform = transform

    def check_input(self, *args, **kwargs):
        # No input validation: delegate entirely to the wrapped transform.
        pass

    def forward(self, data):
        t0 = time.time()
        result = self.transform.forward(data)
        print(f"Time taken: {time.time() - t0} s || Transform: {self.transform}")
        return result
69
+
70
+
71
class IPDB(Transform):
    """Debugging transform that drops into an interactive ipdb session.

    Binds the atom array to a local name (``aa``) so it is inspectable
    from the debugger prompt, then returns the data unchanged.
    """

    def forward(self, data):
        import ipdb

        aa = data["atom_array"]  # noqa
        ipdb.set_trace()
        return data
78
+
79
+
80
# NOTE(review): this is a second AF3SequenceEncoding instance —
# `af3_sequence_encoding` is already created near the top of the module.
# Presumably redundant but harmless; confirm whether both are needed.
sequence_encoding = AF3SequenceEncoding()

# Residue-name vocabularies per molecule class, used by `assign_types_`.
_aa_like_res_names = sequence_encoding.all_res_names[sequence_encoding.is_aa_like]
_rna_like_res_names = sequence_encoding.all_res_names[sequence_encoding.is_rna_like]
_dna_like_res_names = sequence_encoding.all_res_names[sequence_encoding.is_dna_like]
85
+
86
+
87
class AssignTypes(Transform):
    """
    Assigns types to the atoms in the atom array using af3 sequence encoding scheme.
    """

    def check_input(self, data):
        assert "atom_array" in data, "Input data must contain 'atom_array'."

    def forward(self, data):
        annotated = assign_types_(data["atom_array"])
        data["atom_array"] = annotated
        return data
98
+
99
+
100
def assign_types_(atom_array):
    """Annotate *atom_array* in place with per-atom molecule-type flags.

    Classifies each token's residue name against the AF3 sequence-encoding
    vocabularies, then spreads the token-level booleans to every atom via
    the atom-to-token map. Adds the annotations ``is_protein``, ``is_rna``,
    ``is_dna``, ``is_ligand`` and ``is_residue``.
    """
    token_starts = get_token_starts(atom_array)
    token_res_names = atom_array[token_starts].res_name
    n_tokens = get_token_count(atom_array)
    token_ids = np.arange(n_tokens, dtype=np.uint32)  # [n_tokens]
    atom_to_token_map = spread_token_wise(atom_array, token_ids)

    protein_mask = np.isin(token_res_names, _aa_like_res_names).astype(bool)
    residue_mask = np.isin(token_res_names, STANDARD_AA).astype(bool)
    rna_mask = np.isin(token_res_names, _rna_like_res_names).astype(bool)
    dna_mask = np.isin(token_res_names, _dna_like_res_names).astype(bool)
    # Anything that is not protein/RNA/DNA is treated as a ligand.
    ligand_mask = ~(protein_mask | rna_mask | dna_mask).astype(bool)

    # Spread token-level flags to atom level and attach as annotations.
    for name, token_mask in (
        ("is_protein", protein_mask),
        ("is_rna", rna_mask),
        ("is_dna", dna_mask),
        ("is_ligand", ligand_mask),
        ("is_residue", residue_mask),
    ):
        atom_array.set_annotation(name, token_mask[atom_to_token_map])

    return atom_array
120
+
121
+
122
class AggregateFeaturesLikeAF3WithoutMSA(Transform):
    """
    Exactly like AggregateFeaturesLikeAF3 but without MSAs

    Removed comments for readability, no additional code is in this function, just removed msa parts
    """

    # Transforms that must have run earlier in the pipeline.
    requires_previous_transforms = [
        "AtomizeByCCDName",
        "EncodeAF3TokenLevelFeatures",
        "AddAF3TokenBondFeatures",
        "UnindexFlaggedTokens",
    ]
    # Aggregation must happen exactly once per pipeline.
    incompatible_previous_transforms = [
        "AggregateFeaturesLikeAF3",
        "AggregateFeaturesLikeAF3WithoutMSA",
    ]

    def check_input(self, data) -> None:
        """Validate that an AtomArray with the required annotations is present."""
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(
            data, ["coord_to_be_noised", "chain_iid", "occupancy"]
        )

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Aggregates features into the format expected by AlphaFold 3 (MSA-free).

        One-hot encodes the reference atom-name and element features,
        sanitizes reference positions, populates ``data["ground_truth"]``
        from the atom array, copies the to-be-noised coordinates, masks
        token bonds touching unindexed tokens, and forwards the optional
        ``partial_t`` annotation during inference.

        Args:
            data (Dict[str, Any]): The input data dictionary containing the
                atom array, precomputed "feats", and other relevant information.

        Returns:
            Dict[str, Any]: The processed data dictionary with aggregated features.
        """
        # Initialize feats dictionary if not present
        if "feats" not in data:
            data["feats"] = {}

        # One-hot encode reference atom-name characters (64-char alphabet)
        # and element ids (128 classes); both become float tensors.
        data["feats"]["ref_atom_name_chars"] = F.one_hot(
            data["feats"]["ref_atom_name_chars"].long(), num_classes=64
        ).float()
        data["feats"]["ref_element"] = F.one_hot(
            data["feats"]["ref_element"].long(), num_classes=128
        ).float()
        # Replace NaN reference positions with zeros.
        data["feats"]["ref_pos"] = torch.nan_to_num(data["feats"]["ref_pos"], nan=0.0)

        # Process ground truth structure
        atom_array = data["atom_array"]

        coord_atom_lvl = atom_array.coord
        # Atoms with zero occupancy are treated as unresolved.
        mask_atom_lvl = atom_array.occupancy > 0.0
        token_starts = get_token_starts(atom_array)
        token_level_array = atom_array[token_starts]
        chain_iid_token_lvl = token_level_array.chain_iid
        if "ground_truth" not in data:
            data["ground_truth"] = {}

        data["ground_truth"].update(
            {
                "coord_atom_lvl": torch.tensor(coord_atom_lvl),  # [n_atoms, 3]
                "mask_atom_lvl": torch.tensor(mask_atom_lvl),  # [n_atoms]
                "chain_iid_token_lvl": chain_iid_token_lvl,  # numpy.ndarray of strings with shape (n_tokens,)
                # Defaults to all-False when no unindexed-token record exists.
                "is_original_unindexed_token": torch.from_numpy(
                    data["ground_truth"].get(
                        "is_original_unindexed_token",
                        np.zeros(len(token_starts), dtype=bool),
                    )
                ).bool(),  # [n_tokens]
            }
        )
        data["coord_atom_lvl_to_be_noised"] = torch.tensor(
            atom_array.coord_to_be_noised
        )

        # Remove any token bond features relating to unindexed tokens
        if "token_bonds" in data["feats"]:
            token_bonds = data["feats"]["token_bonds"]
            mask = data["feats"]["is_motif_token_unindexed"]

            # tokens bonded to unindexed & unindexed bonded to tokens
            # (mutates the token_bonds tensor in place)
            token_bonds[mask, :] = False
            token_bonds[:, mask] = False

        # Add partial t during inference
        if "partial_t" in atom_array.get_annotation_categories():
            assert data["is_inference"], "Partial diffusion only inference!"
            data["feats"]["partial_t"] = torch.from_numpy(
                atom_array.get_annotation("partial_t")
            )

        return data
218
+
219
+
220
def add_backbone_and_sidechain_annotations(atom_array: AtomArray) -> AtomArray:
    """
    Adds the backbone and sidechain annotations to the AtomArray.

    Protein N/CA/C/O atoms and every atomized atom count as backbone;
    non-atomized protein atoms that are not backbone count as sidechain.

    Args:
        atom_array (AtomArray): The AtomArray to which the annotations will be added.

    Returns:
        AtomArray: The AtomArray with the added annotations.
    """
    is_atomized = atom_array.atomize
    protein_mask = np.isin(atom_array.chain_type, ChainTypeInfo.PROTEINS)
    is_bb_name = np.isin(atom_array.atom_name, ["N", "CA", "C", "O"])

    backbone_mask = (is_bb_name & protein_mask) | is_atomized
    sidechain_mask = protein_mask & ~backbone_mask & ~is_atomized

    atom_array.set_annotation("is_backbone", backbone_mask)
    atom_array.set_annotation("is_sidechain", sidechain_mask)

    return atom_array
243
+
244
+
245
+ ####################################################################################################
246
+ # Changes to datahub base transforms (instead of creating new branches)
247
+ ####################################################################################################
248
+
249
+
250
+ # from atomworks.ml.utils.token import get_af3_token_representative_masks
251
def get_af3_token_representative_masks(
    atom_array: AtomArray, central_atom: str = "CA"
) -> np.ndarray:
    """Boolean mask marking the AF3 representative atom of each token.

    Representatives: C2 for pyrimidines; C4 for purines and unknown
    nucleotides; CA for glycine and unknown amino acids; *central_atom*
    for standard non-glycine amino acids; plus every atomized atom.
    (Local variant of atomworks' helper — ``central_atom`` is the only change.)
    """
    atom_names = atom_array.atom_name
    res_names = atom_array.res_name

    case_masks = (
        is_pyrimidine(res_names) & (atom_names == "C2"),
        is_purine(res_names) & (atom_names == "C4"),
        is_unknown_nucleotide(res_names) & (atom_names == "C4"),
        is_glycine(res_names) & (atom_names == "CA"),
        is_standard_aa_not_glycine(res_names) & (atom_names == central_atom),
        is_protein_unknown(res_names) & (atom_names == "CA"),
        atom_array.atomize,
    )

    combined = case_masks[0]
    for case in case_masks[1:]:
        combined = combined | case
    return combined
287
+
288
+
289
class RemoveTokensWithoutCorrespondingCentralAtom(Transform):
    """
    Remove tokens with missing central atoms.

    For each residue class (pyrimidine, purine, unknown nucleotide, glycine,
    standard non-glycine amino acid, unknown amino acid) the expected central
    atom is checked; tokens of that class lacking it are dropped. Tokens that
    match none of the classes are always kept.
    """

    def __init__(self, central_atom: str = "CA"):
        # Central atom name expected for standard non-glycine amino acids.
        self.central_atom = central_atom

    def check_input(self, data):
        """Validate that an AtomArray with atom_name/res_name annotations is present."""
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["atom_name", "res_name"])

    def forward(self, data):
        central_atom = self.central_atom
        atom_array = data["atom_array"]
        # Per-atom masks for each residue class.
        pyrimidine_mask = is_pyrimidine(atom_array.res_name)
        purine_mask = is_purine(atom_array.res_name)
        unknown_na_mask = is_unknown_nucleotide(atom_array.res_name)
        glycine_mask = is_glycine(atom_array.res_name)
        aa_not_glycine_mask = is_standard_aa_not_glycine(atom_array.res_name)
        unknown_aa_mask = is_protein_unknown(atom_array.res_name)

        # Atoms in none of the classes above (e.g. ligands) are kept as-is.
        anything_else_mask = ~(
            pyrimidine_mask
            | purine_mask
            | unknown_na_mask
            | glycine_mask
            | aa_not_glycine_mask
            | unknown_aa_mask
        )

        def _get_if_central_atom_present_mask(atom_array, case_mask, central_atom):
            # Returns `case_mask` with atoms of tokens that lack the central
            # atom turned off.
            token_starts = get_token_starts(atom_array[case_mask])
            central_atom_mask = atom_array[case_mask].atom_name == central_atom
            if len(token_starts) == central_atom_mask.sum():
                ## all tokens have central atom, *vast majority*
                return case_mask
            else:
                ## find the missing ones, *very rare*
                out_mask = case_mask
                all_token_starts = get_token_starts(atom_array)
                # Token starts whose first atom belongs to this class.
                token_start_mask = case_mask[all_token_starts]
                case_token_starts = all_token_starts[token_start_mask]

                for item in case_token_starts:
                    res_start = item
                    # Position of this token among all token starts; O(n) per
                    # token, acceptable because this path is rare.
                    idx = all_token_starts.tolist().index(res_start)
                    # NOTE(review): np.bool_(np.zeros(...)) — presumably
                    # intended as np.zeros(len(atom_array), dtype=bool); verify.
                    res_mask = np.bool_(np.zeros(len(atom_array)))
                    if idx == len(all_token_starts) - 1:
                        # Last token runs to the end of the array.
                        res_mask[res_start:] = True
                    else:
                        res_end = all_token_starts[idx + 1]
                        res_mask[res_start:res_end] = True
                    res_array = atom_array[res_mask]

                    # remove if central atom not present
                    if (res_array.atom_name == central_atom).sum() == 0:
                        out_mask = out_mask & ~res_mask
                return out_mask

        # Union of all class masks (filtered for central-atom presence) plus
        # everything outside the known classes.
        keep_mask = (
            _get_if_central_atom_present_mask(atom_array, pyrimidine_mask, "C2")
            | _get_if_central_atom_present_mask(atom_array, purine_mask, "C4")
            | _get_if_central_atom_present_mask(atom_array, unknown_na_mask, "C4")
            | _get_if_central_atom_present_mask(atom_array, glycine_mask, "CA")
            | _get_if_central_atom_present_mask(
                atom_array, aa_not_glycine_mask, central_atom
            )
            | _get_if_central_atom_present_mask(atom_array, unknown_aa_mask, "CA")
            | anything_else_mask
        )

        data["atom_array"] = atom_array[keep_mask]
        return data
364
+
365
+
366
class EncodeAF3TokenLevelFeatures(Transform):
    """Encode token-level AF3 features (ids, molecule types, restype one-hots).

    Populates ``data["feats"]`` with per-token identifier indices, molecule
    type flags, one-hot residue types, terminus flags and polarity flags, and
    ``data["feat_metadata"]`` with the name arrays backing the integer ids.
    """

    def __init__(
        self, sequence_encoding: AF3SequenceEncoding, encode_residues_to: int = None
    ):
        self.sequence_encoding = sequence_encoding
        # NOTE(review): annotated as int but used as a res_name fill value
        # (np.full into a string-dtype array below) — presumably a residue
        # name like "UNK"; confirm and fix the annotation.
        self.encode_residues_to = encode_residues_to  # for spoofing the restype

    def check_input(self, data: dict[str, Any]) -> None:
        """Validate the atom array and its required annotations."""
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(
            data,
            [
                "atomize",
                "pn_unit_iid",
                "chain_entity",
                "res_name",
                "within_chain_res_idx",
            ],
        )

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # ... get token-level array
        token_starts = get_token_starts(atom_array)
        token_level_array = atom_array[token_starts]

        # ... identifier tokens
        # ... (residue)
        residue_index = token_level_array.within_chain_res_idx
        # ... (token)
        token_index = np.arange(len(token_starts))
        # ... (chain instance) — integer ids from unique pn_unit instance ids
        asym_name, asym_id = np.unique(
            token_level_array.pn_unit_iid, return_inverse=True
        )
        # ... (chain entity)
        entity_name, entity_id = np.unique(
            token_level_array.pn_unit_entity, return_inverse=True
        )
        # ... (within chain entity)
        sym_name, sym_id = get_within_entity_idx(token_level_array, level="pn_unit")

        # ... molecule type: classify token res_names against the encoding's
        # per-class vocabularies; ligand is the complement of the three.
        _aa_like_res_names = self.sequence_encoding.all_res_names[
            self.sequence_encoding.is_aa_like
        ]
        is_protein = np.isin(token_level_array.res_name, _aa_like_res_names)

        _rna_like_res_names = self.sequence_encoding.all_res_names[
            self.sequence_encoding.is_rna_like
        ]
        is_rna = np.isin(token_level_array.res_name, _rna_like_res_names)

        _dna_like_res_names = self.sequence_encoding.all_res_names[
            self.sequence_encoding.is_dna_like
        ]
        is_dna = np.isin(token_level_array.res_name, _dna_like_res_names)

        is_ligand = ~(is_protein | is_rna | is_dna)

        # Get is_polar features
        polar_restypes = np.array(
            [
                "SER",
                "THR",
                "ASN",
                "GLN",
                "TYR",
                "CYS",
                "HIS",
                "LYS",
                "ARG",
                "ASP",
                "GLU",
            ]
        )
        is_polar = is_protein & np.isin(token_level_array.res_name, polar_restypes)

        # ... sequence tokens
        res_names = token_level_array.res_name
        if self.encode_residues_to is not None:
            # Spoof the restype of tokens without fixed sequence.
            # NOTE(review): res_names is token_level_array.res_name — this
            # assignment presumably mutates the token-level view in place;
            # confirm a copy is not needed.
            is_masked = ~token_level_array.is_motif_atom_with_fixed_seq
            res_names[is_masked] = np.full(
                np.sum(is_masked), self.encode_residues_to, dtype=res_names.dtype
            )

        restype = self.sequence_encoding.encode(res_names)
        data["encoded"] = {"seq": restype}  # For msa's
        # One-hot over the encoding vocabulary, back to numpy.
        restype = F.one_hot(
            torch.tensor(restype), num_classes=self.sequence_encoding.n_tokens
        ).numpy()

        # ... Add termini annotations (n_tok, 2): col 0 = C-terminus,
        # col 1 = N-terminus.
        terminus_type = np.zeros(
            (
                len(token_level_array),
                2,
            ),
            dtype=restype.dtype,
        )
        terminus_type[token_level_array.is_C_terminus, 0] = 1
        terminus_type[token_level_array.is_N_terminus, 1] = 1

        # ... add to data dict
        if "feats" not in data:
            data["feats"] = {}
        if "feat_metadata" not in data:
            data["feat_metadata"] = {}

        # ... add to data dict
        data["feats"] |= {
            "residue_index": residue_index,  # (N_tokens) (int)
            "token_index": token_index,  # (N_tokens) (int)
            "asym_id": asym_id,  # (N_tokens) (int)
            "entity_id": entity_id,  # (N_tokens) (int)
            "sym_id": sym_id,  # (N_tokens) (int)
            "restype": restype,  # (N_tokens, 32) (float, one-hot)
            "is_protein": is_protein,  # (N_tokens) (bool)
            "is_rna": is_rna,  # (N_tokens) (bool)
            "is_dna": is_dna,  # (N_tokens) (bool)
            "is_ligand": is_ligand,  # (N_tokens) (bool)
            "terminus_type": terminus_type,  # (N_tokens, 2) (int)
            "is_polar": is_polar,  # (N_tokens) (bool)
        }
        data["feat_metadata"] |= {
            "asym_name": asym_name,  # (N_asyms)
            "entity_name": entity_name,  # (N_entities)
            "sym_name": sym_name,  # (N_entities)
        }

        return data