prism-pruner 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prism_pruner/__init__.py +1 -0
- prism_pruner/algebra.py +163 -0
- prism_pruner/conformer_ensemble.py +57 -0
- prism_pruner/graph_manipulations.py +195 -0
- prism_pruner/pruner.py +623 -0
- prism_pruner/rmsd.py +38 -0
- prism_pruner/torsion_module.py +472 -0
- prism_pruner/typing.py +15 -0
- prism_pruner/utils.py +153 -0
- prism_pruner-0.0.3.dist-info/METADATA +34 -0
- prism_pruner-0.0.3.dist-info/RECORD +14 -0
- prism_pruner-0.0.3.dist-info/WHEEL +5 -0
- prism_pruner-0.0.3.dist-info/licenses/LICENSE +21 -0
- prism_pruner-0.0.3.dist-info/top_level.txt +1 -0
prism_pruner/pruner.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
1
|
+
"""PRISM - PRuning Interface for Similar Molecules."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from time import perf_counter
|
|
5
|
+
from typing import Any, Callable, Sequence
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from networkx import Graph, connected_components
|
|
9
|
+
from periodictable import elements
|
|
10
|
+
from scipy.spatial.distance import cdist
|
|
11
|
+
|
|
12
|
+
from prism_pruner.algebra import get_inertia_moments
|
|
13
|
+
from prism_pruner.rmsd import rmsd_and_max
|
|
14
|
+
from prism_pruner.torsion_module import (
|
|
15
|
+
get_angles,
|
|
16
|
+
get_hydrogen_bonds,
|
|
17
|
+
get_torsions,
|
|
18
|
+
is_nondummy,
|
|
19
|
+
rotationally_corrected_rmsd_and_max,
|
|
20
|
+
)
|
|
21
|
+
from prism_pruner.typing import (
|
|
22
|
+
Array1D_bool,
|
|
23
|
+
Array1D_float,
|
|
24
|
+
Array1D_int,
|
|
25
|
+
Array1D_str,
|
|
26
|
+
Array2D_float,
|
|
27
|
+
Array2D_int,
|
|
28
|
+
Array3D_float,
|
|
29
|
+
)
|
|
30
|
+
from prism_pruner.utils import flatten, get_double_bonds_indices, time_to_string
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class PrunerConfig:
    """Base configuration and bookkeeping shared by every pruner.

    Holds the structures to prune, their optional energies and energy window,
    plus the call counters and the cache of dissimilar index pairs that the
    pruning routines update while they run.
    """

    structures: Array3D_float

    # Optional inputs, normalised in __post_init__
    energies: Array1D_float = field(default_factory=lambda: np.array([]))
    max_dE: float = 0.0
    debugfunction: Callable[[str], None] | None = None

    # Bookkeeping fields, not settable through __init__
    eval_calls: int = field(default=0, init=False)
    cache_calls: int = field(default=0, init=False)
    cache: set[tuple[int, int]] = field(default_factory=set, init=False)

    def __post_init__(self) -> None:
        """Validate the inputs and fill in defaults for the omitted ones."""
        n_structures = self.structures.shape[0]

        # every structure starts out as active
        self.mask = np.ones(shape=(n_structures,), dtype=np.bool_)

        if len(self.energies) == 0:
            # no energies supplied: fall back to all-zero energies
            self.energies = np.zeros(n_structures, dtype=float)
        else:
            # energies are only meaningful together with an energy window
            assert self.max_dE > 0.0, (
                "If you provide energies, please also provide an appropriate energy window max_dE."
            )

        assert len(self.energies) == len(self.structures), (
            "Please make sure that the energies "
            "provided have the same len as the input structures."
        )

        # an unset window (0.0) is replaced by a default of 1.0
        if self.max_dE == 0.0:
            self.max_dE = 1.0

    def evaluate_sim(self, *args: Any, **kwargs: Any) -> bool:
        """Compare two structures; concrete pruner configs must override this."""
        raise NotImplementedError
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class RMSDRotCorrPrunerConfig(PrunerConfig):
    """Pruner configuration for the rotamer-corrected RMSD similarity metric."""

    atoms: Array1D_str = field(kw_only=True)
    max_rmsd: float = field(kw_only=True)
    max_dev: float = field(kw_only=True)
    angles: Sequence[Sequence[int]] = field(kw_only=True)
    torsions: Array2D_int = field(kw_only=True)
    graph: Graph = field(kw_only=True)
    heavy_atoms_only: bool = True

    def evaluate_sim(self, i1: int, i2: int) -> bool:
        """Tell whether structures i1 and i2 count as duplicates.

        They do unless the rotationally corrected RMSD exceeds ``max_rmsd``
        or the corresponding maximum deviation exceeds ``max_dev``.
        """
        rmsd, max_dev = rotationally_corrected_rmsd_and_max(
            self.structures[i1],
            self.structures[i2],
            atoms=self.atoms,
            torsions=self.torsions,
            graph=self.graph,
            angles=self.angles,
            # self.debugfunction is deliberately not forwarded (lots of printout)
            heavy_atoms_only=self.heavy_atoms_only,
        )

        over_threshold = rmsd > self.max_rmsd or max_dev > self.max_dev
        return not over_threshold
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass
class RMSDPrunerConfig(PrunerConfig):
    """Pruner configuration for the plain (aligned) RMSD similarity metric."""

    atoms: Array1D_str = field(kw_only=True)
    max_rmsd: float = field(kw_only=True)
    max_dev: float = field(kw_only=True)
    heavy_atoms_only: bool = True

    def evaluate_sim(self, i1: int, i2: int) -> bool:
        """Tell whether structures i1 and i2 count as duplicates.

        They do unless the RMSD exceeds ``max_rmsd`` or the maximum deviation
        exceeds ``max_dev``. With ``heavy_atoms_only`` the atoms labelled
        "H" are left out of the comparison.
        """
        atom_selector = (
            self.atoms != "H"
            if self.heavy_atoms_only
            else np.ones(self.structures[0].shape[0], dtype=bool)
        )

        rmsd, max_dev = rmsd_and_max(
            self.structures[i1][atom_selector],
            self.structures[i2][atom_selector],
            center=True,
        )

        return not (rmsd > self.max_rmsd or max_dev > self.max_dev)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@dataclass
class MOIPrunerConfig(PrunerConfig):
    """Pruner configuration for the moments-of-inertia similarity metric.

    The moments of inertia of every structure are precomputed and stored in
    ``moi_vecs``, keyed by the structure's index in ``self.structures``.
    """

    masses: Array1D_float = field(kw_only=True)
    max_dev: float = 0.01

    def __post_init__(self) -> None:
        """Run the parent's __post_init__, then precompute the moments of inertia."""
        super().__post_init__()
        self._refresh_moi_vecs()

    def _refresh_moi_vecs(self) -> None:
        """(Re)compute ``moi_vecs`` from the current ``self.structures`` array."""
        self.moi_vecs = {
            c: get_inertia_moments(coord, self.masses)
            for c, coord in enumerate(self.structures)
        }
        # remember which array the moments belong to, to detect staleness
        self._moi_source = self.structures

    def evaluate_sim(self, i1: int, i2: int) -> bool:
        """Return whether structures i1 and i2 are similar.

        They are similar when each of the three moments of inertia differs by
        less than ``max_dev``, relative to the value of structure i1.
        """
        # prune() replaces self.structures with an energy-sorted copy; moments
        # keyed by the old order would then describe the wrong structures, so
        # recompute whenever the array has been swapped out.
        if getattr(self, "_moi_source", None) is not self.structures:
            self._refresh_moi_vecs()

        im_1 = self.moi_vecs[i1]
        im_2 = self.moi_vecs[i2]

        # compare the three MOIs via a Python loop:
        # apparently much faster than numpy array operations
        # for such a small array!
        # NOTE(review): a zero moment in im_1 makes this a division by zero.
        for j in range(3):
            if np.abs(im_1[j] - im_2[j]) / im_1[j] >= self.max_dev:
                return False
        return True
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _main_compute_subrow(
|
|
173
|
+
prunerconfig: PrunerConfig,
|
|
174
|
+
in_mask: Array1D_bool,
|
|
175
|
+
first_abs_index: int,
|
|
176
|
+
num_str_in_row: int,
|
|
177
|
+
) -> bool:
|
|
178
|
+
"""Evaluate the similarity of a subrow of the similarity matrix.
|
|
179
|
+
|
|
180
|
+
Return True if ref is similar to any
|
|
181
|
+
structure in structures, returning at the first instance of a match.
|
|
182
|
+
Ignores structures that are False (0) in in_mask and does not perform
|
|
183
|
+
the comparison if the energy difference between the structures is less
|
|
184
|
+
than self.max_dE. Saves dissimilar structural pairs (i.e. that evaluate to
|
|
185
|
+
False (0)) by adding them to self.cache, avoiding redundant calcaulations.
|
|
186
|
+
"""
|
|
187
|
+
i1 = first_abs_index
|
|
188
|
+
|
|
189
|
+
# iterate over target structures
|
|
190
|
+
for i in range(num_str_in_row):
|
|
191
|
+
# only compare active structures
|
|
192
|
+
if in_mask[i]:
|
|
193
|
+
# check if we have performed this comparison already:
|
|
194
|
+
# if so, we already know that those two structures are not similar,
|
|
195
|
+
# since the in_mask attribute is not False for ref nor for i
|
|
196
|
+
i2 = first_abs_index + 1 + i
|
|
197
|
+
hash_value = (i1, i2)
|
|
198
|
+
|
|
199
|
+
if hash_value in prunerconfig.cache:
|
|
200
|
+
prunerconfig.cache_calls += 1
|
|
201
|
+
continue
|
|
202
|
+
|
|
203
|
+
# if we have not computed the value before, check if the two
|
|
204
|
+
# structures have close enough energy before running the comparison
|
|
205
|
+
elif (
|
|
206
|
+
np.abs(prunerconfig.energies[i1] - prunerconfig.energies[i2]) < prunerconfig.max_dE
|
|
207
|
+
):
|
|
208
|
+
# function will return True whether the structures are similar,
|
|
209
|
+
# and will stop iterating on this row, returning
|
|
210
|
+
prunerconfig.eval_calls += 1
|
|
211
|
+
if prunerconfig.evaluate_sim(i1, i2):
|
|
212
|
+
return True
|
|
213
|
+
|
|
214
|
+
# if structures are not similar, add the result to the
|
|
215
|
+
# cache, because they will return here,
|
|
216
|
+
# while similar structures are discarded and won't come back
|
|
217
|
+
else:
|
|
218
|
+
prunerconfig.cache.add(hash_value)
|
|
219
|
+
|
|
220
|
+
# if energy is not similar enough, also add to cache
|
|
221
|
+
else:
|
|
222
|
+
prunerconfig.cache.add(hash_value)
|
|
223
|
+
|
|
224
|
+
return False
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def _main_compute_row(
    prunerconfig: PrunerConfig,
    row_indices: Array1D_int,
    in_mask: Array1D_bool,
    first_abs_index: int,
) -> Array1D_bool:
    """Prune one group of consecutive structures.

    Every active structure of the group is compared with the active ones that
    come after it and is rejected if it is similar to any of them. Returns a
    boolean mask, shaped like ``in_mask``, that is True for the structures to
    retain. Structures that are inactive in ``in_mask`` stay inactive.

    Only the length of ``row_indices`` is used; ``first_abs_index`` is the
    absolute index of the first structure of the group. Dissimilar pairs are
    cached by the subrow function.
    """
    n_in_row = len(row_indices)
    keep = np.ones(shape=in_mask.shape, dtype=np.bool_)

    for pos in range(n_in_row):
        # structures that were already discarded stay discarded
        if not in_mask[pos]:
            keep[pos] = False
            continue

        # discard the structure if any later active one is similar to it
        keep[pos] = not _main_compute_subrow(
            prunerconfig,
            in_mask[pos + 1 :],
            first_abs_index=first_abs_index + pos,
            num_str_in_row=n_in_row - pos - 1,
        )

    return keep
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def _main_compute_group(
    prunerconfig: PrunerConfig,
    in_mask: Array1D_bool,
    k: int,
) -> Array1D_bool:
    """Prune the ensemble chunk by chunk and return the updated mask.

    The structures are split into k consecutive chunks of equal size (the
    last chunk also takes the remainder), and each chunk is pruned on its own:
    structures belonging to different chunks are never compared here.
    """
    n_structures = len(prunerconfig.structures)
    chunksize = int(n_structures // k)

    updated_mask = np.ones(shape=in_mask.shape, dtype=np.bool_)

    # chunks are handled one after the other (multithreading here?)
    for chunk in range(int(k)):
        start = chunk * chunksize
        # the last chunk runs to the end of the ensemble
        stop = n_structures if chunk == k - 1 else chunksize * (chunk + 1)

        updated_mask[start:stop] = _main_compute_row(
            prunerconfig,
            np.arange(start, stop, 1, dtype=int),
            in_mask[start:stop],
            first_abs_index=start,
        )

    return updated_mask
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def prune(prunerconfig: PrunerConfig) -> tuple[Array3D_float, Array1D_bool]:
    """Perform the similarity pruning on the ensemble held by ``prunerconfig``.

    Remove similar structures by repeatedly grouping them into k subgroups,
    for decreasing values of k, and removing similar ones within each
    subgroup; the last pass (k=1) spans the whole ensemble. A cache of the
    pairs found to be dissimilar avoids repeating comparisons.

    What "similar" means is decided by ``prunerconfig.evaluate_sim``.

    Side effects on ``prunerconfig``:
      * when the last energy is non-zero, ``structures`` and ``energies`` are
        replaced by copies sorted by ascending energy. The returned mask then
        refers to that sorted order, not to the order of the input;
      * ``eval_calls`` and ``cache_calls`` are incremented (never reset);
      * ``cache`` is emptied before and after the run.

    Returns the retained structures and the boolean mask that selects them
    from ``prunerconfig.structures``.
    """
    start_t = perf_counter()

    # initialize the output mask
    out_mask = np.ones(shape=prunerconfig.structures.shape[0], dtype=np.bool_)
    prunerconfig.cache = set()

    # sort structures by ascending energy: this will have the effect of
    # having energetically similar structures end up in the same chunk
    # and therefore being pruned early.
    # The length check avoids an IndexError on an empty ensemble.
    if len(prunerconfig.energies) > 0 and np.abs(prunerconfig.energies[-1]) > 0:
        sorting_indices = np.argsort(prunerconfig.energies)
        prunerconfig.structures = prunerconfig.structures[sorting_indices]
        prunerconfig.energies = prunerconfig.energies[sorting_indices]

    # split the structure array in subgroups and prune them internally
    for k in (
        500_000,
        200_000,
        100_000,
        50_000,
        20_000,
        10_000,
        5000,
        2000,
        1000,
        500,
        200,
        100,
        50,
        20,
        10,
        5,
        2,
        1,
    ):
        # choose only k values such that every subgroup
        # has on average at least twenty active structures in it;
        # k == 1 always runs, so `after` is always defined below
        if k == 1 or 20 * k < np.count_nonzero(out_mask):
            before = np.count_nonzero(out_mask)

            start_t_k = perf_counter()

            # compute similarities and get back the updated mask
            # (dissimilar pairs are added to the cache along the way)
            out_mask = _main_compute_group(
                prunerconfig,
                out_mask,
                k=k,
            )

            after = np.count_nonzero(out_mask)
            newly_discarded = before - after

            if prunerconfig.debugfunction is not None:
                elapsed = perf_counter() - start_t_k
                prunerconfig.debugfunction(
                    f"DEBUG: {prunerconfig.__class__.__name__} - k={k}, rejected {newly_discarded} "
                    + f"(keeping {after}/{len(out_mask)}), in {time_to_string(elapsed)}"
                )

    # release the cached pairs but keep the attribute: `cache` is a dataclass
    # field, and deleting it would break repr() and == on the config
    prunerconfig.cache = set()

    if prunerconfig.debugfunction is not None:
        elapsed = perf_counter() - start_t
        prunerconfig.debugfunction(
            f"DEBUG: {prunerconfig.__class__.__name__} - keeping "
            + f"{after}/{len(out_mask)} "
            + f"({time_to_string(elapsed)})"
        )

        if prunerconfig.eval_calls == 0:
            fraction = 0.0
        else:
            fraction = prunerconfig.cache_calls / (
                prunerconfig.eval_calls + prunerconfig.cache_calls
            )

        prunerconfig.debugfunction(
            f"DEBUG: {prunerconfig.__class__.__name__} - Used cached data "
            + f"{prunerconfig.cache_calls}/{prunerconfig.eval_calls + prunerconfig.cache_calls}"
            + f" times, {100 * fraction:.2f}% of total calls"
        )

    return prunerconfig.structures[out_mask], out_mask
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def prune_by_rmsd(
    structures: Array3D_float,
    atoms: Array1D_str,
    max_rmsd: float = 0.25,
    max_dev: float | None = None,
    energies: Array1D_float | None = None,
    max_dE: float = 0.0,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicate structures using a heavy-atom RMSD metric.

    Remove similar structures by repeatedly grouping them into k
    subgroups and removing similar ones. A cache is present to avoid
    repeating RMSD computations.

    Two structures are considered similar unless their RMSD exceeds max_rmsd
    or their maximum deviation exceeds max_dev. When max_dev is None it
    defaults to 2 * max_rmsd; an explicit value, 0.0 included, is honoured.

    If energies are provided, a positive max_dE must be provided as well.

    Returns the retained structures and the boolean mask selecting them.
    """
    if energies is None:
        energies = np.array([])

    # set default max_dev only if it was not provided at all
    # (a plain `or` would also discard an explicit 0.0)
    if max_dev is None:
        max_dev = 2 * max_rmsd

    # set up PrunerConfig dataclass
    prunerconfig = RMSDPrunerConfig(
        structures=structures,
        atoms=atoms,
        max_rmsd=max_rmsd,
        max_dev=max_dev,
        energies=energies,
        max_dE=max_dE,
        debugfunction=debugfunction,
    )

    # run the pruning
    return prune(prunerconfig)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def prune_by_rmsd_rot_corr(
    structures: Array3D_float,
    atoms: Array1D_str,
    graph: Graph,
    max_rmsd: float = 0.25,
    max_dev: float | None = None,
    energies: Array1D_float | None = None,
    max_dE: float = 0.0,
    logfunction: Callable[[str], None] | None = None,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicates using a heavy-atom RMSD metric, corrected for degenerate torsions.

    Remove similar structures by repeatedly grouping them into k
    subgroups and removing similar ones. A cache is present to avoid
    repeating RMSD computations.

    Two structures are considered similar unless their RMSD exceeds max_rmsd
    or their maximum deviation exceeds max_dev. When max_dev is None it
    defaults to 2 * max_rmsd; an explicit value, 0.0 included, is honoured.

    The RMSD and maximum deviation metrics used are the lowest ones
    of all the degenerate rotamers of the input structure.

    The structures are centered before anything else, so the returned
    coordinates are the centered ones. No pruning is performed (an all-True
    mask is returned) when the graph has more than two connected components
    or when no suitable torsion is found.

    If the graph has exactly two connected components, an edge between their
    closest atoms (in the first structure) is added for the duration of the
    call and removed again before returning.
    """
    # center structures
    structures = np.array([s - s.mean(axis=0) for s in structures])
    ref = structures[0]

    # get the molecular fragments (computed once, as lists usable for indexing)
    subgraphs = [list(vals) for vals in connected_components(graph)]

    # if they are more than two, give up on pruning by rot corr rmsd
    if len(subgraphs) > 2:
        return structures, np.ones(structures.shape[0], dtype=bool)

    # if they are two, add a fictitious bond between the closest atoms of the
    # two molecular fragments; it is removed in the `finally` clause below
    temp_edge = None
    if len(subgraphs) == 2:
        all_dists_array = cdist(ref[subgraphs[0]], ref[subgraphs[1]])
        min_d = np.min(all_dists_array)
        s1, s2 = np.where(all_dists_array == min_d)
        i1, i2 = subgraphs[0][s1[0]], subgraphs[1][s2[0]]
        graph.add_edge(i1, i2)
        temp_edge = (i1, i2)

        if debugfunction is not None:
            debugfunction(
                f"DEBUG: prune_by_rmsd_rot_corr - temporarily added "
                f"edge {i1}-{i2} to the graph (will be removed before returning)"
            )

    try:
        # set default max_dev only if it was not provided at all
        # (a plain `or` would also discard an explicit 0.0)
        if max_dev is None:
            max_dev = 2 * max_rmsd

        # add hydrogen bonds to molecular graph
        # NOTE: these edges are not removed from the caller's graph afterwards
        hydrogen_bonds = get_hydrogen_bonds(ref, atoms, graph)
        for hb in hydrogen_bonds:
            graph.add_edge(*hb)

        # keep an unraveled set of atoms in hbs
        flat_hbs = set(flatten(hydrogen_bonds))

        # get all rotable bonds in the molecule, including dummy rotations
        torsions = get_torsions(
            graph,
            hydrogen_bonds=hydrogen_bonds,
            double_bonds=get_double_bonds_indices(ref, atoms),
            keepdummy=True,
            mode="symmetry",
        )

        # only keep dummy rotations (checking both directions)
        torsions = [
            t
            for t in torsions
            if not (is_nondummy(t.i2, t.i3, graph) and (is_nondummy(t.i3, t.i2, graph)))
        ]

        # since we only compute RMSD based on heavy atoms, discard
        # quadruplets that contain hydrogen atoms, unless their
        # termini are involved in hydrogen bonding
        torsions = [
            t
            for t in torsions
            if ("H" not in [atoms[i] for i in t.torsion])
            or (t.torsion[0] in flat_hbs or t.torsion[3] in flat_hbs)
        ]

        # get torsions angles
        angles = [get_angles(t, graph) for t in torsions]

        # Used specific directionality of torsions so that we always
        # rotate the dummy portion (the one attached to the last index)
        torsions_ids = np.asarray(
            [
                list(t.torsion) if is_nondummy(t.i2, t.i3, graph) else list(reversed(t.torsion))
                for t in torsions
            ]
        )

        # Halt the run if there are no subsymmetrical bonds
        if len(torsions_ids) == 0:
            if debugfunction is not None:
                debugfunction(
                    "DEBUG: prune_by_rmsd_rot_corr - No subsymmetrical torsions found: skipping "
                    "symmetry-corrected RMSD pruning"
                )

            final_mask = np.ones(structures.shape[0], dtype=bool)
            return structures[final_mask], final_mask

        # Print out torsion information
        if logfunction is not None:
            logfunction("\n >> Dihedrals considered for rotamer corrections:")
            for i, (torsion, angle) in enumerate(zip(torsions_ids, angles, strict=False)):
                logfunction(
                    " {:2s} - {:21s} : {}{}{}{} : {}-fold".format(
                        str(i + 1),
                        str(torsion),
                        atoms[torsion[0]],
                        atoms[torsion[1]],
                        atoms[torsion[2]],
                        atoms[torsion[3]],
                        len(angle),
                    )
                )
            logfunction("\n")

        if energies is None:
            energies = np.array([])

        # Initialize PrunerConfig
        prunerconfig = RMSDRotCorrPrunerConfig(
            structures=structures,
            atoms=atoms,
            energies=energies,
            max_dE=max_dE,
            graph=graph,
            torsions=torsions_ids,
            debugfunction=debugfunction,
            angles=angles,
            max_rmsd=max_rmsd,
            max_dev=max_dev,
        )

        # run pruning
        return prune(prunerconfig)

    finally:
        # remove the extra bond in the molecular graph on every exit path,
        # including the early return above and any exception
        if temp_edge is not None:
            graph.remove_edge(*temp_edge)
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def prune_by_moment_of_inertia(
    structures: Array3D_float,
    atoms: Array1D_str,
    max_deviation: float = 1e-2,
    energies: Array1D_float | None = None,
    max_dE: float = 0.0,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicate structures using a moments of inertia-based metric.

    A structure is removed from the ensemble when all three of its moments of
    inertia deviate from those of another structure by less than
    max_deviation. The deviation is relative and expressed as a fraction
    (i.e. max_deviation = 0.1 is 10% relative deviation).

    Atomic masses are looked up from the element symbols in atoms.

    Returns the retained structures and the boolean mask selecting them.
    """
    # one mass per atom, from the periodic table
    atomic_masses = np.array([elements.symbol(symbol).mass for symbol in atoms])

    config = MOIPrunerConfig(
        structures=structures,
        energies=np.array([]) if energies is None else energies,
        max_dE=max_dE,
        debugfunction=debugfunction,
        max_dev=max_deviation,
        masses=atomic_masses,
    )

    return prune(config)
|
prism_pruner/rmsd.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""PRISM - PRuning Interface for Similar Molecules."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from prism_pruner.algebra import get_alignment_matrix
|
|
6
|
+
from prism_pruner.typing import Array2D_float
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def rmsd_and_max(
    p: Array2D_float,
    q: Array2D_float,
    center: bool = True,
) -> tuple[float, float]:
    """Align p onto q and measure how far apart they remain.

    Returns a tuple with the RMSD between the aligned p and q, and the
    largest distance between corresponding rows (positions) of the two.
    With center=True both sets are first translated so that their mean
    position is at the origin.
    """
    if center:
        p = p - p.mean(axis=0)
        q = q - q.mean(axis=0)

    # rotate p with the alignment matrix and take the row-wise differences
    rotation = get_alignment_matrix(p, q)
    deviations = np.ascontiguousarray(p) @ rotation - q

    # root of the summed squared deviations averaged over the rows
    rmsd = np.sqrt(np.sum(deviations * deviations) / len(deviations))

    # largest per-row displacement
    per_row_distance = np.linalg.norm(deviations, axis=1)

    return rmsd, np.max(per_row_distance)
|