prism-pruner 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prism_pruner/pruner.py ADDED
@@ -0,0 +1,623 @@
1
+ """PRISM - PRuning Interface for Similar Molecules."""
2
+
3
+ from dataclasses import dataclass, field
4
+ from time import perf_counter
5
+ from typing import Any, Callable, Sequence
6
+
7
+ import numpy as np
8
+ from networkx import Graph, connected_components
9
+ from periodictable import elements
10
+ from scipy.spatial.distance import cdist
11
+
12
+ from prism_pruner.algebra import get_inertia_moments
13
+ from prism_pruner.rmsd import rmsd_and_max
14
+ from prism_pruner.torsion_module import (
15
+ get_angles,
16
+ get_hydrogen_bonds,
17
+ get_torsions,
18
+ is_nondummy,
19
+ rotationally_corrected_rmsd_and_max,
20
+ )
21
+ from prism_pruner.typing import (
22
+ Array1D_bool,
23
+ Array1D_float,
24
+ Array1D_int,
25
+ Array1D_str,
26
+ Array2D_float,
27
+ Array2D_int,
28
+ Array3D_float,
29
+ )
30
+ from prism_pruner.utils import flatten, get_double_bonds_indices, time_to_string
31
+
32
+
33
@dataclass
class PrunerConfig:
    """Base configuration and mutable state shared by all similarity pruners.

    Subclasses add the metric-specific parameters and implement
    ``evaluate_sim(i1, i2)``, which ``prune`` calls on pairs of indices
    into ``structures``.

    Attributes:
        structures: stacked coordinates of the ensemble; the first axis
            indexes structures (presumably (n_structures, n_atoms, 3) —
            confirm against callers).
        energies: one energy per structure. If left empty, every energy is
            set to zero, so all pairs fall inside the energy window.
        max_dE: energy window; pairs whose energies differ by at least this
            much are never compared. Must be > 0 when energies are given;
            defaults to 1.0 otherwise.
        debugfunction: optional callable receiving debug messages.
        eval_calls: number of ``evaluate_sim`` calls performed so far.
        cache_calls: number of comparisons skipped thanks to ``cache``.
        cache: index pairs already known not to need (re)evaluation.
    """

    structures: Array3D_float

    # Optional parameters that get initialized
    energies: Array1D_float = field(default_factory=lambda: np.array([]))
    max_dE: float = field(default=0.0)
    debugfunction: Callable[[str], None] | None = field(default=None)

    # Computed fields
    eval_calls: int = field(default=0, init=False)
    cache_calls: int = field(default=0, init=False)
    cache: set[tuple[int, int]] = field(default_factory=set, init=False)

    def __post_init__(self) -> None:
        """Validate inputs and initialize computed fields.

        Raises:
            AssertionError: if energies are given without a positive
                ``max_dE``, or if their number differs from the number of
                structures. The error is raised explicitly rather than via
                ``assert`` so the checks still run under ``python -O``, while
                keeping the exception type existing callers may rely on.
        """
        # Accept any array-like (e.g. lists); ndarrays pass through uncopied.
        self.structures = np.asarray(self.structures)
        self.energies = np.asarray(self.energies, dtype=float)

        self.mask = np.ones(shape=(self.structures.shape[0],), dtype=np.bool_)

        # `not (x > 0)` rather than `x <= 0` so that a NaN window is rejected too
        if len(self.energies) != 0 and not (self.max_dE > 0.0):
            raise AssertionError(
                "If you provide energies, please also provide an appropriate energy window max_dE."
            )

        # Set defaults for optional parameters
        if len(self.energies) == 0:
            self.energies = np.zeros(self.structures.shape[0], dtype=float)

        if len(self.energies) != len(self.structures):
            raise AssertionError(
                "Please make sure that the energies "
                + "provided have the same len as the input structures."
            )

        # with all-zero energies any positive window lets every pair through
        if self.max_dE == 0.0:
            self.max_dE = 1.0

    def evaluate_sim(self, *args: Any, **kwargs: Any) -> bool:
        """Return whether two structures are similar; subclasses must override."""
        raise NotImplementedError
