rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Virtual-atom transforms for Atom14
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import biotite.structure as struc
|
|
6
|
+
import numpy as np
|
|
7
|
+
from atomworks.io.utils.atom_array_plus import insert_atoms
|
|
8
|
+
from atomworks.ml.transforms.base import (
|
|
9
|
+
Transform,
|
|
10
|
+
)
|
|
11
|
+
from atomworks.ml.utils.token import get_token_starts
|
|
12
|
+
from rfd3.constants import (
|
|
13
|
+
ATOM14_ATOM_NAME_TO_ELEMENT,
|
|
14
|
+
ATOM14_ATOM_NAMES,
|
|
15
|
+
VIRTUAL_ATOM_ELEMENT_NAME,
|
|
16
|
+
association_schemes,
|
|
17
|
+
association_schemes_stripped,
|
|
18
|
+
ccd_ordering_atomchar,
|
|
19
|
+
)
|
|
20
|
+
from rfd3.transforms.conditioning_base import (
|
|
21
|
+
UnindexFlaggedTokens,
|
|
22
|
+
)
|
|
23
|
+
from rfd3.transforms.util_transforms import (
|
|
24
|
+
assert_single_representative,
|
|
25
|
+
get_af3_token_representative_masks,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
from foundry.common import exists
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def map_to_association_scheme(atom_names: list | str, res_name: str, scheme="atom14"):
    """
    Maps atom names of one residue to the ATOM14 names given by an association scheme.

    Each name is looked up in ``association_schemes_stripped[scheme][res_name]``; the
    position found there is then used to index ``ATOM14_ATOM_NAMES``.
    NB this function is a bit more general since it is used to handle tip atoms too.

    Args:
        atom_names: A single atom name or a sequence of atom names belonging to ``res_name``.
        res_name: Residue name used as key into the selected scheme.
        scheme: Key into ``association_schemes_stripped``.

    Returns:
        The entries of ``ATOM14_ATOM_NAMES`` selected by the looked-up positions, one per
        input name (empty when ``atom_names`` is empty).

    Raises:
        ValueError: If ``scheme`` is unknown, or if a name is not listed for ``res_name``.
        KeyError: If ``res_name`` is not part of the selected scheme.
    """
    if scheme not in association_schemes_stripped:
        raise ValueError(
            f"Scheme {scheme} not found in association_schemes_stripped. Available schemes: {list(association_schemes_stripped.keys())}"
        )
    if isinstance(atom_names, (str, np.str_)):
        atom_names = [str(atom_names)]

    # Resolve the residue's name list once instead of once per atom.
    residue_scheme_names = association_schemes_stripped[scheme][res_name]

    # Explicit integer dtype: an empty input would otherwise produce a float array,
    # which numpy refuses to use as an index.
    idxs = np.array(
        [residue_scheme_names.index(name) for name in atom_names], dtype=int
    )
    return ATOM14_ATOM_NAMES[idxs]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def map_names_to_elements(
    atom_names: list | str, default=VIRTUAL_ATOM_ELEMENT_NAME
) -> np.ndarray:
    """
    Looks up the element for each ATOM14 atom name.

    A bare string is treated as a single name. Names that are missing from
    ``ATOM14_ATOM_NAME_TO_ELEMENT`` (e.g. a virtual atom name such as VX) are
    mapped to ``default``.
    """
    names = [atom_names] if isinstance(atom_names, str) else atom_names
    lookup = ATOM14_ATOM_NAME_TO_ELEMENT.get
    return np.array([lookup(name, default) for name in names])
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def generate_atom_mappings_(scheme="atom14"):
    """
    Builds, per residue, the index permutation that reorders an association scheme
    into the CCD atom ordering.

    For every residue ``aaa`` in ``ccd_ordering_atomchar`` the resulting index list
    ``m`` is checked (via assertions) to satisfy
    ``np.array(association_schemes[scheme][aaa])[m] == ccd_ordering_atomchar[aaa]``.

    Args:
        scheme: Key into ``association_schemes``.

    Returns:
        Tuple ``(atom_mapping, symmetry_mapping)``. ``atom_mapping`` maps a residue name
        to a list of 14 indices. ``symmetry_mapping`` is currently always an empty dict
        because the symmetric-group handling is disabled.
    """
    scheme_table = association_schemes[scheme]

    atom_mapping = {}

    ##################################################################
    # Symmetric groups are temporarily disabled, so this dict stays empty.
    # Disabled logic, kept for reference:
    #     if aaa in symmetric_atomchar:
    #         symmetry_mapping[aaa] = []
    #         for group in symmetric_atomchar[aaa]:
    #             indices = [atom14_names.index(name) for name in group]
    #             symmetry_mapping[aaa].append(indices)
    symmetry_mapping = {}
    ##################################################################

    for aaa, atom14_names in ccd_ordering_atomchar.items():
        scheme_names = scheme_table[aaa]
        mapping = list(range(14))

        for ccd_index, atom14_name in enumerate(atom14_names):
            if atom14_name is None:
                continue
            assert (
                atom14_name in scheme_names
            ), f"{atom14_name} not in CCD ordering for {aaa}"

            # Find the slot that currently holds the wanted scheme index and swap it
            # into the CCD position.
            wanted = scheme_names.index(atom14_name)
            holder = mapping.index(wanted)
            mapping[ccd_index], mapping[holder] = mapping[holder], mapping[ccd_index]

        assert set(mapping) == set(range(len(scheme_names)))
        atom_mapping[aaa] = mapping

    # Test that the mapping is valid
    for aaa, idxs in atom_mapping.items():
        assert len(idxs) == len(set(idxs)), f"Duplicate indices in mapping for {aaa}"

        atom_mapping_expected = np.array(scheme_table[aaa])[idxs]
        atom_mapping_actual = np.array(ccd_ordering_atomchar[aaa])
        assert np.array_equal(
            atom_mapping_expected, atom_mapping_actual
        ), f"Mapping mismatch for {aaa}: {atom_mapping_expected} != {atom_mapping_actual}"

    return atom_mapping, symmetry_mapping
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def permute_symmetric_atom_names_(
    atom_names: list, res_name: str, association_map: dict, symmetry_map: dict
) -> list:
    """
    Reorders the atom names of one residue and randomly flips its symmetric groups.

    The input ``atom_names`` is never modified in place, so callers may safely pass a
    shared array.

    Args:
        atom_names: Atom names as a list or 1-D numpy array.
        res_name: Residue name used as key into both maps.
        association_map: Residue name -> index list. When ``res_name`` is present, the
            names are reordered by indexing with that list.
        symmetry_map: Residue name -> iterable of index groups. Each group is reversed
            with probability 0.5 (one ``np.random.rand()`` draw per group).

    Returns:
        The input object itself when ``res_name`` is in neither map; otherwise a new
        numpy array holding the reordered names.
    """
    # NB: Can leak GT sequence if the model receives the canonical ordering of atoms as input
    # With the structure-local atom attention it will not unless N_keys(n_attn_seq_neighbours) > n_atom_attn_queries.
    if res_name in association_map:
        idx_to_swap = association_map[res_name]
        # np.asarray lets plain lists be indexed with an index list as well.
        atom_names = np.asarray(atom_names)[idx_to_swap]
    if res_name in symmetry_map:
        # Work on a private copy: the swap below assigns in place and must not leak
        # into the caller's array.
        atom_names = np.array(atom_names)
        for group in symmetry_map[res_name]:
            if np.random.rand() < 0.5:  # random swap
                atom_names[group] = atom_names[group[::-1]]
    return atom_names
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
#####################################################################################################
|
|
135
|
+
# Virtual atom transforms
|
|
136
|
+
#####################################################################################################
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class PadTokensWithVirtualAtoms(Transform):
    """
    Pads tokens with virtual atoms to ensure all residue tokens have a fixed number of atoms

    Applies padding only to the residue tokens that do not have a fixed sequence
    (and are not unindexed motif tokens). These tokens are renamed to the ATOM14 names,
    permuted through the association scheme during training.
    Residue tokens with fixed sequence (and unindexed tokens) are not padded; their
    existing atom names are mapped through the association scheme instead.
    Every token gets a ``gt_atom_name`` annotation: the assigned ATOM14 names for
    paddable tokens, the pre-existing atom names for all other tokens.
    """

    requires_previous_transforms = [UnindexFlaggedTokens]

    def __init__(
        self,
        n_atoms_per_token,
        atom_to_pad_from,
        association_scheme,
    ):
        """
        Args:
            n_atoms_per_token: Target number of atoms for every paddable token.
            atom_to_pad_from: Passed as ``central_atom`` when selecting the atom that
                virtual atoms are copied from.
            association_scheme: Scheme name used for atom-name mapping; when it does not
                ``exists`` (per ``foundry.common.exists``), no mappings are generated.
        """
        self.n_atoms_per_token = n_atoms_per_token
        self.atom_to_pad_from = atom_to_pad_from
        self.association_scheme = association_scheme
        if exists(association_scheme):
            self.association_map_, self.symmetry_map_ = generate_atom_mappings_(
                association_scheme
            )

    @staticmethod
    def _fix_multidimensional_annotations_in_pad_array(atomarray, padarray):
        """Re-stack every annotation that is multidimensional in ``atomarray`` as a float array on ``padarray``."""
        for annotation in atomarray.get_annotation_categories():
            if len(atomarray.get_annotation(annotation).shape) > 1:
                stacked = np.stack(padarray.get_annotation(annotation)).astype(float)
                padarray.del_annotation(annotation)
                padarray.set_annotation(annotation, stacked)
        return padarray

    def _make_virtual_atoms(self, token, n_pad):
        """Builds an array of ``n_pad`` virtual atoms copied from the atom of ``token`` selected via ``atom_to_pad_from``."""
        mask = get_af3_token_representative_masks(
            token, central_atom=self.atom_to_pad_from
        )
        assert_single_representative(token)

        # ... Create virtual atoms (a single template atom, relabelled as virtual)
        pad_atoms = token[mask].copy()
        if isinstance(pad_atoms, struc.AtomArray):
            pad_atoms = pad_atoms[0]
        pad_atoms.element = VIRTUAL_ATOM_ELEMENT_NAME

        # ... Expand to desired number of atoms
        pad_array = struc.array([pad_atoms] * n_pad)

        # ... Change occupancy | 1.0 if the template atom has occupancy, else 0.0
        occ = 1.0 if pad_atoms.occupancy.sum() > 0.0 else 0.0
        pad_array.occupancy = np.full(n_pad, occ)

        # ... Even if the input pad_atoms are all motif, we don't ever want padded atoms to be motif
        pad_array.is_motif_atom = np.zeros(n_pad, dtype=bool)

        # Handle multidimensional annotations
        return self._fix_multidimensional_annotations_in_pad_array(token, pad_array)

    def forward(self, data: dict) -> dict:
        """
        Pads paddable tokens of ``data["atom_array"]`` and assigns atom names.

        Reads ``data["atom_array"]`` and (only when a paddable token is encountered)
        ``data["is_inference"]``; writes the resulting array back to ``data["atom_array"]``.
        When nothing needs padding, the input array object itself is annotated and reused.

        Returns:
            The same ``data`` dict with the updated ``atom_array``.
        """
        atom_array = data["atom_array"]
        starts = get_token_starts(atom_array, add_exclusive_stop=True)
        token_starts = starts[:-1]
        token_level_array = atom_array[token_starts]
        is_motif_atom_with_fixed_seq = token_level_array.is_motif_atom_with_fixed_seq
        is_motif_token_unindexed = token_level_array.is_motif_atom_unindexed

        token_ids = np.unique(atom_array.token_id)
        assert len(token_ids) == len(
            is_motif_atom_with_fixed_seq
        ), "Token ids and token level array have different lengths!"

        # Unindexed tokens are never fully atomized, but may be assigned as atomized to have repr atoms:
        is_residue = (
            token_level_array.is_protein & ~token_level_array.atomize
        ) | is_motif_token_unindexed

        # Unindexed tokens are never padded, and so are treated as residues with fixed sequence.
        has_fixed_identity = is_motif_atom_with_fixed_seq | is_motif_token_unindexed
        is_paddable = is_residue & ~has_fixed_identity
        is_non_paddable_residue = is_residue & has_fixed_identity

        # Collect virtual atoms to insert (we will insert them all at once)
        virtual_atoms_to_insert = []
        insert_positions = []

        # First pass: collect virtual atoms for insertion
        for token_id, (start, end) in enumerate(zip(token_starts, starts[1:])):
            if not is_paddable[token_id]:
                continue
            token = atom_array[start:end]
            # First, pad with virtual atoms if needed
            n_pad = self.n_atoms_per_token - len(token)
            if n_pad > 0:
                # Collect virtual atoms for later insertion (appended at the token's end)
                virtual_atoms_to_insert.append(self._make_virtual_atoms(token, n_pad))
                insert_positions.append(end)

        # Insert all virtual atoms at once using insert_atoms
        if virtual_atoms_to_insert:
            atom_array_padded = insert_atoms(
                atom_array, virtual_atoms_to_insert, insert_positions
            )
        else:
            atom_array_padded = atom_array

        # Initialize gt_atom_name annotation if it doesn't exist
        if "gt_atom_name" not in atom_array_padded.get_annotation_categories():
            atom_array_padded.set_annotation(
                "gt_atom_name", np.empty(len(atom_array_padded), dtype="U4")
            )

        # Second pass: process tokens with proper atom name assignment after padding
        # Get updated token starts after padding
        # NOTE(review): the per-token masks computed before padding are reused here, which
        # presumes padding leaves the number and order of tokens unchanged - confirm.
        starts_padded = get_token_starts(atom_array_padded, add_exclusive_stop=True)

        for token_id, (start, end) in enumerate(
            zip(starts_padded[:-1], starts_padded[1:])
        ):
            if is_paddable[token_id]:
                # ... Permutation of atom names during training
                if not data["is_inference"] and exists(self.association_scheme):
                    atom_names = permute_symmetric_atom_names_(
                        ATOM14_ATOM_NAMES,
                        atom_array_padded.res_name[start],
                        association_map=self.association_map_,
                        symmetry_map=self.symmetry_map_,
                    )
                else:
                    atom_names = ATOM14_ATOM_NAMES
                atom_array_padded.atom_name[start:end] = atom_names
                atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names

            elif is_non_paddable_residue[token_id]:
                # When sequence-constrained, we want to directly map the residue name based on the sequence
                atom_names = atom_array_padded.atom_name[start:end]
                res_name = atom_array_padded.res_name[start]
                # Record the existing names before they are replaced by the scheme names.
                atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names
                atom_names = map_to_association_scheme(
                    atom_names, res_name, scheme=self.association_scheme
                )
                atom_array_padded.atom_name[start:end] = atom_names
            else:
                # ... Add gt_atom_name annotation to other tokens
                atom_names = atom_array_padded.atom_name[start:end]
                atom_array_padded.get_annotation("gt_atom_name")[start:end] = atom_names

            # Sanity check: a token must not consist solely of virtual atoms.
            # NOTE(review): the message text below does not describe this condition
            # (it fires when *only* virtual atoms are found) - confirm intended wording.
            assert {VIRTUAL_ATOM_ELEMENT_NAME} != set(
                atom_array_padded.element[start:end].tolist()
            ), (
                "Padded atoms should be virtual atoms, but found: "
                f"{set(atom_array_padded.element[start:end].tolist())}"
            )

        # ... Update atom array
        data["atom_array"] = atom_array_padded
        return data
|