posebench-fast 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- posebench_fast/__init__.py +53 -0
- posebench_fast/datasets/__init__.py +1 -0
- posebench_fast/filters/__init__.py +15 -0
- posebench_fast/filters/fast_filters.py +526 -0
- posebench_fast/metrics/__init__.py +31 -0
- posebench_fast/metrics/aggregation.py +388 -0
- posebench_fast/metrics/rmsd.py +273 -0
- posebench_fast/utils/__init__.py +1 -0
- posebench_fast-0.1.0.dist-info/METADATA +109 -0
- posebench_fast-0.1.0.dist-info/RECORD +11 -0
- posebench_fast-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
"""Metrics aggregation and filtering for docking evaluation."""
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from posebench_fast.metrics.rmsd import (
|
|
10
|
+
TimeoutException,
|
|
11
|
+
get_symmetry_rmsd_with_isomorphisms,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_best_results_by_score(all_results, score_name):
    """Pick, for every UID, the sample with the lowest value of a score.

    Args:
        all_results: Mapping ``{uid: {'sample_metrics': [sample, ...]}}``
            where each sample is a dict of metrics.
        score_name: Key of the per-sample score to minimise. The special
            value ``'random'`` skips scoring and takes the first sample.

    Returns:
        Mapping ``{uid: sample}`` holding the selected sample of each UID.
    """
    best_by_uid = {}

    for uid, entry in all_results.items():
        candidates = entry["sample_metrics"]
        if score_name == "random":
            # "random" is implemented as "take the first sample".
            winner = 0
        else:
            # np.argmin returns the first index when several samples tie.
            winner = np.argmin(np.array([cand[score_name] for cand in candidates]))
        best_by_uid[uid] = candidates[winner]

    return best_by_uid
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def filter_results_by_posebusters(full_results, use_separate_samples=True):
    """Keep, per UID, only the samples passing the most PoseBusters filters.

    The input dictionary is modified in place and also returned.

    Args:
        full_results: Dictionary of results keyed by UID.
        use_separate_samples: If true, the samples of a UID live under its
            ``'sample_metrics'`` key; otherwise the UID maps directly to the
            list of samples.

    Returns:
        ``full_results`` with every sample list reduced to the samples whose
        ``'posebusters_filters_passed_count'`` equals the per-UID maximum.
    """
    count_key = "posebusters_filters_passed_count"

    for uid in full_results:
        entry = full_results[uid]
        pool = entry["sample_metrics"] if use_separate_samples else entry

        top_count = max(np.array([item[count_key] for item in pool]))
        survivors = [item for item in pool if item[count_key] == top_count]

        # Re-assigning an existing key is safe while iterating the dict.
        if use_separate_samples:
            entry["sample_metrics"] = survivors
        else:
            full_results[uid] = survivors

    return full_results
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def filter_results_by_fast(full_results, use_separate_samples=True):
    """Keep, per UID, only the samples passing the most fast filters.

    The input dictionary is modified in place and also returned. A UID whose
    samples lack the ``'posebusters_filters_passed_count_fast'`` key keeps
    all of its samples.

    Args:
        full_results: Dictionary of results keyed by UID.
        use_separate_samples: If true, the samples of a UID live under its
            ``'sample_metrics'`` key; otherwise the UID maps directly to the
            list of samples.

    Returns:
        ``full_results`` with every sample list reduced to the samples whose
        fast-filter count equals the per-UID maximum.
    """
    fast_key = "posebusters_filters_passed_count_fast"

    for uid in full_results:
        entry = full_results[uid]
        pool = entry["sample_metrics"] if use_separate_samples else entry

        try:
            top_count = max(np.array([item[fast_key] for item in pool]))
            survivors = [item for item in pool if item[fast_key] == top_count]
        except KeyError:
            # No fast-filter count available: leave this UID unfiltered.
            survivors = pool

        if use_separate_samples:
            entry["sample_metrics"] = survivors
        else:
            full_results[uid] = survivors

    return full_results
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def filter_empty_results_and_keep_necessary_ids(
    full_results, use_separate_samples=True, ids_to_keep=None
):
    """Drop entries without samples and entries whose UID is not wanted.

    The input dictionary is modified in place and also returned.

    A key is matched against ``ids_to_keep`` by its base UID, i.e. the part
    of the key before the first ``"_mol"``. Every key sharing an unwanted
    base UID is removed, and a key that is both unwanted and empty is
    removed only once (the removal set is de-duplicated).

    Args:
        full_results: Dictionary of results keyed by UID.
        use_separate_samples: If true, the samples of a UID live under its
            ``'sample_metrics'`` key; otherwise the UID maps directly to the
            list of samples.
        ids_to_keep: Optional iterable of base UIDs to keep. ``None`` keeps
            every UID that has samples.

    Returns:
        ``full_results`` without the removed entries.
    """
    # A dict is used as an insertion-ordered set so that a key scheduled for
    # removal twice (unwanted AND empty) is popped exactly once.
    keys_to_pop = {}

    if ids_to_keep is not None:
        wanted = set(ids_to_keep)
        for key in full_results:
            # Record the real key instead of guessing a "<uid>_mol0" name
            # that may not exist in full_results.
            if key.split("_mol")[0] not in wanted:
                keys_to_pop[key] = None

    if keys_to_pop:
        print(f"Pop {len(keys_to_pop)} uids")

    for uid, entry in full_results.items():
        if len(entry) == 0:
            has_samples = False
        else:
            samples = entry["sample_metrics"] if use_separate_samples else entry
            has_samples = len(samples) != 0

        if not has_samples:
            print(f"{uid} has no valid samples")
            keys_to_pop[uid] = None

    # Removal happens after iteration so the dict size never changes while
    # it is being walked.
    for uid in keys_to_pop:
        full_results.pop(uid)

    return full_results
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def get_final_results_for_df(
    full_results,
    score_names,
    score_name_prefix="",
    posebusters_filter=False,
    fast_filter=False,
    ids_to_keep=None,
):
    """Build one summary row per ranking score (and per filter variant).

    ``full_results`` is first cleaned in place with
    ``filter_empty_results_and_keep_necessary_ids``. For every score name the
    best sample of each UID is selected and aggregated into a row; when
    ``posebusters_filter`` / ``fast_filter`` are set, the same is repeated on
    deep copies of the results that were filtered accordingly.

    Args:
        full_results: Dictionary ``{uid: {'sample_metrics': [...]}}``.
        score_names: List of score names to rank by.
        score_name_prefix: Prefix prepended to every score name in the output.
        posebusters_filter: Whether to also compute PoseBusters-filtered
            metrics (rows suffixed ``_posebusters``) and the
            ``'SymRMSD < 2A & PB valid'`` column.
        fast_filter: Whether to also compute fast-filter metrics (rows
            suffixed ``_fast``).
        ids_to_keep: Optional list of UIDs to keep.

    Returns:
        Tuple ``(rows_list, all_scored_results)``: the list of row dicts and
        a mapping from row name to the selected sample of every UID.
    """

    def get_row(results, score_name, full_score_name, posebusters_filter):
        scored_results = get_best_results_by_score(results, score_name)
        picked = scored_results.values()

        tr_errs = np.array([item["tr_err"] for item in picked])
        rmsd_families = (
            ("RMSD", np.array([item["rmsd"] for item in picked])),
            ("SymRMSD", np.array([item["symm_rmsd"] for item in picked])),
        )

        row = {"ranking": full_score_name}
        # NOTE(review): columns are labelled "<" but thresholds use "<=".
        for label, values in rmsd_families:
            row[f"{label} < 2A"] = (values <= 2).mean()
            row[f"{label} < 5A"] = (values <= 5).mean()
            row[f"avg {label}"] = values.mean()
            row[f"median {label}"] = np.median(values)
        row["avg tr_err"] = tr_errs.mean()
        row["median tr_err"] = np.median(tr_errs)
        row["tr_err < 1A"] = (tr_errs <= 1).mean()
        row["num_samples"] = len(picked)

        if posebusters_filter:
            sym_rmsds = rmsd_families[1][1]
            passed_counts = np.array(
                [item["all_posebusters_filters_passed_count"] for item in picked]
            )
            # NOTE(review): 27 is presumably the total number of PoseBusters
            # checks; this comparison is strict (< 2), unlike the rows above.
            row["SymRMSD < 2A & PB valid"] = np.logical_and(
                sym_rmsds < 2, passed_counts == 27
            ).mean()
        return row, scored_results

    full_results = filter_empty_results_and_keep_necessary_ids(
        full_results, use_separate_samples=True, ids_to_keep=ids_to_keep
    )

    # (row-name suffix, results to rank), in output order per score name.
    variants = [("", full_results)]
    if posebusters_filter:
        variants.append(
            ("_posebusters", filter_results_by_posebusters(copy.deepcopy(full_results)))
        )
    if fast_filter:
        variants.append(("_fast", filter_results_by_fast(copy.deepcopy(full_results))))

    rows_list = []
    all_scored_results = {}
    for score_name in score_names:
        full_score_name = f"{score_name_prefix}{score_name}"
        for suffix, variant_results in variants:
            row_name = f"{full_score_name}{suffix}"
            row, scored_results = get_row(
                variant_results,
                score_name,
                row_name,
                posebusters_filter=posebusters_filter,
            )
            all_scored_results[row_name] = scored_results
            rows_list.append(row)

    return rows_list, all_scored_results
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def add_score_results(all_rmsds_new, score_res, score_name, n_samples=None):
    """Attach averaged (negated) model scores to every sample.

    For sample ``i`` of ``uid`` the score rows are read from
    ``score_res[f"{uid}_{i}"]``. Rows containing NaN are replaced by a fixed
    fallback depending on ``score_name``, all scores are negated, the first
    ``n_samples`` rows are averaged column-wise, and column ``idx`` of that
    mean is stored in the sample as ``f"{score_name}_{idx}"``.

    The sample dicts of ``all_rmsds_new`` are updated in place; the returned
    dictionary references those same dicts.

    Args:
        all_rmsds_new: Dictionary ``{uid: [sample_dict, ...]}``.
        score_res: Dictionary ``{f"{uid}_{i}": 2-D array-like of scores}``.
        score_name: Type of score ('mult', 'bin', 'reg'); selects the NaN
            fallback and the prefix of the stored keys.
        n_samples: Number of leading score rows to average. ``None`` averages
            all rows of each sample; this is resolved per sample, so the row
            count of one sample never limits another.

    Returns:
        Dictionary ``{uid: [sample_dict, ...]}`` with the scores added.
    """
    extended_results = {}
    for uid in tqdm(all_rmsds_new.keys(), desc="Adding score results"):
        new_samples = []
        for i, sample in enumerate(all_rmsds_new[uid]):
            sample_scores = np.array(score_res[f"{uid}_{i}"])

            # True for every row that contains at least one NaN.
            nan_mask = np.isnan(sample_scores).any(axis=1)
            if nan_mask.any():
                if score_name == "mult":
                    sample_scores[nan_mask, 2] = 6.0
                    sample_scores[nan_mask, 0] = 0.0
                    sample_scores[nan_mask, 1] = 0.0
                elif score_name == "bin":
                    sample_scores[nan_mask] = 0.0
                elif score_name == "reg":
                    sample_scores[nan_mask] = 50.0

            sample_scores = -sample_scores

            # Use a per-sample limit: overwriting n_samples here would make
            # the first sample's row count stick for all later samples.
            limit = len(sample_scores) if n_samples is None else n_samples
            mean_scores = np.mean(sample_scores[:limit], axis=0)

            for idx, value in enumerate(mean_scores):
                sample[f"{score_name}_{idx}"] = value

            new_samples.append(sample)
        extended_results[uid] = new_samples
    return extended_results
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def get_simple_metrics_df(
    all_real_rmsds, compute_symm_rmsd, mol2isomorphisms, score_names
):
    """Compute per-sample pose metrics and aggregate them into a DataFrame.

    For every sample the centroid (translation) error, the plain RMSD and,
    optionally, the symmetry-corrected RMSD against the reference positions
    of the UID's first sample are computed. Samples whose atom count differs
    from the reference are skipped, and UIDs with no samples (or no valid
    samples) are reported and left out of the results.

    Args:
        all_real_rmsds: Dictionary ``{uid: [samples]}``; each sample provides
            'true_pos', 'transformed_orig', the requested scores and (first
            sample only) 'orig_mol'.
        compute_symm_rmsd: Whether to compute symmetry RMSD.
        mol2isomorphisms: Dictionary ``{uid: isomorphisms}`` from
            compute_all_isomorphisms, looked up with the part of the UID
            before ``"_conf"``.
        score_names: List of score names to include.

    Returns:
        Tuple of (DataFrame, all_scored_results, full_results).
    """
    # "random" and "symm_rmsd" are not copied from the input samples.
    copied_scores = set(score_names) - {"random", "symm_rmsd"}

    full_results = {}
    for uid, samples in tqdm(all_real_rmsds.items(), desc="Computing metrics"):
        if len(samples) == 0:
            # Without this guard samples[0] below raises IndexError.
            print(f"{uid} has no valid samples")
            continue

        samples_results = []
        failed_symm_rmsd_count = 0

        true_pos = samples[0]["true_pos"]
        for idx, sample in enumerate(samples):
            pred_pos = sample["transformed_orig"]

            if true_pos.shape[0] != pred_pos.shape[0]:
                print(
                    f"{uid}_{idx:<8} true_pos.shape[0] != pred_pos.shape[0]",
                    true_pos.shape,
                    pred_pos.shape,
                )
                continue

            tr_pred = pred_pos.mean(axis=0)
            tr_true = true_pos.mean(axis=0)
            tr_err = np.linalg.norm(tr_pred - tr_true)

            rmsd = np.sqrt(
                ((true_pos - pred_pos) ** 2).sum(axis=1).sum() / true_pos.shape[0]
            )

            # Plain RMSD is the fallback whenever the symmetry-corrected
            # value is not requested, unavailable or timed out. After three
            # failures for a UID no further attempts are made for it.
            symm_rmsd = rmsd
            if compute_symm_rmsd and failed_symm_rmsd_count < 3:
                mol2iso = mol2isomorphisms.get(uid.split("_conf")[0])
                if mol2iso is None:
                    failed_symm_rmsd_count += 1
                else:
                    try:
                        symm_rmsd = get_symmetry_rmsd_with_isomorphisms(
                            true_pos, pred_pos, mol2iso
                        )
                    except TimeoutException:
                        failed_symm_rmsd_count += 1

            results = {
                "tr_pred": tr_pred,
                "tr_err": float(tr_err),
                "symm_rmsd": float(symm_rmsd),
                "rmsd": float(rmsd),
                "pred_pos": pred_pos,
            }
            for score_name in copied_scores:
                results[score_name] = float(sample[score_name])

            samples_results.append(results)

        if len(samples_results) > 0:
            full_results[uid] = {
                "sample_metrics": samples_results,
                "true_pos": true_pos,
                "orig_mol": samples[0]["orig_mol"],
            }
        else:
            # Every sample was skipped because of an atom-count mismatch;
            # pred_pos is the last sample inspected.
            print(f"{uid} has no valid samples")
            print(
                f"{uid} true_pos.shape[0] != pred_pos.shape[0]",
                true_pos.shape,
                pred_pos.shape,
            )

    if len(full_results) != len(all_real_rmsds):
        print("Initial length of test_names", len(all_real_rmsds))
        print("Length of full_results", len(full_results))

    rows_list, all_scored_results = get_final_results_for_df(full_results, score_names)
    return pd.DataFrame(rows_list), all_scored_results, full_results
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
# Taken from https://github.com/RMeli/spyrmsd and https://github.com/gcorso/DiffDock/
|
|
2
|
+
|
|
3
|
+
import signal
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from spyrmsd import graph, molecule, qcp, utils
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TimeoutException(Exception):  # noqa: N818
    """Signals that a time-limited operation ran past its deadline."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@contextmanager
def time_limit(seconds):
    """Context manager that raises ``TimeoutException`` after ``seconds``.

    A SIGALRM handler raising ``TimeoutException`` is installed and an alarm
    is armed for the duration of the ``with`` body. On exit, whether normal
    or through an exception, the alarm is cancelled and the handler that was
    active before entry is restored, so this does not leak its handler into
    unrelated code.

    NOTE(review): relies on ``signal.SIGALRM`` / ``signal.alarm``; per the
    Python ``signal`` documentation these are Unix-only and handlers can only
    be installed from the main thread.

    Args:
        seconds: Whole number of seconds passed to ``signal.alarm``.

    Raises:
        TimeoutException: Inside the ``with`` body when the alarm fires.
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # signal.signal returns the handler that was installed before.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        signal.alarm(seconds)
        yield
    finally:
        signal.alarm(0)
        # None means the previous handler was not installed from Python and
        # cannot be re-installed through signal.signal.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def compute_all_isomorphisms(rdkit_mol):
    """Enumerate the graph isomorphisms of a molecule onto itself.

    The search runs under a 2-second time limit. If it times out, a single
    identity mapping over all atoms is returned instead.

    Args:
        rdkit_mol: RDKit molecule object.

    Returns:
        List of isomorphism tuples ``(idx1, idx2)``.
    """
    try:
        with time_limit(2):
            spy_mol = molecule.Molecule.from_rdkit(rdkit_mol)
            mol_graph = graph.graph_from_adjacency_matrix(
                spy_mol.adjacency_matrix, spy_mol.atomicnums
            )
            # Matching the graph against itself yields its automorphisms.
            return graph.match_graphs(mol_graph, mol_graph)
    except TimeoutException:
        # Fallback: only the identity mapping.
        return [
            (list(range(rdkit_mol.GetNumAtoms())), list(range(rdkit_mol.GetNumAtoms())))
        ]
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_symmetry_rmsd_with_isomorphisms(coords1, coords2, isomorphisms):
    """Return the smallest RMSD obtainable over precomputed isomorphisms.

    The whole computation runs under a 1-second time limit. No alignment is
    performed; coordinates are compared as given.

    Args:
        coords1: Reference coordinates (N, 3).
        coords2: Query coordinates (N, 3).
        isomorphisms: List of isomorphism tuples from
            compute_all_isomorphisms.

    Returns:
        Minimum RMSD over all isomorphisms.
    """
    with time_limit(1):
        assert coords1.shape == coords2.shape

        best_sq_sum = np.inf
        for ref_idx, query_idx in isomorphisms:
            delta = coords1[ref_idx, :] - coords2[query_idx, :]
            # min keeps the current best unless the candidate is strictly
            # smaller.
            best_sq_sum = min(best_sq_sum, np.sum(delta**2))

        # Convert the best summed squared deviation into an RMSD.
        return np.sqrt(best_sq_sum / coords1.shape[0])
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def get_symmetry_rmsd(mol, coords1, coords2, mol2=None, return_permutation=False):
    """Compute the symmetry-corrected RMSD between two coordinate sets.

    Molecule conversion and the RMSD computation run together under a
    10-second time limit.

    Args:
        mol: RDKit molecule describing ``coords1``.
        coords1: Reference coordinates.
        coords2: Query coordinates.
        mol2: Optional RDKit molecule describing ``coords2``; when omitted,
            ``mol`` is used for both coordinate sets.
        return_permutation: Whether to return the best permutation as well.

    Returns:
        RMSD value (and permutation if requested).
    """
    with time_limit(10):
        ref_mol = molecule.Molecule.from_rdkit(mol)
        query_mol = ref_mol if mol2 is None else molecule.Molecule.from_rdkit(mol2)
        return symmrmsd(
            coords1,
            coords2,
            ref_mol.atomicnums,
            query_mol.atomicnums,
            ref_mol.adjacency_matrix,
            query_mol.adjacency_matrix,
            return_permutation=return_permutation,
        )
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _rmsd_isomorphic_core(
|
|
111
|
+
coords1: np.ndarray,
|
|
112
|
+
coords2: np.ndarray,
|
|
113
|
+
aprops1: np.ndarray,
|
|
114
|
+
aprops2: np.ndarray,
|
|
115
|
+
am1: np.ndarray,
|
|
116
|
+
am2: np.ndarray,
|
|
117
|
+
center: bool = False,
|
|
118
|
+
minimize: bool = False,
|
|
119
|
+
isomorphisms: list[tuple[list[int], list[int]]] | None = None,
|
|
120
|
+
atol: float = 1e-9,
|
|
121
|
+
) -> tuple[float, list[tuple[list[int], list[int]]], tuple[list[int], list[int]]]:
|
|
122
|
+
"""
|
|
123
|
+
Compute RMSD using graph isomorphism.
|
|
124
|
+
|
|
125
|
+
Parameters
|
|
126
|
+
----------
|
|
127
|
+
coords1: np.ndarray
|
|
128
|
+
Coordinate of molecule 1
|
|
129
|
+
coords2: np.ndarray
|
|
130
|
+
Coordinates of molecule 2
|
|
131
|
+
aprops1: np.ndarray
|
|
132
|
+
Atomic properties for molecule 1
|
|
133
|
+
aprops2: np.ndarray
|
|
134
|
+
Atomic properties for molecule 2
|
|
135
|
+
am1: np.ndarray
|
|
136
|
+
Adjacency matrix for molecule 1
|
|
137
|
+
am2: np.ndarray
|
|
138
|
+
Adjacency matrix for molecule 2
|
|
139
|
+
center: bool
|
|
140
|
+
Centering flag
|
|
141
|
+
minimize: bool
|
|
142
|
+
Compute minized RMSD
|
|
143
|
+
isomorphisms: Optional[List[Dict[int,int]]]
|
|
144
|
+
Previously computed graph isomorphism
|
|
145
|
+
atol: float
|
|
146
|
+
Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)
|
|
147
|
+
|
|
148
|
+
Returns
|
|
149
|
+
-------
|
|
150
|
+
Tuple[float, List[Dict[int, int]]]
|
|
151
|
+
RMSD (after graph matching) and graph isomorphisms
|
|
152
|
+
"""
|
|
153
|
+
|
|
154
|
+
assert coords1.shape == coords2.shape
|
|
155
|
+
|
|
156
|
+
n = coords1.shape[0]
|
|
157
|
+
|
|
158
|
+
c1 = utils.center(coords1) if center or minimize else coords1
|
|
159
|
+
c2 = utils.center(coords2) if center or minimize else coords2
|
|
160
|
+
|
|
161
|
+
if isomorphisms is None:
|
|
162
|
+
G1 = graph.graph_from_adjacency_matrix(am1, aprops1)
|
|
163
|
+
G2 = graph.graph_from_adjacency_matrix(am2, aprops2)
|
|
164
|
+
isomorphisms = graph.match_graphs(G1, G2)
|
|
165
|
+
|
|
166
|
+
min_result = np.inf
|
|
167
|
+
min_isomorphisms = None
|
|
168
|
+
|
|
169
|
+
for idx1, idx2 in isomorphisms:
|
|
170
|
+
c1i = c1[idx1, :]
|
|
171
|
+
c2i = c2[idx2, :]
|
|
172
|
+
|
|
173
|
+
if not minimize:
|
|
174
|
+
result = np.sum((c1i - c2i) ** 2)
|
|
175
|
+
else:
|
|
176
|
+
result = qcp.qcp_rmsd(c1i, c2i, atol)
|
|
177
|
+
|
|
178
|
+
if result < min_result:
|
|
179
|
+
min_result = result
|
|
180
|
+
min_isomorphisms = (idx1, idx2)
|
|
181
|
+
|
|
182
|
+
if not minimize:
|
|
183
|
+
min_result = np.sqrt(min_result / n)
|
|
184
|
+
|
|
185
|
+
return min_result, isomorphisms, min_isomorphisms
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def symmrmsd(
    coordsref: np.ndarray,
    coords: np.ndarray | list[np.ndarray],
    apropsref: np.ndarray,
    aprops: np.ndarray,
    amref: np.ndarray,
    am: np.ndarray,
    center: bool = False,
    minimize: bool = False,
    cache: bool = True,
    atol: float = 1e-9,
    return_permutation: bool = False,
) -> Any:
    """
    Compute RMSD using graph isomorphism for multiple coordinates.

    Parameters
    ----------
    coordsref: np.ndarray
        Coordinates of reference molecule
    coords: np.ndarray or List[np.ndarray]
        Coordinates of other molecule (a list is processed element-wise)
    apropsref: np.ndarray
        Atomic properties for reference
    aprops: np.ndarray
        Atomic properties for other molecule
    amref: np.ndarray
        Adjacency matrix for reference molecule
    am: np.ndarray
        Adjacency matrix for other molecule
    center: bool
        Centering flag
    minimize: bool
        Minimum RMSD
    cache: bool
        Cache graph isomorphisms (reuse them across the elements of a list)
    atol: float
        Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)
    return_permutation: bool
        Also return the best isomorphism(s)

    Returns
    -------
    Union[float, List[float]]
        Symmetry-corrected RMSD(s); with ``return_permutation`` a tuple of
        the RMSD(s) and the best isomorphism(s)
    """

    shared_options = {"center": center, "minimize": minimize, "atol": atol}

    if isinstance(coords, list):
        RMSD: Any = []
        min_iso = []
        known_isomorphisms = None

        for query_coords in coords:
            # Without caching the isomorphisms are recomputed for each entry.
            srmsd, known_isomorphisms, best_iso = _rmsd_isomorphic_core(
                coordsref,
                query_coords,
                apropsref,
                aprops,
                amref,
                am,
                isomorphisms=known_isomorphisms if cache else None,
                **shared_options,
            )
            min_iso.append(best_iso)
            RMSD.append(srmsd)
    else:
        RMSD, _, min_iso = _rmsd_isomorphic_core(
            coordsref,
            coords,
            apropsref,
            aprops,
            amref,
            am,
            isomorphisms=None,
            **shared_options,
        )

    return (RMSD, min_iso) if return_permutation else RMSD
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Utility functions for posebench-fast."""
|