posebench-fast 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- posebench_fast/__init__.py +53 -0
- posebench_fast/datasets/__init__.py +1 -0
- posebench_fast/filters/__init__.py +15 -0
- posebench_fast/filters/fast_filters.py +526 -0
- posebench_fast/metrics/__init__.py +31 -0
- posebench_fast/metrics/aggregation.py +388 -0
- posebench_fast/metrics/rmsd.py +273 -0
- posebench_fast/utils/__init__.py +1 -0
- posebench_fast-0.1.0.dist-info/METADATA +109 -0
- posebench_fast-0.1.0.dist-info/RECORD +11 -0
- posebench_fast-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
"""
|
|
2
|
+
posebench-fast: Fast docking evaluation metrics
|
|
3
|
+
|
|
4
|
+
Provides:
|
|
5
|
+
- Symmetry-corrected RMSD computation
|
|
6
|
+
- Fast PoseBusters filters (without full energy evaluation)
|
|
7
|
+
- Docking metrics (success rates, averages)
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from posebench_fast.filters.fast_filters import (
|
|
11
|
+
calc_posebusters,
|
|
12
|
+
check_geometry,
|
|
13
|
+
check_intermolecular_distance,
|
|
14
|
+
check_volume_overlap,
|
|
15
|
+
)
|
|
16
|
+
from posebench_fast.metrics.aggregation import (
|
|
17
|
+
filter_results_by_fast,
|
|
18
|
+
filter_results_by_posebusters,
|
|
19
|
+
get_best_results_by_score,
|
|
20
|
+
get_final_results_for_df,
|
|
21
|
+
get_simple_metrics_df,
|
|
22
|
+
)
|
|
23
|
+
from posebench_fast.metrics.rmsd import (
|
|
24
|
+
TimeoutException,
|
|
25
|
+
compute_all_isomorphisms,
|
|
26
|
+
get_symmetry_rmsd,
|
|
27
|
+
get_symmetry_rmsd_with_isomorphisms,
|
|
28
|
+
symmrmsd,
|
|
29
|
+
time_limit,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
__version__ = "0.1.0"

# Names re-exported at package level, grouped by the submodule that defines them.
__all__ = [
    # RMSD (posebench_fast.metrics.rmsd)
    "compute_all_isomorphisms",
    "get_symmetry_rmsd_with_isomorphisms",
    "get_symmetry_rmsd",
    "symmrmsd",
    "TimeoutException",
    "time_limit",
    # Filters (posebench_fast.filters.fast_filters)
    "calc_posebusters",
    "check_intermolecular_distance",
    "check_volume_overlap",
    "check_geometry",
    # Metrics (posebench_fast.metrics.aggregation)
    "get_simple_metrics_df",
    "get_final_results_for_df",
    "filter_results_by_posebusters",
    "filter_results_by_fast",
    "get_best_results_by_score",
]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Dataset utilities for posebench-fast."""
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Fast PoseBusters filters for docking evaluation."""
|
|
2
|
+
|
|
3
|
+
from posebench_fast.filters.fast_filters import (
|
|
4
|
+
calc_posebusters,
|
|
5
|
+
check_geometry,
|
|
6
|
+
check_intermolecular_distance,
|
|
7
|
+
check_volume_overlap,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
# Public API of the filters subpackage, re-exported from fast_filters.
__all__ = [
    "calc_posebusters",
    "check_intermolecular_distance",
    "check_volume_overlap",
    "check_geometry",
]
|
|
@@ -0,0 +1,526 @@
|
|
|
1
|
+
"""Fast PoseBusters filters without full energy evaluation."""
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
from posebusters.modules.intermolecular_distance import _pairwise_distance
|
|
9
|
+
from rdkit import Chem
|
|
10
|
+
from rdkit.Chem import MolFromSmarts
|
|
11
|
+
from rdkit.Chem.rdchem import GetPeriodicTable, Mol
|
|
12
|
+
from rdkit.Chem.rdDistGeom import GetMoleculeBoundsMatrix
|
|
13
|
+
from rdkit.Chem.rdmolops import SanitizeMol
|
|
14
|
+
from rdkit.Chem.rdShapeHelpers import ShapeTverskyIndex
|
|
15
|
+
|
|
16
|
+
_periodic_table = GetPeriodicTable()

# Element symbols ordered by atomic number: position i holds the element with Z = i + 1.
_element_symbols = [
    _periodic_table.GetElementSymbol(atomic_num)
    for atomic_num in range(1, _periodic_table.GetMaxAtomicNumber() + 1)
]

# Element symbol -> 0-based index (atomic number - 1), covering the whole periodic table.
atoms_vocab = {symbol: position for position, symbol in enumerate(_element_symbols)}

# Van der Waals radius per element, indexable with the values of ``atoms_vocab``.
vdw_radius = torch.tensor([_periodic_table.GetRvdw(symbol) for symbol in _element_symbols])
|
|
28
|
+
|
|
29
|
+
# Column names for distance-bound tables.
col_lb = "lower_bound"
col_ub = "upper_bound"
col_pe = "percent_error"
col_bpe = "bound_percent_error"
col_bape = "bound_absolute_percent_error"

# Keyword arguments forwarded to RDKit's GetMoleculeBoundsMatrix.
bound_matrix_params = dict(
    set15bounds=True,
    scaleVDW=True,
    doTriangleSmoothing=True,
    useMacrocycle14config=False,
)

# Names of the geometry result fields (bonds, angles, non-covalent pairs).
col_n_bonds = "number_bonds"
col_shortest_bond = "shortest_bond_relative_length"
col_longest_bond = "longest_bond_relative_length"
col_n_short_bonds = "number_short_outlier_bonds"
col_n_long_bonds = "number_long_outlier_bonds"
col_n_good_bonds = "number_valid_bonds"
col_bonds_result = "bond_lengths_within_bounds"
col_n_angles = "number_angles"
col_extremest_angle = "most_extreme_relative_angle"
col_n_bad_angles = "number_outlier_angles"
col_n_good_angles = "number_valid_angles"
col_angles_result = "bond_angles_within_bounds"
col_n_noncov = "number_noncov_pairs"
col_closest_noncov = "shortest_noncovalent_relative_distance"
col_n_clashes = "number_clashes"
col_n_good_noncov = "number_valid_noncov_pairs"
col_clash_result = "no_internal_clash"

# Placeholder result row with every reported field set to NaN.
_empty_results = dict.fromkeys(
    (
        col_n_bonds,
        col_shortest_bond,
        col_longest_bond,
        col_n_short_bonds,
        col_n_long_bonds,
        col_bonds_result,
        col_n_angles,
        col_extremest_angle,
        col_n_bad_angles,
        col_angles_result,
        col_n_noncov,
        col_closest_noncov,
        col_n_clashes,
        col_clash_result,
    ),
    np.nan,
)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
# Atomic numbers of the supported ligand element types (simplified version).
# An atom-type id is an index into this list; ids equal to len(list) mean "misc".
ALLOWABLE_ATOM_TYPES = [
    1, 5, 6, 7, 8, 9,  # H, B, C, N, O, F
    14, 15, 16, 17,  # Si, P, S, Cl
    23, 26, 27, 29, 30,  # V, Fe, Co, Cu, Zn
    33, 34, 35,  # As, Se, Br
    44, 51, 53, 78,  # Ru, Sb, I, Pt
]
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def symmetrize_conjugated_terminal_bonds(df: pd.DataFrame, mol: Mol) -> pd.DataFrame:
    """Symmetrize the lower and upper bounds of conjugated terminal bonds.

    The SMARTS query matches a degree-1 O/N atom attached to a centre that
    carries another degree-1 O/N, one via a single and one via a double bond.
    Because the two terminal atoms are drawn with different bond orders, their
    bounds can differ. For every matched bond, the bounds are replaced by the
    minimum lower bound and maximum upper bound over all matched bonds with
    the same (sorted) element pair.

    Args:
        df: Atom-pair table with columns ``atom_pair`` (ascending index tuples),
            ``atom_types`` (``"A--B"`` element symbols), ``lower_bound`` and
            ``upper_bound``.
        mol: RDKit molecule the table was built from.

    Returns:
        A new DataFrame with symmetrized bounds. The input ``df`` is left
        unmodified.
    """
    query = MolFromSmarts(
        "[O,N;D1;$([O,N;D1]-[*]=[O,N;D1]),$([O,N;D1]=[*]-[O,N;D1])]~[*]"
    )
    # Each match is (terminal atom, neighbour); sort so it lines up with atom_pair keys.
    matches = tuple(tuple(sorted(match)) for match in mol.GetSubstructMatches(query))

    # Work on a copy so the caller's frame does not gain a helper column or change in place.
    df = df.copy()
    df["atom_types_sorted"] = df["atom_types"].apply(
        lambda types: tuple(sorted(types.split("--")))
    )
    matched = df[df["atom_pair"].isin(matches)].copy()
    # String aggregations: passing np.amin/np.amax triggers a pandas FutureWarning.
    grouped = matched.groupby("atom_types_sorted").agg(
        {"lower_bound": "min", "upper_bound": "max"}
    )
    index_orig = matched.index
    # Align each matched row with its element-pair group, then restore the row index.
    matched = matched.set_index("atom_types_sorted")
    matched.update(grouped)
    matched = matched.set_index(index_orig)
    df.update(matched)
    return df.drop(columns=["atom_types_sorted"])
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _get_bond_atom_indices(mol: Mol) -> list:
    """Return every bond of *mol* as an ascending ``(atom_idx, atom_idx)`` tuple.

    Bonds appear in the order ``mol.GetBonds()`` yields them.
    """
    return [
        _sort_bond((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
        for bond in mol.GetBonds()
    ]
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _get_angle_atom_indices(bonds: list) -> list:
    """Check all combinations of bonds to generate list of molecule angles.

    Every unordered pair of bonds (taken in list order) is handed to
    ``_two_bonds_to_angle``. Pairs spanning exactly three atoms yield an
    ``(end, vertex, end)`` triple; all other pairs are skipped.
    """
    bond_list = list(bonds)
    pairwise_angles = (
        _two_bonds_to_angle(first, second)
        for offset, first in enumerate(bond_list, start=1)
        for second in bond_list[offset:]
    )
    return [angle for angle in pairwise_angles if angle is not None]
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _two_bonds_to_angle(bond1: tuple, bond2: tuple):
    """Return the angle ``(end_a, vertex, end_b)`` formed by two bonds, or None.

    Two bonds form an angle only when together they span exactly three atoms.
    The shared atom becomes the vertex, and the two ends are ordered ascending.
    """
    atoms_a, atoms_b = set(bond1), set(bond2)
    spanned = atoms_a.union(atoms_b)
    if len(spanned) == 3:
        vertex_candidates = atoms_a.intersection(atoms_b)
        ends = spanned.difference(vertex_candidates)
        return min(ends), vertex_candidates.pop(), max(ends)
    return None
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def _sort_bond(bond: tuple) -> tuple:
    """Return *bond* as a ``(smallest_idx, largest_idx)`` pair."""
    smallest = min(bond)
    largest = max(bond)
    return smallest, largest
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _has_hydrogen(mol: Mol, idcs) -> bool:
    """Return True if any atom index in *idcs* refers to a hydrogen atom of *mol*."""
    for atom_idx in idcs:
        if _is_hydrogen(mol, atom_idx):
            return True
    return False
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _is_hydrogen(mol: Mol, idx: int) -> bool:
    """Return True if the atom at *idx* (cast to ``int``) is hydrogen (atomic number 1)."""
    atom = mol.GetAtomWithIdx(int(idx))
    atomic_number = atom.GetAtomicNum()
    return atomic_number == 1
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def mol_from_symbols_and_npcoords(symbols, coords_np: np.ndarray):
    """Build a bond-less RDKit molecule from element symbols and coordinates.

    Each atom is created with implicit hydrogens disabled and zero explicit
    hydrogens, so the molecule contains exactly the given atoms.

    Args:
        symbols: Sequence of element symbols, one per atom.
        coords_np: Coordinates array of shape (N, 3), N == len(symbols).

    Returns:
        RDKit Mol carrying a single conformer with the given positions.
    """
    n_atoms = len(symbols)
    assert coords_np.shape == (n_atoms, 3)

    editable = Chem.RWMol()
    for symbol in symbols:
        atom = Chem.Atom(symbol)
        atom.SetNoImplicit(True)
        atom.SetNumExplicitHs(0)
        editable.AddAtom(atom)
    mol = editable.GetMol()

    conformer = Chem.Conformer(n_atoms)
    conformer.SetPositions(coords_np.astype(np.float64, copy=False))
    mol.AddConformer(conformer, assignId=True)
    return mol
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def check_intermolecular_distance(
    mol_orig,
    pos_pred,
    pos_cond,
    atom_names_pred,
    atom_names_cond,
    radius_type: str = "vdw",
    radius_scale: float = 1.0,
    clash_cutoff: float = 0.75,
    clash_cutoff_volume: float = 0.075,
    ignore_types: set | None = None,
    max_distance: float = 5.0,
    search_distance: float = 6.0,
    vdw_scale: float = 0.8,
):
    """Check that predicted molecule is not too close and not too far away from conditioning molecule.

    All poses are evaluated in one batched tensor computation (on CUDA when
    available). Ligand hydrogens are always excluded from the distance checks.
    Protein hydrogens are excluded whenever ``ignore_types`` is truthy (the
    default ``{"H"}``).

    Args:
        mol_orig: Original ligand molecule, used for the internal-geometry check.
        pos_pred: Predicted ligand positions (n_preds, n_atoms, 3), NumPy array.
        pos_cond: Conditioning (protein) positions (n_atoms, 3), NumPy array.
        atom_names_pred: Ligand element symbols, one per ligand atom.
        atom_names_cond: Protein element symbols, one per protein atom.
        radius_type: Accepted for API compatibility; van der Waals radii are
            always used.
        radius_scale: Scaling factor applied to the summed radii.
        clash_cutoff: A ligand/protein pair clashes when its distance divided by
            the scaled radius sum is below this value.
        clash_cutoff_volume: A pose has no volume clash when its Tversky
            overlap is below this value.
        ignore_types: If truthy, protein hydrogens are removed (only hydrogen
            filtering is implemented).
        max_distance: Maximum allowed minimum ligand-protein distance.
        search_distance: Accepted for API compatibility; currently unused.
        vdw_scale: VDW radius scale used to select nearby protein atoms and in
            the volume-overlap computation.

    Returns:
        ``{"results": {...}}``. Each value is a list with one entry per pose,
        under the keys ``not_too_far_away``, ``no_clashes``, ``no_volume_clash``,
        ``is_buried_fraction`` and ``no_internal_clash``.
    """
    if ignore_types is None:
        ignore_types = {"H"}
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hydrogen_id = atoms_vocab["H"]
    radii_table = vdw_radius.to(device)

    coords_ligand = torch.tensor(pos_pred, device=device).float()
    coords_protein = torch.tensor(pos_cond, device=device).float()
    atoms_ligand = torch.tensor(
        [atoms_vocab[atom] for atom in atom_names_pred], device=device
    ).long()
    atoms_protein = torch.tensor(
        [atoms_vocab[atom] for atom in atom_names_cond], device=device
    ).long()

    # Ligand hydrogens never take part in the distance checks.
    ligand_keep = atoms_ligand != hydrogen_id
    coords_ligand = coords_ligand[:, ligand_keep, :]
    atoms_ligand = atoms_ligand[ligand_keep]

    # NumPy copies of the protein, used to build RDKit molecules for the volume
    # overlap. They must be filtered with the same mask as the tensors, otherwise
    # the per-atom candidate masks below would not line up with them.
    pos_cond_np = np.asarray(pos_cond)
    atom_names_cond_np = np.asarray(atom_names_cond)
    if ignore_types:
        protein_keep = atoms_protein != hydrogen_id
        coords_protein = coords_protein[protein_keep, :]
        atoms_protein = atoms_protein[protein_keep]
        protein_keep_np = protein_keep.cpu().numpy()
        pos_cond_np = pos_cond_np[protein_keep_np]
        atom_names_cond_np = atom_names_cond_np[protein_keep_np]

    radius_ligand = radii_table[atoms_ligand]
    radius_protein = radii_table[atoms_protein]

    # Pairwise distances, shape (n_preds, n_ligand_heavy_atoms, n_protein_atoms).
    distances = (coords_ligand[:, :, None] - coords_protein[None, None, :]).norm(
        dim=-1
    )

    # Fraction of ligand heavy atoms that have any protein atom closer than 5.
    is_buried_fraction = (distances < 5).any(dim=-1).sum(dim=-1) / distances.size(1)

    radius_sum = radius_ligand[None, :, None] + radius_protein[None, None, :]
    relative_distance = distances / (radius_sum * radius_scale)
    clash = relative_distance < clash_cutoff

    # Per pose, the protein atoms within vdw-scaled contact range (plus a fixed
    # margin) of any ligand atom; only these enter the shape-overlap computation.
    candidates = distances < (radius_sum * vdw_scale + 2 * 3 * 0.25)
    ids_conds = candidates.any(dim=1).cpu().numpy()
    overlap = []
    for pose_idx, ids_cond in enumerate(ids_conds):
        if not ids_cond.any():
            # No protein atom nearby, so nothing can overlap; this matches
            # check_volume_overlap, which reports no clash in that case.
            overlap.append(True)
            continue
        tversky = ShapeTverskyIndex(
            mol_from_symbols_and_npcoords(atom_names_pred, pos_pred[pose_idx]),
            mol_from_symbols_and_npcoords(
                atom_names_cond_np[ids_cond], pos_cond_np[ids_cond]
            ),
            alpha=1,
            beta=0,
            vdwScale=vdw_scale,
        )
        overlap.append(tversky < clash_cutoff_volume)

    min_distance = distances.reshape(distances.size(0), -1).min(dim=-1)[0]
    results = {
        "not_too_far_away": (min_distance <= max_distance).tolist(),
        "no_clashes": torch.logical_not(clash.any(dim=(1, 2))).tolist(),
        "no_volume_clash": overlap,
        "is_buried_fraction": is_buried_fraction.tolist(),
        "no_internal_clash": check_geometry(
            mol_orig,
            coords_ligand,
            threshold_bad_bond_length=0.25,
            threshold_clash=0.3,
            threshold_bad_angle=0.25,
            bound_matrix_params=bound_matrix_params,
            ignore_hydrogens=True,
            sanitize=True,
            symmetrize_conjugated_terminal_groups=True,
        ),
    }
    return {"results": results}
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def check_volume_overlap(
|
|
327
|
+
pos_pred,
|
|
328
|
+
pos_cond,
|
|
329
|
+
atom_names_pred,
|
|
330
|
+
atom_names_cond,
|
|
331
|
+
clash_cutoff: float = 0.05,
|
|
332
|
+
vdw_scale: float = 0.8,
|
|
333
|
+
ignore_types: set | None = None,
|
|
334
|
+
search_distance: float = 6.0,
|
|
335
|
+
):
|
|
336
|
+
"""Check volume overlap between ligand and protein.
|
|
337
|
+
|
|
338
|
+
Args:
|
|
339
|
+
pos_pred: Predicted ligand positions
|
|
340
|
+
pos_cond: Protein positions
|
|
341
|
+
atom_names_pred: Ligand atom names
|
|
342
|
+
atom_names_cond: Protein atom names
|
|
343
|
+
clash_cutoff: Maximum allowed volume overlap fraction
|
|
344
|
+
vdw_scale: VDW radius scale
|
|
345
|
+
ignore_types: Atom types to ignore
|
|
346
|
+
search_distance: Search distance for nearby atoms
|
|
347
|
+
|
|
348
|
+
Returns:
|
|
349
|
+
Dictionary with volume overlap results
|
|
350
|
+
"""
|
|
351
|
+
if ignore_types is None:
|
|
352
|
+
ignore_types = {"H"}
|
|
353
|
+
keep_mask = atom_names_cond != "H"
|
|
354
|
+
pos_cond = pos_cond[keep_mask]
|
|
355
|
+
atom_names_cond = atom_names_cond[keep_mask]
|
|
356
|
+
if len(pos_cond) == 0:
|
|
357
|
+
return {"results": {"volume_overlap": np.nan, "no_volume_clash": True}}
|
|
358
|
+
|
|
359
|
+
distances = _pairwise_distance(pos_pred, pos_cond)
|
|
360
|
+
keep_mask = distances.min(axis=0) <= search_distance * vdw_scale
|
|
361
|
+
pos_cond = pos_cond[keep_mask]
|
|
362
|
+
atom_names_cond = atom_names_cond[keep_mask]
|
|
363
|
+
if len(pos_cond) == 0:
|
|
364
|
+
return {"results": {"volume_overlap": np.nan, "no_volume_clash": True}}
|
|
365
|
+
|
|
366
|
+
ignore_hydrogens = "H" in ignore_types
|
|
367
|
+
overlap = ShapeTverskyIndex(
|
|
368
|
+
mol_from_symbols_and_npcoords(atom_names_pred, pos_pred),
|
|
369
|
+
mol_from_symbols_and_npcoords(atom_names_cond, pos_cond),
|
|
370
|
+
alpha=1,
|
|
371
|
+
beta=0,
|
|
372
|
+
vdwScale=vdw_scale,
|
|
373
|
+
ignoreHs=ignore_hydrogens,
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
results = {
|
|
377
|
+
"volume_overlap": overlap,
|
|
378
|
+
"no_volume_clash": overlap <= clash_cutoff,
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
return {"results": results}
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
def check_geometry(
    mol_orig,
    pos_preds,
    threshold_bad_bond_length: float = 0.25,
    threshold_clash: float = 0.3,
    threshold_bad_angle: float = 0.25,
    bound_matrix_params=bound_matrix_params,
    ignore_hydrogens: bool = True,
    sanitize: bool = True,
    symmetrize_conjugated_terminal_groups: bool = True,
):
    """Use RDKit distance geometry bounds to check poses for internal clashes.

    Only certain atom pairs are checked: those that are neither bonded nor the
    two ends of a bond angle. Their pose distance is compared with the lower
    bound from ``GetMoleculeBoundsMatrix``. A pose passes when no such pair is
    shorter than its lower bound by more than ``threshold_clash``, measured
    relative to that bound.

    Args:
        mol_orig: RDKit molecule with one conformer. Its atom order must match
            the second axis of ``pos_preds``.
        pos_preds: Torch tensor of predicted positions, shape (n_poses, n_atoms, 3).
        threshold_bad_bond_length: Accepted for API compatibility; bond lengths
            are not checked.
        threshold_clash: Maximum tolerated relative violation of a lower bound.
        threshold_bad_angle: Accepted for API compatibility; angles are not checked.
        bound_matrix_params: Keyword arguments for ``GetMoleculeBoundsMatrix``.
        ignore_hydrogens: If True, atom pairs involving a hydrogen are not checked.
        sanitize: Whether to sanitize the molecule before computing bounds.
        symmetrize_conjugated_terminal_groups: Whether to symmetrize the bounds
            of conjugated terminal groups.

    Returns:
        A list with one entry per pose. Entries are ``True``/``False`` for
        pass/fail. All entries are ``True`` for a single-atom molecule, and all
        are ``np.nan`` when sanitization fails (geometry not evaluated).

    Raises:
        AssertionError: If the molecule has more than one conformer (RDKit
            itself raises when it has none).
    """
    n_poses = len(pos_preds)
    mol = deepcopy(mol_orig)
    # Only the first pose is written into the conformer; distances for every
    # pose are computed directly from pos_preds below.
    mol.GetConformer().SetPositions(pos_preds[0].cpu().numpy().astype(np.float64))
    assert mol.GetNumConformers() == 1, "Molecule must have exactly one conformer"

    if mol.GetNumAtoms() == 1:
        print(f"Molecule has only {mol.GetNumAtoms()} atoms.")
        return [True] * n_poses

    if sanitize:
        # Best effort: a molecule that cannot be sanitized is reported as
        # "not evaluated" (NaN), one entry per pose.
        try:
            flags = SanitizeMol(mol)
        except Exception:
            return [np.nan] * n_poses
        if flags != 0:
            return [np.nan] * n_poses

    bond_list = sorted(_get_bond_atom_indices(mol))
    bond_lookup = set(bond_list)
    angles = sorted(_get_angle_atom_indices(bond_list))
    # Key each angle by its two (ascending) end atoms, matching the atom_pair keys.
    angle_set = {(angle[0], angle[2]): angle for angle in angles}
    if not bond_list:
        print("Molecule has no bonds.")

    bounds = GetMoleculeBoundsMatrix(mol, **bound_matrix_params)
    lower_triangle_idcs = np.tril_indices(mol.GetNumAtoms(), k=-1)
    upper_triangle_idcs = (lower_triangle_idcs[1], lower_triangle_idcs[0])
    df_12 = pd.DataFrame()
    # Atom pairs (i, j) with i < j. bounds[j, i] holds the lower bound and
    # bounds[i, j] the upper bound.
    df_12["atom_pair"] = list(zip(*upper_triangle_idcs, strict=False))
    df_12["atom_types"] = [
        "--".join(tuple(mol.GetAtomWithIdx(int(idx)).GetSymbol() for idx in pair))
        for pair in df_12["atom_pair"]
    ]
    df_12["angle"] = df_12["atom_pair"].apply(lambda pair: angle_set.get(pair))
    df_12["has_hydrogen"] = [_has_hydrogen(mol, pair) for pair in df_12["atom_pair"]]
    df_12["is_bond"] = [pair in bond_lookup for pair in df_12["atom_pair"]]
    df_12["is_angle"] = df_12["angle"].apply(lambda angle: angle is not None)
    df_12[col_lb] = bounds[lower_triangle_idcs]
    df_12[col_ub] = bounds[upper_triangle_idcs]
    if symmetrize_conjugated_terminal_groups:
        df_12 = symmetrize_conjugated_terminal_bonds(df_12, mol)

    # Non-covalent pairs: drop 1-2 (bonded) and 1-3 (angle) pairs, and
    # optionally any pair that involves a hydrogen.
    is_noncov = ~df_12["is_bond"] & ~df_12["is_angle"]
    if ignore_hydrogens:
        is_noncov = is_noncov & ~df_12["has_hydrogen"]

    distances_all = (pos_preds[:, :, None] - pos_preds[:, None, :]).norm(dim=-1)[
        :, lower_triangle_idcs[0], lower_triangle_idcs[1]
    ]
    distances_noncov = distances_all[:, is_noncov.values]
    lower_bounds = torch.tensor(
        df_12[col_lb][is_noncov].values,
        device=distances_noncov.device,
    )
    # Relative shortfall below the lower bound (negative), or 0 when satisfied.
    relative_violation = torch.where(
        distances_noncov >= lower_bounds[None],
        0,
        (distances_noncov - lower_bounds[None]) / lower_bounds[None],
    )
    n_clashes = (relative_violation < -threshold_clash).sum(dim=-1)
    return (n_clashes == 0).tolist()
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def _log_posebusters_error(names, message: str) -> None:
    """Append an ``Error in <names>`` line followed by *message* to ``error.txt``."""
    with open("error.txt", "a") as f:
        f.write(f"Error in {names}\n")
        f.write(f"{message}\n")


def calc_posebusters(
    pos_pred, pos_cond, atom_ids_pred, atom_names_cond, names, lig_mol_for_posebusters
):
    """Calculate fast PoseBusters filters.

    Args:
        pos_pred: Predicted ligand positions, shape (n_poses, n_atoms, 3).
        pos_cond: Protein positions.
        atom_ids_pred: Ligand atom type IDs (indices into ALLOWABLE_ATOM_TYPES).
            Negative ids are skipped; presumably padding, to be confirmed
            against callers.
        atom_names_cond: Protein atom names.
        names: Sample names (for error logging).
        lig_mol_for_posebusters: Ligand molecule for PoseBusters.

    Returns:
        Dictionary with per-pose filter results from
        ``check_intermolecular_distance``. Returns None, and appends an entry
        to ``error.txt``, in three cases: a "misc" atom type is present,
        ``pos_pred`` is empty, or coordinate and atom-name counts disagree.
    """
    # The id one past the last ALLOWABLE_ATOM_TYPES index (22) marks a "misc" atom.
    misc_atom_id = len(ALLOWABLE_ATOM_TYPES)
    if misc_atom_id in atom_ids_pred:
        _log_posebusters_error(names, f"{misc_atom_id} (misc) in atom_ids_pred")
        return None

    atom_names_pred = np.array(
        [
            _periodic_table.GetElementSymbol(ALLOWABLE_ATOM_TYPES[atom_id])
            for atom_id in atom_ids_pred
            if atom_id >= 0
        ],
        dtype=object,
    )

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(pos_pred) == 0:
        error = "pos_pred contains no poses"
    elif len(pos_pred[0]) != len(atom_names_pred):
        error = (
            f"len(pos_pred[i]) = {len(pos_pred[0])} != "
            f"len(atom_names_pred[i]) = {len(atom_names_pred)}"
        )
    elif len(pos_cond) != len(atom_names_cond):
        error = (
            f"len(pos_cond) = {len(pos_cond)} != "
            f"len(atom_names_cond) = {len(atom_names_cond)}"
        )
    else:
        error = None
    if error is not None:
        print(f"Error in {names}")
        print(error)
        _log_posebusters_error(names, error)
        return None

    res = check_intermolecular_distance(
        lig_mol_for_posebusters, pos_pred, pos_cond, atom_names_pred, atom_names_cond
    )
    return dict(res["results"])
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Metrics for docking evaluation."""
|
|
2
|
+
|
|
3
|
+
from posebench_fast.metrics.aggregation import (
|
|
4
|
+
filter_results_by_fast,
|
|
5
|
+
filter_results_by_posebusters,
|
|
6
|
+
get_best_results_by_score,
|
|
7
|
+
get_final_results_for_df,
|
|
8
|
+
get_simple_metrics_df,
|
|
9
|
+
)
|
|
10
|
+
from posebench_fast.metrics.rmsd import (
|
|
11
|
+
TimeoutException,
|
|
12
|
+
compute_all_isomorphisms,
|
|
13
|
+
get_symmetry_rmsd,
|
|
14
|
+
get_symmetry_rmsd_with_isomorphisms,
|
|
15
|
+
symmrmsd,
|
|
16
|
+
time_limit,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Public API of the metrics subpackage: RMSD helpers (from .rmsd) followed by
# aggregation helpers (from .aggregation).
__all__ = [
    "compute_all_isomorphisms",
    "get_symmetry_rmsd_with_isomorphisms",
    "get_symmetry_rmsd",
    "symmrmsd",
    "TimeoutException",
    "time_limit",
    "get_simple_metrics_df",
    "get_final_results_for_df",
    "filter_results_by_posebusters",
    "filter_results_by_fast",
    "get_best_results_by_score",
]
|