prism-pruner 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ """PRISM - PRuning Interface for Similar Molecules."""
@@ -0,0 +1,163 @@
1
+ """Algebra utilities."""
2
+
3
+ from typing import Sequence
4
+
5
+ import numpy as np
6
+
7
+ from prism_pruner.typing import Array1D_float, Array2D_float, Array3D_float
8
+
9
+
10
def normalize(vec: Array1D_float) -> Array1D_float:
    """
    Scale a vector to unit Euclidean length.

    :param vec: vector to rescale
    :return: ``vec`` divided by its Euclidean norm
    """
    length = np.linalg.norm(vec)
    return vec / length
13
+
14
+
15
def vec_angle(v1: Array1D_float, v2: Array1D_float) -> float:
    """
    Return the angle between two vectors.

    :param v1: first vector
    :param v2: second vector
    :return: angle in degrees, between 0 and 180
    """
    unit_1 = v1 / np.linalg.norm(v1)
    unit_2 = v2 / np.linalg.norm(v2)

    # Rounding can push the cosine marginally outside [-1, 1], where arccos
    # is undefined, hence the clipping.
    cosine = np.clip(np.dot(unit_1, unit_2), -1.0, 1.0)

    return float(np.degrees(np.arccos(cosine)))
31
+
32
+
33
def dihedral(p: Array2D_float) -> float:
    """
    Find the signed dihedral angle in degrees from four 3D points.

    Praxeolitic formula: 1 sqrt, 1 cross product.

    :param p: the four points, one per row; any array-like of shape (4, 3)
        is accepted, including integer coordinates
    :return: torsion angle in degrees, in the interval [-180, 180]
    """
    # Convert to float up front: integer coordinates would otherwise make the
    # normalization of b1 below fail, and plain nested lists cannot be subtracted.
    p0, p1, p2, p3 = np.asarray(p, dtype=float)

    b0 = -1.0 * (p1 - p0)
    b1 = p2 - p1
    b2 = p3 - p2

    # normalize b1 so that it does not influence magnitude of vector
    # rejections that come next
    b1 = b1 / np.linalg.norm(b1)

    # vector rejections
    # v = projection of b0 onto plane perpendicular to b1
    #   = b0 minus component that aligns with b1
    # w = projection of b2 onto plane perpendicular to b1
    #   = b2 minus component that aligns with b1
    v = b0 - np.dot(b0, b1) * b1
    w = b2 - np.dot(b2, b1) * b1

    # angle between v and w in a plane is the torsion angle
    # v and w may not be normalized but that's fine since tan is y/x
    x = np.dot(v, w)
    y = np.dot(np.cross(b1, v), w)

    return float(np.degrees(np.arctan2(y, x)))
63
+
64
+
65
def rot_mat_from_pointer(pointer: Array1D_float, angle: float) -> Array2D_float:
    """
    Build the matrix for a rotation of ``angle`` degrees about ``pointer``.

    The axis-angle pair is turned into a scalar-last quaternion (x, y, z, w)
    and handed to quaternion_to_rotation_matrix.

    :param pointer: 3D vector giving the rotation axis; it is normalized here
    :param angle: rotation angle in degrees
    :return rotation_matrix: matrix returned by quaternion_to_rotation_matrix
        for that quaternion
    """
    assert pointer.shape[0] == 3

    half_angle = np.radians(angle) / 2
    axis = pointer / np.linalg.norm(pointer)
    sine = np.sin(half_angle)

    # Vector part first, scalar part last: the order the converter unpacks.
    quat = [sine * axis[0], sine * axis[1], sine * axis[2], np.cos(half_angle)]
    return quaternion_to_rotation_matrix(quat)
86
+
87
+
88
def quaternion_to_rotation_matrix(quat: Array1D_float | Sequence[float]) -> Array2D_float:
    """
    Convert a quaternion into a full three-dimensional rotation matrix.

    The diagonal terms use the form 2 * (w^2 + x^2) - 1, which equals the
    usual 1 - 2 * (y^2 + z^2) only for a unit quaternion, so the input is
    expected to be normalized.

    :param quat: 4-element quaternion in scalar-last order (x, y, z, w)
    :return: 3x3 element array representing the full 3D rotation matrix
    """
    # Vector part first, scalar part last.
    x, y, z, w = quat

    return np.array(
        [
            [2 * (w * w + x * x) - 1, 2 * (x * y - w * z), 2 * (x * z + w * y)],
            [2 * (x * y + w * z), 2 * (w * w + y * y) - 1, 2 * (y * z - w * x)],
            [2 * (x * z - w * y), 2 * (y * z + w * x), 2 * (w * w + z * z) - 1],
        ]
    )