73
+
74
+
75
@dataclass
class RMSDRotCorrPrunerConfig(PrunerConfig):
    """Pruner configuration based on a rotamer-corrected RMSD metric.

    A pair of structures is similar when the RMSD and the maximum deviation
    returned by ``rotationally_corrected_rmsd_and_max`` are both no larger
    than ``max_rmsd`` and ``max_dev`` respectively.
    """

    atoms: Array1D_str = field(kw_only=True)
    max_rmsd: float = field(kw_only=True)
    max_dev: float = field(kw_only=True)
    angles: Sequence[Sequence[int]] = field(kw_only=True)
    torsions: Array2D_int = field(kw_only=True)
    graph: Graph = field(kw_only=True)
    heavy_atoms_only: bool = True

    def evaluate_sim(self, i1: int, i2: int) -> bool:
        """Return True if structures ``i1`` and ``i2`` are within both thresholds."""
        first, second = self.structures[i1], self.structures[i2]
        rmsd, deviation = rotationally_corrected_rmsd_and_max(
            first,
            second,
            atoms=self.atoms,
            torsions=self.torsions,
            graph=self.graph,
            angles=self.angles,
            heavy_atoms_only=self.heavy_atoms_only,
        )
        # exceeding either threshold makes the pair dissimilar
        too_far = rmsd > self.max_rmsd or deviation > self.max_dev
        return not too_far
107
+
108
+
109
@dataclass
class RMSDPrunerConfig(PrunerConfig):
    """Pruner configuration based on an aligned RMSD metric.

    A pair of structures is similar when both the RMSD and the maximum
    per-atom deviation returned by ``rmsd_and_max`` (after centering) are no
    larger than ``max_rmsd`` and ``max_dev`` respectively. With
    ``heavy_atoms_only`` (the default), hydrogen atoms are ignored.
    """

    atoms: Array1D_str = field(kw_only=True)
    max_rmsd: float = field(kw_only=True)
    max_dev: float = field(kw_only=True)
    heavy_atoms_only: bool = True

    def evaluate_sim(self, i1: int, i2: int) -> bool:
        """Return whether structures ``i1`` and ``i2`` are similar."""
        p = self.structures[i1]
        q = self.structures[i2]

        if self.heavy_atoms_only:
            # np.asarray: with a plain list, `atoms != "H"` evaluates to a
            # single bool, and indexing coordinates with it adds an axis
            # instead of selecting atoms.
            heavy = np.asarray(self.atoms) != "H"
            # fall back to all atoms when there is no heavy atom (e.g. H2),
            # which would otherwise leave nothing to compare
            if np.any(heavy):
                p = p[heavy]
                q = q[heavy]

        rmsd, max_dev = rmsd_and_max(p, q, center=True)

        if rmsd > self.max_rmsd:
            return False

        if max_dev > self.max_dev:
            return False

        return True
138
+
139
+
140
@dataclass
class MOIPrunerConfig(PrunerConfig):
    """Pruner configuration based on principal moments of inertia.

    A pair of structures is similar when each of the three moments returned
    by ``get_inertia_moments`` differs by less than ``max_dev``, relative to
    the moment of the first structure of the pair.
    """

    masses: Array1D_float = field(kw_only=True)
    max_dev: float = 0.01

    def __post_init__(self) -> None:
        """Run the parent's validation, then precompute the moments of inertia."""
        super().__post_init__()
        self._compute_moi_vecs()

    def _compute_moi_vecs(self) -> None:
        """(Re)build ``moi_vecs`` (index -> moments) for the current ``structures``."""
        self.moi_vecs = {
            c: get_inertia_moments(
                coord,
                self.masses,
            )
            for c, coord in enumerate(self.structures)
        }
        # remember which coordinate array these moments belong to
        self._moi_source = self.structures

    def evaluate_sim(self, i1: int, i2: int) -> bool:
        """Return whether structures ``i1`` and ``i2`` have matching moments of inertia."""
        # prune() replaces `structures` with an energy-sorted copy after
        # __post_init__, so the indices it passes refer to the new order:
        # rebuild the cached moments whenever the array has been replaced.
        if getattr(self, "_moi_source", None) is not self.structures:
            self._compute_moi_vecs()

        im_1 = self.moi_vecs[i1]
        im_2 = self.moi_vecs[i2]

        # compare the three MOIs via a Python loop:
        # apparently much faster than numpy array operations
        # for such a small array!
        for j in range(3):
            if np.abs(im_1[j] - im_2[j]) / im_1[j] >= self.max_dev:
                return False
        return True
