posebench-fast 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,53 @@
1
+ """
2
+ posebench-fast: Fast docking evaluation metrics
3
+
4
+ Provides:
5
+ - Symmetry-corrected RMSD computation
6
+ - Fast PoseBusters filters (without full energy evaluation)
7
+ - Docking metrics (success rates, averages)
8
+ """
9
+
10
+ from posebench_fast.filters.fast_filters import (
11
+ calc_posebusters,
12
+ check_geometry,
13
+ check_intermolecular_distance,
14
+ check_volume_overlap,
15
+ )
16
+ from posebench_fast.metrics.aggregation import (
17
+ filter_results_by_fast,
18
+ filter_results_by_posebusters,
19
+ get_best_results_by_score,
20
+ get_final_results_for_df,
21
+ get_simple_metrics_df,
22
+ )
23
+ from posebench_fast.metrics.rmsd import (
24
+ TimeoutException,
25
+ compute_all_isomorphisms,
26
+ get_symmetry_rmsd,
27
+ get_symmetry_rmsd_with_isomorphisms,
28
+ symmrmsd,
29
+ time_limit,
30
+ )
31
+
32
__version__ = "0.1.0"

# Public API: the names exported by a star-import of this package, grouped by
# the submodule that defines them.
__all__ = [
    # RMSD
    "compute_all_isomorphisms",
    "get_symmetry_rmsd_with_isomorphisms",
    "get_symmetry_rmsd",
    "symmrmsd",
    "TimeoutException",
    "time_limit",
    # Filters
    "calc_posebusters",
    "check_intermolecular_distance",
    "check_volume_overlap",
    "check_geometry",
    # Metrics
    "get_simple_metrics_df",
    "get_final_results_for_df",
    "filter_results_by_posebusters",
    "filter_results_by_fast",
    "get_best_results_by_score",
]
@@ -0,0 +1 @@
1
+ """Dataset utilities for posebench-fast."""
@@ -0,0 +1,15 @@
1
+ """Fast PoseBusters filters for docking evaluation."""
2
+
3
+ from posebench_fast.filters.fast_filters import (
4
+ calc_posebusters,
5
+ check_geometry,
6
+ check_intermolecular_distance,
7
+ check_volume_overlap,
8
+ )
9
+
10
# Names exported by a star-import of this package (all re-exported from
# ``posebench_fast.filters.fast_filters``).
__all__ = [
    "calc_posebusters",
    "check_intermolecular_distance",
    "check_volume_overlap",
    "check_geometry",
]
@@ -0,0 +1,526 @@
1
+ """Fast PoseBusters filters without full energy evaluation."""
2
+
3
+ from copy import deepcopy
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ from posebusters.modules.intermolecular_distance import _pairwise_distance
9
+ from rdkit import Chem
10
+ from rdkit.Chem import MolFromSmarts
11
+ from rdkit.Chem.rdchem import GetPeriodicTable, Mol
12
+ from rdkit.Chem.rdDistGeom import GetMoleculeBoundsMatrix
13
+ from rdkit.Chem.rdmolops import SanitizeMol
14
+ from rdkit.Chem.rdShapeHelpers import ShapeTverskyIndex
15
+
16
_periodic_table = GetPeriodicTable()
# Map every element symbol known to RDKit's periodic table to
# (atomic number - 1), i.e. to its row in ``vdw_radius`` below.
atoms_vocab = {
    _periodic_table.GetElementSymbol(i + 1): i
    for i in range(_periodic_table.GetMaxAtomicNumber())
}
# Van der Waals radius of every element, indexed by (atomic number - 1), so
# ``vdw_radius[atoms_vocab[symbol]]`` is the radius of ``symbol``.
vdw_radius = torch.tensor(
    [
        _periodic_table.GetRvdw(_periodic_table.GetElementSymbol(i + 1))
        for i in range(_periodic_table.GetMaxAtomicNumber())
    ]
)

# Column names for the lower / upper distance bounds taken from the RDKit
# bounds matrix in ``check_geometry``.
col_lb = "lower_bound"
col_ub = "upper_bound"
# NOTE(review): the three percent-error column names below are not referenced
# by any code visible in this module.
col_pe = "percent_error"
col_bpe = "bound_percent_error"
col_bape = "bound_absolute_percent_error"

