posebench-fast 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,388 @@
1
+ """Metrics aggregation and filtering for docking evaluation."""
2
+
3
+ import copy
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ from tqdm import tqdm
8
+
9
+ from posebench_fast.metrics.rmsd import (
10
+ TimeoutException,
11
+ get_symmetry_rmsd_with_isomorphisms,
12
+ )
13
+
14
+
15
def get_best_results_by_score(all_results, score_name):
    """Get the best sample for each UID according to a score.

    Args:
        all_results: Dictionary ``{uid: {'sample_metrics': [...]}}`` where every
            entry of ``sample_metrics`` is a dict of per-sample metrics.
        score_name: Name of the per-sample score to minimize, or ``'random'``
            to simply take the first sample.

    Returns:
        Dictionary ``{uid: best_sample_metrics}``. Ties go to the earliest
        sample. NaN scores are treated as worst; if every score of a UID is
        NaN, the first sample is returned.
    """
    best_by_uid = {}
    for uid, entry in all_results.items():
        samples = entry["sample_metrics"]
        best_index = 0
        if score_name != "random":
            scores = np.asarray(
                [sample[score_name] for sample in samples], dtype=float
            )
            if not np.isnan(scores).all():
                # np.argmin would return the position of the first NaN, letting a
                # failed score win the ranking; nanargmin skips NaNs instead.
                best_index = int(np.nanargmin(scores))
        best_by_uid[uid] = samples[best_index]
    return best_by_uid
37
+
38
+
39
def filter_results_by_posebusters(full_results, use_separate_samples=True):
    """Keep, per UID, only the samples passing the most PoseBusters filters.

    Each sample's ``'posebusters_filters_passed_count'`` is compared against the
    maximum for its UID and only samples equal to that maximum are kept. UIDs
    with no samples are left unchanged.

    Args:
        full_results: Dictionary of results; modified in place.
        use_separate_samples: If True, samples live under
            ``full_results[uid]['sample_metrics']``; otherwise
            ``full_results[uid]`` is the sample list itself.

    Returns:
        The same ``full_results`` dictionary, filtered.
    """
    pb_filters_name = "posebusters_filters_passed_count"
    for uid, entry in full_results.items():
        samples = entry["sample_metrics"] if use_separate_samples else entry
        if not samples:
            # max() of an empty sequence raises; nothing to filter anyway.
            continue

        best_score = max(sample[pb_filters_name] for sample in samples)
        filtered_samples = [
            sample for sample in samples if sample[pb_filters_name] == best_score
        ]
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating over it.
        if use_separate_samples:
            entry["sample_metrics"] = filtered_samples
        else:
            full_results[uid] = filtered_samples
    return full_results
67
+
68
+
69
def filter_results_by_fast(full_results, use_separate_samples=True):
    """Keep, per UID, only the samples passing the most fast PoseBusters filters.

    Uses each sample's ``'posebusters_filters_passed_count_fast'``. If any
    sample of a UID lacks that key, or the UID has no samples, its samples are
    kept unchanged.

    Args:
        full_results: Dictionary of results; modified in place.
        use_separate_samples: If True, samples live under
            ``full_results[uid]['sample_metrics']``; otherwise
            ``full_results[uid]`` is the sample list itself.

    Returns:
        The same ``full_results`` dictionary, filtered.
    """
    fast_count_name = "posebusters_filters_passed_count_fast"
    for uid, entry in full_results.items():
        samples = entry["sample_metrics"] if use_separate_samples else entry

        try:
            scores = [sample[fast_count_name] for sample in samples]
        except KeyError:
            # Fast-filter counts missing for at least one sample: keep them all.
            scores = []

        if scores:
            best_score = max(scores)
            filtered_samples = [
                sample
                for sample, score in zip(samples, scores)
                if score == best_score
            ]
        else:
            filtered_samples = samples

        if use_separate_samples:
            entry["sample_metrics"] = filtered_samples
        else:
            full_results[uid] = filtered_samples
    return full_results
