prism-pruner 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of prism-pruner might be problematic. Click here for more details.

prism_pruner/pruner.py ADDED
@@ -0,0 +1,571 @@
1
+ """PRISM - PRuning Interface for Similar Molecules."""
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass, field
5
+ from time import perf_counter
6
+ from typing import Any, Callable, Sequence, TypeVar
7
+
8
+ import numpy as np
9
+ from networkx import Graph, connected_components
10
+ from scipy.spatial.distance import cdist
11
+
12
+ from prism_pruner.algebra import get_moi_deviation_vec
13
+ from prism_pruner.pt import pt
14
+ from prism_pruner.rmsd import rmsd_and_max
15
+ from prism_pruner.torsion_module import (
16
+ get_angles,
17
+ get_hydrogen_bonds,
18
+ get_torsions,
19
+ is_nondummy,
20
+ rotationally_corrected_rmsd_and_max,
21
+ )
22
+ from prism_pruner.typing import (
23
+ Array1D_bool,
24
+ Array1D_float,
25
+ Array1D_int,
26
+ Array2D_float,
27
+ Array2D_int,
28
+ Array3D_float,
29
+ )
30
+ from prism_pruner.utils import flatten, get_double_bonds_indices, time_to_string
31
+
32
+ __version__ = "1.0.0"
33
+
34
+
35
@dataclass
class PrunerConfig:
    """Base configuration and bookkeeping shared by the pruning routines.

    Holds the structures to prune, the optional energies and energy window
    used to gate comparisons, and the call counters and pair cache that the
    pruning functions update while running. Subclasses provide the actual
    similarity metric by overriding ``evaluate_sim``.
    """

    structures: Array3D_float

    # Optional parameters; missing values are filled in by __post_init__
    energies: Array1D_float = field(default_factory=lambda: np.array([]))
    ewin: float = field(default=0.0)
    debugfunction: Callable[[str], None] | None = field(default=None)

    # Bookkeeping fields, not settable through __init__
    calls: int = field(default=0, init=False)
    cache_calls: int = field(default=0, init=False)
    cache: set[tuple[int, int]] = field(default_factory=set, init=False)

    def __post_init__(self) -> None:
        """Validate inputs and initialize computed fields.

        Raises:
            AssertionError: if energies are provided without a positive
                energy window ``ewin``.
        """
        # one flag per structure, all active at the start
        self.mask = np.ones(shape=(self.structures.shape[0],), dtype=np.bool_)

        if len(self.energies) != 0:
            # Explicit raise instead of ``assert`` so the check is not
            # stripped under ``python -O``; the exception type is unchanged.
            if not self.ewin > 0.0:
                raise AssertionError(
                    "If you provide energies, please also provide an appropriate energy window ewin."
                )
        else:
            # no energies given: use zeros so that every pair falls
            # inside the energy window
            self.energies = np.zeros(self.structures.shape[0])

        if self.ewin == 0.0:
            self.ewin = 1.0

    def evaluate_sim(self, *args: Any, **kwargs: Any) -> bool:
        """Stub method - override in subclasses as needed."""
        raise NotImplementedError
70
+
71
+
72
+ PrunerConfigType = TypeVar("PrunerConfigType", bound=PrunerConfig)
73
+
74
+
75
@dataclass
class RMSDRotCorrPrunerConfig(PrunerConfig):
    """Pruner configuration based on the rotationally corrected RMSD metric.

    Two structures are considered similar when neither the RMSD nor the
    maximum deviation returned by ``rotationally_corrected_rmsd_and_max``
    exceeds its threshold (``max_rmsd`` and ``max_dev`` respectively).
    """

    atomnos: Array1D_int = field(kw_only=True)
    max_rmsd: float = field(kw_only=True)
    max_dev: float = field(kw_only=True)
    angles: Sequence[Sequence[int]] = field(kw_only=True)
    torsions: Array2D_int = field(kw_only=True)
    graph: Graph = field(kw_only=True)
    heavy_atoms_only: bool = True

    def evaluate_sim(self, coord1: Array2D_float, coord2: Array2D_float) -> bool:
        """Tell whether the two sets of coordinates are within both thresholds."""
        metric_kwargs = {
            "atomnos": self.atomnos,
            "torsions": self.torsions,
            "graph": self.graph,
            "angles": self.angles,
            "debugfunction": self.debugfunction,
            "heavy_atoms_only": self.heavy_atoms_only,
        }
        rmsd, max_dev = rotationally_corrected_rmsd_and_max(coord1, coord2, **metric_kwargs)

        # dissimilar as soon as either threshold is exceeded
        too_different = rmsd > self.max_rmsd or max_dev > self.max_dev
        return not too_different
