rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from atomworks.enums import ChainType
|
|
3
|
+
from atomworks.ml.transforms._checks import (
|
|
4
|
+
check_atom_array_annotation,
|
|
5
|
+
check_contains_keys,
|
|
6
|
+
)
|
|
7
|
+
from atomworks.ml.transforms.base import Transform
|
|
8
|
+
from atomworks.ml.transforms.crop import resize_crop_info_if_too_many_atoms
|
|
9
|
+
from atomworks.ml.utils.token import (
|
|
10
|
+
get_token_count,
|
|
11
|
+
spread_token_wise,
|
|
12
|
+
)
|
|
13
|
+
from biotite.structure.basepairs import (
|
|
14
|
+
_check_dssr_criteria,
|
|
15
|
+
_get_proximate_residues,
|
|
16
|
+
get_residue_masks,
|
|
17
|
+
get_residue_starts_for,
|
|
18
|
+
)
|
|
19
|
+
from scipy.spatial import distance_matrix
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def protein_dna_contact_contiguous_crop_mask(
    atom_array,
    protein_contact_atoms,
    dna_contact_atoms,
    contact_dist_cutoff,
    protein_expand_min,
    protein_expand_max,
    dna_expand_min,
    dna_expand_max,
    max_tokens=300,
):
    """Compute crop info centred on one randomly sampled protein-DNA contact.

    A protein-DNA atom pair closer than ``contact_dist_cutoff`` is sampled with
    ``identify_and_sample_protein_dna_contact``. Around the contacting protein
    residue, a window of residues on the same chain is kept
    (``expand_connected_component_mask``). Around the contacting DNA residue, a
    window on its strand plus the base-paired partner residues is kept
    (``get_dna_mask``). The left/right extents of each window are drawn
    independently with ``np.random.randint(min, max)``, so the upper bound is
    exclusive.

    Args:
        atom_array: Full structure to crop.
        protein_contact_atoms: Eligible protein atoms (dict, list of atom
            names, or ``"all"``), forwarded to the contact search.
        dna_contact_atoms: Eligible DNA atoms, same accepted forms.
        contact_dist_cutoff: Contact distance cutoff, in the same units as
            the atom coordinates.
        protein_expand_min, protein_expand_max: Sampling range for the
            per-side protein window extent.
        dna_expand_min, dna_expand_max: Sampling range for the per-side DNA
            window extent.
        max_tokens: Maximum number of tokens allowed in the crop. ``None``
            disables the check. Defaults to 300, the previously hard-coded
            limit.

    Returns:
        dict with these keys:
            ``"type"``: the string ``"ProteinDNAContactContiguousCrop"``.
            ``"requires_crop"``: True when at least one atom is kept.
            ``"crop_atom_idxs"``: indices of the kept atoms.
            ``"crop_token_idxs"``: the token id of every kept atom (one entry
                per atom).
            ``"atom_array"``: the input structure.

    Raises:
        ValueError: If the contact search finds no contact, or if the crop
            exceeds ``max_tokens`` tokens.
    """
    dna_contact, prot_contact = identify_and_sample_protein_dna_contact(
        atom_array, protein_contact_atoms, dna_contact_atoms, contact_dist_cutoff
    )

    # Protein window: the left and right extents are sampled independently.
    protein_left = np.random.randint(protein_expand_min, protein_expand_max)
    protein_right = np.random.randint(protein_expand_min, protein_expand_max)
    protein_keep_mask = expand_connected_component_mask(
        atom_array, prot_contact, protein_left, protein_right
    )

    # DNA window (plus base-paired partners on the other strand).
    dna_left = np.random.randint(dna_expand_min, dna_expand_max)
    dna_right = np.random.randint(dna_expand_min, dna_expand_max)
    dna_keep_mask = get_dna_mask(atom_array, dna_contact, dna_left, dna_right)

    mask = np.logical_or(protein_keep_mask, dna_keep_mask)

    # Reject crops whose token count explodes. Checking this first avoids
    # computing crop indices that would be thrown away.
    if max_tokens is not None and get_token_count(atom_array[mask]) > max_tokens:
        raise ValueError(
            "Noncanonical DNAs are causing token count explosion, skipping..."
        )

    crop_atom_idxs = np.where(mask)[0]

    # Broadcast per-token ids to atoms, then keep the ids of the selected atoms.
    token_id = np.arange(get_token_count(atom_array), dtype=np.uint32)
    crop_token_idxs = spread_token_wise(atom_array, token_id)[mask]

    return {
        "type": "ProteinDNAContactContiguousCrop",
        "requires_crop": np.any(mask),
        "crop_atom_idxs": crop_atom_idxs,
        "crop_token_idxs": crop_token_idxs,
        "atom_array": atom_array,
    }
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def atom_array_from_contact_dict(atom_array, contact_atoms):
    """Select the atoms whose (res_name, atom_name) pair is listed in a dict.

    Args:
        atom_array: Structure with ``res_name`` and ``atom_name`` annotations.
        contact_atoms: Mapping ``{res_name: [atom_name, ...]}``. An atom is
            selected when its residue name is a key and its atom name is in
            that key's list.

    Returns:
        ``atom_array`` indexed with the resulting boolean mask.
    """
    res_names = np.asarray(atom_array.res_name)
    atom_names = np.asarray(atom_array.atom_name)

    # Vectorized replacement for iterating over individual Atom rows, which
    # is very slow on large structures.
    mask = np.zeros(res_names.shape[0], dtype=bool)
    for res_name, allowed_atom_names in contact_atoms.items():
        mask |= (res_names == res_name) & np.isin(
            atom_names, list(allowed_atom_names)
        )

    return atom_array[mask]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def identify_and_sample_protein_dna_contact(
    atom_array, protein_contact_atoms, dna_contact_atoms, contact_dist=4
):
    """Sample one protein-DNA atom pair closer than ``contact_dist``.

    Args:
        atom_array: Structure with ``chain_type`` and ``atom_name``
            annotations.
        protein_contact_atoms: Eligible protein atoms. Accepted forms:
            a dict ``{res_name: [atom_name, ...]}``, matched against the
                whole ``atom_array``;
            a list of atom names, restricted to
                ``ChainType.POLYPEPTIDE_L`` chains;
            ``"all"``, meaning every atom in ``POLYPEPTIDE_L`` chains.
        dna_contact_atoms: Eligible DNA atoms, in the same forms. Every form
            is restricted to ``ChainType.DNA`` chains.
        contact_dist: Strict upper bound on the contact distance.

    Returns:
        ``(dna_contact, prot_contact)``: the two atoms of the sampled pair.

    Raises:
        ValueError: If a selection argument has an unsupported form, or if no
            eligible protein/DNA atom pair is closer than ``contact_dist``.
    """
    if isinstance(protein_contact_atoms, dict):
        protein = atom_array_from_contact_dict(atom_array, protein_contact_atoms)
    elif isinstance(protein_contact_atoms, list):
        protein = atom_array[
            (atom_array.chain_type == ChainType.POLYPEPTIDE_L)
            & np.isin(atom_array.atom_name, protein_contact_atoms)
        ]
    elif isinstance(protein_contact_atoms, str) and protein_contact_atoms == "all":
        protein = atom_array[atom_array.chain_type == ChainType.POLYPEPTIDE_L]
    else:
        raise ValueError(
            f"Unsupported protein_contact_atoms {protein_contact_atoms!r}: "
            "expected a dict, a list of atom names, or 'all'."
        )

    if isinstance(dna_contact_atoms, dict):
        # Restrict to DNA chains before the dict lookup. A local name is used
        # so the full `atom_array` parameter is not shadowed.
        dna_chain_atoms = atom_array[atom_array.chain_type == ChainType.DNA]
        dna = atom_array_from_contact_dict(dna_chain_atoms, dna_contact_atoms)
    elif isinstance(dna_contact_atoms, list):
        dna = atom_array[
            (atom_array.chain_type == ChainType.DNA)
            & np.isin(atom_array.atom_name, dna_contact_atoms)
        ]
    elif isinstance(dna_contact_atoms, str) and dna_contact_atoms == "all":
        dna = atom_array[atom_array.chain_type == ChainType.DNA]
    else:
        raise ValueError(
            f"Unsupported dna_contact_atoms {dna_contact_atoms!r}: "
            "expected a dict, a list of atom names, or 'all'."
        )

    if len(dna) == 0 or len(protein) == 0:
        raise ValueError("No protein-DNA contacts found")

    # Dense (n_dna, n_protein) distance matrix: rows index `dna`, columns
    # index `protein`.
    pdist = distance_matrix(dna.coord, protein.coord)
    contacts = np.argwhere(pdist < contact_dist)
    if len(contacts) == 0:
        raise ValueError("No protein-DNA contacts found")

    # Pick one contacting pair uniformly at random.
    dna_idx, prot_idx = contacts[np.random.randint(len(contacts))]

    return dna[dna_idx], protein[prot_idx]
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def create_residue_mask(atom_array, first_atom_indices):
    """
    Build a boolean atom mask covering whole residues picked by atom index.

    Each index in ``first_atom_indices`` identifies a residue through the
    ``(res_id, chain_id)`` of that atom. Every atom sharing one of those
    pairs is selected. Any atom of a residue works as the key; the first atom
    is simply the conventional choice.

    Parameters
    ----------
    atom_array : biotite.structure.AtomArray
        Structure providing ``res_id`` and ``chain_id`` annotations.
    first_atom_indices : array-like
        Atom indices, one per residue to select. May be empty.

    Returns
    -------
    numpy.ndarray
        Boolean mask of length ``len(atom_array)``.
    """
    # Requested residue keys as (k, 1) columns, so that comparing them with
    # the per-atom (N,) annotations yields a (k, N) match table.
    wanted_res_ids = np.expand_dims(atom_array.res_id[first_atom_indices], axis=1)
    wanted_chain_ids = np.expand_dims(
        atom_array.chain_id[first_atom_indices], axis=1
    )

    hits_per_residue = np.logical_and(
        atom_array.res_id == wanted_res_ids,
        atom_array.chain_id == wanted_chain_ids,
    )

    # An atom is kept if it matches any requested residue.
    return np.any(hits_per_residue, axis=0)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def expand_connected_component_mask(atom_array, origin, left_expand, right_expand):
    """Select a contiguous window of residues around ``origin`` on its chain.

    An atom is kept when it has the same ``chain_id`` as ``origin`` and its
    ``within_poly_res_idx`` lies in the closed interval
    ``[center - left_expand, center + right_expand]``, where
    ``center = origin.within_poly_res_idx``.

    Both bounds are inclusive, so the window extends symmetrically. For
    non-negative expansions it always contains ``origin``'s own residue.
    With an exclusive right bound, one residue fewer was kept on the right,
    and the origin residue was dropped entirely when ``right_expand == 0``.

    Args:
        atom_array: Structure with ``chain_id`` and ``within_poly_res_idx``
            annotations.
        origin: Single atom providing ``chain_id`` and
            ``within_poly_res_idx``.
        left_expand: Number of residues to keep before ``origin``.
        right_expand: Number of residues to keep after ``origin``.

    Returns:
        Boolean atom mask.
    """
    center = origin.within_poly_res_idx
    residue_idx = atom_array.within_poly_res_idx

    in_window = (residue_idx >= center - left_expand) & (
        residue_idx <= center + right_expand
    )
    return (atom_array.chain_id == origin.chain_id) & in_window
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def get_dna_mask(atom_array, origin, left_expand, right_expand):
    """Select a DNA window around ``origin`` plus its base-paired partners.

    The window on ``origin``'s strand comes from
    ``expand_connected_component_mask``. For every base pair returned by
    ``base_pairs`` that has one residue inside that window, the other
    residue is added in full via ``create_residue_mask``. When both residues
    are inside the window, the second one is used.

    Args:
        atom_array: Structure to select from.
        origin: Contacting DNA atom that anchors the window.
        left_expand: Residues to keep before ``origin`` on its strand.
        right_expand: Residues to keep after ``origin`` on its strand.

    Returns:
        Boolean atom mask: the window atoms OR the partner-residue atoms.
    """
    one_chain_mask = expand_connected_component_mask(
        atom_array, origin, left_expand, right_expand
    )

    # Each pair holds the first-atom indices of the two paired residues.
    partner_first_atom_indices = []
    for first, second in base_pairs(atom_array):
        if one_chain_mask[first]:
            partner_first_atom_indices.append(second)
        elif one_chain_mask[second]:
            partner_first_atom_indices.append(first)

    partner_mask = create_residue_mask(atom_array, partner_first_atom_indices)

    return np.logical_or(one_chain_mask, partner_mask)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class ProteinDNAContactContiguousCrop(Transform):
    """
    A transform that crops the DNA-protein contact region around a contiguous window.

    One protein-DNA atom contact is sampled. A contiguous window of residues
    is kept around the contacting protein residue, and another window is kept
    around the contacting DNA residue together with its base-paired partners.

    Args:
        protein_contact_type (str): The protein atoms eligible for contacts.
            Supported values:
                'backbone': N, CA, C.
                'all': every protein atom.
                'from_dict': uses ``protein_contact_atom_dict``.
            Any other value raises ``ValueError``.
        dna_contact_type (str): The DNA atoms eligible for contacts.
            Supported values:
                'backbone': P, OP1, OP2.
                'base': selected base N/O atoms.
                'from_dict': uses ``dna_contact_atom_dict`` when it is given.
            Any other value, including 'from_dict' without a dict, considers
            all DNA atoms.
        contact_distance_cutoff (float): The distance cutoff for considering two
            atoms to be in contact.
        protein_expand_min, protein_expand_max (int): Sampling range for the
            per-side protein window extent (upper bound exclusive).
        dna_expand_min, dna_expand_max (int): Sampling range for the per-side
            DNA window extent (upper bound exclusive).
        keep_uncropped_atom_array (bool): If True, store the input structure
            under ``data["uncropped_atom_array"]``.
        max_atoms_in_crop (int | None): Forwarded to
            ``resize_crop_info_if_too_many_atoms``.
        protein_contact_atom_dict (dict | None): ``{res_name: [atom_name, ...]}``
            used when ``protein_contact_type == 'from_dict'``.
        dna_contact_atom_dict (dict | None): ``{res_name: [atom_name, ...]}``
            used when ``dna_contact_type == 'from_dict'``.

    Raises:
        ValueError: If ``protein_contact_type`` is unsupported, or is
            'from_dict' without ``protein_contact_atom_dict``.
    """

    def __init__(
        self,
        protein_contact_type,
        dna_contact_type,
        contact_distance_cutoff=10.0,
        protein_expand_min=15,
        protein_expand_max=40,
        dna_expand_min=3,
        dna_expand_max=10,
        keep_uncropped_atom_array: bool = False,
        max_atoms_in_crop=None,
        protein_contact_atom_dict=None,
        dna_contact_atom_dict=None,
    ):
        if protein_contact_type == "backbone":
            self.protein_contact_atoms = ["N", "CA", "C"]
        elif protein_contact_type == "all":
            self.protein_contact_atoms = "all"
        elif protein_contact_type == "from_dict":
            if protein_contact_atom_dict is None:
                raise ValueError(
                    "protein_contact_type='from_dict' requires "
                    "protein_contact_atom_dict to be provided."
                )
            self.protein_contact_atoms = self._normalize_contact_dict(
                protein_contact_atom_dict
            )
        else:
            # Fail here rather than leaving `protein_contact_atoms` unset,
            # which would only surface later as an AttributeError in `forward`.
            raise ValueError(
                f"Unsupported protein_contact_type {protein_contact_type!r}; "
                "expected 'backbone', 'all', or 'from_dict'."
            )

        if dna_contact_type == "backbone":
            self.dna_contact_atoms = ["P", "OP1", "OP2"]
        elif dna_contact_type == "base":
            self.dna_contact_atoms = {
                "DA": ["N7", "N6"],
                "DC": ["N4"],
                "DG": ["N7", "O6"],
                "DT": ["O4"],
            }
        elif dna_contact_type == "from_dict" and dna_contact_atom_dict is not None:
            self.dna_contact_atoms = self._normalize_contact_dict(
                dna_contact_atom_dict
            )
        else:
            # Any other value keeps the existing fallback of using all DNA atoms.
            self.dna_contact_atoms = "all"

        self.protein_contact_type = protein_contact_type
        self.dna_contact_type = dna_contact_type

        self.protein_expand_min = protein_expand_min
        self.protein_expand_max = protein_expand_max
        self.dna_expand_min = dna_expand_min
        self.dna_expand_max = dna_expand_max
        self.contact_distance_cutoff = contact_distance_cutoff

        self.keep_uncropped_atom_array = keep_uncropped_atom_array
        self.max_atoms_in_crop = max_atoms_in_crop

    @staticmethod
    def _normalize_contact_dict(contact_dict):
        """Return a plain ``{res_name: [atom_name, ...]}`` dict.

        The contact search dispatches on ``isinstance(..., dict)``, so
        mapping-like configs (e.g. from a config system) are copied into
        plain dicts/lists. A bare string value is treated as a single atom
        name.
        """
        return {
            res_name: [atom_names] if isinstance(atom_names, str) else list(atom_names)
            for res_name, atom_names in contact_dict.items()
        }

    def check_input(self, data: dict):
        check_contains_keys(data, ["atom_array"])
        # `forward` reads `chain_type` (contact search, base pairing) and
        # `within_poly_res_idx` (window expansion), so both must be present.
        check_atom_array_annotation(
            data, ["res_name", "chain_type", "within_poly_res_idx"]
        )

    def forward(self, data: dict) -> dict:
        """Crop ``data["atom_array"]`` around a sampled protein-DNA contact.

        When the crop keeps any atom, ``data["atom_array"]`` is replaced by
        the cropped structure and ``data["crop_info"]`` is set.
        """
        atom_array = data["atom_array"]

        crop_info = protein_dna_contact_contiguous_crop_mask(
            atom_array,
            self.protein_contact_atoms,
            self.dna_contact_atoms,
            self.contact_distance_cutoff,
            self.protein_expand_min,
            self.protein_expand_max,
            self.dna_expand_min,
            self.dna_expand_max,
        )

        crop_info = resize_crop_info_if_too_many_atoms(
            crop_info=crop_info,
            atom_array=atom_array,
            max_atoms=self.max_atoms_in_crop,
        )

        if self.keep_uncropped_atom_array:
            data["uncropped_atom_array"] = atom_array

        if crop_info["requires_crop"]:
            data["atom_array"] = atom_array[crop_info["crop_atom_idxs"]]
            data["crop_info"] = crop_info
        else:
            data["atom_array"] = atom_array
        return data
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def fill_nan_coords_with_random(atoms, min_val=-50, max_val=50, seed=None):
    """
    Fill NaN coordinates in a biotite AtomArray with random values.

    Each NaN coordinate component is replaced independently by a value drawn
    uniformly from ``[min_val, max_val)``. Finite components are left as-is.

    Parameters
    ----------
    atoms : biotite.structure.AtomArray
        The atom array containing coordinates to be filled. It is not
        modified.
    min_val : float, optional
        Minimum value for random coordinates (default: -50).
    max_val : float, optional
        Maximum value for random coordinates (default: 50).
    seed : int, optional
        Random seed for reproducibility. When given, a private
        ``np.random.RandomState(seed)`` is used, so the global NumPy RNG is
        no longer reseeded as a side effect. The drawn values are the same
        as with ``np.random.seed(seed)``. When omitted, the global NumPy RNG
        is used.

    Returns
    -------
    biotite.structure.AtomArray
        A new AtomArray with NaN coordinates filled.
    """
    # Work on a copy so the caller's array is untouched.
    filled_atoms = atoms.copy()

    rng = np.random if seed is None else np.random.RandomState(seed)

    coords = filled_atoms.coord
    nan_mask = np.isnan(coords)

    coords[nan_mask] = rng.uniform(
        low=min_val, high=max_val, size=coords[nan_mask].shape
    )

    return filled_atoms
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def base_pairs(atom_array, min_atoms_per_base=3, unique=True):
    """
    Use DSSR criteria to find the DNA base pairs in an :class:`atom_array`.

    This appears to be adapted from ``biotite.structure.base_pairs``; it
    imports that module's private helpers (``_get_proximate_residues``,
    ``_check_dssr_criteria``). It differs from biotite in two ways:

    - Only DNA is considered: atoms must have ``chain_type == ChainType.DNA``
      and a residue name in ``DA``, ``DG``, ``DT`` or ``DC``. Other
      nucleotides, RNA included, are ignored.
    - NaN coordinates of the considered atoms are replaced with uniform
      random values via ``fill_nan_coords_with_random`` before the search.
      No seed is passed, so this consumes draws from the global NumPy RNG.

    The DSSR Criteria are as follows :footcite:`Lu2015`:

    (i) Distance between base origins <=15 Å

    (ii) Vertical separation between the base planes <=2.5 Å

    (iii) Angle between the base normal vectors <=65°

    (iv) Absence of stacking between the two bases

    (v) Presence of at least one hydrogen bond involving a base atom

    Parameters
    ----------
    atom_array : atom_array
        The :class:`atom_array` to find base pairs in. It must carry a
        ``chain_type`` annotation.
    min_atoms_per_base : integer, optional (default: 3)
        The number of atoms a nucleotide's base must have to be
        considered a candidate for a base pair.
    unique : bool, optional (default: True)
        If ``True``, each base is assumed to be paired with only one
        other base. If multiple pairings are plausible, the pairing with
        the most hydrogen bonds is selected.

    Returns
    -------
    basepairs : ndarray, dtype=int, shape=(n,2)
        Each row is one base pair and contains the first atom indices
        (in ``atom_array``) of the two paired residues. If no pair is
        found, an empty array is returned.

    Notes
    -----
    The bases from the standard reference frame described in
    :footcite:`Olson2001` were modified such that only the base atoms
    are implemented.
    Sugar atoms (specifically C1') were disregarded, as nucleosides such
    as PSU do not possess the usual N-glycosidic linkage, which would lead
    to inaccurate results.

    The vertical separation is implemented as the scalar
    projection of the distance vectors between the base origins
    according to :footcite:`Lu1997` onto the averaged base normal
    vectors.

    The presence of base stacking is assumed if the following criteria
    are met :footcite:`Gabb1996`:

    (i) Distance between aromatic ring centers <=4.5 Å

    (ii) Angle between the ring normal vectors <=23°

    (iii) Angle between normalized distance vector between two ring
    centers and both bases' normal vectors <=40°

    Please note that ring normal vectors are assumed to be equal to the
    base normal vectors.

    For structures without hydrogens the accuracy of the algorithm is
    limited, as hydrogen bonds can only be checked for plausibility.
    A hydrogen bond is considered plausible if N/O atom pairs are within
    a cutoff of 3.6 Å. 3.6 Å was chosen because hydrogen bonds are
    typically 1.5-2.5 Å long, and N-H and O-H bonds have lengths of
    1.00 Å and 0.96 Å respectively. Thus, including some buffer, a 3.6 Å
    cutoff should cover all hydrogen bonds.

    Examples
    --------
    Compute the base pairs for the structure with the PDB ID 1QXB
    (illustrative only: ``load_structure`` and ``path_to_structures`` are
    not defined in this module, and the loaded array additionally needs a
    ``chain_type`` annotation):

    >>> from os.path import join
    >>> dna_helix = load_structure(join(path_to_structures, "base_pairs", "1qxb.cif"))
    >>> basepairs = base_pairs(dna_helix)
    >>> print(dna_helix[basepairs].res_name)
    [['DC' 'DG']
     ['DG' 'DC']
     ['DC' 'DG']
     ['DG' 'DC']
     ['DA' 'DT']
     ['DA' 'DT']
     ['DT' 'DA']
     ['DT' 'DA']
     ['DC' 'DG']
     ['DG' 'DC']
     ['DC' 'DG']
     ['DG' 'DC']]

    References
    ----------

    .. footbibliography::
    """
    # Restrict to canonical DNA residues on DNA chains.
    dna_boolean = np.logical_and(
        atom_array.chain_type == ChainType.DNA,
        np.isin(atom_array.res_name, ["DA", "DG", "DT", "DC"]),
    )

    # Get the nucleotides for the given atom_array
    # Disregard the phosphate-backbone
    non_phosphate_boolean = ~np.isin(
        atom_array.atom_name, ["O5'", "P", "OP1", "OP2", "OP3", "HOP2", "HOP3"]
    )

    # Combine the two boolean masks
    boolean_mask = np.logical_and(non_phosphate_boolean, dna_boolean)

    # Get only nucleosides
    nucleosides = atom_array[boolean_mask]

    # Get the base pair candidates according to a N/O cutoff distance,
    # where each base is identified as the first index of its respective
    # residue
    n_o_mask = np.isin(nucleosides.element, ["N", "O"])

    # Replace NaN coordinates with random values before the distance-based
    # candidate search.
    nucleosides = fill_nan_coords_with_random(nucleosides)
    basepair_candidates, n_o_matches = _get_proximate_residues(
        nucleosides, n_o_mask, 3.6
    )

    # Contains the plausible base pairs
    basepairs = []
    # Contains the number of hydrogen bonds for each plausible base pair
    basepairs_hbonds = []

    # Get the residue masks for each residue
    base_masks = get_residue_masks(nucleosides, basepair_candidates.flatten())

    # Group every two masks together for easy iteration (each 'row' is
    # respective to a row in ``basepair_candidates``)
    base_masks = base_masks.reshape(
        (basepair_candidates.shape[0], 2, nucleosides.shape[0])
    )

    for (base1_index, base2_index), (base1_mask, base2_mask), n_o_pairs in zip(
        basepair_candidates, base_masks, n_o_matches
    ):
        base1 = nucleosides[base1_mask]
        base2 = nucleosides[base2_mask]

        hbonds = _check_dssr_criteria((base1, base2), min_atoms_per_base, unique)

        # If no hydrogens are present use the number N/O pairs to
        # decide between multiple pairing possibilities.

        if hbonds is None:
            # Each N/O-pair is detected twice. Thus, the number of
            # matches must be divided by two.
            hbonds = n_o_pairs / 2
        # -1 marks candidates rejected by the DSSR criteria.
        if hbonds != -1:
            basepairs.append((base1_index, base2_index))
            if unique:
                basepairs_hbonds.append(hbonds)

    basepair_array = np.array(basepairs)

    if unique:
        # Contains all non-unique base pairs that are flagged to be
        # removed
        to_remove = []

        # Get all bases that have non-unique pairing interactions
        base_indices, occurrences = np.unique(basepairs, return_counts=True)
        for base_index, occurrence in zip(base_indices, occurrences):
            if occurrence > 1:
                # Write the non-unique base pairs to a dictionary as
                # 'index: number of hydrogen bonds'
                remove_candidates = {}
                for i, row in enumerate(np.asarray(basepair_array == base_index)):
                    if np.any(row):
                        remove_candidates[i] = basepairs_hbonds[i]
                # Flag all non-unique base pairs for removal except the
                # one that has the most hydrogen bonds
                del remove_candidates[max(remove_candidates, key=remove_candidates.get)]
                to_remove += list(remove_candidates.keys())
        # Remove all flagged base pairs from the output `ndarray`
        basepair_array = np.delete(basepair_array, to_remove, axis=0)

    # Remap values to original atom array: indices so far refer to
    # `nucleosides`, so map them back through `boolean_mask` and then to the
    # first atom of each residue.
    if len(basepair_array) > 0:
        basepair_array = np.where(boolean_mask)[0][basepair_array]
        for i, row in enumerate(basepair_array):
            basepair_array[i] = get_residue_starts_for(atom_array, row)
    return basepair_array
|