104
+
105
+
106
def filter_empty_results_and_keep_necessary_ids(
    full_results, use_separate_samples=True, ids_to_keep=None
):
    """Drop entries without samples and, optionally, entries not in ``ids_to_keep``.

    A key's base ID is the part before the first ``"_mol"`` (the whole key when
    it has no such suffix). When ``ids_to_keep`` is given, every key whose base
    ID is not in it is removed, whatever its ``_molN`` suffix.

    Args:
        full_results: Dictionary of results; modified in place.
        use_separate_samples: If True, samples live under
            ``full_results[uid]['sample_metrics']``; otherwise
            ``full_results[uid]`` is the sample list itself.
        ids_to_keep: Optional iterable of base IDs to keep.

    Returns:
        The same ``full_results`` dictionary, filtered.
    """
    uids_to_pop = []
    if ids_to_keep is not None:
        keep = set(ids_to_keep)
        # Collect the real keys instead of synthesizing "<uid>_mol0", which
        # raised KeyError for keys without that exact suffix.
        uids_to_pop = [
            key for key in full_results if key.split("_mol")[0] not in keep
        ]

    if uids_to_pop:
        print(f"Pop {len(uids_to_pop)} uids")

    for uid, entry in full_results.items():
        if len(entry) == 0:
            print(f"{uid} has no valid samples")
            uids_to_pop.append(uid)
            continue

        samples = entry["sample_metrics"] if use_separate_samples else entry
        if len(samples) == 0:
            print(f"{uid} has no valid samples")
            uids_to_pop.append(uid)

    # A key can be listed twice (unrequested *and* empty); popping with a
    # default makes the repeat a no-op instead of a KeyError.
    for uid in uids_to_pop:
        full_results.pop(uid, None)

    return full_results