107
+
108
+
109
@dataclass
class RMSDPrunerConfig(PrunerConfig):
    """Pruner configuration based on the plain RMSD metric.

    Two structures are considered similar when neither the RMSD nor the
    maximum deviation returned by ``rmsd_and_max`` exceeds its threshold.
    With ``heavy_atoms_only`` set, atoms whose atomic number is 1 are left
    out of the comparison.
    """

    atomnos: Array1D_int = field(kw_only=True)
    max_rmsd: float = field(kw_only=True)
    max_dev: float = field(kw_only=True)
    heavy_atoms_only: bool = True

    def evaluate_sim(self, coord1: Array2D_float, coord2: Array2D_float) -> bool:
        """Tell whether the two sets of coordinates are within both thresholds."""
        # select the atoms that take part in the comparison
        selection = (
            (self.atomnos != 1)
            if self.heavy_atoms_only
            else np.ones(self.structures[0].shape[0], dtype=bool)
        )

        rmsd, max_dev = rmsd_and_max(coord1[selection], coord2[selection], center=True)

        # dissimilar as soon as either threshold is exceeded
        too_different = rmsd > self.max_rmsd or max_dev > self.max_dev
        return not too_different
138
+
139
+
140
@dataclass
class MOIPrunerConfig(PrunerConfig):
    """Pruner configuration based on a moments-of-inertia metric.

    Two structures are considered similar when every component of the
    deviation vector returned by ``get_moi_deviation_vec`` is below
    ``max_dev``.
    """

    masses: Array1D_float = field(kw_only=True)
    max_dev: float = 0.01

    def evaluate_sim(self, coord1: Array2D_float, coord2: Array2D_float) -> bool:
        """Tell whether all deviation components are below ``max_dev``."""
        deviations = get_moi_deviation_vec(coord1, coord2, masses=self.masses)
        within_threshold = deviations < self.max_dev
        return bool(within_threshold.all())
156
+
157
+
158
def _main_compute_subrow(
    prunerconfig: PrunerConfigType,
    ref: Array2D_float,
    structures: Array3D_float,
    in_mask: Array1D_bool,
    first_abs_index: int,
) -> bool:
    """Compare one reference structure against the structures that follow it.

    Return True at the first structure in ``structures`` that
    ``prunerconfig.evaluate_sim`` reports as similar to ``ref``, and False
    if none is. Entries that are False in ``in_mask`` are skipped. The
    comparison is only run when the absolute energy difference of the pair
    is smaller than ``prunerconfig.ewin``. Pairs that were compared and
    found dissimilar are stored in ``prunerconfig.cache`` so that they are
    not recomputed; ``prunerconfig.calls`` and ``prunerconfig.cache_calls``
    are updated along the way.

    ``first_abs_index`` is the index of ``ref`` in the full structure
    array; ``structures[0]`` is taken to sit at ``first_abs_index + 1``.
    """
    for offset, candidate in enumerate(structures):
        # only compare active structures
        if not in_mask[offset]:
            continue

        # absolute indices identify the pair in the cache
        pair = (first_abs_index, first_abs_index + 1 + offset)

        prunerconfig.calls += 1

        # a cached pair is already known to be dissimilar, since both
        # of its members are still active
        if pair in prunerconfig.cache:
            prunerconfig.cache_calls += 1
            continue

        # skip pairs whose energies are too far apart
        energy_gap = np.abs(prunerconfig.energies[pair[0]] - prunerconfig.energies[pair[1]])
        if not energy_gap < prunerconfig.ewin:
            continue

        # stop at the first similar structure
        if prunerconfig.evaluate_sim(ref, candidate):
            return True

        # dissimilar pairs may be visited again, so remember them;
        # similar ones get discarded and will not come back
        prunerconfig.cache.add(pair)

    return False