118
+
119
+
120
def get_inertia_moments(coords: Array2D_float, masses: Array1D_float) -> Array1D_float:
    """Compute the principal moments of inertia of a molecule.

    :param coords: atomic coordinates, one (x, y, z) row per atom
    :param masses: atomic masses, one per atom, in the same order as ``coords``
    :return: length-3 array with the principal moments, sorted ascending
    """
    # Shift to center of mass
    com = np.sum(coords * masses[:, np.newaxis], axis=0) / np.sum(masses)
    centered = coords - com

    # Inertia tensor: sum_n m_n * (|r_n|^2 * 1 - r_n r_n^T)
    norms_sq = np.einsum("ni,ni->n", centered, centered)
    total = np.sum(masses * norms_sq)
    inertia = total * np.eye(3) - np.einsum("n,ni,nj->ij", masses, centered, centered)

    # The tensor is symmetric and only its eigenvalues are needed: eigvalsh
    # skips the eigenvectors and already returns the values in ascending order.
    return np.linalg.eigvalsh(inertia)
138
+
139
+
140
def diagonalize(a: Array2D_float) -> Array2D_float:
    """
    Express a square matrix in the basis of its eigenvectors.

    The eigenvectors are ordered by increasing absolute eigenvalue before the
    similarity transform B^-1 A B is applied.

    :param a: square matrix
    :return: the transformed matrix
    """
    eigenvalues, eigenvectors = np.linalg.eig(a)
    order = np.argsort(np.abs(eigenvalues))
    basis = eigenvectors[:, order]
    return np.linalg.inv(basis) @ (a @ basis)  # type: ignore[no-any-return]
145
+
146
+
147
def get_alignment_matrix(p: Array2D_float, q: Array2D_float) -> Array2D_float:
    """
    Build the rotation matrix that aligns vectors q to p (Kabsch algorithm).

    Assumes centered vector sets (i.e. their mean is the origin).

    :param p: reference vectors, one per row
    :param q: vectors to be rotated, one per row, paired row-by-row with ``p``
    :return: proper rotation matrix R (det = +1) for column vectors, i.e.
        ``q @ R.T`` is the rotated set that best overlaps ``p``
    """
    # Covariance matrix between the two sets: sum_i p_i q_i^T
    cov_mat = p.T @ q

    # SVD of the covariance matrix
    u, _, vt = np.linalg.svd(cov_mat)

    # A negative determinant product means u @ vt would be a reflection.
    # Flipping the last column of u (the one paired with the smallest
    # singular value) yields a proper rotation instead (det = 1, not -1).
    if np.linalg.det(u) * np.linalg.det(vt) < 0.0:
        u[:, -1] *= -1

    return u @ vt  # type: ignore[no-any-return]
@@ -0,0 +1,57 @@
1
+ """ConformerEnsemble class."""
2
+
3
+ import re
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Self
7
+
8
+ import numpy as np
9
+
10
+ from prism_pruner.typing import Array1D_float, Array1D_str, Array2D_float, Array3D_float
11
+
12
+
13
+ @dataclass
14
+ class ConformerEnsemble:
15
+ """Class representing a conformer ensemble."""
16
+
17
+ coords: Array3D_float
18
+ atoms: Array1D_str
19
+ energies: Array1D_float = field(default_factory=lambda: np.array([]))
20
+
21
+ @classmethod
22
+ def from_xyz(cls, file: Path | str, read_energies: bool = False) -> Self:
23
+ """Generate ensemble from a multiple conformer xyz file."""
24
+ coords = []
25
+ atoms = []
26
+ energies = []
27
+ with Path(file).open() as f:
28
+ for num in f:
29
+ if read_energies:
30
+ energy = next(re.finditer(r"-*\d+\.\d+", next(f))).group()
31
+ energies.append(float(energy))
32
+ else:
33
+ _comment = next(f)
34
+
35
+ conf_atoms = []
36
+ conf_coords = []
37
+ for _ in range(int(num)):
38
+ atom, *xyz = next(f).split()
39
+ conf_atoms.append(atom)
40
+ conf_coords.append([float(x) for x in xyz])
41
+
42
+ atoms.append(conf_atoms)
43
+ coords.append(conf_coords)
44
+
45
+ return cls(coords=np.array(coords), atoms=np.array(atoms[0]), energies=np.array(energies))
46
+
47
+ def to_xyz(self, file: Path | str) -> None:
48
+ """Write ensemble to an xyz file."""
49
+
50
+ def to_xyz(coords: Array2D_float) -> str:
51
+ return "\n".join(
52
+ f"{atom} {x:15.8f} {y:15.8f} {z:15.8f}"
53
+ for atom, (x, y, z) in zip(self.atoms, coords, strict=True)
54
+ )
55
+
56
+ with Path(file).open("w") as f:
57
+ f.write("\n".join(map(to_xyz, self.coords)))
@@ -0,0 +1,195 @@
1
+ """Graph manipulation utilities for molecular structures."""
2
+
3
+ from functools import lru_cache
4
+
5
+ import numpy as np
6
+ from networkx import Graph, all_simple_paths, from_numpy_array, set_node_attributes
7
+ from periodictable import elements
8
+ from scipy.spatial.distance import cdist
9
+
10
+ from prism_pruner.algebra import dihedral
11
+ from prism_pruner.typing import Array1D_bool, Array1D_str, Array2D_float
12
+
13
+
14
@lru_cache()
def d_min_bond(a1: str, a2: str, factor: float = 1.2) -> float:
    """
    Return the reference bond distance for a pair of elements.

    :param a1: symbol of the first element
    :param a2: symbol of the second element
    :param factor: multiplier applied to the sum of the covalent radii
    :return: ``factor`` times the sum of the covalent radii of the two elements
    """
    radius_1 = elements.symbol(a1).covalent_radius
    radius_2 = elements.symbol(a2).covalent_radius
    return factor * (radius_1 + radius_2)  # type: ignore [no-any-return]