170
+
171
+
172
def _main_compute_subrow(
    prunerconfig: PrunerConfig,
    in_mask: Array1D_bool,
    first_abs_index: int,
    num_str_in_row: int,
) -> bool:
    """Evaluate one subrow of the similarity matrix.

    Compare the reference structure (absolute index ``first_abs_index``)
    with the ``num_str_in_row`` structures that follow it, returning True at
    the first similar one. ``in_mask[i]`` tells whether the structure at
    absolute index ``first_abs_index + 1 + i`` is still active; inactive
    structures are skipped.

    A pair is only evaluated when the absolute energy difference of its two
    structures is smaller than ``prunerconfig.max_dE``. Pairs that turn out
    dissimilar, or that fall outside the energy window, are added to
    ``prunerconfig.cache`` so that later passes skip them, avoiding
    redundant calculations.

    Returns:
        True if the reference is similar to any active structure of the
        subrow, False otherwise.
    """
    ref = first_abs_index
    cache = prunerconfig.cache
    energies = prunerconfig.energies
    max_dE = prunerconfig.max_dE

    # iterate over target structures
    for offset in range(num_str_in_row):
        # only compare active structures
        if not in_mask[offset]:
            continue

        other = first_abs_index + 1 + offset
        pair = (ref, other)

        # a cached pair is already known to be dissimilar (or outside the
        # energy window) and both structures are still active: skip it
        if pair in cache:
            prunerconfig.cache_calls += 1
            continue

        # only run the (expensive) comparison for energetically close pairs
        if np.abs(energies[ref] - energies[other]) < max_dE:
            prunerconfig.eval_calls += 1
            # similar: stop here, the reference will be discarded
            if prunerconfig.evaluate_sim(ref, other):
                return True

        # dissimilar or outside the energy window: both structures stay
        # active and may meet again in a later pass, so remember the pair
        cache.add(pair)

    return False
225
+
226
+
227
def _main_compute_row(
    prunerconfig: PrunerConfig,
    row_indices: Array1D_int,
    in_mask: Array1D_bool,
    first_abs_index: int,
) -> Array1D_bool:
    """Prune one chunk of structures internally.

    Every active structure of the chunk is compared with the active
    structures that follow it inside the same chunk, and is dropped if any
    of them is similar. Inactive structures stay inactive. Dissimilar pairs
    are cached by ``_main_compute_subrow``.

    Args:
        prunerconfig: configuration providing the similarity test and cache.
        row_indices: indices of the chunk; only its length is used.
        in_mask: activity mask of the chunk, aligned with ``row_indices``.
        first_abs_index: absolute index of the first structure of the chunk.

    Returns:
        The updated activity mask of the chunk.
    """
    n_rows = len(row_indices)
    out_mask = np.ones(shape=in_mask.shape, dtype=np.bool_)

    for row in range(n_rows):
        if not in_mask[row]:
            out_mask[row] = False
            continue

        # keep the structure only if nothing after it is similar
        out_mask[row] = not _main_compute_subrow(
            prunerconfig,
            in_mask[row + 1 :],
            first_abs_index=first_abs_index + row,
            num_str_in_row=n_rows - row - 1,
        )

    return out_mask