203
+
204
+
205
def _main_compute_row(
    prunerconfig: PrunerConfigType,
    structures: Array3D_float,
    in_mask: Array1D_bool,
    first_abs_index: int,
) -> Array1D_bool:
    """Build the keep-mask for one group of structures.

    Each active structure is compared with the active ones that come after
    it in ``structures``; it is rejected when a similar one is found.
    Inactive structures (False in ``in_mask``) stay rejected. The returned
    boolean mask has the shape of ``in_mask``. ``first_abs_index`` is the
    position of ``structures[0]`` in the full structure array and is
    forwarded to ``_main_compute_subrow``, which caches dissimilar pairs.
    """
    keep = np.ones(shape=in_mask.shape, dtype=np.bool_)

    for idx, ref in enumerate(structures):
        # inactive structures are never re-activated
        if not in_mask[idx]:
            keep[idx] = 0
            continue

        # reject structure idx if it is similar to any later one
        has_later_twin = _main_compute_subrow(
            prunerconfig,
            ref,
            structures[idx + 1 :],
            in_mask[idx + 1 :],
            first_abs_index=first_abs_index + idx,
        )
        keep[idx] = not has_later_twin

    return keep
240
+
241
+
242
def _main_compute_group(
    prunerconfig: PrunerConfigType,
    structures: Array2D_float,
    in_mask: Array1D_bool,
    k: int,
) -> Array1D_bool:
    """Prune the structures independently inside each of ``k`` chunks.

    The structure array is cut into ``k`` consecutive chunks of
    ``len(structures) // k`` elements, the last chunk also taking the
    remainder. Each chunk is pruned with ``_main_compute_row`` and the
    per-chunk masks are assembled into the returned mask.
    """
    total = len(structures)
    n_chunks = int(k)
    chunksize = int(total // k)

    # final result container
    merged_mask = np.ones(shape=structures.shape[0], dtype=np.bool_)

    # iterate over chunks (multithreading here?)
    for chunk in range(n_chunks):
        first = chunk * chunksize
        # the last chunk extends to the end of the array
        last = total if chunk == k - 1 else chunksize * (chunk + 1)

        # compare structures within the chunk and store the outcome
        merged_mask[first:last] = _main_compute_row(
            prunerconfig,
            structures[first:last],
            in_mask[first:last],
            first_abs_index=first,
        )

    return merged_mask
278
+
279
+
280
def prune(prunerconfig: PrunerConfigType) -> tuple[Array2D_float, Array1D_bool]:
    """Perform the similarity pruning.

    Remove similar structures by repeatedly grouping them into k
    subgroups and removing similar ones within each subgroup, with k
    decreasing down to 1 (a single group with every structure). A cache of
    dissimilar pairs avoids repeating comparisons across passes.

    Similarity between two structures is decided by
    ``prunerconfig.evaluate_sim``.

    Returns:
        The retained structures (``prunerconfig.structures[mask]``) and the
        boolean mask over the input structures.

    Side effects: ``prunerconfig.cache`` is emptied before and after the
    run, and the ``calls`` / ``cache_calls`` counters are incremented by the
    helper functions. If ``prunerconfig.debugfunction`` is set, progress and
    timing messages are sent to it.
    """
    start_t = perf_counter()

    # work on a private copy of the coordinates
    structures = deepcopy(prunerconfig.structures)

    # initialize the output mask
    out_mask = np.ones(shape=prunerconfig.structures.shape[0], dtype=np.bool_)
    prunerconfig.cache = set()

    # split the structure array in subgroups and prune them internally
    for k in (
        5e5,
        2e5,
        1e5,
        5e4,
        2e4,
        1e4,
        5000,
        2000,
        1000,
        500,
        200,
        100,
        50,
        20,
        10,
        5,
        2,
        1,
    ):
        # choose only k values such that every subgroup
        # has on average at least twenty active structures in it
        if k == 1 or 20 * k < np.count_nonzero(out_mask):
            before = np.count_nonzero(out_mask)

            start_t_k = perf_counter()

            # compute similarities and get back the updated out_mask;
            # dissimilar pairs are added to the cache along the way
            out_mask = _main_compute_group(
                prunerconfig,
                structures,
                out_mask,
                k=int(k),
            )

            after = np.count_nonzero(out_mask)
            newly_discarded = before - after

            if prunerconfig.debugfunction is not None:
                # elapsed time is "now minus start" (was reversed, giving
                # negative durations)
                elapsed = perf_counter() - start_t_k
                prunerconfig.debugfunction(
                    f"DEBUG: {prunerconfig.__class__.__name__} - k={k}, rejected {newly_discarded} "
                    + f"(keeping {after}/{len(out_mask)}), in {time_to_string(elapsed)}"
                )

    # Release the cached pairs. The attribute is reset rather than deleted:
    # ``cache`` is a dataclass field, and deleting it would make the
    # generated __repr__/__eq__ of the config raise AttributeError.
    prunerconfig.cache = set()

    if prunerconfig.debugfunction is not None:
        elapsed = perf_counter() - start_t
        prunerconfig.debugfunction(
            f"DEBUG: {prunerconfig.__class__.__name__} - keeping "
            + f"{after}/{len(out_mask)} "
            + f"({time_to_string(elapsed)})"
        )

        fraction = 0 if prunerconfig.calls == 0 else prunerconfig.cache_calls / prunerconfig.calls
        prunerconfig.debugfunction(
            f"DEBUG: {prunerconfig.__class__.__name__} - Used cached data "
            + f"{prunerconfig.cache_calls}/{prunerconfig.calls} times, "
            + f"{100 * fraction:.2f}% of total calls"
        )

    return prunerconfig.structures[out_mask], out_mask
365
+
366
+
367
def prune_by_rmsd(
    structures: Array3D_float,
    atomnos: Array1D_int,
    max_rmsd: float = 0.25,
    max_dev: float | None = None,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicate structures using a heavy-atom RMSD metric.

    Remove similar structures by repeatedly grouping them into k
    subgroups and removing similar ones. A cache is present to avoid
    repeating RMSD computations.

    Two structures are similar when neither their RMSD exceeds max_rmsd
    nor their maximum deviation exceeds max_dev. When max_dev is None it
    defaults to 2 * max_rmsd.

    Returns:
        The retained structures and the boolean mask over the input.
    """
    # Only fall back to the default when max_dev was not provided; an
    # explicit 0.0 is a valid threshold and must not be replaced.
    if max_dev is None:
        max_dev = 2 * max_rmsd

    # set up PrunerConfig dataclass
    prunerconfig = RMSDPrunerConfig(
        structures=structures,
        atomnos=atomnos,
        max_rmsd=max_rmsd,
        max_dev=max_dev,
        debugfunction=debugfunction,
    )

    # run the pruning
    return prune(prunerconfig)
397
+
398
+
399
def prune_by_rmsd_rot_corr(
    structures: Array3D_float,
    atomnos: Array1D_int,
    graph: Graph,
    max_rmsd: float = 0.25,
    max_dev: float | None = None,
    logfunction: Callable[[str], None] | None = None,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicates using a heavy-atom RMSD metric, corrected for degenerate torsions.

    Remove similar structures by repeatedly grouping them into k
    subgroups and removing similar ones. A cache is present to avoid
    repeating RMSD computations.

    Two structures are similar when neither their RMSD exceeds max_rmsd
    nor their maximum deviation exceeds max_dev. When max_dev is None it
    defaults to 2 * max_rmsd.

    The RMSD and maximum deviation metrics used are the lowest ones
    of all the degenerate rotamers of the input structure.

    If the graph has more than two connected components, or no suitable
    torsion is found, no pruning is done and every structure is kept.

    Returns:
        The retained (centered) structures and the boolean mask over the
        input structures.

    NOTE(review): hydrogen-bond edges found on the first structure are added
    to the caller's ``graph`` and are not removed afterwards - confirm that
    this is intended.
    """
    # center structures
    structures = np.array([s - s.mean(axis=0) for s in structures])
    ref = structures[0]

    # get the molecular fragments
    subgraphs = [list(vals) for vals in connected_components(graph)]

    # if they are more than two, give up on pruning by rot corr rmsd
    if len(subgraphs) > 2:
        return structures, np.ones(structures.shape[0], dtype=bool)

    # if they are two, we can add a fictitious bond between the closest
    # atoms on the two molecular fragments in the provided graph, and
    # then remove it before returning
    temp_edge = None
    if len(subgraphs) == 2:
        all_dists_array = cdist(ref[subgraphs[0]], ref[subgraphs[1]])
        min_d = np.min(all_dists_array)
        s1, s2 = np.where(all_dists_array == min_d)
        i1, i2 = subgraphs[0][s1[0]], subgraphs[1][s2[0]]
        graph.add_edge(i1, i2)
        temp_edge = (i1, i2)

        if debugfunction is not None:
            debugfunction(
                f"DEBUG: prune_by_rmsd_rot_corr - temporarily added "
                f"edge {i1}-{i2} to the graph (will be removed before returning)"
            )

    # try/finally guarantees that the fictitious bond is removed on every
    # exit path, including the early return below and any exception
    try:
        # Only fall back to the default when max_dev was not provided; an
        # explicit 0.0 is a valid threshold and must not be replaced.
        if max_dev is None:
            max_dev = 2 * max_rmsd

        # add hydrogen bonds to molecular graph
        hydrogen_bonds = get_hydrogen_bonds(ref, atomnos, graph)
        for hb in hydrogen_bonds:
            graph.add_edge(*hb)

        # keep an unraveled set of atoms in hbs
        flat_hbs = set(flatten(hydrogen_bonds))

        # get all rotable bonds in the molecule, including dummy rotations
        torsions = get_torsions(
            graph,
            hydrogen_bonds=hydrogen_bonds,
            double_bonds=get_double_bonds_indices(ref, atomnos),
            keepdummy=True,
            mode="symmetry",
        )

        # only keep dummy rotations (checking both directions)
        torsions = [
            t
            for t in torsions
            if not (is_nondummy(t.i2, t.i3, graph) and (is_nondummy(t.i3, t.i2, graph)))
        ]

        # since we only compute RMSD based on heavy atoms, discard
        # quadruplets that involve hydrogen atoms, unless their
        # termini are involved in hydrogen bonding
        torsions = [
            t
            for t in torsions
            if (1 not in [atomnos[i] for i in t.torsion])
            or (t.torsion[0] in flat_hbs or t.torsion[3] in flat_hbs)
        ]

        # get torsions angles
        angles = [get_angles(t, graph) for t in torsions]

        # Use a specific directionality of torsions so that we always
        # rotate the dummy portion (the one attached to the last index)
        torsions_ids = np.asarray(
            [
                list(t.torsion) if is_nondummy(t.i2, t.i3, graph) else list(reversed(t.torsion))
                for t in torsions
            ]
        )

        # Halt the run if there are no subsymmetrical bonds
        if len(torsions_ids) == 0:
            if debugfunction is not None:
                debugfunction(
                    "DEBUG: prune_by_rmsd_rot_corr - No subsymmetrical torsions found: skipping "
                    "symmetry-corrected RMSD pruning"
                )

            # keep everything
            final_mask = np.ones(structures.shape[0], dtype=bool)
            return structures[final_mask], final_mask

        # Print out torsion information
        if logfunction is not None:
            logfunction("\n >> Dihedrals considered for rotamer corrections:")
            for i, (torsion, angle) in enumerate(zip(torsions_ids, angles, strict=False)):
                logfunction(
                    " {:2s} - {:21s} : {}{}{}{} : {}-fold".format(
                        str(i + 1),
                        str(torsion),
                        pt[atomnos[torsion[0]]].symbol,
                        pt[atomnos[torsion[1]]].symbol,
                        pt[atomnos[torsion[2]]].symbol,
                        pt[atomnos[torsion[3]]].symbol,
                        len(angle),
                    )
                )
            logfunction("\n")

        # Initialize PrunerConfig
        prunerconfig = RMSDRotCorrPrunerConfig(
            structures=structures,
            atomnos=atomnos,
            graph=graph,
            torsions=torsions_ids,
            debugfunction=debugfunction,
            angles=angles,
            max_rmsd=max_rmsd,
            max_dev=max_dev,
        )

        # run pruning
        return prune(prunerconfig)

    finally:
        # remove the extra bond in the molecular graph
        if temp_edge is not None:
            graph.remove_edge(*temp_edge)
547
+
548
+
549
def prune_by_moment_of_inertia(
    structures: Array3D_float,
    atomnos: Array1D_int,
    max_deviation: float = 1e-2,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicate structures using a moments of inertia-based metric.

    Structures (enantiomeric or rotameric duplicates) are compared through
    the moments of inertia on their principal axes. When all three moments
    deviate from those of another structure by less than max_deviation
    (a relative value, i.e. max_deviation = 0.1 is 10% relative deviation),
    the two are classified as rotamers or enantiomers and only one of them
    is kept.

    Returns:
        The retained structures and the boolean mask over the input.
    """
    # atomic masses looked up from the periodic table, one per atom
    atomic_masses = np.array([pt[a].mass for a in atomnos])

    # set up the PrunerConfig dataclass and run the pruning
    config = MOIPrunerConfig(
        structures=structures,
        debugfunction=debugfunction,
        max_dev=max_deviation,
        masses=atomic_masses,
    )
    return prune(config)
prism_pruner/pt.py ADDED
@@ -0,0 +1,12 @@
1
+ """PRISM - PRuning Interface for Similar Molecules."""
2
+
3
+ from periodictable import core, covalent_radius, mass
4
+
5
# Build the package-wide periodic table ``pt`` and attach covalent radius
# and mass data to it. Up to five table names ("H=1" ... "H=5") are tried;
# an attempt that raises ValueError is skipped and the next name is used.
# NOTE(review): presumably ValueError means the table name is already
# registered in periodictable (e.g. on module reload) - confirm.
for pt_n in range(5):
    try:
        pt = core.PeriodicTable(table=f"H={pt_n + 1}")
        covalent_radius.init(pt)
        mass.init(pt)
    except ValueError:
        continue
    break
else:
    # Every attempt failed: fail loudly here instead of leaving ``pt``
    # undefined or half-initialized for the modules that import it.
    raise ImportError(
        "prism_pruner.pt: could not initialize a periodic table "
        "(all table names 'H=1' to 'H=5' raised ValueError)"
    )
prism_pruner/rmsd.py ADDED
@@ -0,0 +1,39 @@
1
+ """PRISM - PRuning Interface for Similar Molecules."""
2
+
3
+ import numpy as np
4
+
5
+ from prism_pruner.algebra import get_alignment_matrix, norm_of
6
+ from prism_pruner.typing import Array2D_float
7
+
8
+
9
def rmsd_and_max(
    p: Array2D_float,
    q: Array2D_float,
    center: bool = False,
) -> tuple[float, float]:
    """Return RMSD and max deviation.

    Return a tuple with the RMSD between p and q after aligning p with the
    matrix from ``get_alignment_matrix``, and the maximum deviation of
    their positions (largest ``norm_of`` over the per-row differences).

    If ``center`` is True, the mean position of each array is subtracted
    before the alignment. The input arrays are never modified.
    """
    if center:
        # Build centered copies: in-place ``-=`` would silently modify
        # the caller's arrays (and fails on integer arrays).
        p = p - p.mean(axis=0)
        q = q - q.mean(axis=0)

    # get alignment matrix
    rot_mat = get_alignment_matrix(p, q)

    # Apply it to p
    p = np.ascontiguousarray(p) @ rot_mat

    # Calculate deviations
    diff = p - q

    # Calculate RMSD
    rmsd = np.sqrt((diff * diff).sum() / len(diff))

    # Calculate max deviation
    # NOTE(review): equivalent to np.linalg.norm(diff, axis=1).max() if
    # norm_of is the Euclidean norm - confirm before switching.
    max_delta = max([norm_of(v) for v in diff])

    return rmsd, max_delta