rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,498 @@
|
|
|
1
|
+
# see atomworks.ml.transforms.feature_aggregation
|
|
2
|
+
import time
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from atomworks.constants import STANDARD_AA
|
|
9
|
+
from atomworks.enums import ChainTypeInfo
|
|
10
|
+
from atomworks.io.utils.sequence import (
|
|
11
|
+
is_purine,
|
|
12
|
+
is_pyrimidine,
|
|
13
|
+
)
|
|
14
|
+
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
|
15
|
+
from atomworks.ml.transforms._checks import (
|
|
16
|
+
check_atom_array_annotation,
|
|
17
|
+
check_contains_keys,
|
|
18
|
+
check_is_instance,
|
|
19
|
+
)
|
|
20
|
+
from atomworks.ml.transforms.atom_array import get_within_entity_idx
|
|
21
|
+
from atomworks.ml.transforms.base import Transform
|
|
22
|
+
from atomworks.ml.utils.token import (
|
|
23
|
+
get_token_count,
|
|
24
|
+
get_token_starts,
|
|
25
|
+
is_glycine,
|
|
26
|
+
is_protein_unknown,
|
|
27
|
+
is_standard_aa_not_glycine,
|
|
28
|
+
is_unknown_nucleotide,
|
|
29
|
+
spread_token_wise,
|
|
30
|
+
)
|
|
31
|
+
from biotite.structure import AtomArray
|
|
32
|
+
|
|
33
|
+
# Module-level AF3 sequence encoding instance.
# NOTE(review): not referenced elsewhere in the visible part of this module and
# duplicates `sequence_encoding` further down; presumably kept for importers — confirm.
af3_sequence_encoding = AF3SequenceEncoding()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def assert_single_representative(token, central_atom="CB"):
    """Assert that ``token`` has exactly one AF3 representative atom.

    Args:
        token: AtomArray holding the atoms of a single token.
        central_atom: Representative atom name for standard non-glycine amino
            acids, forwarded to ``get_af3_token_representative_masks``.

    Raises:
        AssertionError: If the representative mask selects zero atoms or more
            than one atom.
    """
    mask = get_af3_token_representative_masks(token, central_atom=central_atom)
    n_representatives = int(np.sum(mask))
    # Report the atom actually requested and how many matched; the mask can
    # fail by selecting several atoms, not only by selecting none.
    assert n_representatives == 1, (
        f"Expected exactly one representative atom ({central_atom}), "
        f"found {n_representatives}. mask: {mask}\nToken: {token}"
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def assert_single_token(token, central_atom="CB"):
    """Assert that ``token`` is exactly one token with one representative atom.

    Args:
        token: AtomArray expected to hold a single token.
        central_atom: Representative atom name for standard non-glycine amino
            acids; forwarded to ``assert_single_representative`` (defaults to
            "CB", matching the previous hard-wired behaviour).

    Raises:
        AssertionError: If ``token`` spans more than one token or does not have
            exactly one representative atom.
    """
    assert get_token_count(token) == 1, f"Token is not a single token: {token}"
    assert_single_representative(token, central_atom=central_atom)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def add_representative_atom(token, central_atom="CB"):
    """Make sure ``token`` has exactly one representative atom.

    If the representative mask for ``central_atom`` already selects exactly one
    atom, ``token`` is returned untouched. Otherwise the ``atomize`` annotation
    is overwritten so that only the first atom is flagged, which makes that
    atom part of the representative mask.

    Args:
        token: AtomArray for a single token; its ``atomize`` annotation is
            reassigned in place when a representative has to be added.
        central_atom: Representative atom name for standard non-glycine amino
            acids.

    Returns:
        The same ``token`` object.

    Raises:
        AssertionError: If, after re-flagging ``atomize``, the token still does
            not have exactly one representative atom.
    """
    mask = get_af3_token_representative_masks(token, central_atom=central_atom)
    if mask.sum() == 1:
        return token

    # Flag only the first atom as atomized (an empty token yields an empty,
    # correctly-sized annotation instead of a length-1 array).
    atomize = np.zeros(token.array_length(), dtype=bool)
    atomize[:1] = True
    token.atomize = atomize

    # Validate with the same central atom that was used for the check above;
    # falling back to the default would re-check against "CB" regardless.
    assert_single_representative(token, central_atom=central_atom)
    return token
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class TimerWrapper(Transform):
    """Debug wrapper that prints how long a transform's ``forward`` takes.

    The wrapped transform's ``forward`` is called directly rather than through
    ``transform(data)``.
    NOTE(review): this presumably skips any checks ``Transform.__call__``
    performs for the wrapped transform — confirm that is intended.
    """

    def check_input(self, *args, **kwargs):
        # No validation: the wrapper accepts whatever the wrapped transform does.
        pass

    def __init__(self, transform):
        self.transform = transform

    def forward(self, data):
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # system clock and can jump, which corrupts duration measurements.
        start = time.perf_counter()
        data = self.transform.forward(data)
        elapsed = time.perf_counter() - start
        print(f"Time taken: {elapsed} s || Transform: {self.transform}")
        return data
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class IPDB(Transform):
    """Debugging transform that opens an ``ipdb`` session mid-pipeline.

    ``ipdb`` is imported inside ``forward``, so it is only needed when this
    transform is actually executed. ``data`` is returned as-is (unless it is
    modified interactively from the debugger prompt).
    """

    def forward(self, data):
        # Short local alias so the atom array can be inspected directly at the
        # debugger prompt; intentionally otherwise unused.
        aa = data["atom_array"]  # noqa
        import ipdb

        ipdb.set_trace()
        return data
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# AF3 sequence encoding used to classify residue names by polymer type.
sequence_encoding = AF3SequenceEncoding()

# Residue names the encoding treats as amino-acid-like, RNA-like and DNA-like,
# respectively (consumed by `assign_types_`).
_aa_like_res_names, _rna_like_res_names, _dna_like_res_names = (
    sequence_encoding.all_res_names[polymer_mask]
    for polymer_mask in (
        sequence_encoding.is_aa_like,
        sequence_encoding.is_rna_like,
        sequence_encoding.is_dna_like,
    )
)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class AssignTypes(Transform):
    """Annotate atoms with polymer-type flags using the AF3 sequence encoding.

    Thin transform wrapper around :func:`assign_types_`, which sets the
    ``is_protein``, ``is_rna``, ``is_dna``, ``is_ligand`` and ``is_residue``
    annotations on ``data["atom_array"]``.
    """

    def check_input(self, data):
        assert "atom_array" in data, "Input data must contain 'atom_array'."

    def forward(self, data):
        atom_array = data["atom_array"]
        data["atom_array"] = assign_types_(atom_array)
        return data
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def assign_types_(atom_array):
    """Set per-atom polymer-type annotations on ``atom_array`` in place.

    Every token is classified from the residue name of its first atom:

    * ``is_protein`` / ``is_rna`` / ``is_dna``: residue name is among the AF3
      encoding's aa-like / rna-like / dna-like names;
    * ``is_ligand``: none of the three above;
    * ``is_residue``: residue name is in ``STANDARD_AA``.

    The token-level flags are then broadcast to all atoms of each token.

    Args:
        atom_array: AtomArray to annotate.

    Returns:
        The same ``atom_array`` with the five annotations set.
    """
    first_atoms = atom_array[get_token_starts(atom_array)]
    token_res_names = first_atoms.res_name

    # Per-atom index of the token each atom belongs to.
    n_tokens = get_token_count(atom_array)
    atom_to_token = spread_token_wise(
        atom_array, np.arange(n_tokens, dtype=np.uint32)
    )

    protein = np.isin(token_res_names, _aa_like_res_names)
    rna = np.isin(token_res_names, _rna_like_res_names)
    dna = np.isin(token_res_names, _dna_like_res_names)

    # Insertion order fixes the order the annotations are added.
    token_flags = {
        "is_protein": protein,
        "is_rna": rna,
        "is_dna": dna,
        "is_ligand": ~(protein | rna | dna),
        "is_residue": np.isin(token_res_names, STANDARD_AA),
    }
    for annotation, flags in token_flags.items():
        atom_array.set_annotation(annotation, flags.astype(bool)[atom_to_token])

    return atom_array
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class AggregateFeaturesLikeAF3WithoutMSA(Transform):
    """
    Exactly like AggregateFeaturesLikeAF3 but without MSAs.

    Only the MSA-related parts were removed (and comments stripped); no other
    logic was added.
    """

    requires_previous_transforms = [
        "AtomizeByCCDName",
        "EncodeAF3TokenLevelFeatures",
        "AddAF3TokenBondFeatures",
        "UnindexFlaggedTokens",
    ]
    incompatible_previous_transforms = [
        "AggregateFeaturesLikeAF3",
        "AggregateFeaturesLikeAF3WithoutMSA",
    ]

    def check_input(self, data) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(
            data, ["coord_to_be_noised", "chain_iid", "occupancy"]
        )

    def forward(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Aggregate reference-conformer features and ground truth (no MSA).

        Updates ``data`` in place:

        * ``data["feats"]``: one-hot encodes ``ref_atom_name_chars`` (64
          classes) and ``ref_element`` (128 classes) as float tensors and
          replaces NaNs in ``ref_pos`` with 0.
        * ``data["ground_truth"]``: adds ``coord_atom_lvl``, ``mask_atom_lvl``
          (occupancy > 0), token-level ``chain_iid_token_lvl`` and
          ``is_original_unindexed_token`` (all-False unless already present).
        * ``data["coord_atom_lvl_to_be_noised"]``: tensor built from the
          ``coord_to_be_noised`` annotation.
        * If ``token_bonds`` is present, clears its rows and columns for
          tokens flagged in ``is_motif_token_unindexed`` (in place).
        * If the atom array has a ``partial_t`` annotation, copies it to
          ``data["feats"]["partial_t"]``.

        Args:
            data (Dict[str, Any]): Pipeline data; must contain ``atom_array``
                and the ``ref_*`` entries of ``data["feats"]`` used above.

        Returns:
            Dict[str, Any]: The same ``data`` dictionary.

        Raises:
            AssertionError: If ``partial_t`` is present while
                ``data["is_inference"]`` is missing or falsy.
        """
        feats = data.setdefault("feats", {})

        feats["ref_atom_name_chars"] = F.one_hot(
            feats["ref_atom_name_chars"].long(), num_classes=64
        ).float()
        feats["ref_element"] = F.one_hot(
            feats["ref_element"].long(), num_classes=128
        ).float()
        feats["ref_pos"] = torch.nan_to_num(feats["ref_pos"], nan=0.0)

        # Process ground truth structure
        atom_array = data["atom_array"]

        coord_atom_lvl = atom_array.coord
        mask_atom_lvl = atom_array.occupancy > 0.0
        token_starts = get_token_starts(atom_array)
        chain_iid_token_lvl = atom_array[token_starts].chain_iid

        ground_truth = data.setdefault("ground_truth", {})
        # May already have been provided by an earlier transform; otherwise no
        # token is marked as originally unindexed.
        is_original_unindexed_token = ground_truth.get(
            "is_original_unindexed_token",
            np.zeros(len(token_starts), dtype=bool),
        )
        ground_truth.update(
            {
                "coord_atom_lvl": torch.tensor(coord_atom_lvl),  # [n_atoms, 3]
                "mask_atom_lvl": torch.tensor(mask_atom_lvl),  # [n_atoms]
                # numpy.ndarray of strings with shape (n_tokens,)
                "chain_iid_token_lvl": chain_iid_token_lvl,
                # as_tensor behaves like from_numpy for numpy input but also
                # accepts a tensor or list stored by an upstream transform.
                "is_original_unindexed_token": torch.as_tensor(
                    is_original_unindexed_token
                ).bool(),  # [n_tokens]
            }
        )
        data["coord_atom_lvl_to_be_noised"] = torch.tensor(
            atom_array.coord_to_be_noised
        )

        # Remove any token bond features relating to unindexed tokens
        if "token_bonds" in feats:
            token_bonds = feats["token_bonds"]
            unindexed = feats["is_motif_token_unindexed"]

            # tokens bonded to unindexed & unindexed bonded to tokens
            token_bonds[unindexed, :] = False
            token_bonds[:, unindexed] = False

        # Add partial t during inference
        if "partial_t" in atom_array.get_annotation_categories():
            # .get so a missing flag reports this message instead of a KeyError
            assert data.get("is_inference", False), "Partial diffusion only inference!"
            feats["partial_t"] = torch.from_numpy(
                atom_array.get_annotation("partial_t")
            )

        return data
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def add_backbone_and_sidechain_annotations(atom_array: AtomArray) -> AtomArray:
    """
    Set ``is_backbone`` / ``is_sidechain`` annotations on ``atom_array`` in place.

    * ``is_backbone``: atoms named N, CA, C or O in protein chains, plus every
      atomized atom.
    * ``is_sidechain``: protein-chain atoms that are neither backbone nor
      atomized.

    Atoms outside protein chains that are not atomized get False for both.

    Args:
        atom_array (AtomArray): Array with ``atomize``, ``chain_type`` and
            ``atom_name`` annotations.

    Returns:
        AtomArray: The same array with the two annotations added.
    """
    atomized = atom_array.atomize
    in_protein_chain = np.isin(atom_array.chain_type, ChainTypeInfo.PROTEINS)
    has_backbone_name = np.isin(atom_array.atom_name, ["N", "CA", "C", "O"])

    is_backbone = (has_backbone_name & in_protein_chain) | atomized
    is_sidechain = ~is_backbone & ~atomized & in_protein_chain

    atom_array.set_annotation("is_backbone", is_backbone)
    atom_array.set_annotation("is_sidechain", is_sidechain)
    return atom_array
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
####################################################################################################
|
|
246
|
+
# Changes to datahub base transforms (instead of creating new branches)
|
|
247
|
+
####################################################################################################
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
# Local override of atomworks.ml.utils.token.get_af3_token_representative_masks
# that makes the representative atom of standard non-glycine amino acids configurable.
|
|
251
|
+
def get_af3_token_representative_masks(
    atom_array: AtomArray, central_atom: str = "CA"
) -> np.ndarray:
    """Return a per-atom mask of AF3 token-representative atoms.

    Representative atom per residue class:

    * pyrimidines: ``C2``
    * purines and unknown nucleotides: ``C4``
    * glycine and unknown amino acids: ``CA``
    * standard non-glycine amino acids: ``central_atom`` (the only difference
      from the atomworks implementation this mirrors)

    Every atom flagged by the ``atomize`` annotation is also included.

    Args:
        atom_array: AtomArray with ``res_name``, ``atom_name`` and ``atomize``.
        central_atom: Representative atom name for standard non-glycine amino
            acids.

    Returns:
        Per-atom mask, True where the atom is a representative.
    """
    res_names = atom_array.res_name
    atom_names = atom_array.atom_name

    # (residue-class predicate, representative atom name), OR-ed in this order.
    representative_rules = (
        (is_pyrimidine, "C2"),
        (is_purine, "C4"),
        (is_unknown_nucleotide, "C4"),
        (is_glycine, "CA"),
        (is_standard_aa_not_glycine, central_atom),
        (is_protein_unknown, "CA"),
    )

    token_rep_mask = None
    for in_class, rep_atom_name in representative_rules:
        hits = in_class(res_names) & (atom_names == rep_atom_name)
        token_rep_mask = hits if token_rep_mask is None else token_rep_mask | hits

    # Atomized atoms are always included (presumably each is its own token).
    return token_rep_mask | atom_array.atomize
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
class RemoveTokensWithoutCorrespondingCentralAtom(Transform):
    """
    Remove tokens with missing central atoms.

    Residue classes (from ``res_name``) and their expected central atom:

    * pyrimidines: ``C2``
    * purines and unknown nucleotides: ``C4``
    * glycine and unknown amino acids: ``CA``
    * standard non-glycine amino acids: ``central_atom``

    A token whose first atom falls in one of these classes is removed entirely
    if none of its atoms carries the expected atom name. Atoms whose residue
    matches none of the classes are always kept.

    Args:
        central_atom: Central atom name for standard non-glycine amino acids.
    """

    def __init__(self, central_atom: str = "CA"):
        self.central_atom = central_atom

    def check_input(self, data):
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["atom_name", "res_name"])

    @staticmethod
    def _drop_tokens_missing_atom(
        atom_names, token_starts, token_idx, case_mask, expected_atom
    ):
        """Return ``case_mask`` with every token lacking ``expected_atom`` cleared.

        A token belongs to the case when its first atom is in ``case_mask``.
        Done in one vectorized pass over atoms: a per-token loop is quadratic,
        and comparing the number of tokens with the number of matching atoms
        misses a token lacking the atom when another token has it twice.

        Args:
            atom_names: Per-atom atom names.
            token_starts: Sorted indices of the first atom of every token.
            token_idx: Per-atom token index (-1 for atoms before the first start).
            case_mask: Per-atom boolean mask of the residue class.
            expected_atom: Atom name each token of the class must contain.
        """
        in_token = token_idx >= 0
        has_atom = np.zeros(len(token_starts), dtype=bool)
        has_atom[token_idx[(atom_names == expected_atom) & in_token]] = True

        drop_token = case_mask[token_starts] & ~has_atom
        drop_atom = np.zeros(len(case_mask), dtype=bool)
        drop_atom[in_token] = drop_token[token_idx[in_token]]
        return case_mask & ~drop_atom

    def forward(self, data):
        atom_array = data["atom_array"]
        res_names = atom_array.res_name
        atom_names = atom_array.atom_name
        n_atoms = len(atom_array)

        # (residue-class mask, expected central atom name)
        cases = [
            (is_pyrimidine(res_names), "C2"),
            (is_purine(res_names), "C4"),
            (is_unknown_nucleotide(res_names), "C4"),
            (is_glycine(res_names), "CA"),
            (is_standard_aa_not_glycine(res_names), self.central_atom),
            (is_protein_unknown(res_names), "CA"),
        ]

        # Token membership is the same for every case, so compute it once.
        token_starts = np.asarray(get_token_starts(atom_array), dtype=int)
        is_start = np.zeros(n_atoms, dtype=bool)
        is_start[token_starts] = True
        token_idx = np.cumsum(is_start) - 1

        covered = np.zeros(n_atoms, dtype=bool)
        keep_mask = np.zeros(n_atoms, dtype=bool)
        for case_mask, expected_atom in cases:
            case_mask = np.asarray(case_mask, dtype=bool)
            covered |= case_mask
            keep_mask |= self._drop_tokens_missing_atom(
                atom_names, token_starts, token_idx, case_mask, expected_atom
            )

        # Residues outside every class above are kept unconditionally.
        keep_mask |= ~covered

        data["atom_array"] = atom_array[keep_mask]
        return data
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
class EncodeAF3TokenLevelFeatures(Transform):
    """Encode AF3 token-level identifier, molecule-type and sequence features.

    Each token is represented by its first atom. Adds to ``data["feats"]``:
    residue/token indices, ``asym_id`` / ``entity_id`` / ``sym_id``
    identifiers, one-hot ``restype``, molecule-type flags, terminus flags and
    ``is_polar``. The names behind the identifiers go to
    ``data["feat_metadata"]``, and the integer-encoded sequence to
    ``data["encoded"]["seq"]``.

    Args:
        sequence_encoding: AF3 sequence encoding used for molecule typing and
            for encoding residue names.
        encode_residues_to: If given, a residue name written over every token
            whose ``is_motif_atom_with_fixed_seq`` annotation is False before
            encoding (spoofs the restype).
    """

    def __init__(
        self,
        sequence_encoding: AF3SequenceEncoding,
        encode_residues_to: "str | None" = None,
    ):
        self.sequence_encoding = sequence_encoding
        self.encode_residues_to = encode_residues_to  # for spoofing the restype

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        required_annotations = [
            "atomize",
            "pn_unit_iid",
            "chain_entity",
            "res_name",
            "within_chain_res_idx",
            # Read unconditionally by `forward`.
            "pn_unit_entity",
            "is_C_terminus",
            "is_N_terminus",
        ]
        if self.encode_residues_to is not None:
            # Needed to decide which tokens get their residue name spoofed.
            required_annotations.append("is_motif_atom_with_fixed_seq")
        check_atom_array_annotation(data, required_annotations)

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # ... get token-level array
        token_starts = get_token_starts(atom_array)
        token_level_array = atom_array[token_starts]

        # ... identifier tokens
        # ... (residue)
        residue_index = token_level_array.within_chain_res_idx
        # ... (token)
        token_index = np.arange(len(token_starts))
        # ... (chain instance)
        asym_name, asym_id = np.unique(
            token_level_array.pn_unit_iid, return_inverse=True
        )
        # ... (chain entity)
        entity_name, entity_id = np.unique(
            token_level_array.pn_unit_entity, return_inverse=True
        )
        # ... (within chain entity)
        sym_name, sym_id = get_within_entity_idx(token_level_array, level="pn_unit")

        # ... molecule type
        _aa_like_res_names = self.sequence_encoding.all_res_names[
            self.sequence_encoding.is_aa_like
        ]
        is_protein = np.isin(token_level_array.res_name, _aa_like_res_names)

        _rna_like_res_names = self.sequence_encoding.all_res_names[
            self.sequence_encoding.is_rna_like
        ]
        is_rna = np.isin(token_level_array.res_name, _rna_like_res_names)

        _dna_like_res_names = self.sequence_encoding.all_res_names[
            self.sequence_encoding.is_dna_like
        ]
        is_dna = np.isin(token_level_array.res_name, _dna_like_res_names)

        is_ligand = ~(is_protein | is_rna | is_dna)

        # Get is_polar features
        polar_restypes = np.array(
            [
                "SER",
                "THR",
                "ASN",
                "GLN",
                "TYR",
                "CYS",
                "HIS",
                "LYS",
                "ARG",
                "ASP",
                "GLU",
            ]
        )
        is_polar = is_protein & np.isin(token_level_array.res_name, polar_restypes)

        # ... sequence tokens
        # Copy so spoofing does not silently rewrite token_level_array.res_name.
        res_names = token_level_array.res_name.copy()
        if self.encode_residues_to is not None:
            is_masked = ~token_level_array.is_motif_atom_with_fixed_seq
            res_names[is_masked] = np.full(
                np.sum(is_masked), self.encode_residues_to, dtype=res_names.dtype
            )

        restype = self.sequence_encoding.encode(res_names)
        data["encoded"] = {"seq": restype}  # For msa's
        # F.one_hot returns an int64 tensor, so restype is an int64 numpy array.
        restype = F.one_hot(
            torch.tensor(restype), num_classes=self.sequence_encoding.n_tokens
        ).numpy()

        # ... Add termini annotations (n_tok, 2): column 0 = C-terminus, 1 = N-terminus
        terminus_type = np.zeros(
            (
                len(token_level_array),
                2,
            ),
            dtype=restype.dtype,
        )
        terminus_type[token_level_array.is_C_terminus, 0] = 1
        terminus_type[token_level_array.is_N_terminus, 1] = 1

        # ... add to data dict
        if "feats" not in data:
            data["feats"] = {}
        if "feat_metadata" not in data:
            data["feat_metadata"] = {}

        # ... add to data dict
        data["feats"] |= {
            "residue_index": residue_index,  # (N_tokens) (int)
            "token_index": token_index,  # (N_tokens) (int)
            "asym_id": asym_id,  # (N_tokens) (int)
            "entity_id": entity_id,  # (N_tokens) (int)
            "sym_id": sym_id,  # (N_tokens) (int)
            # (N_tokens, sequence_encoding.n_tokens) (int64, one-hot)
            "restype": restype,
            "is_protein": is_protein,  # (N_tokens) (bool)
            "is_rna": is_rna,  # (N_tokens) (bool)
            "is_dna": is_dna,  # (N_tokens) (bool)
            "is_ligand": is_ligand,  # (N_tokens) (bool)
            "terminus_type": terminus_type,  # (N_tokens, 2) (int)
            "is_polar": is_polar,  # (N_tokens) (bool)
        }
        data["feat_metadata"] |= {
            "asym_name": asym_name,  # (N_asyms)
            "entity_name": entity_name,  # (N_entities)
            "sym_name": sym_name,  # (N_entities)
        }

        return data
|