rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,541 @@
|
|
|
1
|
+
# from atomworks.ml.utils.token import get_token_masks, get_token_starts
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
import biotite.structure as struc
|
|
5
|
+
import numpy as np
|
|
6
|
+
from assertpy import assert_that
|
|
7
|
+
from atomworks.ml.preprocessing.utils.structure_utils import (
|
|
8
|
+
get_atom_mask_from_cell_list,
|
|
9
|
+
)
|
|
10
|
+
from atomworks.ml.transforms._checks import (
|
|
11
|
+
check_atom_array_annotation,
|
|
12
|
+
check_contains_keys,
|
|
13
|
+
check_is_instance,
|
|
14
|
+
)
|
|
15
|
+
from atomworks.ml.transforms.atom_array import atom_id_to_atom_idx, atom_id_to_token_idx
|
|
16
|
+
from atomworks.ml.transforms.base import Transform
|
|
17
|
+
from atomworks.ml.transforms.crop import (
|
|
18
|
+
get_spatial_crop_center,
|
|
19
|
+
get_token_count,
|
|
20
|
+
resize_crop_info_if_too_many_atoms,
|
|
21
|
+
)
|
|
22
|
+
from atomworks.ml.utils.token import (
|
|
23
|
+
get_af3_token_center_coords,
|
|
24
|
+
get_af3_token_center_masks,
|
|
25
|
+
get_token_starts,
|
|
26
|
+
spread_token_wise,
|
|
27
|
+
)
|
|
28
|
+
from biotite.structure import AtomArray
|
|
29
|
+
from rfd3.transforms.conditioning_utils import sample_island_tokens
|
|
30
|
+
from scipy.spatial import KDTree
|
|
31
|
+
|
|
32
|
+
from foundry.utils.ddp import RankedLogger
|
|
33
|
+
|
|
34
|
+
# Module-level logger; rank_zero_only=True presumably restricts output to the rank-0 process
# in distributed (DDP) runs so messages are not repeated per worker -- TODO confirm in foundry.utils.ddp.
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
35
|
+
|
|
36
|
+
# NOTE: This transform is based off of `rf_diffusion_aa.rf_diffusion.ppi.FindHotspotsTrainingTransform`
|
|
37
|
+
# However, this is progressing piecewise, and many features of that transform are not yet implemented.
|
|
38
|
+
# If this seems to be working, those should definitely be added in the future!
|
|
39
|
+
|
|
40
|
+
# NOTE: In contrast to RFD, we are providing hotspots at the atom level, not the residue level.
|
|
41
|
+
# Future hotspot subsampling schemes might want to avoid giving redundant information via (say) bonded atoms
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def get_hotspot_atoms(atom_array, binder_pn_unit_iid, distance_cutoff=4.5):
    """Return a boolean mask over ``atom_array`` marking hotspot atoms.

    A hotspot atom belongs to a pn_unit other than the binder and lies within
    ``distance_cutoff`` of at least one binder atom (atom-to-atom distance).

    Args:
        atom_array (AtomArray): Structure to analyse; must carry a ``pn_unit_iid`` annotation.
        binder_pn_unit_iid (str): ``pn_unit_iid`` of the binder (the diffused unit).
        distance_cutoff (float): Contact distance defining a hotspot, in the units of
            ``atom_array.coord``. Defaults to 4.5.

    Returns:
        np.ndarray: Boolean array of shape ``(len(atom_array),)``. Atoms with any NaN
        coordinate are never marked as hotspots.
    """
    # Distance queries only make sense for atoms whose coordinates are fully defined.
    has_coords = ~np.isnan(atom_array.coord).any(axis=1)
    resolved = atom_array[has_coords]
    in_binder = resolved.pn_unit_iid == binder_pn_unit_iid

    cells = struc.CellList(resolved, cell_size=distance_cutoff)
    # (n_query, n_cell_list): which resolved atoms are within the cutoff of each binder atom
    contact_matrix = get_atom_mask_from_cell_list(
        resolved[in_binder].coord, cells, len(resolved), distance_cutoff
    )
    near_binder = np.any(contact_matrix, axis=0)  # (n_cell_list,)

    # Keep only non-binder atoms, then scatter back onto the full (NaN-including) array.
    hotspots = np.zeros(len(atom_array), dtype=bool)
    hotspots[has_coords] = near_binder & ~in_binder
    return hotspots
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def get_secondary_structure_types(atom_array: AtomArray) -> np.ndarray:
    """Compute per-atom one-hot secondary-structure labels (helix, sheet, loop).

    Secondary structure is assigned per token with ``biotite.structure.annotate_sse`` and
    then spread to every atom of that token with ``spread_token_wise``.

    Args:
        atom_array (AtomArray): Structure with ``res_id``, ``token_id`` and ``chain_iid``
            annotations. ``res_id`` is temporarily overwritten and always restored (even if
            SSE annotation raises); ``token_id`` is never modified.

    Returns:
        np.ndarray: Boolean array of shape ``(n_atoms, 3)`` whose columns are helix (``"a"``),
        sheet (``"b"``) and loop/coil (``"c"``). Atoms whose token got any other SSE code
        (e.g. the empty string) are False in all three columns.
    """
    ss_types = np.zeros((atom_array.array_length(), 3), dtype=bool)

    actual_res_id = atom_array.res_id.copy()

    # Since annotate_sse detects chainbreaks based on res_id discontinuities, we create them:
    # a running chain index (in array order) that increments at every chain change makes res_id
    # jump by >= 2 at each chain boundary. (np.unique(..., return_inverse=True) orders chains
    # alphabetically, so the offset could *decrease* at a boundary and erase the break.)
    chain_iid = atom_array.chain_iid
    chain_change = np.zeros(len(chain_iid), dtype=np.int64)
    chain_change[1:] = chain_iid[1:] != chain_iid[:-1]
    chain_offsets = np.cumsum(chain_change)

    try:
        # HACK: Temporarily overwrite res_id with token_id so that the sse_array will have
        # length n_tokens. Build a fresh array: assigning token_id directly can make res_id
        # share memory with token_id, and an in-place `+=` would then also shift token_id.
        atom_array.res_id = atom_array.token_id + chain_offsets
        sse_array = struc.annotate_sse(atom_array)
    finally:
        # Restore original res_id, also when annotate_sse fails.
        atom_array.res_id = actual_res_id

    assert len(sse_array) == len(
        np.unique(atom_array.token_id)
    ), "SSE array length does not match number of tokens."

    sse_array = spread_token_wise(atom_array, sse_array)
    for column, sse_code in enumerate(("a", "b", "c")):
        ss_types[:, column] = sse_array == sse_code

    return ss_types
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class AddGlobalIsNonLoopyFeature(Transform):
    """Add a global flag telling whether loop content in the non-motif region is below 30%.

    Sets three annotations on ``data["atom_array"]``:
      - ``is_loop_gt``: per-atom ground-truth loop (coil) label from
        ``get_secondary_structure_types``.
      - ``is_non_loopy``: per-atom int that is ``1`` on every atom if the fraction of non-motif
        atoms labelled as loop is below ``_MAX_LOOP_FRACTION``, and ``-1`` otherwise (including
        when there are no non-motif atoms).
      - ``is_non_loopy_atom_level``: a copy of ``is_non_loopy`` so it can also be used as an
        atom-level feature.
    """

    # Loop fraction (over non-motif atoms) below which an example counts as "non-loopy".
    _MAX_LOOP_FRACTION = 0.3

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["is_motif_token"])

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # Ground-truth secondary structure for all atoms; only the loop column is used here.
        # For now boolean, later could include distances as in RFD. But maybe that's better as a 2D condition
        gt_secondary_structures = get_secondary_structure_types(atom_array)
        atom_array.set_annotation("is_loop_gt", gt_secondary_structures[:, 2])

        # Explicit bool cast: `~` on an integer annotation would be a bitwise NOT.
        is_diffused_atom = ~np.asarray(atom_array.is_motif_token, dtype=bool)
        diffused_is_loop = atom_array.is_loop_gt[is_diffused_atom]

        # An empty non-motif region would make mean() NaN (with a RuntimeWarning); NaN < x is
        # False, so treat it explicitly as "not non-loopy" -- the same outcome as before.
        is_non_loopy = (
            diffused_is_loop.size > 0
            and diffused_is_loop.mean() < self._MAX_LOOP_FRACTION
        )
        is_non_loopy_annot = np.full(
            atom_array.array_length(), 1 if is_non_loopy else -1, dtype=int
        )

        atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)

        # HACK: Enables adding as atom-level features as well.
        # Copy so the two annotations do not share one buffer (in-place edits of one would leak).
        atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot.copy())

        return data
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class Add1DSSFeature(Transform):
    """Add (partial) 1D secondary-structure conditioning to training examples.

    Ground-truth labels for three categories (helix, sheet, loop) are always stored as
    ``is_helix_gt`` / ``is_sheet_gt`` / ``is_loop_gt``. The conditioning annotations
    ``is_helix_conditioning`` / ``is_sheet_conditioning`` / ``is_loop_conditioning`` are always
    created (all False). When ``data["conditions"]["add_1d_ss_features"]`` is truthy, the ground
    truth is revealed on randomly sampled islands of eligible tokens: non-motif protein tokens
    that carry a helix/sheet/loop label.

    Args:
        max_secondary_structure_frac_to_provide: Upper bound on the fraction of eligible tokens
            that may receive conditioning.
        min_ss_island_len: Minimum island length, in tokens.
        max_ss_island_len: Maximum island length, in tokens.
        n_islands_max: Maximum number of islands to sample (at least one is requested).
    """

    def __init__(
        self,
        max_secondary_structure_frac_to_provide: float = 0.4,
        min_ss_island_len: int = 1,
        max_ss_island_len: int = 10,  # Might want to expand later, this is only average. Done for now to avoid over-conditioning.
        n_islands_max: int = 3,
    ):
        self.max_secondary_structure_frac_to_provide = (
            max_secondary_structure_frac_to_provide
        )
        self.min_ss_island_len = min_ss_island_len
        self.max_ss_island_len = max_ss_island_len
        self.n_islands_max = n_islands_max

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["is_motif_token"])

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]
        n_atoms = atom_array.array_length()

        # Ground-truth secondary structure types for every atom (all chains).
        gt_secondary_structures = get_secondary_structure_types(atom_array)
        atom_array.set_annotation("is_helix_gt", gt_secondary_structures[:, 0])
        atom_array.set_annotation("is_sheet_gt", gt_secondary_structures[:, 1])
        atom_array.set_annotation("is_loop_gt", gt_secondary_structures[:, 2])

        # Always add the conditioning annotations, even if all zeros, so the set of annotations
        # does not depend on whether conditioning is enabled for this example.
        for name in (
            "is_helix_conditioning",
            "is_sheet_conditioning",
            "is_loop_conditioning",
        ):
            atom_array.set_annotation(name, np.zeros(n_atoms, dtype=bool))

        if not data["conditions"]["add_1d_ss_features"]:
            return data

        # Eligible tokens: diffused (non-motif) protein tokens with a secondary structure label.
        # Protein atoms with NaN coordinates would have no secondary structure annotation.
        token_level_array = atom_array[get_token_starts(atom_array)]
        is_motif_token = np.asarray(token_level_array.is_motif_token, dtype=bool)
        eligible_for_ss_info_mask = (
            ~is_motif_token
            & token_level_array.is_protein
            & (
                token_level_array.is_helix_gt
                | token_level_array.is_sheet_gt
                | token_level_array.is_loop_gt
            )
        )
        n_eligible_tokens = int(eligible_for_ss_info_mask.sum())
        if n_eligible_tokens == 0:
            return data

        # Cap on the number of conditioned tokens, counted in TOKENS. Counting labelled atoms
        # (as previously done) inflated the cap several-fold relative to the token-level
        # lengths that sample_island_tokens works with.
        max_tokens_with_ss_conditioning = int(
            np.ceil(n_eligible_tokens * self.max_secondary_structure_frac_to_provide)
        )
        if max_tokens_with_ss_conditioning <= 0:
            return data

        where_to_show_ss = sample_island_tokens(
            n_eligible_tokens,
            island_len_min=self.min_ss_island_len,
            island_len_max=self.max_ss_island_len,
            n_islands_min=1,
            n_islands_max=self.n_islands_max,
            max_length=max_tokens_with_ss_conditioning,
        )

        # Convert the mask over eligible tokens to a mask over all tokens, then over atoms.
        token_level_ss_mask = np.zeros(token_level_array.array_length(), dtype=bool)
        token_level_ss_mask[eligible_for_ss_info_mask] = where_to_show_ss
        ss_mask = spread_token_wise(atom_array, token_level_ss_mask)

        # Reveal the ground truth on the selected atoms.
        atom_array.is_helix_conditioning[ss_mask] = atom_array.is_helix_gt[ss_mask]
        atom_array.is_sheet_conditioning[ss_mask] = atom_array.is_sheet_gt[ss_mask]
        atom_array.is_loop_conditioning[ss_mask] = atom_array.is_loop_gt[ss_mask]

        return data
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class AddPPIHotspotFeature(Transform):
    """Add atom-level hotspot features to PPI training examples.

    Ground-truth hotspots (non-binder atoms within ``hotspot_max_distance`` of the binder,
    see ``get_hotspot_atoms``) are stored as ``is_hotspot_gt``. A random subset of them is
    revealed through ``is_atom_level_hotspot``, which is always created (all False by default).
    """

    def __init__(
        self,
        max_hotspots_frac_to_provide: float = 0.2,
        hotspot_max_distance: float = 7.0,
    ):
        """
        Args:
            max_hotspots_frac_to_provide (float): Maximum fraction of ground-truth hotspots to add
                to the training example. The actual number added is sampled uniformly from 0 to
                ``ceil(n_hotspots * max_hotspots_frac_to_provide)`` inclusive, capped at the
                number of ground-truth hotspots.
            hotspot_max_distance (float): Maximum distance to the binder for an atom to be
                considered a hotspot.
        """
        self.max_hotspots_frac_to_provide = max_hotspots_frac_to_provide
        self.hotspot_max_distance = hotspot_max_distance

    def check_input(self, data: dict[str, Any]) -> None:
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["is_motif_token"])

    def forward(self, data: dict[str, Any]) -> dict[str, Any]:
        atom_array = data["atom_array"]

        # Always add the hotspot annotation, even if all zeros
        atom_array.set_annotation(
            "is_atom_level_hotspot", np.zeros(atom_array.array_length(), dtype=bool)
        )

        # Compute all ground-truth hotspots for the binder chain.
        # For now boolean, later could include distances as in RFD. But maybe that's better as a 2D condition
        is_hotspot_atom_mask = get_hotspot_atoms(
            atom_array,
            binder_pn_unit_iid=data["binder_pn_unit"],
            distance_cutoff=self.hotspot_max_distance,
        )
        atom_array.set_annotation("is_hotspot_gt", is_hotspot_atom_mask)

        true_hotspot_indices = np.flatnonzero(is_hotspot_atom_mask)
        n_true_hotspots = len(true_hotspot_indices)

        # Upper bound on revealed hotspots; clamped so choice(replace=False) can always succeed.
        max_hotspots_to_keep = int(
            np.clip(
                np.ceil(n_true_hotspots * self.max_hotspots_frac_to_provide),
                0,
                n_true_hotspots,
            )
        )
        if max_hotspots_to_keep == 0:
            if n_true_hotspots == 0:
                ranked_logger.warning("No hotspots found in the input data")
            return data

        # np.random.randint excludes its upper bound; +1 makes `max_hotspots_to_keep` reachable
        # (otherwise a cap of 1 would always yield 0 hotspots).
        num_hotspots_to_keep = np.random.randint(0, max_hotspots_to_keep + 1)

        # Subsample hotspots to add.
        # For now random, later could add speckle_or_region from RFD
        hotspots_to_provide = np.random.choice(
            true_hotspot_indices, size=num_hotspots_to_keep, replace=False
        )
        atom_array.is_atom_level_hotspot[hotspots_to_provide] = True

        return data
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
class PPIFullBinderCropSpatial(Transform):
    """Crop which keeps the entire binder chain, then crops spatially around the given interface.

    Requires ``data["binder_pn_unit"]``. Query pn_units are taken from
    ``data["query_pn_unit_iids"]`` when present and non-empty, otherwise all pn_units of the
    atom array are used.

    Args:
        crop_size (int): The maximum number of tokens to crop. Must be greater than 0.
        jitter_scale (float, optional): The scale of the jitter to apply to the crop center.
            This is to break ties between atoms with the same spatial distance. Defaults to 1e-3.
        crop_center_cutoff_distance (float, optional): The cutoff distance to consider for
            selecting crop centers. Measured in Angstroms. Defaults to 15.0.
        keep_uncropped_atom_array (bool, optional): Whether to keep the uncropped atom array in the data.
            If `True`, the uncropped atom array will be stored in the `crop_info` dictionary
            under the key `"atom_array"`. Defaults to `False`.
        force_crop (bool, optional): Whether to force crop even if the atom array is already small enough.
            Defaults to `False`.
        max_atoms_in_crop (int, optional): Maximum number of atoms allowed in a crop. If None, no resizing is performed.
            Defaults to None.
    """

    def __init__(
        self,
        crop_size: int,
        jitter_scale: float = 1e-3,
        crop_center_cutoff_distance: float = 15.0,
        keep_uncropped_atom_array: bool = False,
        force_crop: bool = False,
        max_atoms_in_crop: int | None = None,
    ):
        self.crop_size = crop_size
        self.jitter_scale = jitter_scale
        self.crop_center_cutoff_distance = crop_center_cutoff_distance
        self.keep_uncropped_atom_array = keep_uncropped_atom_array
        self.force_crop = force_crop
        self.max_atoms_in_crop = max_atoms_in_crop

    def check_input(self, data: dict):
        check_contains_keys(data, ["atom_array"])
        check_is_instance(data, "atom_array", AtomArray)
        check_atom_array_annotation(data, ["pn_unit_iid", "atomize", "atom_id"])

    def forward(self, data: dict) -> dict:
        atom_array = data["atom_array"]

        # Explicit None/length check: plain truthiness raises
        # "truth value of an array is ambiguous" when the ids are a multi-element np.ndarray.
        query_pn_units = data.get("query_pn_unit_iids")
        if query_pn_units is None or len(query_pn_units) == 0:
            query_pn_units = np.unique(atom_array.pn_unit_iid)
            ranked_logger.info(
                f"No query PN unit(s) provided for spatial crop. Randomly selecting from {query_pn_units}."
            )

        if "binder_pn_unit" not in data:
            raise ValueError("Data dict must contain 'binder_pn_unit' key.")

        crop_info = crop_spatial_keep_full_binder(
            atom_array=atom_array,
            query_pn_unit_iids=query_pn_units,
            binder_pn_unit_iid=data["binder_pn_unit"],
            crop_size=self.crop_size,
            jitter_scale=self.jitter_scale,
            crop_center_cutoff_distance=self.crop_center_cutoff_distance,
            force_crop=self.force_crop,
        )
        crop_info = resize_crop_info_if_too_many_atoms(
            crop_info=crop_info,
            atom_array=atom_array,
            max_atoms=self.max_atoms_in_crop,
        )

        data["crop_info"] = {"type": self.__class__.__name__} | crop_info

        if self.keep_uncropped_atom_array:
            data["crop_info"]["atom_array"] = atom_array

        # Update data with cropped atom array
        data["atom_array"] = atom_array[crop_info["crop_atom_idxs"]]

        return data
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def crop_spatial_keep_full_binder(
    atom_array: AtomArray,
    query_pn_unit_iids: list[str],
    binder_pn_unit_iid: str,
    crop_size: int,
    jitter_scale: float = 1e-3,
    crop_center_cutoff_distance: float = 15.0,
    force_crop: bool = False,
) -> dict:
    """Spatially crop ``atom_array`` while always keeping the whole binder pn_unit.

    When a crop is needed (or forced), a crop-center atom is drawn at random from the
    candidates returned by ``get_spatial_crop_center`` for the query pn_units. All binder
    tokens are kept, and the remaining ``crop_size - n_binder_tokens`` slots are filled with
    the non-binder tokens nearest to the crop center (token-center coordinates, optional jitter).

    Args:
        atom_array (AtomArray): The atom array to crop.
        query_pn_unit_iids (list[str]): Query polymer/non-polymer unit instance IDs; must
            include ``binder_pn_unit_iid``.
        binder_pn_unit_iid (str): The polymer/non-polymer unit instance ID of the binder.
        crop_size (int): The maximum number of tokens to crop.
        jitter_scale (float, optional): Scale of jitter to apply when calculating distances.
            Defaults to 1e-3.
        crop_center_cutoff_distance (float, optional): Maximum distance from query units to
            consider for crop center. Defaults to 15.0 Angstroms.
        force_crop (bool, optional): Whether to crop even if the atom array is already small
            enough. Defaults to False.

    Returns:
        dict: Crop information with the keys
            - requires_crop (bool): Whether the token count exceeds ``crop_size``.
            - crop_center_atom_id (int or np.nan): ID of the atom chosen as crop center.
            - crop_center_atom_idx (int or np.nan): Index of the atom chosen as crop center.
            - crop_center_token_idx (int or np.nan): Index of the token containing the crop center.
            - crop_token_idxs (np.ndarray): Indices of tokens included in the crop.
            - crop_atom_idxs (np.ndarray): Indices of atoms included in the crop.

    Raises:
        ValueError: If ``binder_pn_unit_iid`` is not in ``query_pn_unit_iids``.

    Note:
        This function implements the spatial cropping procedure as described in AlphaFold 3 and AlphaFold 2 Multimer.

    References:
        - AF3 https://static-content.springer.com/esm/art%3A10.1038%2Fs41586-024-07487-w/MediaObjects/41586_2024_7487_MOESM1_ESM.pdf
        - AF2 Multimer https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf
    """
    if binder_pn_unit_iid not in query_pn_unit_iids:
        raise ValueError(
            f"Binder polymer/non-polymer unit instance ID '{binder_pn_unit_iid}' "
            f"not found in query polymer/non-polymer unit instance IDs: {query_pn_unit_iids}"
        )

    n_tokens = get_token_count(atom_array)
    requires_crop = n_tokens > crop_size

    # Binder membership at atom level and at token level (one entry per token center).
    in_binder_atom = atom_array.pn_unit_iid == binder_pn_unit_iid
    token_center_atoms = atom_array[get_af3_token_center_masks(atom_array)]
    in_binder_token = token_center_atoms.pn_unit_iid == binder_pn_unit_iid
    n_binder_tokens = get_token_count(atom_array[in_binder_atom])

    if not (force_crop or requires_crop):
        # Already small enough: keep everything and report no crop center.
        crop_center_atom_id = crop_center_atom_idx = crop_center_token_idx = np.nan
        keep_token = np.ones(n_tokens, dtype=bool)
        keep_atom = np.ones(len(atom_array), dtype=bool)
    else:
        candidate_centers = get_spatial_crop_center(
            atom_array, query_pn_unit_iids, crop_center_cutoff_distance
        )
        crop_center_atom_id = np.random.choice(atom_array[candidate_centers].atom_id)
        crop_center_atom_idx = atom_id_to_atom_idx(atom_array, crop_center_atom_id)

        # Nearest non-binder tokens fill the budget left after reserving the binder tokens.
        center_coords = get_af3_token_center_coords(atom_array)
        crop_center_token_idx = atom_id_to_token_idx(atom_array, crop_center_atom_id)
        nearby_tokens = get_spatial_crop_excluding_mask(
            center_coords,
            crop_center_token_idx,
            crop_size=crop_size - n_binder_tokens,
            mask_to_exclude=in_binder_token,
            jitter_scale=jitter_scale,
        )

        # Union of the spatial neighbourhood with the full binder, at both levels.
        keep_atom = spread_token_wise(atom_array, nearby_tokens) | in_binder_atom
        keep_token = nearby_tokens | in_binder_token

    return {
        "requires_crop": requires_crop,  # whether cropping was necessary
        "crop_center_atom_id": crop_center_atom_id,  # atom_id of crop center
        "crop_center_atom_idx": crop_center_atom_idx,  # atom_idx of crop center
        "crop_center_token_idx": crop_center_token_idx,  # token_idx of crop center
        "crop_token_idxs": np.flatnonzero(keep_token),  # token_idxs in crop
        "crop_atom_idxs": np.flatnonzero(keep_atom),  # atom_idxs in crop
    }
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def get_spatial_crop_excluding_mask(
    coord: np.ndarray,
    crop_center_idx: int,
    crop_size: int,
    mask_to_exclude: np.ndarray,
    jitter_scale: float = 1e-3,
) -> np.ndarray:
    """
    Crop spatial tokens around a given `crop_center_idx`, keeping the nearest neighbors (with jitter) while
    excluding the tokens flagged in `mask_to_exclude`, until reaching the `crop_size`.

    Tokens with non-finite coordinates (e.g. `nan` token centers) are never included in the crop. If fewer than
    `crop_size` tokens are eligible, all eligible tokens are included.

    Args:
        coord (np.ndarray): A 2D numpy array of shape (N, 3) representing the 3D token-level coordinates.
            Coordinates are expected to be in Angstroms.
        crop_center_idx (int): The index of the token to be used as the center of the crop.
        crop_size (int): The maximum number of nearest neighbors to include in the crop.
        mask_to_exclude (np.ndarray): A boolean mask over the N tokens; True marks tokens to exclude from the crop.
        jitter_scale (float): Standard deviation of the Gaussian jitter added to the coordinates to break ties.
            Set to 0 to disable jitter.

    Returns:
        crop_mask (np.ndarray): A boolean mask of shape (N,) where True indicates that the token is within the crop.
    """
    assert_that(coord.ndim).is_equal_to(2)
    assert_that(coord.shape[1]).is_equal_to(3)
    assert_that(crop_center_idx).is_less_than(coord.shape[0])
    assert_that(crop_size).is_greater_than(0)
    assert_that(jitter_scale).is_greater_than_or_equal_to(0)

    # Coerce to bool so that `~` below is a logical (not bitwise) negation, even for integer masks
    mask_to_exclude = np.asarray(mask_to_exclude, dtype=bool)

    # Add small jitter to coordinates to break ties between equidistant tokens
    if jitter_scale > 0:
        coord = coord + np.random.normal(scale=jitter_scale, size=coord.shape)

    # ... get query center
    query_center = coord[crop_center_idx]

    # ... only tokens with finite coordinates are eligible (`nan` indicates an unknown token center, presumably
    #  e.g. unoccupied tokens), and the excluded tokens are removed as well
    is_valid = np.isfinite(coord).all(axis=1) & ~mask_to_exclude
    valid_idxs = np.flatnonzero(is_valid)

    crop_mask = np.zeros(coord.shape[0], dtype=bool)
    if valid_idxs.size == 0:
        # ... nothing is eligible: return an empty crop instead of building an empty KDTree
        return crop_mask

    # ... build a KDTree for efficient querying over the eligible tokens only
    tree = KDTree(coord[valid_idxs])

    # ... query the `crop_size` nearest neighbors of the crop center; `np.atleast_1d` guards against the
    #  result being squeezed to a scalar when `crop_size == 1`
    _, nearest_neighbor_idxs = tree.query(query_center, k=crop_size, p=2)
    nearest_neighbor_idxs = np.atleast_1d(nearest_neighbor_idxs)

    # ... filter out missing neighbours (index equal to `tree.n`), which occur when fewer than `crop_size`
    #  tokens are eligible
    nearest_neighbor_idxs = nearest_neighbor_idxs[nearest_neighbor_idxs < tree.n]

    # ... map tree-local indices back to token indices and mark them as in the crop
    crop_mask[valid_idxs[nearest_neighbor_idxs]] = True
    return crop_mask