150
+
151
+
152
def get_final_results_for_df(
    full_results,
    score_names,
    score_name_prefix="",
    posebusters_filter=False,
    fast_filter=False,
    ids_to_keep=None,
):
    """Build summary metric rows, one per ranking score and filter variant.

    Empty entries (and, when ``ids_to_keep`` is given, unrequested IDs) are
    first removed from ``full_results`` via
    ``filter_empty_results_and_keep_necessary_ids``. For every score name the
    best sample per UID is picked with ``get_best_results_by_score`` and
    summarized into a row. With ``posebusters_filter`` / ``fast_filter`` the
    same is repeated on deep copies filtered by ``filter_results_by_posebusters``
    / ``filter_results_by_fast``; those rows are named with the suffix
    ``"_posebusters"`` / ``"_fast"``.

    Args:
        full_results: Dictionary ``{uid: {'sample_metrics': [...], ...}}``.
        score_names: List of score names to rank by.
        score_name_prefix: Prefix prepended to every ranking name.
        posebusters_filter: Also emit PoseBusters-filtered rows and add the
            ``"SymRMSD < 2A & PB valid"`` column to every row.
        fast_filter: Also emit fast-filter rows.
        ids_to_keep: Optional list of UIDs to keep.

    Returns:
        Tuple ``(rows_list, all_scored_results)``: ``rows_list`` holds one dict
        per row (for each score: unfiltered, then PoseBusters, then fast) and
        ``all_scored_results`` maps each ranking name to its
        ``{uid: best_sample_metrics}`` selection.
    """

    def get_row(results, score_name, full_score_name, posebusters_filter):
        best_samples = get_best_results_by_score(results, score_name)
        picked = list(best_samples.values())

        rmsds = np.array([sample["rmsd"] for sample in picked])
        sym_rmsds = np.array([sample["symm_rmsd"] for sample in picked])
        tr_errs = np.array([sample["tr_err"] for sample in picked])

        # NOTE(review): the "< xA" columns count values <= x, while the PB-valid
        # column below uses a strict "<"; kept as-is to preserve results.
        row = {"ranking": full_score_name}
        row["RMSD < 2A"] = (rmsds <= 2).mean()
        row["RMSD < 5A"] = (rmsds <= 5).mean()
        row["avg RMSD"] = rmsds.mean()
        row["median RMSD"] = np.median(rmsds)
        row["SymRMSD < 2A"] = (sym_rmsds <= 2).mean()
        row["SymRMSD < 5A"] = (sym_rmsds <= 5).mean()
        row["avg SymRMSD"] = sym_rmsds.mean()
        row["median SymRMSD"] = np.median(sym_rmsds)
        row["avg tr_err"] = tr_errs.mean()
        row["median tr_err"] = np.median(tr_errs)
        row["tr_err < 1A"] = (tr_errs <= 1).mean()
        row["num_samples"] = len(best_samples)

        if posebusters_filter:
            passed_counts = np.array(
                [sample["all_posebusters_filters_passed_count"] for sample in picked]
            )
            # 27 is the hard-coded "all checks passed" count used by this code.
            row["SymRMSD < 2A & PB valid"] = np.logical_and(
                sym_rmsds < 2, passed_counts == 27
            ).mean()
        return row, best_samples

    full_results = filter_empty_results_and_keep_necessary_ids(
        full_results, use_separate_samples=True, ids_to_keep=ids_to_keep
    )

    # (suffix, results) pairs evaluated for every score, in row order.
    variants = [("", full_results)]
    if posebusters_filter:
        variants.append(
            (
                "_posebusters",
                filter_results_by_posebusters(copy.deepcopy(full_results)),
            )
        )
    if fast_filter:
        variants.append(
            ("_fast", filter_results_by_fast(copy.deepcopy(full_results)))
        )

    rows_list = []
    all_scored_results = {}
    for score_name in score_names:
        base_name = f"{score_name_prefix}{score_name}"
        for suffix, variant_results in variants:
            ranking_name = f"{base_name}{suffix}"
            row, best_samples = get_row(
                variant_results,
                score_name,
                ranking_name,
                posebusters_filter=posebusters_filter,
            )
            all_scored_results[ranking_name] = best_samples
            rows_list.append(row)

    return rows_list, all_scored_results
259
+
260
+
261
def add_score_results(all_rmsds_new, score_res, score_name, n_samples=None):
    """Attach averaged score outputs to every sample.

    For each ``uid`` and sample index ``i``, ``score_res[f"{uid}_{i}"]`` is
    converted to a 2-D array (rows presumably are repeated score evaluations —
    TODO confirm). Rows containing any NaN are replaced by a sentinel chosen
    by ``score_name``. The array is then negated, its first ``n_samples`` rows
    are averaged column-wise, and column ``idx`` of the mean is stored on the
    sample under ``f"{score_name}_{idx}"``.

    Args:
        all_rmsds_new: Dictionary ``{uid: [sample_dict, ...]}``.
        score_res: Score results dictionary keyed by ``f"{uid}_{i}"``.
        score_name: Type of score (``'mult'``, ``'bin'``, ``'reg'``); selects
            the NaN sentinel and the prefix of the stored keys. Other names
            leave NaN rows untouched.
        n_samples: Number of leading rows to average. ``None`` (default)
            averages all rows of every sample.

    Returns:
        New dictionary ``{uid: [samples]}``. The sample dicts are the same
        objects as in ``all_rmsds_new`` and are updated in place.
    """
    extended_results = {}
    for uid, samples in tqdm(all_rmsds_new.items(), desc="Adding score results"):
        new_samples = []
        for i, sample in enumerate(samples):
            sample_scores = np.array(score_res[f"{uid}_{i}"])
            nan_mask = np.isnan(sample_scores).any(axis=1)
            if nan_mask.any():
                if score_name == "mult":
                    sample_scores[nan_mask, 2] = 6.0
                    sample_scores[nan_mask, 0] = 0.0
                    sample_scores[nan_mask, 1] = 0.0
                elif score_name == "bin":
                    sample_scores[nan_mask] = 0.0
                elif score_name == "reg":
                    sample_scores[nan_mask] = 50.0

            sample_scores = -sample_scores
            # Slicing with None keeps every row. The caller's n_samples is no
            # longer overwritten inside the loop, which previously pinned all
            # later samples to the first sample's row count.
            mean_scores = np.mean(sample_scores[:n_samples], axis=0)

            for idx, value in enumerate(mean_scores):
                sample[f"{score_name}_{idx}"] = value

            new_samples.append(sample)
        extended_results[uid] = new_samples
    return extended_results