# Keyword arguments forwarded to RDKit's ``GetMoleculeBoundsMatrix``; used as
# the default ``bound_matrix_params`` of ``check_geometry``.
bound_matrix_params = {
    "set15bounds": True,
    "scaleVDW": True,
    "doTriangleSmoothing": True,
    "useMacrocycle14config": False,
}

# Result keys of the geometry report (bonds, angles, non-covalent pairs).
col_n_bonds = "number_bonds"
col_shortest_bond = "shortest_bond_relative_length"
col_longest_bond = "longest_bond_relative_length"
col_n_short_bonds = "number_short_outlier_bonds"
col_n_long_bonds = "number_long_outlier_bonds"
col_n_good_bonds = "number_valid_bonds"
col_bonds_result = "bond_lengths_within_bounds"
col_n_angles = "number_angles"
col_extremest_angle = "most_extreme_relative_angle"
col_n_bad_angles = "number_outlier_angles"
col_n_good_angles = "number_valid_angles"
col_angles_result = "bond_angles_within_bounds"
col_n_noncov = "number_noncov_pairs"
col_closest_noncov = "shortest_noncovalent_relative_distance"
col_n_clashes = "number_clashes"
col_n_good_noncov = "number_valid_noncov_pairs"
col_clash_result = "no_internal_clash"

# All-NaN template of the geometry result keys above. Note that it has no
# entries for the ``col_n_good_*`` keys.
_empty_results = {
    col_n_bonds: np.nan,
    col_shortest_bond: np.nan,
    col_longest_bond: np.nan,
    col_n_short_bonds: np.nan,
    col_n_long_bonds: np.nan,
    col_bonds_result: np.nan,
    col_n_angles: np.nan,
    col_extremest_angle: np.nan,
    col_n_bad_angles: np.nan,
    col_angles_result: np.nan,
    col_n_noncov: np.nan,
    col_closest_noncov: np.nan,
    col_n_clashes: np.nan,
    col_clash_result: np.nan,
}


# Allowable features for atom types (simplified version).
# Atomic numbers of the supported ligand atom types. A ligand atom-type id is
# an index into this list (``calc_posebusters`` looks up
# ``ALLOWABLE_ATOM_TYPES[atom_id]``). The list has 22 entries, so id 22 falls
# outside it; ``calc_posebusters`` rejects that id as "misc".
ALLOWABLE_ATOM_TYPES = [
    1,
    5,
    6,
    7,
    8,
    9,
    14,
    15,
    16,
    17,
    23,
    26,
    27,
    29,
    30,
    33,
    34,
    35,
    44,
    51,
    53,
    78,
]
103
+
104
+
105
def symmetrize_conjugated_terminal_bonds(df: pd.DataFrame, mol: Mol) -> pd.DataFrame:
    """
    Symmetrize the lower and upper bounds of the conjugated terminal bonds.

    Atom pairs matched by a SMARTS for terminal O/N atoms attached to a common
    atom through one single and one double bond are grouped by their sorted
    element pair. Within each group the lower bound is replaced by the
    group minimum and the upper bound by the group maximum, so that
    chemically equivalent terminal bonds share the same bounds.

    The input frame is not modified; a new frame is returned.

    Args:
        df: Dataframe with the bond geometry information and bounds. It must
            contain the columns ``atom_pair`` (sorted index tuples),
            ``atom_types`` (``"A--B"`` element strings), ``lower_bound`` and
            ``upper_bound``.
        mol: RDKit molecule object.

    Returns:
        Dataframe with symmetrized bounds for conjugated terminal bonds.
    """

    def _sort_bond_ids(bond_ids: tuple) -> tuple:
        # Normalise each matched atom pair to ascending index order.
        return tuple(tuple(sorted(ids)) for ids in bond_ids)

    def _get_terminal_group_matches(_mol: Mol) -> tuple:
        qsmarts = "[O,N;D1;$([O,N;D1]-[*]=[O,N;D1]),$([O,N;D1]=[*]-[O,N;D1])]~[*]"
        query = MolFromSmarts(qsmarts)
        matches = _mol.GetSubstructMatches(query)
        return _sort_bond_ids(matches)

    # Work on a copy so the temporary grouping column and the updated bounds
    # never leak into the caller's frame.
    df = df.copy()
    df["atom_types_sorted"] = df["atom_types"].apply(
        lambda a: tuple(sorted(a.split("--")))
    )
    matches = _get_terminal_group_matches(mol)
    matched = df[df["atom_pair"].isin(matches)].copy()
    # String aggregators instead of np.amin/np.amax: passing NumPy callables
    # to ``agg`` is deprecated in pandas >= 2.1 and yields the same values.
    grouped = matched.groupby("atom_types_sorted").agg(
        {"lower_bound": "min", "upper_bound": "max"}
    )
    # Broadcast the per-group bounds back onto every matched row, then restore
    # the original row index so the result can be written into ``df``.
    index_orig = matched.index
    matched = matched.set_index("atom_types_sorted")
    matched.update(grouped)
    matched = matched.set_index(index_orig)
    df.update(matched)
    return df.drop(columns=["atom_types_sorted"])
