rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
from typing import Any, Literal, Tuple
|
|
2
|
+
|
|
3
|
+
import biotite.structure as struc
|
|
4
|
+
import hydride
|
|
5
|
+
import numpy as np
|
|
6
|
+
from atomworks.io.transforms.atom_array import remove_hydrogens
|
|
7
|
+
from atomworks.io.utils.ccd import atom_array_from_ccd_code
|
|
8
|
+
from atomworks.ml.transforms._checks import (
|
|
9
|
+
check_atom_array_annotation,
|
|
10
|
+
check_contains_keys,
|
|
11
|
+
check_is_instance,
|
|
12
|
+
)
|
|
13
|
+
from atomworks.ml.transforms.base import Transform
|
|
14
|
+
from biotite.structure import AtomArray, AtomArrayStack
|
|
15
|
+
from rfd3.constants import SELECTION_NONPROTEIN, SELECTION_PROTEIN
|
|
16
|
+
|
|
17
|
+
from foundry.utils.ddp import RankedLogger
|
|
18
|
+
|
|
19
|
+
# Module-level logger from foundry.utils.ddp (presumably rank-aware, per its
# name). NOTE(review): not used by the code visible in this module.
ranked_logger = RankedLogger()

# Element symbols regarded as hydrogen-like.
# NOTE(review): not referenced within this module (the hydrogen checks here
# compare against "H" only) — confirm external users before changing.
HYDROGEN_LIKE_SYMBOLS = ("H", "H2", "D", "T")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# TODO: Once the cifutils submodule is bumped, we can use the built-in add_hydrogen_atom_positions function
def add_hydrogen_atom_positions(
    atom_array: AtomArray | AtomArrayStack,
) -> AtomArray | AtomArrayStack:
    """Add hydrogens using the biotite-supported ``hydride`` library.

    Any hydrogens already present are removed first (``remove_hydrogens``) and
    then re-placed with ``hydride.add_hydrogen``, model by model for an
    ``AtomArrayStack``.

    If the input has no ``charge`` annotation, one is derived for every atom
    from the CCD entry of its residue (0 when the residue or the atom name
    cannot be found). Note that this annotation is set on the *input* array
    in place.

    Atoms with NaN coordinates are given small random placeholder coordinates
    so that ``hydride`` can run; the NaNs are written back onto those atoms in
    the result.

    Args:
        atom_array (AtomArray | AtomArrayStack): The structure to protonate.

    Returns:
        AtomArray | AtomArrayStack: The structure with hydrogens added (an
        ``AtomArrayStack`` if the hydrogen-stripped input is one). The
        residue-level annotations ``auth_seq_id``, ``label_entity_id``,
        ``is_can_prot``, ``is_can_nucl``, ``is_sm`` and ``chain_type`` are
        copied onto the new atoms when present in the input.

    Raises:
        TypeError: If ``atom_array`` is neither an ``AtomArray`` nor an
            ``AtomArrayStack``.
    """
    if not isinstance(atom_array, (AtomArray, AtomArrayStack)):
        raise TypeError(
            "Expected an AtomArray or AtomArrayStack, "
            f"got {type(atom_array).__name__}."
        )

    def _get_charges_from_ccd(arr) -> np.ndarray:
        """Return per-atom formal charges looked up in the CCD.

        Each distinct residue name is looked up only once. Atoms whose residue
        is not in the CCD (or whose CCD entry cannot be read), or whose atom
        name is not in that entry, get a charge of 0.
        """
        # res_name -> {atom_name: charge}, or None if the CCD lookup failed.
        charge_tables: dict = {}
        charges = np.zeros(arr.array_length(), dtype=int)
        for idx, (res_name, atom_name) in enumerate(
            zip(arr.res_name, arr.atom_name)
        ):
            if res_name not in charge_tables:
                try:
                    ccd_array = atom_array_from_ccd_code(res_name)
                    table = {}
                    for ccd_atom_name, ccd_charge in zip(
                        ccd_array.atom_name, ccd_array.charge
                    ):
                        # Keep the first occurrence of each atom name.
                        table.setdefault(ccd_atom_name, ccd_charge)
                except Exception:
                    ## res_name not found in ccd (or its entry is unusable)
                    table = None
                charge_tables[res_name] = table
            table = charge_tables[res_name]
            if table is not None:
                charges[idx] = table.get(atom_name, 0)
        return charges

    if "charge" not in atom_array.get_annotation_categories():
        atom_array.set_annotation("charge", _get_charges_from_ccd(atom_array))

    # Residue-level annotations to propagate onto the atoms hydride creates.
    present = set(atom_array.get_annotation_categories())
    fields_to_copy_from_residue_if_present = [
        field
        for field in (
            "auth_seq_id",
            "label_entity_id",
            "is_can_prot",
            "is_can_nucl",
            "is_sm",
            "chain_type",
        )
        if field in present
    ]

    def _copy_missing_annotations_residue_wise(
        arr_to_copy_from: AtomArray,
        arr_to_update: AtomArray,
        fields_to_copy_from_residue_if_present: list[str],
    ) -> AtomArray:
        """Copy specified annotations residue-wise from one AtomArray to another. Updates annotations in-place."""
        residue_starts = struc.get_residue_starts(arr_to_copy_from)
        residue_starts_atom_array = arr_to_copy_from[residue_starts]
        for field in fields_to_copy_from_residue_if_present:
            # spread_residue_wise needs one value per residue of arr_to_update,
            # i.e. this assumes hydride leaves the residue partition unchanged.
            per_residue = getattr(residue_starts_atom_array, field)
            arr_to_update.set_annotation(
                field, struc.spread_residue_wise(arr_to_update, per_residue)
            )
        return arr_to_update

    def _handle_nan_coords(atom_array, noise_level=1e-3):
        """Replace NaN coordinates in place with uniform noise in [-noise_level, noise_level).

        Returns the array and the boolean NaN mask (same shape as ``coord``).
        """
        coords = atom_array.coord
        nan_mask = np.isnan(coords)
        coords[nan_mask] = np.random.uniform(
            -noise_level, noise_level, size=nan_mask.sum()
        )
        atom_array.coord = coords
        return atom_array, nan_mask

    def _protonate(heavy: AtomArray) -> AtomArray:
        """Protonate one model, keeping NaN coordinates on the original atoms."""
        if heavy.bonds is None:
            heavy.bonds = struc.connect_via_distances(heavy)

        ## give some values to nan
        heavy, nan_mask = _handle_nan_coords(heavy)
        protonated, is_original_atom = hydride.add_hydrogen(heavy)

        ## put back nans
        # Scatter through explicit indices: chained boolean indexing
        # (coord[mask][nan_mask] = nan) would only modify a temporary copy.
        original_idx = np.flatnonzero(is_original_atom)
        rows, cols = np.nonzero(nan_mask)
        protonated.coord[original_idx[rows], cols] = np.nan

        return _copy_missing_annotations_residue_wise(
            heavy, protonated, fields_to_copy_from_residue_if_present
        )

    array = remove_hydrogens(atom_array)
    if isinstance(array, AtomArrayStack):
        return struc.stack([_protonate(model) for model in array])
    return _protonate(array)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def check_atom_array_has_hydrogen(data: dict[str, Any]):
    """Check that ``data["atom_array"]`` contains at least one hydrogen atom.

    Args:
        data: Transform data dictionary; its ``"atom_array"`` entry must have an
            ``element`` annotation.

    Raises:
        ValueError: If no atom has element ``"H"``.
    """
    if not np.any(data["atom_array"].element == "H"):
        raise ValueError("Key `atom_array` in data has no hydrogens.")
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def calculate_hbonds(
    atom_array: AtomArray,
    selection1: np.ndarray = None,
    selection2: np.ndarray = None,
    selection1_type: Literal["acceptor", "donor", "both"] = "both",
    cutoff_dist: float = 3,
    cutoff_angle: float = 120,
    donor_elements: Tuple[str] = ("O", "N", "S", "F"),
    acceptor_elements: Tuple[str] = ("O", "N", "S", "F"),
    periodic: bool = False,
) -> Tuple[np.ndarray, np.ndarray, AtomArray]:
    """
    Calculates Hbonds with ``biotite.structure.hbond``.

    Atoms with any NaN coordinate are excluded from the search. The returned
    indices refer to the original (unfiltered) ``atom_array``.

    Args:
        atom_array (AtomArray): Expects an atom_array that contains hydrogens.
        selection1, selection2 (np.ndarray, optional): Boolean masks (one entry
            per atom of ``atom_array``) restricting the hydrogen bond search.
            ``None`` selects all atoms. If either selection is empty after
            removing unresolved atoms, no hydrogen bonds are returned.
            (Default: None)
        selection1_type (Literal, optional): Determines the type of selection1.
            The type of selection2 is chosen accordingly ('both' or the
            opposite). (Default: 'both')
        cutoff_dist (float, optional): The maximal distance between the
            hydrogen and acceptor to be considered a hydrogen bond. (Default: 3)
        cutoff_angle (float, optional): The angle cutoff in degrees for
            Donor-H..Acceptor to be considered a hydrogen bond. (Default: 120)
        donor_elements, acceptor_elements (tuple of str): Elements to be
            considered as possible donors or acceptors. (Default: O, N, S, F)
        periodic (bool, optional): If true, hydrogen bonds can also be detected
            across periodic boundaries; requires the ``box`` attribute.
            (Default: False)

    Returns:
        triplets (np.ndarray): Integer array of shape (n_hbonds, 3) with
            (donor, hydrogen, acceptor) indices into ``atom_array``; shape
            (0, 3) when there are none.
        types (np.ndarray): Float array of shape (n_atoms, 2). Column 0 is 1.0
            for atoms acting as donor in at least one triplet, column 1 is 1.0
            for atoms acting as acceptor; all other entries are 0.0.
        atom_array (AtomArray): The input array, unchanged.
    """
    # Remove NaN coordinates
    has_resolved_coordinates = ~np.isnan(atom_array.coord).any(axis=-1)
    resolved_indices = np.flatnonzero(has_resolved_coordinates)
    nonNaN_array = atom_array[has_resolved_coordinates]
    n_resolved = len(resolved_indices)

    # Restrict selections to resolved atoms; None means "all atoms", as in
    # struc.hbond (previously None crashed on the emptiness check below).
    if selection1 is None:
        selection1 = np.ones(n_resolved, dtype=bool)
    else:
        selection1 = np.asarray(selection1, dtype=bool)[has_resolved_coordinates]
    if selection2 is None:
        selection2 = np.ones(n_resolved, dtype=bool)
    else:
        selection2 = np.asarray(selection2, dtype=bool)[has_resolved_coordinates]

    if not selection1.any() or not selection2.any():
        # no ligand, or ligand is of same type as selection1 (e.g. 6) (peptide)
        triplets = np.empty((0, 3), dtype=int)
    else:
        # Compute H bonds; assumes an AtomArray (an AtomArrayStack would also
        # return per-model masks).
        triplets = struc.hbond(
            nonNaN_array,
            selection1=selection1,
            selection2=selection2,
            selection1_type=selection1_type,
            cutoff_dist=cutoff_dist,
            cutoff_angle=cutoff_angle,
            donor_elements=donor_elements,
            acceptor_elements=acceptor_elements,
            periodic=periodic,
        )
        ## map back triplet indices, nonNaN indices to original indices
        triplets = resolved_indices[np.asarray(triplets, dtype=int)].reshape(-1, 3)

    ## [is_active_donor, is_active_acceptor] per atom
    types = np.zeros((len(atom_array), 2), dtype=float)
    if len(triplets) > 0:
        types[triplets[:, 0], 0] = 1.0
        types[triplets[:, 2], 1] = 1.0

    return triplets, types, atom_array
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class CalculateHbonds(Transform):
    """Transform that annotates atoms taking part in hydrogen bonds.

    Hydrogens are added with ``add_hydrogen_atom_positions``, hydrogen bonds are
    computed with ``calculate_hbonds`` and the result is stored in the
    ``active_donor`` / ``active_acceptor`` annotations (non-zero only on motif
    atoms, optionally subsampled). Hydrogens are removed again before the array
    is written back to ``data["atom_array"]``.
    """

    def __init__(
        self,
        selection1_type: Literal["acceptor", "donor", "both"] = "both",
        cutoff_dist: float = 3,
        cutoff_angle: float = 120,
        donor_elements: Tuple[str] = ("O", "N", "S", "F"),
        acceptor_elements: Tuple[str] = ("O", "N", "S", "F"),
        periodic: bool = False,
        make2d: bool = False,
    ):
        """
        Initialize the Hbonds transform.

        The atom selections are not constructor arguments; ``forward`` derives
        them from ``data["sampled_condition_name"]`` and the ``chain_type`` /
        ``is_motif_atom`` annotations.

        Args:
            selection1_type (Literal, optional): Determines the type of
                selection1. The type of selection2 is chosen accordingly
                ('both' or the opposite). (Default: 'both')
            cutoff_dist (float, optional): The maximal distance between the
                hydrogen and acceptor to be considered a hydrogen bond.
                (Default: 3)
            cutoff_angle (float, optional): The angle cutoff in degrees for
                Donor-H..Acceptor to be considered a hydrogen bond.
                (Default: 120)
            donor_elements, acceptor_elements (tuple of str): Elements to be
                considered as possible donors or acceptors.
                (Default: O, N, S, F)
            periodic (bool, optional): If true, hydrogen bonds can also be
                detected across periodic boundaries; the ``box`` attribute is
                required in this case. (Default: False)
            make2d (bool, optional): Stored as ``self.make2d``; not read by this
                class. (Default: False)
        """
        self.selection1_type = selection1_type
        self.cutoff_dist = cutoff_dist
        self.cutoff_angle = cutoff_angle
        self.donor_elements = donor_elements
        self.acceptor_elements = acceptor_elements
        self.periodic = periodic
        self.make2d = make2d

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["res_name"])

        ## turned off because H addition debugging is ongoing
        # check_atom_array_has_hydrogen(data)

    def forward(self, data: dict) -> dict:
        """
        Calculates Hbonds and stores them as per-atom annotations.

        Args:
            data: dict
                A dictionary containing the input data. Reads
                ``data["atom_array"]`` (needs ``chain_type`` and
                ``is_motif_atom`` annotations), ``data["sampled_condition_name"]``
                and ``data["conditions"]["hbond_subsample"]``.

        Returns:
            dict: The data dictionary, with ``active_donor`` / ``active_acceptor``
            annotations set on ``data["atom_array"]`` (hydrogens removed) and
            ``hbond_total_count``, ``hbond_total_atoms`` and
            ``hbond_subsample_atoms`` written to ``data["log_dict"]``.
        """
        atom_array: AtomArray = data["atom_array"]

        try:
            atom_array = add_hydrogen_atom_positions(atom_array)
        except Exception as e:
            # Best effort: keep the example, but without any H-bond annotations.
            print(
                f"WARNING: problem adding hydrogens: {e}.\nThis example will get no hydrogen bond annotations."
            )
            # NOTE(review): bool dtype here vs. float on the success path —
            # confirm downstream consumers accept both.
            atom_array.set_annotation(
                "active_donor", np.zeros(atom_array.array_length(), dtype=bool)
            )
            atom_array.set_annotation(
                "active_acceptor", np.zeros(atom_array.array_length(), dtype=bool)
            )
            data["atom_array"] = atom_array
            return data

        # Cast once: `~` on an integer array is a bitwise NOT (not a logical
        # one), and non-boolean masks must not reach the H-bond search.
        is_motif = np.asarray(atom_array.is_motif_atom, dtype=bool)

        ## These are the only two use-cases we have so far. Can be extended as needed
        if data["sampled_condition_name"] == "ppi":
            selection1_chain_types = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"]
            selection2_chain_types = ["POLYPEPTIDE(D)", "POLYPEPTIDE(L)"]
            separate_selections_for_motif_and_diffused = True
        else:
            selection1_chain_types = SELECTION_PROTEIN
            selection2_chain_types = SELECTION_NONPROTEIN
            separate_selections_for_motif_and_diffused = False

        selection1 = np.isin(atom_array.chain_type, selection1_chain_types)
        selection2 = np.isin(atom_array.chain_type, selection2_chain_types)

        if separate_selections_for_motif_and_diffused:
            # Restrict to Hbonds between motif and diffused regions
            selection1 = selection1 & is_motif
            selection2 = selection2 & ~is_motif
        else:
            # Include fixed motif atoms in selection2; selection1 is the rest
            selection2 = selection2 | is_motif
            selection1 = ~selection2

        hbonds, hbond_types, atom_array = calculate_hbonds(
            atom_array,
            selection1=selection1,
            selection2=selection2,
            selection1_type=self.selection1_type,
            cutoff_dist=self.cutoff_dist,
            cutoff_angle=self.cutoff_angle,
            donor_elements=self.donor_elements,
            acceptor_elements=self.acceptor_elements,
            periodic=self.periodic,
        )

        log_dict = data.setdefault("log_dict", {})

        # Log hbond statistics (before restricting to motif atoms)
        log_dict["hbond_total_count"] = len(hbonds)
        log_dict["hbond_total_atoms"] = hbond_types.sum()

        # Only motif atoms may carry donor/acceptor flags.
        final_hbond_types = hbond_types * is_motif[:, None]

        # Subsample if hbond_subsample is set and more than 3 flags remain
        if data["conditions"]["hbond_subsample"] and np.sum(final_hbond_types) > 3:
            # Linear correlation: fewer hbonds = higher fraction
            base_fraction = 0.1  # minimum fraction (when many hbonds)
            max_fraction = 0.9  # maximum fraction (when few hbonds)
            n_hbonds = len(hbonds)
            max_hbonds = 50  # Expected maximum number of hbonds for scaling

            # Linear interpolation: fraction decreases linearly with number of hbonds
            fraction = max_fraction - (max_fraction - base_fraction) * min(
                n_hbonds / max_hbonds, 1.0
            )
            final_hbond_types = subsample_one_hot_np(final_hbond_types, fraction)

        # Set annotations and log subsample atoms
        atom_array.set_annotation("active_donor", final_hbond_types[:, 0])
        atom_array.set_annotation("active_acceptor", final_hbond_types[:, 1])
        log_dict["hbond_subsample_atoms"] = final_hbond_types.sum()

        # Remove hydrogens after processing
        atom_array = remove_hydrogens(atom_array)
        data["log_dict"] = log_dict
        data["atom_array"] = atom_array
        return data
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def subsample_one_hot_np(array, fraction):
    """
    Subsamples a one-hot encoded NumPy array by randomly keeping a given fraction of the 1s.

    Randomness comes from NumPy's global state (``np.random.shuffle``), so
    results are reproducible under ``np.random.seed``.

    Args:
        array (np.ndarray): Array of 0s and 1s, of any shape.
        fraction (float): Fraction of 1s to keep (0 < fraction <= 1). The number
            kept is ``int(num_ones * fraction)``, i.e. rounded down.

    Returns:
        np.ndarray: New array with the same shape and dtype where only the kept
        entries are 1. The input is not modified.

    Raises:
        ValueError: If ``fraction`` is outside (0, 1].
    """
    if not (0 < fraction <= 1):
        raise ValueError("Fraction must be in the range (0, 1].")

    # Coordinates of every 1, one row per entry (row-major order).
    one_indices = np.argwhere(array == 1)
    keep_count = int(len(one_indices) * fraction)

    # Shuffle and choose a subset of indices to keep
    np.random.shuffle(one_indices)
    keep_indices = one_indices[:keep_count]

    new_array = np.zeros_like(array)
    if keep_count > 0:
        # Vectorised scatter; works for any dimensionality (the previous
        # `for i, j in ...` loop only supported 2-D arrays).
        new_array[tuple(keep_indices.T)] = 1
    return new_array
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import string
|
|
3
|
+
import subprocess
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any, Tuple
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from atomworks.ml.transforms._checks import (
|
|
9
|
+
check_atom_array_annotation,
|
|
10
|
+
check_contains_keys,
|
|
11
|
+
check_is_instance,
|
|
12
|
+
)
|
|
13
|
+
from atomworks.ml.transforms.base import Transform
|
|
14
|
+
from biotite.structure import AtomArray
|
|
15
|
+
from biotite.structure.io.pdb import PDBFile
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def save_atomarray_to_pdb(atom_array, output_path):
    """Write ``atom_array`` to a PDB file suitable as HBPLUS input.

    A copy of the structure is written, so the caller's array is left
    untouched. In the copy:

    - NaN coordinates are replaced by small random values so the PDB can be
      written;
    - chain IDs are replaced by single characters derived from ``chain_iid``
      (PDB chain IDs are one column wide);
    - B-factors are set to zero.

    Args:
        atom_array (AtomArray): Structure with a ``chain_iid`` annotation.
        output_path (str): Destination path of the PDB file.

    Returns:
        tuple: ``(atom_array, nan_mask, inverted_chain_map)``:

        - ``atom_array`` is the (unmodified) input;
        - ``nan_mask`` is a boolean array shaped like ``atom_array.coord``
          marking coordinates that were NaN;
        - ``inverted_chain_map`` maps each single-character PDB chain ID back
          to its ``chain_iid``.

    Raises:
        ValueError: If there are more than 52 distinct ``chain_iid`` values
            (more than the ASCII letters available as chain IDs).
    """
    # Work on a copy: the caller's array must keep its NaNs and chain IDs.
    pdb_array = atom_array.copy()

    # Give unresolved atoms placeholder coordinates near the origin.
    noise_level = 1e-3
    coords = pdb_array.coord
    nan_mask = np.isnan(coords)
    coords[nan_mask] = np.random.uniform(
        -noise_level, noise_level, size=nan_mask.sum()
    )
    pdb_array.coord = coords

    chain_iids = np.unique(pdb_array.chain_iid)
    if len(chain_iids) > 52:
        raise ValueError(
            "Too many chain_iids, cannot convert to PDB; skipping HBPLUS"
        )

    available_chain_ids = string.ascii_letters
    chain_map = {}
    # Single-character chain_iids are reused directly as PDB chain IDs.
    for chain_iid in chain_iids:
        if len(chain_iid) == 1:
            chain_map[chain_iid] = chain_iid
            available_chain_ids = available_chain_ids.replace(chain_iid, "")
    # All other chain_iids get the next unused letter.
    for chain_iid in chain_iids:
        if len(chain_iid) != 1:
            chain_map[chain_iid] = available_chain_ids[0]
            available_chain_ids = available_chain_ids[1:]

    inverted_chain_map = {v: k for k, v in chain_map.items()}
    pdb_array.chain_id = [chain_map[c] for c in pdb_array.chain_iid]
    pdb_array.b_factor = np.zeros(len(pdb_array))

    pdb = PDBFile()
    pdb.set_structure(pdb_array)
    pdb.write(output_path)

    return atom_array, nan_mask, inverted_chain_map
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def check_atom_array_has_hydrogen(data: dict[str, Any]):
    """Raise ``ValueError`` unless ``data["atom_array"]`` contains a hydrogen.

    Args:
        data: Transform data dictionary whose ``"atom_array"`` entry has an
            ``element`` annotation; an atom counts as hydrogen when its element
            is exactly ``"H"``.
    """
    elements = data["atom_array"].element
    if np.any(elements == "H"):
        return
    raise ValueError("Key `atom_array` in data has no hydrogens.")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def calculate_hbonds(
    atom_array: AtomArray,
    cutoff_HA_dist: float = 3,
    cutoff_DA_distance: float = 3.5,
) -> Tuple[AtomArray, list[dict[str, Any]], int]:
    """Run HBPLUS on ``atom_array`` and annotate motif/diffused hydrogen bonds.

    The structure is written to a PDB file in a private temporary directory.
    HBPLUS (executable taken from the ``HBPLUS_PATH`` environment variable) is
    run there and its ``.hb2`` output is parsed. The temporary directory is
    removed afterwards, also when an error occurs.

    Only hydrogen bonds whose donor and acceptor differ in ``is_motif_atom``
    are kept. Bonds involving atoms whose coordinates were NaN are ignored,
    because those atoms are written to the PDB with placeholder coordinates.

    Args:
        atom_array (AtomArray): Structure with ``chain_iid``, ``res_id``,
            ``atom_name`` and ``is_motif_atom`` annotations.
        cutoff_HA_dist (float): Passed to HBPLUS as ``-h``.
        cutoff_DA_distance (float): Passed to HBPLUS as ``-d``.

    Returns:
        tuple: ``(atom_array, motif_hbonds, n_motif_hbonds)``:

        - ``atom_array`` gets float ``active_donor`` / ``active_acceptor``
          annotations (1.0 for atoms in a kept hydrogen bond, else 0.0);
        - ``motif_hbonds`` is the list of kept hydrogen bonds as parsed dicts;
        - ``n_motif_hbonds`` is its length.

    Raises:
        ValueError: If ``HBPLUS_PATH`` is unset or empty, if the structure has
            too many chains for a PDB file, or if a donor/acceptor reported by
            HBPLUS cannot be uniquely identified in ``atom_array``.
    """
    import tempfile

    hbplus_exe = os.environ.get("HBPLUS_PATH")
    if hbplus_exe is None or hbplus_exe == "":
        raise ValueError(
            "HBPLUS_PATH environment variable not set. "
            "Please set it to the path of the hbplus executable in order to calculate hydrogen bonds."
        )
    # HBPLUS runs with the temp dir as cwd, so resolve a relative path first.
    if os.path.dirname(hbplus_exe):
        hbplus_exe = os.path.abspath(hbplus_exe)

    # A private directory (also used as cwd) avoids file-name clashes between
    # concurrent workers and guarantees cleanup of every file HBPLUS writes.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pdb_name = "structure.pdb"
        atom_array, nan_mask, chain_map = save_atomarray_to_pdb(
            atom_array, os.path.join(tmp_dir, pdb_name)
        )
        subprocess.call(
            [
                hbplus_exe,
                "-h",
                str(cutoff_HA_dist),
                "-d",
                str(cutoff_DA_distance),
                pdb_name,
                pdb_name,
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            cwd=tmp_dir,
        )
        with open(os.path.join(tmp_dir, "structure.hb2"), "r") as handle:
            hb_lines = handle.readlines()

    hbonds = []
    # The first 8 lines of the .hb2 output are skipped as header.
    for line in hb_lines[8:]:
        if not line.strip():
            continue
        d_chain = line[0]
        d_resi = str(int(line[1:5].strip()))
        d_resn = line[6:9].strip()
        d_ins = line[5].replace("-", " ")
        d_atom = line[9:13].strip()
        a_chain = line[14]
        a_resi = str(int(line[15:19].strip()))
        a_ins = line[19].replace("-", " ")
        a_resn = line[20:23].strip()
        a_atom = line[23:27].strip()
        dist = float(line[27:32].strip())

        hbonds.append(
            {
                "d_chain": chain_map[d_chain],
                "d_resi": d_resi,
                "d_resn": d_resn,
                "d_ins": d_ins,
                "d_atom": d_atom,
                "a_chain": chain_map[a_chain],
                "a_resi": a_resi,
                "a_resn": a_resn,
                "a_ins": a_ins,
                "a_atom": a_atom,
                "dist": dist,
            }
        )

    # Atoms that had NaN coordinates (placeholder positions in the PDB).
    unresolved = nan_mask.any(axis=-1)
    donor_mask = np.zeros(len(atom_array), dtype=bool)
    acceptor_mask = np.zeros(len(atom_array), dtype=bool)

    motif_hbonds = []
    for item in hbonds:
        current_donor_mask = (
            (atom_array.chain_iid == item["d_chain"])
            & (atom_array.res_id == float(item["d_resi"]))
            & (atom_array.atom_name == item["d_atom"])
        )
        current_acceptor_mask = (
            (atom_array.chain_iid == item["a_chain"])
            & (atom_array.res_id == float(item["a_resi"]))
            & (atom_array.atom_name == item["a_atom"])
        )

        # Ensure that we can uniquely identify the donor and acceptor atoms
        if current_donor_mask.sum() != 1:
            raise ValueError(
                f"Unable to uniquely identify a donor atom with chain_iid={item['d_chain']}, res_id={item['d_resi']}, atom_name={item['d_atom']}."
            )
        if current_acceptor_mask.sum() != 1:
            raise ValueError(
                f"Unable to uniquely identify an acceptor atom with chain_iid={item['a_chain']}, res_id={item['a_resi']}, atom_name={item['a_atom']}."
            )

        # Ignore bonds found between placeholder (unresolved) positions.
        if unresolved[current_donor_mask].any() or unresolved[current_acceptor_mask].any():
            continue

        current_donor_is_motif = atom_array.is_motif_atom[current_donor_mask][0]
        current_acceptor_is_motif = atom_array.is_motif_atom[current_acceptor_mask][0]

        # Only keep hbonds between the motif and diffused regions
        if current_donor_is_motif != current_acceptor_is_motif:
            motif_hbonds.append(item)
            donor_mask |= current_donor_mask
            acceptor_mask |= current_acceptor_mask

    atom_array.set_annotation("active_donor", donor_mask.astype(float))
    atom_array.set_annotation("active_acceptor", acceptor_mask.astype(float))

    return atom_array, motif_hbonds, len(motif_hbonds)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class CalculateHbondsPlus(Transform):
    """Transform for calculating Hbonds, expects an AtomArray containing hydrogens.

    Hydrogen bonds are found with HBPLUS via ``calculate_hbonds``. The
    resulting ``active_donor`` / ``active_acceptor`` annotations are zeroed on
    non-motif atoms and, when ``data["conditions"]["hbond_subsample"]`` is
    set, randomly subsampled.
    """

    def __init__(
        self,
        cutoff_HA_dist: float = 3,
        cutoff_DA_distance: float = 3.5,
    ):
        """Store the HBPLUS cutoffs (passed on as ``-h`` and ``-d``)."""
        self.cutoff_HA_dist = cutoff_HA_dist
        self.cutoff_DA_distance = cutoff_DA_distance

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["res_name"])
        # The hydrogen presence check (check_atom_array_has_hydrogen) is
        # currently disabled.

    def forward(self, data: dict) -> dict:
        """Annotate motif atoms involved in HBPLUS hydrogen bonds.

        Writes ``hbond_total_count`` and ``hbond_subsample_atoms`` (sums of the
        donor/acceptor flags on motif atoms) to ``data["log_dict"]``.
        """
        atom_array, _, _ = calculate_hbonds(
            data["atom_array"],
            cutoff_HA_dist=self.cutoff_HA_dist,
            cutoff_DA_distance=self.cutoff_DA_distance,
        )

        log_dict = data.setdefault("log_dict", {})

        # Zero the donor/acceptor flags of every non-motif atom.
        motif_weight = np.array(atom_array.is_motif_atom)
        final_hbond_types = np.vstack(
            (atom_array.active_donor, atom_array.active_acceptor)
        ).T
        final_hbond_types[:, 0] *= motif_weight
        final_hbond_types[:, 1] *= motif_weight
        total_flags = np.sum(final_hbond_types)
        log_dict["hbond_total_count"] = total_flags

        wants_subsample = data["conditions"]["hbond_subsample"]
        if wants_subsample and total_flags > 3:
            base_fraction = 0.1
            max_fraction = 0.9
            n_hbonds = total_flags
            max_hbonds = 50
            # Fewer flags -> larger kept fraction, decreasing linearly.
            fraction = max_fraction - (max_fraction - base_fraction) * min(
                n_hbonds / max_hbonds, 1.0
            )
            final_hbond_types = subsample_one_hot_np(final_hbond_types, fraction)

        atom_array.set_annotation("active_donor", final_hbond_types[:, 0])
        atom_array.set_annotation("active_acceptor", final_hbond_types[:, 1])
        log_dict["hbond_subsample_atoms"] = np.sum(final_hbond_types)

        data["log_dict"] = log_dict
        data["atom_array"] = atom_array
        return data
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def subsample_one_hot_np(array, fraction):
    """Subsample a one-hot (0/1) NumPy array by randomly keeping a fraction of its 1s.

    Randomness comes from NumPy's global state (``np.random.shuffle``), so
    results are reproducible under ``np.random.seed``.

    Args:
        array (np.ndarray): Array of 0s and 1s, of any shape.
        fraction (float): Fraction of 1s to keep (0 < fraction <= 1). The number
            kept is ``int(num_ones * fraction)``, i.e. rounded down.

    Returns:
        np.ndarray: New array with the same shape and dtype where only the kept
        entries are 1. The input is not modified.

    Raises:
        ValueError: If ``fraction`` is outside (0, 1].
    """
    if not (0 < fraction <= 1):
        raise ValueError("Fraction must be in the range (0, 1].")

    # Coordinates of every 1, one row per entry (row-major order).
    one_indices = np.argwhere(array == 1)
    keep_count = int(len(one_indices) * fraction)

    np.random.shuffle(one_indices)
    keep_indices = one_indices[:keep_count]

    new_array = np.zeros_like(array)
    if keep_count > 0:
        # Vectorised scatter; works for any dimensionality (the previous
        # `for i, j in ...` loop only supported 2-D arrays).
        new_array[tuple(keep_indices.T)] = 1
    return new_array
|