|
rfd3/transforms/rasa.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from atomworks.ml.transforms.base import Transform
|
|
3
|
+
from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
|
|
4
|
+
from atomworks.ml.utils.token import apply_and_spread_token_wise
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CalculateRASA(Transform):
    """Transform for calculating relative SASA (RASA) for each atom in an AtomArray.

    The per-atom values are stored in the ``rasa`` annotation of ``data["atom_array"]``.
    """

    def __init__(
        self,
        probe_radius: float = 1.4,
        atom_radii: str | np.ndarray = "ProtOr",
        point_number: int = 100,
        requires_ligand: bool = False,
    ):
        """
        Args:
            probe_radius (float, optional): Van-der-Waals radius of the probe in Angstrom. Defaults to 1.4 (for water).
            atom_radii (str | np.ndarray, optional): Atom radii set to use for calculation. Defaults to "ProtOr".
                "ProtOr" will not yield SASA values for hydrogen atoms and some other atoms, like ions or certain
                charged atoms.
            point_number (int, optional): Number of points in the Shrake-Rupley algorithm to sample for calculating
                SASA. Defaults to 100.
            requires_ligand (bool, optional): If True, the transform returns the data unchanged (no RASA is
                computed) when the atom array contains no ligand atoms. Defaults to False.
        """
        self.probe_radius = probe_radius
        self.atom_radii = atom_radii
        self.point_number = point_number
        self.requires_ligand = requires_ligand

    def forward(self, data):
        """Compute atom-wise RASA and store it as the ``rasa`` annotation of ``data["atom_array"]``."""
        atom_array = data["atom_array"]

        # Check the flag first so the `is_ligand` annotation is only consulted when it actually matters
        if self.requires_ligand and not np.any(atom_array.is_ligand):
            return data

        # Calculate exact rasa
        rasa = calculate_atomwise_rasa(
            atom_array, self.probe_radius, self.atom_radii, self.point_number
        )
        atom_array.set_annotation("rasa", rasa)

        data["atom_array"] = atom_array
        return data
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def discretize_rasa(atom_array, low=0, high=0.2, n_bins=3, keep_protein_motif=False):
    """Discretize atom-wise relative SASA (RASA) into integer bins.

    ``n_bins`` evenly spaced edges are placed on ``[low, high]`` (e.g. ``[0.0, 0.1, 0.2]`` for the defaults),
    giving the bins ``[low, e1), [e1, e2), ..., [high, inf)``. Values below ``low`` are assigned to the first
    bin, so every included atom falls in ``[0, n_bins - 1]``.

    Atoms are excluded (assigned the extra bin ``n_bins``) when their RASA is ``nan``, when they are not part of
    a motif token, or, unless ``keep_protein_motif`` is True, when they are protein atoms.

    Args:
        atom_array: Atom array with ``rasa``, ``is_motif_token`` and ``is_protein`` annotations.
        low (float): Lower edge of the binning range.
        high (float): Upper edge of the binning range; values at or above it fall in the last bin.
        n_bins (int): Number of bins for included atoms.
        keep_protein_motif (bool): Whether protein motif atoms keep their RASA bin.

    Returns:
        np.ndarray: Integer array of bin indices, one per atom, in ``[0, n_bins]``.
    """
    rasa = np.asarray(atom_array.rasa, dtype=float)

    # Coerce masks to bool so `&` / `~` are logical operations
    inclusion_mask = ~np.isnan(rasa) & np.asarray(atom_array.is_motif_token, dtype=bool)
    if not keep_protein_motif:
        inclusion_mask &= ~np.asarray(atom_array.is_protein, dtype=bool)

    bin_edges = np.linspace(low, high, n_bins)  # e.g., [0.0, 0.1, 0.2]
    # `np.digitize` returns 0 for values below the first edge, hence the shift by 1; clip so values below
    # `low` (possible when `low > 0`) land in the first bin instead of an invalid bin -1
    bins = np.clip(np.digitize(rasa, bin_edges, right=False) - 1, 0, n_bins - 1)

    bins[~inclusion_mask] = n_bins  # Assign excluded atoms to an additional, unused bin
    return bins
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class SetZeroOccOnDeltaRASA(Transform):
    """
    Recompute RASA and set zero occupancy on the side-chain atoms of tokens that have become significantly exposed.

    Used to measure whether the atom-wise RASA changed during cropping: the ``rasa`` annotation written earlier by
    `CalculateRASA` is compared against a fresh RASA computation on the current atom array. Both values are clipped
    to ``[0, max_rasa]`` before taking the difference, and a token counts as newly exposed if any of its atoms'
    clipped RASA increased by more than ``exposure_threshold``.
    """

    requires_previous_transforms = [CalculateRASA]
    incompatible_previous_transforms = [
        "PadWithVirtualAtoms",  # must have the same atom names
        "CreateDesignReferenceFeatures",
        "AggregateFeaturesLikeAF3WithoutMSA",
    ]

    def __init__(
        self,
        probe_radius: float = 1.4,
        atom_radii: str | np.ndarray = "ProtOr",
        point_number: int = 100,
        max_rasa: float = 0.2,
        exposure_threshold: float = 0.075,
    ):
        """
        Args:
            probe_radius (float, optional): Van-der-Waals radius of the probe in Angstrom. Defaults to 1.4.
            atom_radii (str | np.ndarray, optional): Atom radii set to use for calculation. Defaults to "ProtOr".
            point_number (int, optional): Number of points in the Shrake-Rupley algorithm. Defaults to 100.
            max_rasa (float, optional): Old and new RASA are clipped to ``[0, max_rasa]`` before comparison, so
                changes above this value are ignored. Defaults to 0.2.
            exposure_threshold (float, optional): Minimum increase in clipped RASA for an atom to count as
                having become exposed. Defaults to 0.075.
        """
        self.probe_radius = probe_radius
        self.atom_radii = atom_radii
        self.point_number = point_number
        self.max_rasa = max_rasa
        self.exposure_threshold = exposure_threshold

    def check_input(self, data):
        """Require the ``rasa`` annotation produced by `CalculateRASA`."""
        assert "rasa" in data["atom_array"].get_annotation_categories(), (
            "SetZeroOccOnDeltaRASA requires a `rasa` annotation; run CalculateRASA first."
        )

    def forward(self, data):
        """Zero the occupancy of side-chain atoms in tokens whose RASA increased beyond the threshold."""
        atom_array = data["atom_array"]
        rasa_old = atom_array.rasa

        rasa_new = calculate_atomwise_rasa(
            atom_array, self.probe_radius, self.atom_radii, self.point_number
        )

        # Compare on the clipped scale so that changes among already well-exposed atoms are ignored
        delta_rasa = np.clip(rasa_new, a_min=0, a_max=self.max_rasa) - np.clip(
            rasa_old, a_min=0, a_max=self.max_rasa
        )
        # `nan` deltas (atoms without a RASA value) never count as exposure
        has_become_exposed = np.nan_to_num(delta_rasa) > self.exposure_threshold

        # A token is exposed if any of its atoms is; spread that flag back to all atoms of the token
        token_has_become_exposed = apply_and_spread_token_wise(
            atom_array,
            has_become_exposed,
            function=np.any,
        )
        is_sidechain = (
            ~np.isin(atom_array.atom_name, ["N", "CA", "C", "O"])
            & atom_array.is_residue
        )

        # Set zero occupancy for sidechains only
        atom_has_become_exposed = token_has_become_exposed & is_sidechain
        atom_array.occupancy[atom_has_become_exposed] = 0.0

        data["atom_array"] = atom_array

        return data
|