rc-foundry 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- foundry/__init__.py +57 -0
- foundry/callbacks/__init__.py +5 -0
- foundry/callbacks/callback.py +116 -0
- foundry/callbacks/health_logging.py +419 -0
- foundry/callbacks/metrics_logging.py +211 -0
- foundry/callbacks/timing_logging.py +67 -0
- foundry/callbacks/train_logging.py +278 -0
- foundry/common.py +108 -0
- foundry/constants.py +28 -0
- foundry/hydra/resolvers.py +77 -0
- foundry/inference_engines/base.py +235 -0
- foundry/inference_engines/checkpoint_registry.py +66 -0
- foundry/metrics/__init__.py +12 -0
- foundry/metrics/losses.py +30 -0
- foundry/metrics/metric.py +319 -0
- foundry/model/layers/blocks.py +47 -0
- foundry/testing/__init__.py +6 -0
- foundry/testing/fixtures.py +19 -0
- foundry/testing/pytest_hooks.py +15 -0
- foundry/trainers/fabric.py +923 -0
- foundry/training/EMA.py +67 -0
- foundry/training/checkpoint.py +61 -0
- foundry/training/schedulers.py +91 -0
- foundry/utils/alignment.py +86 -0
- foundry/utils/components.py +415 -0
- foundry/utils/datasets.py +405 -0
- foundry/utils/ddp.py +103 -0
- foundry/utils/instantiators.py +72 -0
- foundry/utils/logging.py +279 -0
- foundry/utils/rigid.py +1460 -0
- foundry/utils/rotation_augmentation.py +65 -0
- foundry/utils/squashfs.py +172 -0
- foundry/utils/torch.py +317 -0
- foundry/utils/weights.py +271 -0
- foundry/version.py +34 -0
- foundry_cli/__init__.py +3 -0
- foundry_cli/download_checkpoints.py +281 -0
- mpnn/__init__.py +1 -0
- mpnn/collate/feature_collator.py +265 -0
- mpnn/inference.py +53 -0
- mpnn/inference_engines/mpnn.py +549 -0
- mpnn/loss/nll_loss.py +122 -0
- mpnn/metrics/nll.py +369 -0
- mpnn/metrics/sequence_recovery.py +440 -0
- mpnn/model/layers/graph_embeddings.py +2372 -0
- mpnn/model/layers/message_passing.py +332 -0
- mpnn/model/layers/position_wise_feed_forward.py +44 -0
- mpnn/model/layers/positional_encoding.py +98 -0
- mpnn/model/mpnn.py +2632 -0
- mpnn/pipelines/mpnn.py +162 -0
- mpnn/samplers/samplers.py +167 -0
- mpnn/train.py +341 -0
- mpnn/trainers/mpnn.py +193 -0
- mpnn/transforms/feature_aggregation/mpnn.py +184 -0
- mpnn/transforms/feature_aggregation/polymer_ligand_interface.py +76 -0
- mpnn/transforms/feature_aggregation/token_encodings.py +132 -0
- mpnn/transforms/feature_aggregation/user_settings.py +347 -0
- mpnn/transforms/polymer_ligand_interface.py +164 -0
- mpnn/utils/inference.py +2397 -0
- mpnn/utils/probability.py +37 -0
- mpnn/utils/weights.py +309 -0
- rc_foundry-0.1.1.dist-info/METADATA +239 -0
- rc_foundry-0.1.1.dist-info/RECORD +180 -0
- rc_foundry-0.1.1.dist-info/WHEEL +4 -0
- rc_foundry-0.1.1.dist-info/entry_points.txt +5 -0
- rc_foundry-0.1.1.dist-info/licenses/LICENSE.md +28 -0
- rf3/__init__.py +3 -0
- rf3/_version.py +33 -0
- rf3/alignment.py +79 -0
- rf3/callbacks/dump_validation_structures.py +101 -0
- rf3/callbacks/metrics_logging.py +324 -0
- rf3/chemical.py +1529 -0
- rf3/cli.py +77 -0
- rf3/data/cyclic_transform.py +78 -0
- rf3/data/extra_xforms.py +36 -0
- rf3/data/ground_truth_template.py +463 -0
- rf3/data/paired_msa.py +206 -0
- rf3/data/pipeline_utils.py +128 -0
- rf3/data/pipelines.py +558 -0
- rf3/diffusion_samplers/inference_sampler.py +222 -0
- rf3/inference.py +65 -0
- rf3/inference_engines/__init__.py +5 -0
- rf3/inference_engines/rf3.py +735 -0
- rf3/kinematics.py +354 -0
- rf3/loss/af3_confidence_loss.py +515 -0
- rf3/loss/af3_losses.py +655 -0
- rf3/loss/loss.py +179 -0
- rf3/metrics/chiral.py +179 -0
- rf3/metrics/clashing_chains.py +68 -0
- rf3/metrics/distogram.py +421 -0
- rf3/metrics/lddt.py +523 -0
- rf3/metrics/metadata.py +43 -0
- rf3/metrics/metric_utils.py +192 -0
- rf3/metrics/predicted_error.py +134 -0
- rf3/metrics/rasa.py +108 -0
- rf3/metrics/selected_distances.py +91 -0
- rf3/model/RF3.py +527 -0
- rf3/model/RF3_blocks.py +92 -0
- rf3/model/RF3_structure.py +303 -0
- rf3/model/layers/af3_auxiliary_heads.py +255 -0
- rf3/model/layers/af3_diffusion_transformer.py +544 -0
- rf3/model/layers/attention.py +313 -0
- rf3/model/layers/layer_utils.py +127 -0
- rf3/model/layers/mlff.py +118 -0
- rf3/model/layers/outer_product.py +59 -0
- rf3/model/layers/pairformer_layers.py +783 -0
- rf3/model/layers/structure_bias.py +56 -0
- rf3/scoring.py +1787 -0
- rf3/symmetry/resolve.py +284 -0
- rf3/train.py +194 -0
- rf3/trainers/rf3.py +570 -0
- rf3/util_module.py +47 -0
- rf3/utils/frames.py +109 -0
- rf3/utils/inference.py +665 -0
- rf3/utils/io.py +198 -0
- rf3/utils/loss.py +72 -0
- rf3/utils/predict_and_score.py +165 -0
- rf3/utils/predicted_error.py +673 -0
- rf3/utils/recycling.py +42 -0
- rf3/validate.py +140 -0
- rfd3/.gitignore +7 -0
- rfd3/Makefile +76 -0
- rfd3/__init__.py +12 -0
- rfd3/callbacks.py +66 -0
- rfd3/cli.py +41 -0
- rfd3/constants.py +212 -0
- rfd3/engine.py +543 -0
- rfd3/inference/datasets.py +193 -0
- rfd3/inference/input_parsing.py +1123 -0
- rfd3/inference/legacy_input_parsing.py +717 -0
- rfd3/inference/parsing.py +165 -0
- rfd3/inference/symmetry/atom_array.py +298 -0
- rfd3/inference/symmetry/checks.py +241 -0
- rfd3/inference/symmetry/contigs.py +63 -0
- rfd3/inference/symmetry/frames.py +355 -0
- rfd3/inference/symmetry/symmetry_utils.py +398 -0
- rfd3/metrics/design_metrics.py +465 -0
- rfd3/metrics/hbonds_hbplus_metrics.py +308 -0
- rfd3/metrics/hbonds_metrics.py +389 -0
- rfd3/metrics/losses.py +325 -0
- rfd3/metrics/metrics_utils.py +118 -0
- rfd3/metrics/sidechain_metrics.py +349 -0
- rfd3/model/RFD3.py +105 -0
- rfd3/model/RFD3_diffusion_module.py +387 -0
- rfd3/model/cfg_utils.py +81 -0
- rfd3/model/inference_sampler.py +635 -0
- rfd3/model/layers/attention.py +577 -0
- rfd3/model/layers/block_utils.py +580 -0
- rfd3/model/layers/blocks.py +777 -0
- rfd3/model/layers/chunked_pairwise.py +377 -0
- rfd3/model/layers/encoders.py +417 -0
- rfd3/model/layers/layer_utils.py +197 -0
- rfd3/model/layers/pairformer_layers.py +128 -0
- rfd3/run_inference.py +45 -0
- rfd3/testing/debug.py +139 -0
- rfd3/testing/debug_utils.py +73 -0
- rfd3/testing/testing_utils.py +356 -0
- rfd3/train.py +194 -0
- rfd3/trainer/dump_validation_structures.py +154 -0
- rfd3/trainer/fabric_trainer.py +923 -0
- rfd3/trainer/recycling.py +42 -0
- rfd3/trainer/rfd3.py +485 -0
- rfd3/trainer/trainer_utils.py +502 -0
- rfd3/transforms/conditioning_base.py +508 -0
- rfd3/transforms/conditioning_utils.py +200 -0
- rfd3/transforms/design_transforms.py +807 -0
- rfd3/transforms/dna_crop.py +523 -0
- rfd3/transforms/hbonds.py +407 -0
- rfd3/transforms/hbonds_hbplus.py +246 -0
- rfd3/transforms/ncaa_transforms.py +153 -0
- rfd3/transforms/pipelines.py +632 -0
- rfd3/transforms/ppi_transforms.py +541 -0
- rfd3/transforms/rasa.py +116 -0
- rfd3/transforms/symmetry.py +76 -0
- rfd3/transforms/training_conditions.py +552 -0
- rfd3/transforms/util_transforms.py +498 -0
- rfd3/transforms/virtual_atoms.py +305 -0
- rfd3/utils/inference.py +648 -0
- rfd3/utils/io.py +245 -0
- rfd3/utils/vizualize.py +276 -0
|
@@ -0,0 +1,308 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import biotite.structure as struc
|
|
4
|
+
import numpy as np
|
|
5
|
+
from rfd3.constants import (
|
|
6
|
+
ATOM14_ATOM_NAMES,
|
|
7
|
+
association_schemes_stripped,
|
|
8
|
+
)
|
|
9
|
+
from rfd3.transforms.conditioning_base import get_motif_features
|
|
10
|
+
from rfd3.transforms.hbonds_hbplus import calculate_hbonds
|
|
11
|
+
|
|
12
|
+
from foundry.metrics.metric import Metric
|
|
13
|
+
from foundry.utils.ddp import RankedLogger
|
|
14
|
+
|
|
15
|
+
logging.basicConfig(level=logging.INFO)
|
|
16
|
+
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def simplified_processing_atom_array(atom_arrays, central_atom="CB", threshold=0.5):
    """
    Allows for sequence extraction from cleaned up virtual atoms. Needed for hbond metrics.

    Residues whose atom names include any name starting with "V" are treated as
    "sequence unknown". For those residues, every atom closer than ``threshold``
    to the central atom (the central atom itself excluded) is dropped, and the
    surviving atom-name pattern is matched against the "atom14" association
    scheme to recover the residue name and atom names. If no scheme matches,
    the residue is labelled "UNK". Residues without "V" atoms are passed
    through unchanged.

    Args:
        atom_arrays: Iterable of atom arrays exposing ``res_id``, ``atom_name``,
            ``coord`` and ``res_name``.
        central_atom: Atom name used as the reference point for detecting
            leftover virtual atoms.
        threshold: Distance (in the units of ``coord``) below which an atom is
            considered to sit on the central atom.

    Returns:
        list: One processed atom array per input, with ``element`` re-inferred
        from the (possibly rewritten) atom names.
    """
    atom14_schemes = association_schemes_stripped["atom14"]
    processed_arrays = []

    for atom_array in atom_arrays:
        residues = []

        # Residue boundaries are the positions where res_id changes.
        res_ids = atom_array.res_id
        res_starts = np.concatenate(
            [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
        )
        res_ends = np.concatenate([res_starts[1:], [len(res_ids)]])

        for start, end in zip(res_starts, res_ends):
            residue = atom_array[start:end]

            # Check if the current residue is after padding (seq unknown):
            seq_known = not any(name.startswith("V") for name in residue.atom_name)
            if seq_known:
                residues.append(residue)
                continue

            # Glycine fallback to CA. The choice is made per residue: the
            # previous code overwrote the ``central_atom`` argument, so a single
            # glycine-like residue switched every later residue (and every
            # later atom array) to "CA" as well.
            residue_central_atom = central_atom
            ca_coord = residue.coord[residue.atom_name == "CA"]
            cb_coord = residue.coord[residue.atom_name == "CB"]
            if np.linalg.norm(ca_coord - cb_coord) < threshold:
                residue_central_atom = "CA"

            central_mask = residue.atom_name == residue_central_atom
            central_coord = residue.coord[central_mask][0]
            dists = np.linalg.norm(residue.coord - central_coord, axis=-1)
            is_virtual = (dists < threshold) & ~central_mask
            residue = residue[~is_virtual]

            # Which atom14 slots are occupied by the surviving atoms. This does
            # not depend on the candidate residue type, so compute it once.
            atom14_idx = np.array(
                [
                    np.where(ATOM14_ATOM_NAMES == name)[0][0]
                    for name in residue.atom_name
                ]
            )
            atom14_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
            atom14_mask[atom14_idx] = True

            for restype, scheme_names in atom14_schemes.items():
                if restype in ("UNK", "MSK"):
                    continue
                scheme_names = np.array(scheme_names)

                # A residue type matches when it defines a real atom in every
                # occupied slot and nothing in every unoccupied slot.
                occupied_ok = all(x is not None for x in scheme_names[atom14_mask])
                empty_ok = all(x is None for x in scheme_names[~atom14_mask])
                if occupied_ok and empty_ok:
                    residue.res_name = np.array([restype] * len(residue))
                    residue.atom_name = np.asarray(
                        scheme_names[atom14_mask], dtype=str
                    )
                    break
            else:
                # No residue type fits the surviving atom pattern.
                residue.res_name = np.array(["UNK"] * len(residue))

            residues.append(residue)

        cur_atom_array = struc.concatenate(residues)
        cur_atom_array.element = struc.infer_elements(cur_atom_array.atom_name)
        processed_arrays.append(cur_atom_array)

    return processed_arrays
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def calculate_hbond_stats(
    input_atom_array_stack,
    output_atom_array_stack,
    cutoff_HA_dist=3,
    cutoff_DA_distance=3.5,
    inference_metrics=False,
):
    """
    Compare requested hydrogen-bond donors/acceptors with those found in the output.

    Each output atom array is first passed through
    ``simplified_processing_atom_array``. Samples whose input has no active
    donor and no active acceptor are skipped. For the remaining samples,
    ``calculate_hbonds`` is run on the output, its donor/acceptor annotations
    are restricted to motif atoms, and the fraction of input donors/acceptors
    that are also flagged in the output is accumulated.

    Args:
        input_atom_array_stack: Sequence of input atom arrays. Missing
            ``active_donor`` / ``active_acceptor`` annotations are added
            in place as all-False arrays.
        output_atom_array_stack: Sequence of output atom arrays, same length
            as ``input_atom_array_stack``.
        cutoff_HA_dist: Accepted for interface compatibility; not referenced
            in this function body.
        cutoff_DA_distance: Accepted for interface compatibility; not
            referenced in this function body.
        inference_metrics: When True, return a dictionary instead of a tuple.

    Returns:
        If ``inference_metrics`` is False: a tuple
        ``(avg_donor_pct, avg_acceptor_pct, avg_hbonds)`` averaged over the
        non-skipped samples, or ``(0, 0, 0)`` if every sample was skipped.
        If ``inference_metrics`` is True: a dict with the same averages plus
        ``"hbonds"`` (from the last non-skipped sample),
        ``"total_number_donors_acceptors"`` and ``"output_atom_array"`` (the
        last iterated output array); when every sample was skipped the numeric
        entries are empty strings, ``"hbonds"`` is ``[]`` and
        ``"output_atom_array"`` is None.
    """
    output_atom_array_stack = simplified_processing_atom_array(output_atom_array_stack)
    assert len(input_atom_array_stack) == len(output_atom_array_stack)

    total_correct_donors_percent = 0.0
    total_correct_acceptors_percent = 0.0
    total_number_donors_acceptors = 0
    total_number_hbonds = 0
    num_valid_samples = 0

    for input_atom_array, output_atom_array in zip(
        input_atom_array_stack, output_atom_array_stack
    ):
        # Ensure required annotations exist
        for annotation in ("active_donor", "active_acceptor"):
            if annotation not in input_atom_array.get_annotation_categories():
                input_atom_array.set_annotation(
                    annotation, np.zeros(len(input_atom_array), dtype=bool)
                )

        given_donors = np.sum(input_atom_array.active_donor)
        given_acceptors = np.sum(input_atom_array.active_acceptor)

        # Skip samples with no donors or acceptors
        if given_donors == 0 and given_acceptors == 0:
            continue

        # Clean up coordinate annotations
        for atom_array in (input_atom_array, output_atom_array):
            if "coord_to_be_noised" in atom_array.get_annotation_categories():
                atom_array.del_annotation("coord_to_be_noised")

        # Calculate hydrogen bonds
        output_atom_array, hbonds, motif_diffused_hbond_count = calculate_hbonds(
            output_atom_array,
        )

        # Update hbond annotations for motif atoms only
        hbond_types = np.vstack(
            (output_atom_array.active_donor, output_atom_array.active_acceptor)
        ).T
        motif_mask = np.array(get_motif_features(output_atom_array)["is_motif_atom"])
        hbond_types[:, 0] *= motif_mask
        hbond_types[:, 1] *= motif_mask

        output_atom_array.set_annotation("active_donor", hbond_types[:, 0])
        output_atom_array.set_annotation("active_acceptor", hbond_types[:, 1])

        # Count correct predictions
        correct_donors = _count_correct_hbond_atoms(
            input_atom_array, output_atom_array, "active_donor"
        )
        correct_acceptors = _count_correct_hbond_atoms(
            input_atom_array, output_atom_array, "active_acceptor"
        )

        # A category with nothing requested counts as fully satisfied.
        correct_donor_pct = correct_donors / given_donors if given_donors > 0 else 1.0
        correct_acceptor_pct = (
            correct_acceptors / given_acceptors if given_acceptors > 0 else 1.0
        )

        # Accumulate totals
        total_correct_donors_percent += correct_donor_pct
        total_correct_acceptors_percent += correct_acceptor_pct
        total_number_donors_acceptors += np.sum(hbond_types)
        total_number_hbonds += motif_diffused_hbond_count
        num_valid_samples += 1

    if num_valid_samples == 0:
        if inference_metrics:
            return {
                "correct_donor_percent": "",
                "correct_acceptor_percent": "",
                "num_hbonds": "",
                "hbonds": [],
                "total_number_donors_acceptors": "",
                "output_atom_array": None,
            }
        return 0, 0, 0

    avg_donor_pct = total_correct_donors_percent / num_valid_samples
    avg_acceptor_pct = total_correct_acceptors_percent / num_valid_samples
    avg_hbonds = total_number_hbonds / num_valid_samples

    if inference_metrics:
        return {
            "correct_donor_percent": avg_donor_pct,
            "correct_acceptor_percent": avg_acceptor_pct,
            "num_hbonds": avg_hbonds,
            "hbonds": hbonds,
            "total_number_donors_acceptors": total_number_donors_acceptors,
            "output_atom_array": output_atom_array,
        }

    return avg_donor_pct, avg_acceptor_pct, avg_hbonds
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _count_correct_hbond_atoms(input_atom_array, output_atom_array, annotation_type):
    """Count correctly predicted hydrogen bond atoms.

    For every atom flagged (``== 1``) under ``annotation_type`` in the input,
    look up the output atoms with the same ``chain_iid`` and ``res_id`` whose
    ``atom_name`` equals the input's ``gt_atom_name``. The atom is counted when
    at least one of those output atoms carries a truthy ``annotation_type``.

    Args:
        input_atom_array: Atom array providing ``chain_iid``, ``res_id``,
            ``gt_atom_name`` and the ``annotation_type`` annotation.
        output_atom_array: Atom array providing ``chain_iid``, ``res_id``,
            ``atom_name`` and the ``annotation_type`` annotation.
        annotation_type: Annotation name, e.g. "active_donor" or
            "active_acceptor".

    Returns:
        int: Number of flagged input atoms that are also flagged in the output.
    """
    correct_count = 0
    target_indices = np.where(getattr(input_atom_array, annotation_type) == 1)[0]

    for idx in target_indices:
        match_mask = (
            (output_atom_array.chain_iid == input_atom_array.chain_iid[idx])
            & (output_atom_array.res_id == input_atom_array.res_id[idx])
            & (output_atom_array.atom_name == input_atom_array.gt_atom_name[idx])
        )
        matching_atoms = output_atom_array[match_mask]

        # np.any instead of bool(): bool() on an array with more than one
        # element raises "truth value is ambiguous" when several atoms match.
        if len(matching_atoms) > 0 and np.any(getattr(matching_atoms, annotation_type)):
            correct_count += 1

    return correct_count
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def get_hbond_metrics(atom_array=None):
    """
    Summarise hydrogen bonds of a single atom array as a metrics dictionary.

    The array is copied twice and handed to ``calculate_hbond_stats`` as both
    the input and the output, with ``inference_metrics=True``.

    Args:
        atom_array: Atom array to analyse, or None.

    Returns:
        None when ``atom_array`` is None. Otherwise a dict with the unique
        donor labels, acceptor labels and donor-acceptor connection labels
        (each ``atom_resn_resi``), plus the donor/acceptor percentages and
        hbond count as floats. An empty dict is returned when any step raises.
    """
    if atom_array is None:
        global_logger.warning("atom_array is None")
        return None

    def _label(hb, role):
        # role is "d" (donor) or "a" (acceptor); yields "<atom>_<resn>_<resi>".
        return f"{hb[role + '_atom']}_{hb[role + '_resn']}_{hb[role + '_resi']}"

    try:
        output = calculate_hbond_stats(
            [atom_array.copy()], [atom_array.copy()], inference_metrics=True
        )
        hbonds = output["hbonds"]

        donor_labels = {_label(hb, "d") for hb in hbonds}
        acceptor_labels = {_label(hb, "a") for hb in hbonds}
        connection_labels = {f"{_label(hb, 'd')}-{_label(hb, 'a')}" for hb in hbonds}

        return {
            "donor_atom_names": list(donor_labels),
            "acceptor_atom_names": list(acceptor_labels),
            "hbond_connections": list(connection_labels),
            "correct_donor_percent": float(output["correct_donor_percent"]),
            "correct_acceptor_percent": float(output["correct_acceptor_percent"]),
            "num_hbonds": float(output["num_hbonds"]),
        }

    except Exception as e:
        # Best effort: a failure here is logged and reported as "no metrics".
        global_logger.warning(f"Could not calculate hbond metrics: {e}")
        return {}
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class HbondMetrics(Metric):
    """Metric reporting mean donor/acceptor recovery and mean hbond count."""

    def __init__(
        self,
        cutoff_HA_dist: float = 3,
        cutoff_DA_distance: float = 3.5,
    ):
        """Store the two distance cutoffs forwarded to ``calculate_hbond_stats``."""
        super().__init__()
        self.cutoff_HA_dist = cutoff_HA_dist
        self.cutoff_DA_distance = cutoff_DA_distance

    @property
    def kwargs_to_compute_args(self):
        """Map each ``compute`` keyword to a one-element key tuple of the same name."""
        stack_names = ("ground_truth_atom_array_stack", "predicted_atom_array_stack")
        return {name: (name,) for name in stack_names}

    def compute(self, *, ground_truth_atom_array_stack, predicted_atom_array_stack):
        """
        Run ``calculate_hbond_stats`` on the ground-truth and predicted stacks.

        Returns:
            dict: The three averages as floats, or an empty dict when the
            calculation raises (the error is logged).
        """
        try:
            stats = calculate_hbond_stats(
                input_atom_array_stack=ground_truth_atom_array_stack,
                output_atom_array_stack=predicted_atom_array_stack,
                cutoff_HA_dist=self.cutoff_HA_dist,
                cutoff_DA_distance=self.cutoff_DA_distance,
            )
            donor_pct, acceptor_pct, mean_hbonds = stats
        except Exception as e:
            global_logger.error(
                f"Error calculating hydrogen bond metrics: {e} | Skipping"
            )
            return {}

        result = {}
        result["mean_correct_donors_percent"] = float(donor_pct)
        result["mean_correct_acceptors_percent"] = float(acceptor_pct)
        result["mean_num_hbonds"] = float(mean_hbonds)
        return result
|
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
import biotite.structure as struc
|
|
5
|
+
import numpy as np
|
|
6
|
+
from atomworks.enums import ChainType
|
|
7
|
+
from atomworks.io.transforms.atom_array import remove_hydrogens
|
|
8
|
+
from rfd3.constants import (
|
|
9
|
+
ATOM14_ATOM_NAMES,
|
|
10
|
+
SELECTION_NONPROTEIN,
|
|
11
|
+
SELECTION_PROTEIN,
|
|
12
|
+
association_schemes_stripped,
|
|
13
|
+
)
|
|
14
|
+
from rfd3.transforms.hbonds import (
|
|
15
|
+
add_hydrogen_atom_positions,
|
|
16
|
+
calculate_hbonds,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from foundry.metrics.base import Metric
|
|
20
|
+
from foundry.utils.ddp import RankedLogger
|
|
21
|
+
|
|
22
|
+
logging.basicConfig(level=logging.INFO)
|
|
23
|
+
global_logger = RankedLogger(__name__, rank_zero_only=False)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def simplified_processing_atom_array(atom_arrays, central_atom="CB", threshold=0.5):
    """
    Allows for sequence extraction from cleaned up virtual atoms. Needed for hbond metrics.

    Residues containing any atom whose name starts with "V" are treated as
    "sequence unknown": atoms closer than ``threshold`` to the central atom
    (excluding the central atom itself) are removed, and the remaining
    atom-name pattern is matched against the "atom14" association scheme to
    assign a residue name and atom names ("UNK" when nothing matches).
    Residues without "V" atoms are kept as they are.

    Args:
        atom_arrays: Iterable of atom arrays exposing ``res_id``, ``atom_name``,
            ``coord`` and ``res_name``.
        central_atom: Atom name used as the reference point for detecting
            leftover virtual atoms.
        threshold: Distance (in the units of ``coord``) below which an atom is
            considered to coincide with the central atom.

    Returns:
        list: One processed atom array per input, with ``element`` re-inferred
        from the atom names.
    """
    atom14_schemes = association_schemes_stripped["atom14"]
    final_atom_array = []

    for atom_array in atom_arrays:
        cur_atom_array_list = []

        # Residue boundaries are the positions where res_id changes.
        res_ids = atom_array.res_id
        res_start_indices = np.concatenate(
            [[0], np.where(res_ids[1:] != res_ids[:-1])[0] + 1]
        )
        res_end_indices = np.concatenate([res_start_indices[1:], [len(res_ids)]])

        for start, end in zip(res_start_indices, res_end_indices):
            cur_res = atom_array[start:end]

            # Check if the current residue is after padding (seq unknown):
            if not any(name.startswith("V") for name in cur_res.atom_name):
                cur_atom_array_list.append(cur_res)
                continue

            # For Glycine: it doesn't have CB, so set the virtual atom as CA.
            # The current way to handle this is to check if predicted CA and CB
            # are too close, because in the case of glycine, when virtual atoms
            # are padded based on CB, CB's coords are set as CA.
            # The decision is kept local to this residue; previously the
            # ``central_atom`` argument itself was overwritten, which made "CA"
            # stick for all following residues and atom arrays.
            cur_central_atom = central_atom
            ca_coord = cur_res.coord[cur_res.atom_name == "CA"]
            cb_coord = cur_res.coord[cur_res.atom_name == "CB"]
            if np.linalg.norm(ca_coord - cb_coord) < threshold:
                cur_central_atom = "CA"

            central_mask = cur_res.atom_name == cur_central_atom

            # ... Calculate the distance to the central atom
            central_coord = cur_res.coord[central_mask][0]
            dists = np.linalg.norm(cur_res.coord - central_coord, axis=-1)

            # ... Select virtual atom by the distance. Shouldn't count the central atom itself.
            is_virtual = (dists < threshold) & ~central_mask
            cur_res = cur_res[~is_virtual]

            # Occupied atom14 slots, e.g. names [N, CA, C, O, CB, V6, V2]
            # -> indices [0, 1, 2, 3, 4, 11, 7]. Independent of the candidate
            # residue type, so computed once per residue.
            idx_in_atom14 = np.array(
                [
                    np.where(ATOM14_ATOM_NAMES == name)[0][0]
                    for name in cur_res.atom_name
                ]
            )
            atom14_scheme_mask = np.zeros_like(ATOM14_ATOM_NAMES, dtype=bool)
            atom14_scheme_mask[idx_in_atom14] = True

            for restype, atom_names in atom14_schemes.items():
                if restype in ("UNK", "MSK"):
                    continue
                atom_names = np.array(atom_names)

                # Match: real atom in every occupied slot, nothing elsewhere.
                if all(x is not None for x in atom_names[atom14_scheme_mask]) and all(
                    x is None for x in atom_names[~atom14_scheme_mask]
                ):
                    cur_res.res_name = np.array([restype] * len(cur_res))
                    cur_res.atom_name = np.asarray(
                        atom_names[atom14_scheme_mask], dtype=str
                    )
                    break
            else:
                # No residue type fits the surviving atom pattern.
                cur_res.res_name = np.array(["UNK"] * len(cur_res))

            cur_atom_array_list.append(cur_res)

        cur_atom_array = struc.concatenate(cur_atom_array_list)
        cur_atom_array.element = struc.infer_elements(cur_atom_array.atom_name)

        final_atom_array.append(cur_atom_array)

    return final_atom_array
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# Training comparison
|
|
126
|
+
def _count_recapitulated_atoms(
    input_atom_array, output_atom_array, indices, annotation
):
    """Count input atoms at ``indices`` whose counterpart in the output has ``annotation`` set."""
    count = 0
    for idx in indices:
        match_mask = (
            (output_atom_array.chain_id == input_atom_array.chain_id[idx])
            & (output_atom_array.res_id == input_atom_array.res_id[idx])
            & (output_atom_array.atom_name == input_atom_array.gt_atom_name[idx])
        )
        # np.any handles zero or several matching atoms; bool() on such arrays
        # is deprecated/raises depending on the number of elements.
        if np.any(getattr(output_atom_array[match_mask], annotation)):
            count += 1
    return count


def calculate_hbond_stats(
    input_atom_array_stack,
    output_atom_array_stack,
    selection1,
    selection2,
    selection1_type,
    cutoff_dist,
    cutoff_angle,
    donor_elements,
    acceptor_elements,
    periodic,
):
    """
    Compare the number of hbonds correctly recapitulated in the output atom array.

    Args:
        input_atom_array_stack: Input atom array stack
        output_atom_array_stack: Output atom array stack
        selection1: Chain-type values; output atoms whose ``chain_type`` is in
            this collection form the first selection passed to ``calculate_hbonds``
        selection2: Chain-type values for the second selection; motif atoms
            are always added to this selection
        selection1_type: Forwarded unchanged to ``calculate_hbonds``
        cutoff_dist: Cutoff distance for hbonds
        cutoff_angle: Cutoff angle for hbonds
        donor_elements: Forwarded unchanged to ``calculate_hbonds``
        acceptor_elements: Forwarded unchanged to ``calculate_hbonds``
        periodic: Forwarded unchanged to ``calculate_hbonds``

    Returns:
        tuple: ``(mean correct donor fraction, mean correct acceptor fraction,
        mean number of hbonds)`` over the evaluated samples, or ``(0, 0, 0)``
        when no sample could be evaluated.
    """
    # This module's top-level imports do not provide get_motif_features, so it
    # is imported here to keep the function self-contained.
    from rfd3.transforms.conditioning_base import get_motif_features

    # Used the latest function above, should check if it works correctly
    output_atom_array_stack = simplified_processing_atom_array(output_atom_array_stack)

    assert len(input_atom_array_stack) == len(
        output_atom_array_stack
    ), "Input and output atom arrays must have the same length"

    # Without both selections no sample can be evaluated.
    if selection1 is None or selection2 is None:
        return 0, 0, 0

    total_correct_donors_percent = 0.0
    total_correct_acceptors_percent = 0.0
    total_number_hbonds = 0
    num_valid_samples = 0

    for input_atom_array, output_atom_array in zip(
        input_atom_array_stack, output_atom_array_stack
    ):
        categories = input_atom_array.get_annotation_categories()
        has_donor = "active_donor" in categories
        has_acceptor = "active_acceptor" in categories
        if not (has_donor or has_acceptor):
            continue

        # A missing annotation is treated as "nothing requested" instead of
        # failing on attribute access when only one of the two is present.
        n_atoms = len(input_atom_array)
        given_hbond_donors = (
            np.array(input_atom_array.active_donor, dtype=bool)
            if has_donor
            else np.zeros(n_atoms, dtype=bool)
        )
        given_hbond_acceptors = (
            np.array(input_atom_array.active_acceptor, dtype=bool)
            if has_acceptor
            else np.zeros(n_atoms, dtype=bool)
        )

        # Skip samples that request neither donors nor acceptors. (The earlier
        # test summed ``annotation == 0`` and therefore skipped any sample that
        # had at least one non-donor and one non-acceptor atom.)
        if not given_hbond_donors.any() and not given_hbond_acceptors.any():
            continue

        # Hack: Temporarily use biotite to infer bonds, should be replaced with cifutils?
        output_atom_array.bonds = struc.connect_via_distances(
            output_atom_array, default_bond_type=1
        )

        # Hack: delete coord_to_be_noised (if exists) to temporarily solve a weird bug in create hydrogens. Anyway it will not be used.
        for arr in (input_atom_array, output_atom_array):
            if "coord_to_be_noised" in arr.get_annotation_categories():
                arr.del_annotation("coord_to_be_noised")

        output_atom_array = add_hydrogen_atom_positions(output_atom_array)

        # Select possible donors and acceptors for the model output
        cur_selection1 = np.isin(output_atom_array.chain_type, selection1)
        cur_selection2 = (
            np.isin(output_atom_array.chain_type, selection2)
            | get_motif_features(output_atom_array)["is_motif_atom"]
        )

        hbonds, hbond_types, output_atom_array = calculate_hbonds(
            output_atom_array,
            cur_selection1,
            cur_selection2,
            selection1_type=selection1_type,
            cutoff_dist=cutoff_dist,
            cutoff_angle=cutoff_angle,
            donor_elements=donor_elements,
            acceptor_elements=acceptor_elements,
            periodic=periodic,
        )

        output_atom_array.set_annotation("active_donor", hbond_types[:, 0])
        output_atom_array.set_annotation("active_acceptor", hbond_types[:, 1])

        output_atom_array = remove_hydrogens(output_atom_array)

        # Ensure the produced hbonds match input hbond requirements: same
        # chain, residue id, and atom name.
        correct_donors = _count_recapitulated_atoms(
            input_atom_array,
            output_atom_array,
            np.flatnonzero(given_hbond_donors),
            "active_donor",
        )
        correct_acceptors = _count_recapitulated_atoms(
            input_atom_array,
            output_atom_array,
            np.flatnonzero(given_hbond_acceptors),
            "active_acceptor",
        )

        # A category with nothing requested counts as fully satisfied.
        n_donors = np.sum(given_hbond_donors)
        n_acceptors = np.sum(given_hbond_acceptors)
        correct_hbond_donors_percent = correct_donors / n_donors if n_donors > 0 else 1.0
        correct_hbond_acceptors_percent = (
            correct_acceptors / n_acceptors if n_acceptors > 0 else 1.0
        )

        total_correct_donors_percent += correct_hbond_donors_percent
        total_correct_acceptors_percent += correct_hbond_acceptors_percent
        total_number_hbonds += len(hbonds)
        num_valid_samples += 1

    if num_valid_samples == 0:
        return 0, 0, 0
    return (
        total_correct_donors_percent / num_valid_samples,
        total_correct_acceptors_percent / num_valid_samples,
        total_number_hbonds / num_valid_samples,
    )
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# Inference comparison -> temporary fix to test out sm_hbonds; should be merged with hbond in transforms down the line
|
|
276
|
+
def get_hbond_metrics(atom_array=None):
    """
    Count hydrogen bonds, donors and acceptors for a single atom array.

    Works on a copy of ``atom_array``: bonds are inferred with biotite,
    hydrogens are added (best effort), and ``calculate_hbonds`` is run between
    two complementary selections. The second selection is every atom whose
    ``chain_type`` is in ``SELECTION_NONPROTEIN`` plus every motif atom; the
    first selection is everything else.

    Args:
        atom_array: Atom array exposing ``chain_type`` and ``is_motif_atom``,
            or None.

    Returns:
        None when ``atom_array`` is None; otherwise a dict with integer
        ``"num_hbonds"``, ``"num_donors"`` and ``"num_acceptors"``.
    """
    if atom_array is None:
        global_logger.warning("atom_array is None")
        return None

    curr_copy = atom_array.copy()
    nonprotein_chain_types = np.array(
        [ChainType.as_enum(item).value for item in SELECTION_NONPROTEIN]
    )

    # Hack: Temporarily use biotite to infer bonds, should be replaced with cifutils?
    curr_copy.bonds = struc.connect_via_distances(curr_copy, default_bond_type=1)
    # Hack: delete coord_to_be_noised (if exists) to temporarily solve a weird bug in create hydrogens. Anyway it will not be used.
    if "coord_to_be_noised" in curr_copy.get_annotation_categories():
        curr_copy.del_annotation("coord_to_be_noised")

    try:
        curr_copy = add_hydrogen_atom_positions(curr_copy)
    except Exception as e:
        # Best effort: continue with the hydrogen-free copy.
        global_logger.warning(f"problem adding hydrogen: {e}")

    # Always include fixed motif atoms for hbond calculations. The first
    # selection is defined as the complement, so a separate protein-based
    # selection is not needed.
    selection2 = np.isin(curr_copy.chain_type, nonprotein_chain_types) | np.array(
        curr_copy.is_motif_atom, dtype=bool
    )
    selection1 = ~selection2

    hbonds, hbond_types, curr_copy = calculate_hbonds(
        curr_copy,
        selection1=selection1,
        selection2=selection2,
    )

    return {
        "num_hbonds": int(len(hbonds)),
        "num_donors": int(np.sum(hbond_types[:, 0])),
        "num_acceptors": int(np.sum(hbond_types[:, 1])),
    }
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
class HbondMetrics(Metric):
    """Metric reporting mean donor/acceptor recovery and mean hbond count."""

    def __init__(
        self,
        selection1: list[str] = SELECTION_PROTEIN,
        selection2: list[str] = SELECTION_NONPROTEIN,
        selection1_type: Literal["acceptor", "donor", "both"] = "both",
        cutoff_dist: float = 3.0,
        cutoff_angle: float = 120.0,
        donor_elements: list[str] = ["N", "O", "S", "F"],
        acceptor_elements: list[str] = ["N", "O", "S", "F"],
        periodic: bool = False,
    ):
        """
        Store the hbond search configuration.

        Args:
            selection1: Chain-type names, converted to ``ChainType`` values.
            selection2: Chain-type names, converted to ``ChainType`` values.
            selection1_type: Forwarded to ``calculate_hbond_stats``.
            cutoff_dist: Cutoff distance for hbonds.
            cutoff_angle: Cutoff angle for hbonds.
            donor_elements: Forwarded to ``calculate_hbond_stats``.
            acceptor_elements: Forwarded to ``calculate_hbond_stats``.
            periodic: Forwarded to ``calculate_hbond_stats``.
        """
        super().__init__()

        self.selection1 = np.array(
            [ChainType.as_enum(item).value for item in selection1]
        )
        self.selection2 = np.array(
            [ChainType.as_enum(item).value for item in selection2]
        )

        self.selection1_type = selection1_type
        self.cutoff_dist = cutoff_dist
        self.cutoff_angle = cutoff_angle
        # Copy the element lists: the defaults are single list objects shared
        # by every call, so storing them directly would let a mutation on one
        # instance leak into all others.
        self.donor_elements = list(donor_elements)
        self.acceptor_elements = list(acceptor_elements)
        self.periodic = periodic

    @property
    def kwargs_to_compute_args(self):
        """Map each ``compute`` keyword to a one-element key tuple of the same name."""
        return {
            "ground_truth_atom_array_stack": ("ground_truth_atom_array_stack",),
            "predicted_atom_array_stack": ("predicted_atom_array_stack",),
        }

    def compute(self, *, ground_truth_atom_array_stack, predicted_atom_array_stack):
        """
        Run ``calculate_hbond_stats`` on the ground-truth and predicted stacks.

        Returns:
            dict: The three averages as floats, or an empty dict when the
            calculation raises (the error is logged).
        """
        try:
            (
                mean_correct_donors_percent,
                mean_correct_acceptors_percent,
                mean_num_hbonds,
            ) = calculate_hbond_stats(
                input_atom_array_stack=ground_truth_atom_array_stack,
                output_atom_array_stack=predicted_atom_array_stack,
                selection1=self.selection1,
                selection2=self.selection2,
                selection1_type=self.selection1_type,
                cutoff_dist=self.cutoff_dist,
                cutoff_angle=self.cutoff_angle,
                donor_elements=self.donor_elements,
                acceptor_elements=self.acceptor_elements,
                periodic=self.periodic,
            )
        except Exception as e:
            global_logger.error(
                f"Error calculating hydrogen bond metrics: {e} | Skipping"
            )
            return {}

        # Aggregate output for batch-level metrics
        return {
            "mean_correct_donors_percent": float(mean_correct_donors_percent),
            "mean_correct_acceptors_percent": float(mean_correct_acceptors_percent),
            "mean_num_hbonds": float(mean_num_hbonds),
        }
|