140
+
141
+
142
def _get_bond_atom_indices(mol: Mol) -> list:
    """List every bond of ``mol`` as an ascending ``(i, j)`` atom-index pair.

    Pairs appear in the order ``mol.GetBonds()`` yields the bonds.
    """
    return [
        _sort_bond((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
        for bond in mol.GetBonds()
    ]
149
+
150
+
151
def _get_angle_atom_indices(bonds: list) -> list:
    """Check all combinations of bonds to generate list of molecule angles.

    Every unordered pair of bonds (in the order given by ``bonds``) is passed
    to ``_two_bonds_to_angle``. Each pair for which it returns a triple
    contributes that triple to the result.
    """
    bond_list = list(bonds)
    candidates = (
        _two_bonds_to_angle(first, second)
        for pos, first in enumerate(bond_list)
        for second in bond_list[pos + 1 :]
    )
    return [angle for angle in candidates if angle is not None]
161
+
162
+
163
+ def _two_bonds_to_angle(bond1: tuple, bond2: tuple):
164
+ set1 = set(bond1)
165
+ set2 = set(bond2)
166
+ all_atoms = set1 | set2
167
+ if len(all_atoms) != 3:
168
+ return None
169
+ shared_atom = set1 & set2
170
+ other_atoms = all_atoms - shared_atom
171
+ return (min(other_atoms), shared_atom.pop(), max(other_atoms))
172
+
173
+
174
+ def _sort_bond(bond: tuple) -> tuple:
175
+ return (min(bond), max(bond))
176
+
177
+
178
def _has_hydrogen(mol: Mol, idcs) -> bool:
    """Return True as soon as any atom index in ``idcs`` refers to a hydrogen."""
    for idx in idcs:
        if _is_hydrogen(mol, idx):
            return True
    return False
180
+
181
+
182
def _is_hydrogen(mol: Mol, idx: int) -> bool:
    """Return True when atom ``idx`` of ``mol`` has atomic number 1."""
    # Normalise to a plain int; callers may pass NumPy integer indices.
    atom = mol.GetAtomWithIdx(int(idx))
    return atom.GetAtomicNum() == 1
184
+
185
+
186
def mol_from_symbols_and_npcoords(symbols, coords_np: np.ndarray):
    """Build a bond-less RDKit molecule with one conformer from raw atoms.

    Every atom is created with implicit hydrogens disabled and zero explicit
    hydrogens, so the molecule contains exactly the atoms listed.

    Args:
        symbols: Sequence of element symbols, one per atom.
        coords_np: Coordinates array of shape ``(len(symbols), 3)``.

    Returns:
        RDKit Mol carrying a single conformer at ``coords_np``.
    """
    n_atoms = len(symbols)
    assert coords_np.shape == (n_atoms, 3)

    def _bare_atom(symbol):
        atom = Chem.Atom(symbol)
        atom.SetNoImplicit(True)
        atom.SetNumExplicitHs(0)
        return atom

    editable = Chem.RWMol()
    for atom in map(_bare_atom, symbols):
        editable.AddAtom(atom)
    molecule = editable.GetMol()

    conformer = Chem.Conformer(n_atoms)
    # The conformer expects float64 positions; avoid a copy when possible.
    conformer.SetPositions(coords_np.astype(np.float64, copy=False))
    molecule.AddConformer(conformer, assignId=True)
    return molecule
208
+
209
+
210
def check_intermolecular_distance(
    mol_orig,
    pos_pred,
    pos_cond,
    atom_names_pred,
    atom_names_cond,
    radius_type: str = "vdw",
    radius_scale: float = 1.0,
    clash_cutoff: float = 0.75,
    clash_cutoff_volume: float = 0.075,
    ignore_types: set | None = None,
    max_distance: float = 5.0,
    search_distance: float = 6.0,
    vdw_scale: float = 0.8,
):
    """Check that predicted molecule is not too close and not too far away from conditioning molecule.

    Ligand hydrogens are always excluded from the distance-based checks.
    Protein hydrogens are excluded whenever ``ignore_types`` is non-empty; its
    contents are not inspected further.

    Args:
        mol_orig: Original ligand molecule (used for the internal-clash check).
        pos_pred: Predicted ligand positions (n_preds, n_atoms, 3), NumPy array.
        pos_cond: Conditioning (protein) positions (n_atoms, 3).
        atom_names_pred: Ligand element symbols, in ``pos_pred`` atom order.
        atom_names_cond: Protein element symbols, in ``pos_cond`` atom order.
        radius_type: Accepted for API compatibility; van der Waals radii are
            always used.
        radius_scale: Scaling factor applied to the summed atomic radii.
        clash_cutoff: A ligand/protein pair clashes when its distance divided
            by the scaled radius sum is below this value.
        clash_cutoff_volume: A pose passes the volume check when its Tversky
            overlap with nearby protein atoms is below this value.
        ignore_types: Non-empty (default ``{"H"}``) to drop protein hydrogens.
        max_distance: Maximum allowed ligand/protein minimum distance.
        search_distance: Accepted for API compatibility; not used.
        vdw_scale: VdW radius scale for the volume-overlap calculation.

    Returns:
        ``{"results": {...}}`` with per-pose lists under ``not_too_far_away``,
        ``no_clashes``, ``no_volume_clash`` and ``is_buried_fraction``, plus
        the output of ``check_geometry`` under ``no_internal_clash``.
    """
    if ignore_types is None:
        ignore_types = {"H"}
    device = "cuda" if torch.cuda.is_available() else "cpu"
    coords_ligand = torch.tensor(pos_pred, device=device).float()
    coords_protein = torch.tensor(pos_cond, device=device).float()

    atoms_ligand = torch.tensor(
        [atoms_vocab[atom] for atom in atom_names_pred], device=device
    ).long()
    atoms_protein = torch.tensor(
        [atoms_vocab[atom] for atom in atom_names_cond], device=device
    ).long()

    # NumPy copies of the protein inputs, kept in lock-step with the tensors
    # so that masks computed on the tensors can index them below.
    names_cond_np = np.asarray(atom_names_cond, dtype=object)
    pos_cond_np = np.asarray(pos_cond)

    ligand_heavy = atoms_ligand != atoms_vocab["H"]
    coords_ligand = coords_ligand[:, ligand_heavy, :]
    atoms_ligand = atoms_ligand[ligand_heavy]
    if ignore_types:
        protein_keep = atoms_protein != atoms_vocab["H"]
        coords_protein = coords_protein[protein_keep, :]
        atoms_protein = atoms_protein[protein_keep]
        # Apply the same filter to the NumPy views; otherwise the volume
        # check below would index unfiltered arrays with a filtered mask.
        protein_keep_np = protein_keep.cpu().numpy()
        names_cond_np = names_cond_np[protein_keep_np]
        pos_cond_np = pos_cond_np[protein_keep_np]

    radii = vdw_radius.to(device)
    radius_ligand = radii[atoms_ligand]
    radius_protein = radii[atoms_protein]

    # Pairwise distances, shape (n_preds, n_ligand_heavy, n_protein_kept).
    distances = (coords_ligand[:, :, None] - coords_protein[None, None, :]).norm(
        dim=-1
    )

    # Fraction of ligand heavy atoms within 5 A of any kept protein atom.
    is_buried_fraction = (distances < 5).any(dim=-1).sum(dim=-1) / distances.size(1)

    radius_sum = radius_ligand[None, :, None] + radius_protein[None, None, :]
    relative_distance = distances / (radius_sum * radius_scale)
    clash = relative_distance < clash_cutoff

    # Protein atoms close enough to matter for the shape overlap: scaled radius
    # sum plus a fixed 1.5 A margin (2 * 3 * 0.25).
    candidates = distances < (radius_sum * vdw_scale + 2 * 3 * 0.25)
    ids_conds = candidates.any(dim=1).cpu().numpy()
    overlap = []
    for i in range(coords_ligand.size(0)):
        ids_cond = ids_conds[i]
        overlap.append(
            ShapeTverskyIndex(
                mol_from_symbols_and_npcoords(atom_names_pred, pos_pred[i]),
                mol_from_symbols_and_npcoords(
                    names_cond_np[ids_cond], pos_cond_np[ids_cond]
                ),
                alpha=1,
                beta=0,
                vdwScale=vdw_scale,
            )
            < clash_cutoff_volume
        )

    results = {
        "not_too_far_away": (
            distances.reshape(distances.size(0), -1).min(dim=-1)[0] <= max_distance
        ).tolist(),
        "no_clashes": torch.logical_not(clash.any(dim=(1, 2))).tolist(),
        "no_volume_clash": overlap,
        "is_buried_fraction": is_buried_fraction.tolist(),
        # NOTE(review): coords_ligand has ligand hydrogens removed, while
        # check_geometry sets these positions on mol_orig. This only lines up
        # when mol_orig has no hydrogens; confirm against callers.
        "no_internal_clash": check_geometry(
            mol_orig,
            coords_ligand,
            threshold_bad_bond_length=0.25,
            threshold_clash=0.3,
            threshold_bad_angle=0.25,
            bound_matrix_params=bound_matrix_params,
            ignore_hydrogens=True,
            sanitize=True,
            symmetrize_conjugated_terminal_groups=True,
        ),
    }
    return {"results": results}
324
+
325
+
326
+ def check_volume_overlap(
327
+ pos_pred,
328
+ pos_cond,
329
+ atom_names_pred,
330
+ atom_names_cond,
331
+ clash_cutoff: float = 0.05,
332
+ vdw_scale: float = 0.8,
333
+ ignore_types: set | None = None,
334
+ search_distance: float = 6.0,
335
+ ):
336
+ """Check volume overlap between ligand and protein.
337
+
338
+ Args:
339
+ pos_pred: Predicted ligand positions
340
+ pos_cond: Protein positions
341
+ atom_names_pred: Ligand atom names
342
+ atom_names_cond: Protein atom names
343
+ clash_cutoff: Maximum allowed volume overlap fraction
344
+ vdw_scale: VDW radius scale
345
+ ignore_types: Atom types to ignore
346
+ search_distance: Search distance for nearby atoms
347
+
348
+ Returns:
349
+ Dictionary with volume overlap results
350
+ """
351
+ if ignore_types is None:
352
+ ignore_types = {"H"}
353
+ keep_mask = atom_names_cond != "H"
354
+ pos_cond = pos_cond[keep_mask]
355
+ atom_names_cond = atom_names_cond[keep_mask]
356
+ if len(pos_cond) == 0:
357
+ return {"results": {"volume_overlap": np.nan, "no_volume_clash": True}}
358
+
359
+ distances = _pairwise_distance(pos_pred, pos_cond)
360
+ keep_mask = distances.min(axis=0) <= search_distance * vdw_scale
361
+ pos_cond = pos_cond[keep_mask]
362
+ atom_names_cond = atom_names_cond[keep_mask]
363
+ if len(pos_cond) == 0:
364
+ return {"results": {"volume_overlap": np.nan, "no_volume_clash": True}}
365
+
366
+ ignore_hydrogens = "H" in ignore_types
367
+ overlap = ShapeTverskyIndex(
368
+ mol_from_symbols_and_npcoords(atom_names_pred, pos_pred),
369
+ mol_from_symbols_and_npcoords(atom_names_cond, pos_cond),
370
+ alpha=1,
371
+ beta=0,
372
+ vdwScale=vdw_scale,
373
+ ignoreHs=ignore_hydrogens,
374
+ )
375
+
376
+ results = {
377
+ "volume_overlap": overlap,
378
+ "no_volume_clash": overlap <= clash_cutoff,
379
+ }
380
+
381
+ return {"results": results}
382
+
383
+
384
def check_geometry(
    mol_orig,
    pos_preds,
    threshold_bad_bond_length: float = 0.25,
    threshold_clash: float = 0.3,
    threshold_bad_angle: float = 0.25,
    bound_matrix_params=bound_matrix_params,
    ignore_hydrogens: bool = True,
    sanitize: bool = True,
    symmetrize_conjugated_terminal_groups: bool = True,
):
    """Use RDKit distance geometry bounds to check poses for internal clashes.

    Atom pairs are selected if they are neither bonded nor the two end atoms
    of a bond angle. For each such pair, the predicted distance is compared
    with the lower bound of the RDKit bounds matrix. A pair clashes when it is
    shorter than that bound by more than ``threshold_clash`` (relative).

    Args:
        mol_orig: RDKit molecule with exactly one conformer. Its topology
            defines bonds, angles and bounds. It is not modified.
        pos_preds: Torch tensor of predicted coordinates with shape
            ``(n_preds, n_atoms, 3)``, atoms in ``mol_orig`` order.
        threshold_bad_bond_length: Accepted for API compatibility; not used.
        threshold_clash: Maximum tolerated relative violation of a lower bound.
        threshold_bad_angle: Accepted for API compatibility; not used.
        bound_matrix_params: Keyword arguments for ``GetMoleculeBoundsMatrix``.
        ignore_hydrogens: Accepted for API compatibility; not used.
        sanitize: Whether to sanitize a copy of the molecule first.
        symmetrize_conjugated_terminal_groups: Whether to symmetrize the bounds
            of conjugated terminal groups.

    Returns:
        List with one entry per pose. An entry is ``True`` if the pose has no
        internal clash and ``False`` otherwise. Single-atom molecules always
        pass. When the check cannot run (no conformer, sanitization failure),
        every entry is ``np.nan``.

    Raises:
        AssertionError: if ``mol_orig`` has more than one conformer.
    """
    n_preds = len(pos_preds)
    # Check the conformer count before touching the conformer; GetConformer()
    # would raise on a conformer-less molecule.
    if mol_orig.GetNumConformers() == 0:
        print("Molecule does not have a conformer.")
        return [np.nan] * n_preds
    assert mol_orig.GetNumConformers() == 1, "Molecule must have exactly one conformer"

    # Work on a copy so the caller's molecule and conformer stay untouched.
    mol = deepcopy(mol_orig)
    mol.GetConformer().SetPositions(pos_preds[0].cpu().numpy().astype(np.float64))
    if mol.GetNumAtoms() == 1:
        print(f"Molecule has only {mol.GetNumAtoms()} atoms.")
        return [True] * n_preds
    if sanitize:
        try:
            flags = SanitizeMol(mol)
        except Exception:
            return [np.nan] * n_preds
        if flags != 0:
            return [np.nan] * n_preds

    bonds = sorted(_get_bond_atom_indices(mol))
    bond_lookup = set(bonds)  # O(1) membership tests in the pair loop below
    angles = sorted(_get_angle_atom_indices(bonds))
    # Map the end-atom pair (i, k) of every angle to its full (i, j, k) triple.
    angle_set = {(a[0], a[2]): a for a in angles}
    if not bonds:
        print("Molecule has no bonds.")

    bounds = GetMoleculeBoundsMatrix(mol, **bound_matrix_params)
    # Lower bounds sit below the diagonal, upper bounds above it.
    lower_triangle_idcs = np.tril_indices(mol.GetNumAtoms(), k=-1)
    upper_triangle_idcs = (lower_triangle_idcs[1], lower_triangle_idcs[0])
    df_12 = pd.DataFrame()
    df_12["atom_pair"] = list(zip(*upper_triangle_idcs, strict=False))
    df_12["atom_types"] = [
        "--".join(mol.GetAtomWithIdx(int(j)).GetSymbol() for j in pair)
        for pair in df_12["atom_pair"]
    ]
    df_12["angle"] = df_12["atom_pair"].apply(lambda x: angle_set.get(x))
    df_12["is_bond"] = [pair in bond_lookup for pair in df_12["atom_pair"]]
    df_12["is_angle"] = df_12["angle"].apply(lambda x: x is not None)
    df_12[col_lb] = bounds[lower_triangle_idcs]
    df_12[col_ub] = bounds[upper_triangle_idcs]
    if symmetrize_conjugated_terminal_groups:
        df_12 = symmetrize_conjugated_terminal_bonds(df_12, mol)

    # Only pairs that are neither bonded nor angle ends are clash candidates.
    nonbonded = ~df_12["is_bond"] & ~df_12["is_angle"]
    distances_all = (pos_preds[:, :, None] - pos_preds[:, None, :]).norm(dim=-1)[
        :, lower_triangle_idcs[0], lower_triangle_idcs[1]
    ]
    distances_valid = distances_all[:, nonbonded.values]
    lower_bounds_valid = torch.tensor(
        df_12[col_lb][nonbonded].values,
        device=distances_valid.device,
    )
    # Relative shortfall below the lower bound (0 where the bound is met).
    relative_violation = torch.where(
        distances_valid >= lower_bounds_valid[None],
        0,
        (distances_valid - lower_bounds_valid[None]) / lower_bounds_valid[None],
    )
    n_clashes = (relative_violation < -threshold_clash).sum(dim=-1)
    return (n_clashes == 0).tolist()
471
+
472
+
473
def calc_posebusters(
    pos_pred, pos_cond, atom_ids_pred, atom_names_cond, names, lig_mol_for_posebusters
):
    """Calculate fast PoseBusters filters.

    Args:
        pos_pred: Predicted ligand positions; ``pos_pred[i]`` is one pose.
        pos_cond: Protein positions, one row per atom.
        atom_ids_pred: Ligand atom-type ids (indices into
            ``ALLOWABLE_ATOM_TYPES``). Negative ids are skipped.
        atom_names_cond: Protein element symbols.
        names: Sample name(s), used only for error logging.
        lig_mol_for_posebusters: Ligand molecule passed on to
            ``check_intermolecular_distance``.

    Returns:
        Dictionary with filter results, or ``None`` when the input is invalid.
        Details of invalid input are appended to ``error.txt`` in the current
        working directory.
    """

    def _log_error(message: str) -> None:
        # Append one error record for this sample to error.txt.
        with open("error.txt", "a", encoding="utf-8") as f:
            f.write(f"Error in {names}\n")
            f.write(f"{message}\n")

    if 22 in atom_ids_pred:
        _log_error("22 (misc) in atom_ids_pred")
        return None
    # Any other id beyond the table would raise IndexError during lookup.
    n_types = len(ALLOWABLE_ATOM_TYPES)
    unknown_ids = [int(atom_id) for atom_id in atom_ids_pred if atom_id >= n_types]
    if unknown_ids:
        _log_error(f"atom ids {unknown_ids} out of range for ALLOWABLE_ATOM_TYPES")
        return None

    atom_names_pred = np.array(
        [
            _periodic_table.GetElementSymbol(ALLOWABLE_ATOM_TYPES[atom_id])
            for atom_id in atom_ids_pred
            if atom_id >= 0
        ],
        dtype=object,
    )

    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if len(pos_pred) == 0:
        error = "pos_pred is empty"
    elif len(pos_pred[0]) != len(atom_names_pred):
        error = (
            f"len(pos_pred[i]) = {len(pos_pred[0])} != "
            f"len(atom_names_pred[i]) = {len(atom_names_pred)}"
        )
    elif len(pos_cond) != len(atom_names_cond):
        error = (
            f"len(pos_cond) = {len(pos_cond)} != "
            f"len(atom_names_cond) = {len(atom_names_cond)}"
        )
    else:
        error = None
    if error is not None:
        print(f"Error in {names}")
        print(error)
        _log_error(error)
        return None

    res = check_intermolecular_distance(
        lig_mol_for_posebusters, pos_pred, pos_cond, atom_names_pred, atom_names_cond
    )
    return dict(res["results"])
@@ -0,0 +1,31 @@
1
+ """Metrics for docking evaluation."""
2
+
3
+ from posebench_fast.metrics.aggregation import (
4
+ filter_results_by_fast,
5
+ filter_results_by_posebusters,
6
+ get_best_results_by_score,
7
+ get_final_results_for_df,
8
+ get_simple_metrics_df,
9
+ )
10
+ from posebench_fast.metrics.rmsd import (
11
+ TimeoutException,
12
+ compute_all_isomorphisms,
13
+ get_symmetry_rmsd,
14
+ get_symmetry_rmsd_with_isomorphisms,
15
+ symmrmsd,
16
+ time_limit,
17
+ )
18
+
19
# Names exported by a star-import of this package: RMSD utilities first, then
# the result-aggregation helpers.
__all__ = [
    "compute_all_isomorphisms",
    "get_symmetry_rmsd_with_isomorphisms",
    "get_symmetry_rmsd",
    "symmrmsd",
    "TimeoutException",
    "time_limit",
    "get_simple_metrics_df",
    "get_final_results_for_df",
    "filter_results_by_posebusters",
    "filter_results_by_fast",
    "get_best_results_by_score",
]