301
+
302
+
303
def get_simple_metrics_df(
    all_real_rmsds, compute_symm_rmsd, mol2isomorphisms, score_names
):
    """Compute per-sample pose metrics and aggregate them into a DataFrame.

    For every ``uid``, the reference coordinates are the first sample's
    ``'true_pos'``, and each sample's ``'transformed_orig'`` is compared against
    them. Samples whose atom count differs from the reference are skipped with
    a printed message. Each remaining sample gets:

    - ``tr_err``: the distance between the two centroids.
    - ``rmsd``: the plain RMSD.
    - ``symm_rmsd``: the symmetry-corrected RMSD. It equals ``rmsd`` when
      correction is disabled, no isomorphisms are available, or it times out.
    - the value of every requested score, read from the sample.

    Args:
        all_real_rmsds: Dictionary ``{uid: [samples]}``. Every sample holds
            ``'transformed_orig'`` and the requested score values; the first
            sample also holds ``'true_pos'`` and ``'orig_mol'``.
        compute_symm_rmsd: Whether to compute symmetry-corrected RMSD. After
            three failures for a ``uid`` the plain RMSD is used for the rest
            of its samples.
        mol2isomorphisms: Dictionary ``{uid: isomorphisms}`` (e.g. from
            ``compute_all_isomorphisms``), looked up with the part of ``uid``
            before ``"_conf"``.
        score_names: Score names to copy from each sample and rank by;
            ``'random'`` and ``'symm_rmsd'`` are not read from the samples.

    Returns:
        Tuple ``(DataFrame, all_scored_results, full_results)``. The first two
        come from ``get_final_results_for_df``; ``full_results`` is the
        per-uid metrics dictionary.
    """
    # "random" needs no stored value and "symm_rmsd" is computed below.
    # dict.fromkeys de-duplicates while keeping the caller's order, so each
    # result dict has a deterministic key order (iterating a set did not).
    copied_score_names = [
        name
        for name in dict.fromkeys(score_names)
        if name not in ("random", "symm_rmsd")
    ]

    full_results = {}
    for uid, samples in tqdm(all_real_rmsds.items(), desc="Computing metrics"):
        if len(samples) == 0:
            # No first sample means no reference pose (samples[0] would fail).
            print(f"{uid} has no valid samples")
            continue

        samples_results = []
        failed_symm_rmsd_count = 0

        true_pos = samples[0]["true_pos"]
        for idx, sample in enumerate(samples):
            pred_pos = sample["transformed_orig"]

            if true_pos.shape[0] != pred_pos.shape[0]:
                print(
                    f"{uid}_{idx:<8} true_pos.shape[0] != pred_pos.shape[0]",
                    true_pos.shape,
                    pred_pos.shape,
                )
                continue

            tr_pred = pred_pos.mean(axis=0)
            tr_true = true_pos.mean(axis=0)
            tr_err = np.linalg.norm(tr_pred - tr_true)

            rmsd = np.sqrt(
                ((true_pos - pred_pos) ** 2).sum(axis=1).sum() / true_pos.shape[0]
            )

            # Stop attempting symmetry correction for this uid after three
            # failures (missing isomorphisms or timeouts).
            symm_rmsd = rmsd
            if compute_symm_rmsd and failed_symm_rmsd_count < 3:
                try:
                    mol2iso = mol2isomorphisms.get(uid.split("_conf")[0])
                    if mol2iso is None:
                        failed_symm_rmsd_count += 1
                    else:
                        symm_rmsd = get_symmetry_rmsd_with_isomorphisms(
                            true_pos, pred_pos, mol2iso
                        )
                except TimeoutException:
                    symm_rmsd = rmsd
                    failed_symm_rmsd_count += 1

            results = {
                "tr_pred": tr_pred,
                "tr_err": float(tr_err),
                "symm_rmsd": float(symm_rmsd),
                "rmsd": float(rmsd),
                "pred_pos": pred_pos,
            }
            for score_name in copied_score_names:
                results[score_name] = float(sample[score_name])

            samples_results.append(results)

        if samples_results:
            full_results[uid] = {
                "sample_metrics": samples_results,
                "true_pos": true_pos,
                "orig_mol": samples[0]["orig_mol"],
            }
        else:
            print(f"{uid} has no valid samples")
            print(
                f"{uid} true_pos.shape[0] != pred_pos.shape[0]",
                true_pos.shape,
                pred_pos.shape,
            )

    if len(full_results) != len(all_real_rmsds):
        print("Initial length of test_names", len(all_real_rmsds))
        print("Length of full_results", len(full_results))

    rows_list, all_scored_results = get_final_results_for_df(full_results, score_names)
    return pd.DataFrame(rows_list), all_scored_results, full_results