18
+
19
+
20
def graphize(
    atoms: Array1D_str,
    coords: Array2D_float,
    mask: Array1D_bool | None = None,
) -> Graph:
    """
    Return a NetworkX undirected graph of molecular connectivity.

    Two unmasked atoms are connected when their distance is below the value
    given by d_min_bond for their symbols. Each node stores its atomic symbol
    under the "atoms" attribute.

    :param atoms: atomic symbols
    :param coords: atomic coordinates as 3D vectors
    :param mask: bool array, with False for atoms to be excluded in the bond evaluation
    :return: connectivity graph
    """
    if mask is None:
        mask = np.ones(len(atoms), dtype=bool)
    assert len(coords) == len(atoms)
    assert len(coords) == len(mask)

    n_atoms = len(coords)
    adjacency = np.zeros((n_atoms, n_atoms))

    # Only atoms left in by the mask take part in the pairwise distance check.
    active = [i for i, keep in enumerate(mask) if keep]
    for position, i in enumerate(active):
        for j in active[position + 1 :]:
            # Each pair is visited once (i < j), filling the upper triangle only.
            if np.linalg.norm(coords[i] - coords[j]) < d_min_bond(atoms[i], atoms[j]):
                adjacency[i, j] = 1

    graph = from_numpy_array(adjacency)
    set_node_attributes(graph, dict(enumerate(atoms)), "atoms")

    return graph
53
+
54
+
55
def get_sp_n(index: int, graph: Graph) -> int | None:
    """
    Get hybridization of selected atom.

    Return n, the exponent of the sp^n hybridization, for C, N, O, P and S
    atoms, judged from the number of neighbors only. This is just an
    assimilation to the carbon geometry in relation to sp^n:
    - sp¹ is linear
    - sp² is planar
    - sp³ is tetrahedral
    This is mainly used to understand if a torsion is to be rotated or not.

    :param index: node index of the atom in ``graph``
    :param graph: connectivity graph with the symbol stored under "atoms"
    :return: n, or None for other elements, for neighbor counts that are not
        tabulated, and for nitrogen with three neighbors (ambiguous)
    """
    # Number of neighbors -> sp^n exponent, per element
    neighbors_to_sp_n: dict[str, dict[int, int | None]] = {
        "C": {2: 1, 3: 2, 4: 3},
        "N": {2: 2, 3: None, 4: 3},  # 3 could mean sp3 or sp2
        "O": {1: 2, 2: 3, 3: 3, 4: 3},
        "P": {2: 2, 3: 3, 4: 3},
        "S": {2: 2, 3: 3, 4: 3},
    }

    by_neighbor_count = neighbors_to_sp_n.get(graph.nodes[index]["atoms"])
    if by_neighbor_count is None:
        return None

    n_neighbors = len(set(graph.neighbors(index)))
    return by_neighbor_count.get(n_neighbors)
80
+
81
+
82
def is_amide_n(index: int, graph: Graph, mode: int = -1) -> bool:
    """
    Assess if the atom is an amide-like nitrogen.

    A nitrogen qualifies when at least one neighboring carbon has exactly
    three neighbors, one of which is an oxygen.

    Note: carbamates and ureas are considered amides.

    mode:
    -1 - any amide
    0 - primary amide (CONH2)
    1 - secondary amide (CONHR)
    2 - tertiary amide (CONR2)
    """
    # Must be a nitrogen atom
    if graph.nodes[index]["atoms"] != "N":
        return False

    neighbors = set(graph.neighbors(index))

    if mode != -1:
        # Hydrogens required on N: two for primary, one for secondary and
        # none for tertiary amides
        n_hydrogens = [graph.nodes[j]["atoms"] for j in neighbors].count("H")
        if n_hydrogens != (2, 1, 0)[mode]:
            return False

    for neighbor in neighbors:
        # Only carbon neighbors can be the carbonyl carbon
        if graph.nodes[neighbor]["atoms"] != "C":
            continue

        carbon_neighbors = set(graph.neighbors(neighbor))
        # Bonded to three atoms, at least one of them an oxygen
        if len(carbon_neighbors) == 3 and "O" in {
            graph.nodes[i]["atoms"] for i in carbon_neighbors
        }:
            return True

    return False