262
+
263
+
264
def _main_compute_group(
    prunerconfig: PrunerConfig,
    in_mask: Array1D_bool,
    k: int,
) -> Array1D_bool:
    """Prune the ensemble split into ``k`` contiguous chunks.

    The structures are divided into ``k`` chunks of ``len // k`` elements,
    the last chunk also taking the remainder, and each chunk is pruned
    independently with ``_main_compute_row``.

    Returns:
        The updated activity mask for the whole ensemble.
    """
    n_structures = len(prunerconfig.structures)
    size = int(n_structures // k)
    out_mask = np.ones(shape=in_mask.shape, dtype=np.bool_)

    # iterate over chunks (multithreading here?)
    for chunk in range(int(k)):
        start = chunk * size
        # the last chunk absorbs the remainder of the division
        stop = n_structures if chunk == k - 1 else size * (chunk + 1)

        out_mask[start:stop] = _main_compute_row(
            prunerconfig,
            np.arange(start, stop, 1, dtype=int),
            in_mask[start:stop],
            first_abs_index=start,
        )

    return out_mask
299
+
300
+
301
def prune(prunerconfig: PrunerConfig) -> tuple[Array3D_float, Array1D_bool]:
    """Perform the similarity pruning.

    Remove similar structures by repeatedly splitting the ensemble into k
    contiguous chunks, for decreasing k, and pruning each chunk internally.
    A cache of already-evaluated dissimilar pairs avoids repeating
    expensive comparisons across passes.

    Similarity is decided by ``prunerconfig.evaluate_sim``. Pairs whose
    energies differ by ``prunerconfig.max_dE`` or more are never compared.
    Within a chunk, a structure is discarded when it is similar to any
    active structure that comes after it.

    While pruning, structures are sorted by ascending energy, so that
    energetically close structures end up in the same chunk. The original
    ``prunerconfig.structures`` / ``prunerconfig.energies`` are restored
    before returning.

    Returns:
        ``(kept, mask)``: ``mask`` is a boolean array aligned with the input
        order of ``prunerconfig.structures`` (True = kept), and ``kept``
        equals ``prunerconfig.structures[mask]``.
    """
    start_t = perf_counter()

    input_structures = prunerconfig.structures
    input_energies = prunerconfig.energies
    energies_array = np.asarray(input_energies)

    # Sort structures by ascending energy: energetically similar structures
    # end up in the same chunk and are therefore pruned early. A stable sort
    # keeps the input order among ties. All-zero (default) energies, and an
    # empty ensemble, need no sorting.
    sorting_indices = None
    if np.any(energies_array):
        sorting_indices = np.argsort(energies_array, kind="stable")
        prunerconfig.structures = input_structures[sorting_indices]
        prunerconfig.energies = energies_array[sorting_indices]

    # activity mask aligned with the (possibly sorted) config arrays
    sorted_mask = np.ones(shape=input_structures.shape[0], dtype=np.bool_)
    prunerconfig.cache = set()

    try:
        # split the structure array in subgroups and prune them internally
        for k in (
            500_000,
            200_000,
            100_000,
            50_000,
            20_000,
            10_000,
            5000,
            2000,
            1000,
            500,
            200,
            100,
            50,
            20,
            10,
            5,
            2,
            1,
        ):
            before = np.count_nonzero(sorted_mask)

            # choose only k values such that every subgroup
            # has on average at least twenty active structures in it
            if k != 1 and 20 * k >= before:
                continue

            start_t_k = perf_counter()

            sorted_mask = _main_compute_group(
                prunerconfig,
                sorted_mask,
                k=k,
            )

            after = np.count_nonzero(sorted_mask)
            newly_discarded = before - after

            if prunerconfig.debugfunction is not None:
                elapsed = perf_counter() - start_t_k
                prunerconfig.debugfunction(
                    f"DEBUG: {prunerconfig.__class__.__name__} - k={k}, rejected {newly_discarded} "
                    + f"(keeping {after}/{len(sorted_mask)}), in {time_to_string(elapsed)}"
                )
    finally:
        # Release the cache memory (keeping the attribute defined) and undo
        # the energy sorting, also if a comparison raised.
        prunerconfig.cache = set()
        prunerconfig.structures = input_structures
        prunerconfig.energies = input_energies

    # scatter the mask back to the input order of the structures
    if sorting_indices is None:
        out_mask = sorted_mask
    else:
        out_mask = np.empty_like(sorted_mask)
        out_mask[sorting_indices] = sorted_mask

    kept = np.count_nonzero(out_mask)

    if prunerconfig.debugfunction is not None:
        elapsed = perf_counter() - start_t
        prunerconfig.debugfunction(
            f"DEBUG: {prunerconfig.__class__.__name__} - keeping "
            + f"{kept}/{len(out_mask)} "
            + f"({time_to_string(elapsed)})"
        )

        if prunerconfig.eval_calls == 0:
            fraction = 0.0
        else:
            fraction = prunerconfig.cache_calls / (
                prunerconfig.eval_calls + prunerconfig.cache_calls
            )

        prunerconfig.debugfunction(
            f"DEBUG: {prunerconfig.__class__.__name__} - Used cached data "
            + f"{prunerconfig.cache_calls}/{prunerconfig.eval_calls + prunerconfig.cache_calls}"
            + f" times, {100 * fraction:.2f}% of total calls"
        )

    return input_structures[out_mask], out_mask
397
+
398
+
399
def prune_by_rmsd(
    structures: Array3D_float,
    atoms: Array1D_str,
    max_rmsd: float = 0.25,
    max_dev: float | None = None,
    energies: Array1D_float | None = None,
    max_dE: float = 0.0,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicate structures using a heavy-atom RMSD metric.

    Remove similar structures by repeatedly grouping them into k
    subgroups and removing similar ones. A cache is present to avoid
    repeating RMSD computations.

    Two structures are similar when their RMSD is at most ``max_rmsd`` and
    their maximum atomic deviation is at most ``max_dev``. When ``max_dev``
    is None it defaults to ``2 * max_rmsd``; an explicit value (including
    0.0) is used as given.

    Returns:
        ``(kept_structures, mask)`` as produced by ``prune``.
    """
    if energies is None:
        energies = np.array([])

    # set default max_dev only if not provided: `max_dev or ...` would
    # silently override an explicit 0.0
    if max_dev is None:
        max_dev = 2 * max_rmsd

    # set up PrunerConfig dataclass; atoms as ndarray so that the
    # element-wise `atoms != "H"` selection works for list input too
    prunerconfig = RMSDPrunerConfig(
        structures=structures,
        atoms=np.asarray(atoms),
        max_rmsd=max_rmsd,
        max_dev=max_dev,
        energies=energies,
        max_dE=max_dE,
        debugfunction=debugfunction,
    )

    # run the pruning
    return prune(prunerconfig)
436
+
437
+
438
def prune_by_rmsd_rot_corr(
    structures: Array3D_float,
    atoms: Array1D_str,
    graph: Graph,
    max_rmsd: float = 0.25,
    max_dev: float | None = None,
    energies: Array1D_float | None = None,
    max_dE: float = 0.0,
    logfunction: Callable[[str], None] | None = None,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicates using a heavy-atom RMSD metric, corrected for degenerate torsions.

    Remove similar structures by repeatedly grouping them into k
    subgroups and removing similar ones. A cache is present to avoid
    repeating RMSD computations.

    Two structures are similar when their RMSD is at most ``max_rmsd`` and
    their maximum deviation is at most ``max_dev`` (``2 * max_rmsd`` when
    None). The metrics used are the lowest ones over all the degenerate
    rotamers of the input structure.

    Every structure is centered on its mean atomic position, and the
    returned structures are the centered ones. Systems with more than two
    molecular fragments are returned unpruned. The caller's ``graph`` is
    not modified: hydrogen bonds and, for two-fragment systems, a bridging
    edge are only added to a private copy.
    """
    # center structures
    structures = np.array([s - s.mean(axis=0) for s in structures])

    # nothing to prune (and no reference structure to work with)
    if structures.shape[0] == 0:
        return structures, np.ones(0, dtype=bool)

    ref = structures[0]

    # work on a copy, so that the edges added below (fragment bridge,
    # hydrogen bonds) never leak into the caller's graph, whatever the
    # return path
    graph = graph.copy()

    # get the molecular fragments
    subgraphs = [list(nodes) for nodes in connected_components(graph)]

    # if they are more than two, give up on pruning by rot corr rmsd
    if len(subgraphs) > 2:
        return structures, np.ones(structures.shape[0], dtype=bool)

    # if they are two, add a fictitious bond between the closest atoms of
    # the two fragments, so that the graph is connected
    if len(subgraphs) == 2:
        all_dists_array = cdist(ref[subgraphs[0]], ref[subgraphs[1]])
        s1, s2 = np.unravel_index(np.argmin(all_dists_array), all_dists_array.shape)
        i1, i2 = subgraphs[0][s1], subgraphs[1][s2]
        graph.add_edge(i1, i2)

        if debugfunction is not None:
            debugfunction(
                f"DEBUG: prune_by_rmsd_rot_corr - added "
                f"edge {i1}-{i2} to a working copy of the graph"
            )

    # set default max_dev only if not provided (an explicit 0.0 is kept)
    if max_dev is None:
        max_dev = 2 * max_rmsd

    # add hydrogen bonds to the (copied) molecular graph
    hydrogen_bonds = get_hydrogen_bonds(ref, atoms, graph)
    for hb in hydrogen_bonds:
        graph.add_edge(*hb)

    # keep an unraveled set of atoms in hbs
    flat_hbs = set(flatten(hydrogen_bonds))

    # get all rotable bonds in the molecule, including dummy rotations
    torsions = get_torsions(
        graph,
        hydrogen_bonds=hydrogen_bonds,
        double_bonds=get_double_bonds_indices(ref, atoms),
        keepdummy=True,
        mode="symmetry",
    )

    # only keep dummy rotations (checking both directions)
    torsions = [
        t
        for t in torsions
        if not (is_nondummy(t.i2, t.i3, graph) and is_nondummy(t.i3, t.i2, graph))
    ]

    # since we only compute RMSD based on heavy atoms, discard
    # quadruplets that involve hydrogen atom as termini, unless
    # they are involved in hydrogen bonding
    torsions = [
        t
        for t in torsions
        if ("H" not in [atoms[i] for i in t.torsion])
        or (t.torsion[0] in flat_hbs or t.torsion[3] in flat_hbs)
    ]

    # get torsions angles
    angles = [get_angles(t, graph) for t in torsions]

    # Use a specific directionality of torsions so that we always
    # rotate the dummy portion (the one attached to the last index)
    torsions_ids = np.asarray(
        [
            list(t.torsion) if is_nondummy(t.i2, t.i3, graph) else list(reversed(t.torsion))
            for t in torsions
        ]
    )

    final_mask = np.ones(structures.shape[0], dtype=bool)

    # Halt the run if there are no subsymmetrical torsions
    if len(torsions_ids) == 0:
        if debugfunction is not None:
            debugfunction(
                "DEBUG: prune_by_rmsd_rot_corr - No subsymmetrical torsions found: skipping "
                "symmetry-corrected RMSD pruning"
            )

        return structures[final_mask], final_mask

    # Print out torsion information
    if logfunction is not None:
        logfunction("\n >> Dihedrals considered for rotamer corrections:")
        for i, (torsion, angle) in enumerate(zip(torsions_ids, angles, strict=False)):
            logfunction(
                " {:2s} - {:21s} : {}{}{}{} : {}-fold".format(
                    str(i + 1),
                    str(torsion),
                    atoms[torsion[0]],
                    atoms[torsion[1]],
                    atoms[torsion[2]],
                    atoms[torsion[3]],
                    len(angle),
                )
            )
        logfunction("\n")

    if energies is None:
        energies = np.array([])

    # Initialize PrunerConfig
    prunerconfig = RMSDRotCorrPrunerConfig(
        structures=structures,
        atoms=atoms,
        energies=energies,
        max_dE=max_dE,
        graph=graph,
        torsions=torsions_ids,
        debugfunction=debugfunction,
        angles=angles,
        max_rmsd=max_rmsd,
        max_dev=max_dev,
    )

    # run pruning
    return prune(prunerconfig)
593
+
594
+
595
def prune_by_moment_of_inertia(
    structures: Array3D_float,
    atoms: Array1D_str,
    max_deviation: float = 1e-2,
    energies: Array1D_float | None = None,
    max_dE: float = 0.0,
    debugfunction: Callable[[str], None] | None = None,
) -> tuple[Array3D_float, Array1D_bool]:
    """Remove duplicate structures using a moments of inertia-based metric.

    Two structures are duplicates (e.g. enantiomers or rotamers) when each
    of their three principal moments of inertia differs by less than
    ``max_deviation`` from the other's. The deviation is relative, expressed
    as a fraction: ``max_deviation = 0.1`` means 10%. Atomic masses are
    looked up from the element symbols in ``atoms``.

    Returns:
        ``(kept_structures, mask)`` as produced by ``prune``.
    """
    energy_values = np.array([]) if energies is None else energies
    atomic_masses = np.array([elements.symbol(symbol).mass for symbol in atoms])

    config = MOIPrunerConfig(
        structures=structures,
        energies=energy_values,
        max_dE=max_dE,
        debugfunction=debugfunction,
        max_dev=max_deviation,
        masses=atomic_masses,
    )
    return prune(config)
prism_pruner/rmsd.py ADDED
@@ -0,0 +1,38 @@
1
+ """PRISM - PRuning Interface for Similar Molecules."""
2
+
3
+ import numpy as np
4
+
5
+ from prism_pruner.algebra import get_alignment_matrix
6
+ from prism_pruner.typing import Array2D_float
7
+
8
+
9
def rmsd_and_max(
    p: Array2D_float,
    q: Array2D_float,
    center: bool = True,
) -> tuple[float, float]:
    """Align ``p`` onto ``q`` and measure their residual distance.

    Args:
        p: coordinates to be aligned, one row per point.
        q: reference coordinates, with the same layout as ``p``.
        center: if True, both point sets are first translated so that their
            mean position is at the origin.

    Returns:
        ``(rmsd, max_delta)``: the root-mean-square deviation over points, and
        the largest per-point Euclidean distance after alignment.
    """
    if center:
        p, q = p - p.mean(axis=0), q - q.mean(axis=0)

    # rotate p with the matrix returned by get_alignment_matrix(p, q)
    aligned = np.ascontiguousarray(p) @ get_alignment_matrix(p, q)
    delta = aligned - q

    n_points = len(delta)
    rmsd = np.sqrt((delta * delta).sum() / n_points)
    max_delta = np.linalg.norm(delta, axis=1).max()

    return rmsd, max_delta