@@ -0,0 +1,273 @@
1
+ # Taken from https://github.com/RMeli/spyrmsd and https://github.com/gcorso/DiffDock/
2
+
3
+ import signal
4
+ from contextlib import contextmanager
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ from spyrmsd import graph, molecule, qcp, utils
9
+
10
+
11
class TimeoutException(Exception):  # noqa: N818
    """Raised by ``time_limit`` when the guarded code runs past its budget."""
13
+
14
+
15
@contextmanager
def time_limit(seconds):
    """Raise ``TimeoutException`` if the ``with`` body runs longer than ``seconds``.

    Implemented with ``SIGALRM`` through ``signal.setitimer``, so it is Unix
    only and must be used from the main thread (``signal.signal`` raises
    ``ValueError`` elsewhere). ``seconds`` may be an int or a float; 0 disables
    the timer. On exit the timer is cancelled and the previously installed
    ``SIGALRM`` handler is restored, so later alarms set by other code are not
    turned into ``TimeoutException``.
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    # setitimer accepts fractional seconds; signal.alarm only takes ints.
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal returns None when the old handler was not installed
        # from Python; fall back to the default disposition in that case.
        signal.signal(
            signal.SIGALRM,
            previous_handler if previous_handler is not None else signal.SIG_DFL,
        )
26
+
27
+
28
def compute_all_isomorphisms(rdkit_mol):
    """Enumerate the graph automorphisms of a molecule with spyrmsd.

    The molecular graph (adjacency matrix with atomic numbers as node
    properties) is matched against itself, with a 2-second limit.

    Args:
        rdkit_mol: RDKit molecule object.

    Returns:
        List of ``(idx1, idx2)`` atom-index pairs as produced by
        ``graph.match_graphs``. If matching times out, only the identity
        mapping over all atoms is returned.
    """
    try:
        with time_limit(2):
            spy_mol = molecule.Molecule.from_rdkit(rdkit_mol)
            mol_graph = graph.graph_from_adjacency_matrix(
                spy_mol.adjacency_matrix, spy_mol.atomicnums
            )
            return graph.match_graphs(mol_graph, mol_graph)
    except TimeoutException:
        n_atoms = rdkit_mol.GetNumAtoms()
        return [(list(range(n_atoms)), list(range(n_atoms)))]
47
+
48
+
49
def get_symmetry_rmsd_with_isomorphisms(coords1, coords2, isomorphisms):
    """Compute symmetry-corrected RMSD using precomputed isomorphisms.

    No superposition is performed: for each ``(idx1, idx2)`` pair the squared
    deviation between ``coords1[idx1]`` and ``coords2[idx2]`` is summed, and the
    smallest sum is converted to an RMSD. The search runs under a 1-second
    ``time_limit``, so ``TimeoutException`` may propagate.

    Args:
        coords1: Reference coordinates (N, 3).
        coords2: Query coordinates (N, 3).
        isomorphisms: List of ``(idx1, idx2)`` pairs from
            ``compute_all_isomorphisms``.

    Returns:
        Minimum RMSD over all isomorphisms (``inf`` if ``isomorphisms`` is
        empty).

    Raises:
        ValueError: If the two coordinate arrays have different shapes.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if coords1.shape != coords2.shape:
        raise ValueError(
            f"Coordinate shapes differ: {coords1.shape} vs {coords2.shape}"
        )

    n_atoms = coords1.shape[0]
    with time_limit(1):
        min_sq_dev = np.inf
        for idx1, idx2 in isomorphisms:
            sq_dev = np.sum((coords1[idx1, :] - coords2[idx2, :]) ** 2)
            if sq_dev < min_sq_dev:
                min_sq_dev = sq_dev

    return np.sqrt(min_sq_dev / n_atoms)