114
+
115
+
116
def is_ester_o(index: int, graph: Graph) -> bool:
    """
    Assess if the index is an ester-like oxygen.

    An oxygen qualifies when it carries no hydrogen and at least one of its
    carbon neighbors has exactly three neighbors, two or more of them oxygens.

    Note: carbamates and carbonates return True, carboxylic acids return False.

    :param index: node index of the atom in ``graph``
    :param graph: connectivity graph with the symbol stored under "atoms"
    :return: True if the atom is an ester-like oxygen
    """
    if graph.nodes[index]["atoms"] != "O":
        return False

    neighbors = set(graph.neighbors(index))

    # The neighbors are node indices, so a hydrogen has to be recognized by
    # its stored symbol: an O-H oxygen (acid, alcohol) is not ester-like.
    if any(graph.nodes[n]["atoms"] == "H" for n in neighbors):
        return False

    for n in neighbors:
        if graph.nodes[n]["atoms"] != "C":
            continue

        carbon_neighbors = set(graph.neighbors(n))
        if len(carbon_neighbors) == 3:
            symbols = [graph.nodes[i]["atoms"] for i in carbon_neighbors]
            # The carbon must bear a second oxygen besides this one
            if symbols.count("O") > 1:
                return True

    return False
134
+
135
+
136
def is_phenyl(coords: Array2D_float) -> bool:
    """
    Assess if the six atomic coords refer to a phenyl-like ring.

    Two geometric tests are applied: no pair of atoms may be more than 3 A
    apart, and the dihedral of the first four atoms must be within 10 degrees
    of planarity (0 or 180 degrees).

    Note: quinones evaluate to True

    :params coords: six coordinates of C/N atoms
    :return: bool indicating if the six atoms look like part of a
        phenyl/naphtyl/pyridine system
    """
    # if any atomic couple is more than 3 A away from each other, this is not a Ph
    pairwise_distances = cdist(coords, coords)
    if np.max(pairwise_distances) > 3:
        return False

    # Deviation from planarity, measured as 1 - |cos(torsion)|
    torsion = dihedral(coords[[0, 1, 2, 3]])
    flat_delta: float = 1 - np.abs(np.cos(torsion * np.pi / 180))

    # Same measure evaluated for the largest accepted deviation, 10 degrees
    threshold_delta: float = 1 - np.cos(10 * np.pi / 180)

    return flat_delta < threshold_delta
154
+
155
+
156
def get_phenyl_ids(index: int, graph: Graph) -> list[int] | None:
    """
    If index is part of a phenyl, return the six heavy atoms ids associated with the ring.

    The ring is searched as a simple path of six atoms from ``index`` to one
    of its neighbors, with no hydrogens and with every atom having exactly
    three neighbors.

    :param index: node index of the atom in ``graph``
    :param graph: connectivity graph with the symbol stored under "atoms"
    :return: the first matching path, starting at ``index``, or None
    """
    for neighbor in graph.neighbors(index):
        for path in all_simple_paths(graph, source=index, target=neighbor, cutoff=6):
            # A six-membered ring closes through a path of exactly six atoms
            if len(path) != 6:
                continue

            # Hydrogens cannot be ring members
            if any(graph.nodes[i]["atoms"] == "H" for i in path):
                continue

            if all(len(set(graph.neighbors(i))) == 3 for i in path):
                return path  # type: ignore [no-any-return]

    return None
166
+
167
+
168
def find_paths(
    graph: Graph,
    u: int,
    n: int,
    exclude_set: set[int] | None = None,
) -> list[list[int]]:
    """
    Find paths in graph.

    Recursively collect every path made of n steps (n + 1 nodes) that starts
    from node u and never visits a node twice.

    :param graph: NetworkX graph
    :param u: starting node
    :param n: path length, as number of steps away from u
    :param exclude_set: set of nodes to exclude from the paths
    :return: list of paths (each path is a list of node indices)
    """
    if n == 0:
        return [[u]]

    # Nodes that may not appear further along the path; a new set is built so
    # that the caller's set is left untouched.
    visited = (exclude_set | {u}) if exclude_set else {u}

    paths = []
    for neighbor in graph.neighbors(u):
        if neighbor in visited:
            continue
        for tail in find_paths(graph, neighbor, n - 1, visited):
            paths.append([u, *tail])

    return paths