76
+
77
+
78
def get_symmetry_rmsd(mol, coords1, coords2, mol2=None, return_permutation=False):
    """Compute symmetry-corrected RMSD between two poses of RDKit molecules.

    Both molecules are converted to spyrmsd ``Molecule`` objects and passed to
    ``symmrmsd`` with default centering/minimization flags. The whole call runs
    under a 10-second ``time_limit``; ``TimeoutException`` is not caught here.

    Args:
        mol: RDKit molecule for ``coords1``.
        coords1: Reference coordinates.
        coords2: Query coordinates.
        mol2: Optional RDKit molecule for ``coords2``; defaults to ``mol``.
        return_permutation: Whether to also return the best permutation.

    Returns:
        RMSD value, or ``(RMSD, permutation)`` if ``return_permutation``.
    """
    with time_limit(10):
        reference = molecule.Molecule.from_rdkit(mol)
        query = reference if mol2 is None else molecule.Molecule.from_rdkit(mol2)
        return symmrmsd(
            coords1,
            coords2,
            reference.atomicnums,
            query.atomicnums,
            reference.adjacency_matrix,
            query.adjacency_matrix,
            return_permutation=return_permutation,
        )
108
+
109
+
110
+ def _rmsd_isomorphic_core(
111
+ coords1: np.ndarray,
112
+ coords2: np.ndarray,
113
+ aprops1: np.ndarray,
114
+ aprops2: np.ndarray,
115
+ am1: np.ndarray,
116
+ am2: np.ndarray,
117
+ center: bool = False,
118
+ minimize: bool = False,
119
+ isomorphisms: list[tuple[list[int], list[int]]] | None = None,
120
+ atol: float = 1e-9,
121
+ ) -> tuple[float, list[tuple[list[int], list[int]]], tuple[list[int], list[int]]]:
122
+ """
123
+ Compute RMSD using graph isomorphism.
124
+
125
+ Parameters
126
+ ----------
127
+ coords1: np.ndarray
128
+ Coordinate of molecule 1
129
+ coords2: np.ndarray
130
+ Coordinates of molecule 2
131
+ aprops1: np.ndarray
132
+ Atomic properties for molecule 1
133
+ aprops2: np.ndarray
134
+ Atomic properties for molecule 2
135
+ am1: np.ndarray
136
+ Adjacency matrix for molecule 1
137
+ am2: np.ndarray
138
+ Adjacency matrix for molecule 2
139
+ center: bool
140
+ Centering flag
141
+ minimize: bool
142
+ Compute minized RMSD
143
+ isomorphisms: Optional[List[Dict[int,int]]]
144
+ Previously computed graph isomorphism
145
+ atol: float
146
+ Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)
147
+
148
+ Returns
149
+ -------
150
+ Tuple[float, List[Dict[int, int]]]
151
+ RMSD (after graph matching) and graph isomorphisms
152
+ """
153
+
154
+ assert coords1.shape == coords2.shape
155
+
156
+ n = coords1.shape[0]
157
+
158
+ c1 = utils.center(coords1) if center or minimize else coords1
159
+ c2 = utils.center(coords2) if center or minimize else coords2
160
+
161
+ if isomorphisms is None:
162
+ G1 = graph.graph_from_adjacency_matrix(am1, aprops1)
163
+ G2 = graph.graph_from_adjacency_matrix(am2, aprops2)
164
+ isomorphisms = graph.match_graphs(G1, G2)
165
+
166
+ min_result = np.inf
167
+ min_isomorphisms = None
168
+
169
+ for idx1, idx2 in isomorphisms:
170
+ c1i = c1[idx1, :]
171
+ c2i = c2[idx2, :]
172
+
173
+ if not minimize:
174
+ result = np.sum((c1i - c2i) ** 2)
175
+ else:
176
+ result = qcp.qcp_rmsd(c1i, c2i, atol)
177
+
178
+ if result < min_result:
179
+ min_result = result
180
+ min_isomorphisms = (idx1, idx2)
181
+
182
+ if not minimize:
183
+ min_result = np.sqrt(min_result / n)
184
+
185
+ return min_result, isomorphisms, min_isomorphisms
186
+
187
+
188
def symmrmsd(
    coordsref: np.ndarray,
    coords: np.ndarray | list[np.ndarray],
    apropsref: np.ndarray,
    aprops: np.ndarray,
    amref: np.ndarray,
    am: np.ndarray,
    center: bool = False,
    minimize: bool = False,
    cache: bool = True,
    atol: float = 1e-9,
    return_permutation: bool = False,
) -> Any:
    """
    Symmetry-corrected RMSD between a reference and one or more poses.

    Parameters
    ----------
    coordsref: np.ndarray
        Coordinates of the reference molecule
    coords: np.ndarray | list[np.ndarray]
        Coordinates of the other molecule; a list means several poses
    apropsref: np.ndarray
        Atomic properties for the reference
    aprops: np.ndarray
        Atomic properties for the other molecule
    amref: np.ndarray
        Adjacency matrix for the reference molecule
    am: np.ndarray
        Adjacency matrix for the other molecule
    center: bool
        Centering flag
    minimize: bool
        Compute the minimized RMSD
    cache: bool
        For a list of poses, reuse the isomorphisms found for the first pose
    atol: float
        Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)
    return_permutation: bool
        Also return the best ``(idx1, idx2)`` mapping(s)

    Returns
    -------
    Any
        A float (single pose) or list of floats (list of poses); when
        ``return_permutation`` is set, a tuple of that value and the
        corresponding best mapping(s)
    """
    if not isinstance(coords, list):
        rmsd_value, _, best_permutation = _rmsd_isomorphic_core(
            coordsref,
            coords,
            apropsref,
            aprops,
            amref,
            am,
            center=center,
            minimize=minimize,
            isomorphisms=None,
            atol=atol,
        )
        if return_permutation:
            return rmsd_value, best_permutation
        return rmsd_value

    rmsd_values: Any = []
    best_permutations = []
    known_isomorphisms = None
    for pose in coords:
        rmsd_value, known_isomorphisms, best_permutation = _rmsd_isomorphic_core(
            coordsref,
            pose,
            apropsref,
            aprops,
            amref,
            am,
            center=center,
            minimize=minimize,
            # Without caching, isomorphisms are recomputed for every pose.
            isomorphisms=known_isomorphisms if cache else None,
            atol=atol,
        )
        rmsd_values.append(rmsd_value)
        best_permutations.append(best_permutation)

    if return_permutation:
        return rmsd_values, best_permutations
    return rmsd_values
@@ -0,0 +1 @@
1
+ """Utility functions for posebench-fast."""