molscope 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molscope/__init__.py +53 -0
- molscope/__main__.py +3 -0
- molscope/cli.py +74 -0
- molscope/coarsegrain.py +411 -0
- molscope/contactmap.py +116 -0
- molscope/descriptors.py +232 -0
- molscope/elements.py +75 -0
- molscope/ensemble.py +143 -0
- molscope/graph.py +151 -0
- molscope/io.py +342 -0
- molscope/molecule.py +502 -0
- molscope/plotting.py +191 -0
- molscope-0.6.0.dist-info/METADATA +335 -0
- molscope-0.6.0.dist-info/RECORD +18 -0
- molscope-0.6.0.dist-info/WHEEL +5 -0
- molscope-0.6.0.dist-info/entry_points.txt +2 -0
- molscope-0.6.0.dist-info/licenses/LICENSE +21 -0
- molscope-0.6.0.dist-info/top_level.txt +1 -0
molscope/descriptors.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""Fixed-size molecular descriptors for quick ML feature tables."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import Counter
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from .io import read
|
|
11
|
+
|
|
12
|
+
# Element symbols (upper-case) that get a ``count_<SYMBOL>`` column by default.
DEFAULT_ELEMENTS = (
    "H", "C", "N", "O", "S", "P",
    "F", "CL", "BR", "I",
    "NA", "MG", "CA", "FE", "ZN",
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def descriptors(
    molecule,
    *,
    elements_to_count=DEFAULT_ELEMENTS,
    distance_bins: int = 10,
    distance_range: tuple[float, float] = (0.0, 20.0),
    contact_cutoff: float = 5.0,
    residue_contact_cutoff: float = 8.0,
) -> dict:
    """Return a flat descriptor dictionary for a molecule.

    The defaults are fixed-size and suitable for small ML tables. Matrix-valued
    features such as contact maps remain available through ``mol.contact_map()``;
    this function records table-friendly summaries of them.

    Args:
        molecule: Structure to describe. It is read through ``len()``,
            ``coords``, ``elements``, ``masses``, ``dimensions``, ``centroid``,
            ``center_of_mass``, ``radius_of_gyration``, ``bonds()``,
            ``contacts()``, ``resids``, ``residue_groups()`` and
            ``contact_map()``.
        elements_to_count: Element symbols that get a ``count_<SYMBOL>``
            entry; matching against ``molecule.elements`` is case-insensitive.
        distance_bins: Number of bins of the pairwise-distance histogram.
        distance_range: ``(low, high)`` range of that histogram; distances
            outside it are not counted.
        contact_cutoff: Distance passed to ``molecule.contacts()``.
        residue_contact_cutoff: Distance passed to the residue-level
            ``molecule.contact_map()``.

    Returns:
        A dict of float scalars plus list-of-float entries (``inertia_tensor``,
        ``principal_moments``, ``principal_axes``, ``distance_histogram``).
        Molecules with no atoms get the same keys, zero-filled, so rows from
        different molecules line up. Use :func:`flatten_descriptors` to expand
        the list entries into scalar columns.
    """
    coords = np.asarray(molecule.coords, dtype=float)
    n_atoms = len(molecule)
    molecular_mass = float(np.sum(molecule.masses)) if n_atoms else 0.0
    desc = {
        "n_atoms": float(n_atoms),
        "n_residues": float(_n_residues(molecule)),
        "molecular_mass": molecular_mass,
    }

    # Symbols are upper-cased on both sides so "Cl", "cl" and "CL" all match;
    # empty element strings are skipped.
    counts = Counter(e.upper() for e in molecule.elements if e)
    for symbol in elements_to_count:
        key = symbol.upper()
        desc[f"count_{key}"] = float(counts.get(key, 0))

    if n_atoms == 0:
        return _empty_descriptors(desc, distance_bins)

    dims = molecule.dimensions
    centroid = molecule.centroid
    center_of_mass = molecule.center_of_mass
    desc.update({
        "centroid_x": float(centroid[0]),
        "centroid_y": float(centroid[1]),
        "centroid_z": float(centroid[2]),
        "center_of_mass_x": float(center_of_mass[0]),
        "center_of_mass_y": float(center_of_mass[1]),
        "center_of_mass_z": float(center_of_mass[2]),
        # Coerced like every other scalar so the dict holds plain floats only.
        "radius_of_gyration": float(molecule.radius_of_gyration),
        "dim_x": float(dims[0]),
        "dim_y": float(dims[1]),
        "dim_z": float(dims[2]),
        "bbox_volume": float(np.prod(dims)),
        "compactness": _compactness(n_atoms, dims),
    })

    inertia = inertia_tensor(molecule)
    principal_moments, principal_axes = np.linalg.eigh(inertia)
    # Sort ascending; the eigenvectors are the *columns* of principal_axes,
    # so they are permuted along axis 1 to stay paired with their moments.
    order = np.argsort(principal_moments)
    principal_moments = principal_moments[order]
    principal_axes = principal_axes[:, order]
    desc["inertia_tensor"] = inertia.reshape(-1).astype(float).tolist()
    desc["principal_moments"] = principal_moments.astype(float).tolist()
    # NOTE(review): eigenvector signs are arbitrary, so these nine values may
    # flip sign between otherwise similar molecules.
    desc["principal_axes"] = principal_axes.reshape(-1).astype(float).tolist()
    desc["shape_anisotropy"] = shape_anisotropy(principal_moments)

    distances = _pairwise_distances(coords)
    hist, _ = np.histogram(distances, bins=distance_bins, range=distance_range)
    desc["distance_histogram"] = hist.astype(float).tolist()
    desc.update(_bond_length_summary(molecule))
    desc.update(_contact_summary(molecule, contact_cutoff))
    desc.update(_residue_contact_summary(molecule, residue_contact_cutoff))
    return desc
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def inertia_tensor(molecule) -> np.ndarray:
    """Mass-weighted inertia tensor around the centre of mass.

    Computes ``I = sum_k m_k * (|r_k|^2 * Id - r_k r_k^T)``, where ``r_k`` is
    each atom position relative to ``molecule.center_of_mass``.

    Args:
        molecule: Object exposing ``len()``, ``coords``, ``masses`` and
            ``center_of_mass``.

    Returns:
        A ``(3, 3)`` float array; all zeros when the molecule has no atoms.
    """
    if len(molecule) == 0:
        return np.zeros((3, 3), dtype=float)
    coords = np.asarray(molecule.coords, dtype=float)
    # Coerce masses the same way as coords: the ``masses[:, None]`` broadcast
    # below needs an ndarray, not a plain Python sequence.
    masses = np.asarray(molecule.masses, dtype=float)
    centered = coords - molecule.center_of_mass
    r2 = (centered ** 2).sum(axis=1)
    tensor = np.eye(3) * np.sum(masses * r2)
    tensor -= centered.T @ (centered * masses[:, None])
    return tensor
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def shape_anisotropy(principal_moments) -> float:
    """Dimensionless anisotropy from principal moments of inertia.

    Computes ``1.5 * sum((lam - mean(lam))**2) / sum(lam**2)``. This is 0 for
    three equal moments and 1 when exactly one moment is non-zero. Note that
    the denominator is ``sum(lam**2)``, not the ``(sum lam)**2`` used by the
    textbook relative shape anisotropy kappa^2, so intermediate values differ
    from kappa^2.

    Returns:
        The anisotropy as a Python float; 0.0 when every moment is zero.
    """
    lam = np.asarray(principal_moments, dtype=float)
    sum_sq = float(np.sum(lam ** 2))
    if not sum_sq:
        return 0.0
    spread = np.sum((lam - float(lam.mean())) ** 2)
    return float(1.5 * spread / sum_sq)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def featurize_many(
    paths,
    *,
    feature_names: Optional[list[str]] = None,
    return_names: bool = False,
    **descriptor_kwargs,
):
    """Read structures and return a numeric descriptor matrix.

    By default columns are the sorted union of descriptor keys found across
    the input molecules. Pass ``feature_names`` to force a stable column order
    (an explicitly empty list yields zero columns), or ``return_names=True`` to
    receive ``(X, names)``.

    Args:
        paths: Iterable of structure paths; each is loaded with ``read``.
        feature_names: Column names to emit, in order. Descriptors missing from
            a molecule are filled with ``0.0``; descriptors not listed are
            dropped.
        return_names: If true, also return the list of column names.
        **descriptor_kwargs: Forwarded to :func:`descriptors`.

    Returns:
        A float array of shape ``(n_paths, n_names)``. It is always
        two-dimensional, even when ``paths`` is empty. With
        ``return_names=True``, returns ``(matrix, names)`` instead.
    """
    rows = [flatten_descriptors(descriptors(read(path), **descriptor_kwargs)) for path in paths]
    # ``is not None`` rather than truthiness so an explicit [] is honoured.
    if feature_names is not None:
        names = list(feature_names)
    else:
        names = sorted({key for row in rows for key in row})
    # The reshape keeps the result 2-D when there are no rows (np.array([])
    # would otherwise be 1-D with shape (0,)).
    matrix = np.array(
        [[row.get(name, 0.0) for name in names] for row in rows], dtype=float
    ).reshape(len(rows), len(names))
    return (matrix, names) if return_names else matrix
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def flatten_descriptors(desc: dict) -> dict[str, float]:
    """Expand list-valued descriptors into scalar columns.

    Scalars are kept under their own key. Lists, tuples and arrays of any
    dimensionality are flattened in C order into ``<key>_0``, ``<key>_1``, …
    columns. A 0-d array is treated as a scalar. Every value is converted to
    a Python float.
    """
    flat: dict[str, float] = {}
    for key, value in desc.items():
        if isinstance(value, (list, tuple, np.ndarray)) and np.ndim(value) > 0:
            # np.ravel handles both flat sequences and N-d arrays (e.g. an
            # unflattened 3x3 inertia tensor) without per-row float() errors.
            for i, item in enumerate(np.ravel(value)):
                flat[f"{key}_{i}"] = float(item)
        else:
            flat[key] = float(value)
    return flat
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _empty_descriptors(desc: dict, distance_bins: int) -> dict:
    """Zero-fill every geometry-dependent descriptor for an atom-less molecule.

    Writes, in order, the same keys the non-empty path of ``descriptors``
    produces. Vector entries get zero lists of matching length, with
    ``distance_bins`` zeros for the histogram. ``desc`` is updated in place
    and also returned.
    """
    zero = 0.0
    for key in (
        "centroid_x", "centroid_y", "centroid_z",
        "center_of_mass_x", "center_of_mass_y", "center_of_mass_z",
        "radius_of_gyration",
        "dim_x", "dim_y", "dim_z",
        "bbox_volume", "compactness",
    ):
        desc[key] = zero
    desc["inertia_tensor"] = [zero] * 9
    desc["principal_moments"] = [zero] * 3
    desc["principal_axes"] = [zero] * 9
    desc["shape_anisotropy"] = zero
    desc["distance_histogram"] = [zero] * distance_bins
    for key in (
        "bond_count", "bond_length_mean", "bond_length_std",
        "bond_length_min", "bond_length_max",
        "atom_contact_count", "atom_contact_density",
        "residue_contact_count", "residue_contact_density",
    ):
        desc[key] = zero
    return desc
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def _pairwise_distances(coords: np.ndarray) -> np.ndarray:
    """Condensed vector of Euclidean distances between all distinct point pairs.

    Pairs follow ``np.triu_indices(n, k=1)`` order: (0, 1), (0, 2), …, (1, 2), ….
    Fewer than two points yields an empty float array.
    """
    num_points = len(coords)
    if num_points <= 1:
        return np.empty(0, dtype=float)
    rows, cols = np.triu_indices(num_points, k=1)
    deltas = coords[rows] - coords[cols]
    return np.linalg.norm(deltas, axis=1)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _bond_length_summary(molecule) -> dict[str, float]:
    """Bond count and length statistics (mean, population std, min, max).

    ``molecule.bonds()`` is expected to return ``(i, j)`` atom-index pairs.
    Both it and ``molecule.coords`` are coerced with ``np.asarray`` (as
    ``descriptors`` already does for coordinates), so plain Python sequences
    work as well as arrays. Every statistic is 0.0 when there are no bonds.
    """
    bonds = np.asarray(molecule.bonds(), dtype=int).reshape(-1, 2)
    if len(bonds) == 0:
        return {
            "bond_count": 0.0,
            "bond_length_mean": 0.0,
            "bond_length_std": 0.0,
            "bond_length_min": 0.0,
            "bond_length_max": 0.0,
        }
    coords = np.asarray(molecule.coords, dtype=float)
    lengths = np.linalg.norm(coords[bonds[:, 0]] - coords[bonds[:, 1]], axis=1)
    return {
        "bond_count": float(len(lengths)),
        "bond_length_mean": float(lengths.mean()),
        "bond_length_std": float(lengths.std()),  # ddof=0 (population std)
        "bond_length_min": float(lengths.min()),
        "bond_length_max": float(lengths.max()),
    }
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _contact_summary(molecule, cutoff: float) -> dict[str, float]:
    """Atom-level contact count and density at ``cutoff``.

    Density is the number of entries returned by ``molecule.contacts()``
    divided by the ``n * (n - 1) / 2`` possible atom pairs. It is 0.0 when
    there are fewer than two atoms.
    """
    n_contacts = len(molecule.contacts(cutoff=cutoff))
    n_atoms = len(molecule)
    pair_total = n_atoms * (n_atoms - 1) / 2
    density = float(n_contacts / pair_total) if pair_total else 0.0
    return {
        "atom_contact_count": float(n_contacts),
        "atom_contact_density": density,
    }
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _residue_contact_summary(molecule, cutoff: float) -> dict[str, float]:
    """Residue-residue contact count and density from the residue contact map.

    The count is the number of non-zero entries strictly above the diagonal of
    ``molecule.contact_map(cutoff=..., level="residue").matrix``. The density
    divides that count by the ``n * (n - 1) / 2`` residue pairs. Both are 0.0
    when the molecule has no ``resids`` or when building the map raises
    ``ValueError`` (treated as "no residue information", not an error).
    """
    zero = {"residue_contact_count": 0.0, "residue_contact_density": 0.0}
    if len(molecule.resids) == 0:
        return zero
    try:
        matrix = molecule.contact_map(cutoff=cutoff, level="residue").matrix
    except ValueError:
        return zero
    n_res = len(matrix)
    n_pairs = n_res * (n_res - 1) / 2
    upper = np.triu(matrix.astype(bool), k=1)
    count = float(upper.sum())
    return {
        "residue_contact_count": count,
        "residue_contact_density": float(count / n_pairs) if n_pairs else 0.0,
    }
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def _n_residues(molecule) -> int:
    """Number of groups yielded by ``molecule.residue_groups()``.

    Returns 0 without calling ``residue_groups()`` when ``molecule.resids``
    is empty.
    """
    if not len(molecule.resids):
        return 0
    total = 0
    for _group in molecule.residue_groups():
        total += 1
    return total
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def _compactness(n_atoms: int, dims: np.ndarray) -> float:
    """Atoms per unit bounding-box volume (product of ``dims``).

    Returns 0.0 unless the volume is strictly positive, which also covers a
    NaN volume.
    """
    volume = float(np.prod(dims))
    if not volume > 0.0:
        return 0.0
    return float(n_atoms / volume)
|
molscope/elements.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""Per-element reference data: CPK colours and covalent radii.
|
|
2
|
+
|
|
3
|
+
Values cover the elements common in the sample structures; anything missing
|
|
4
|
+
falls back to a neutral default so unknown atoms still render and bond.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
# CPK colours (normalised RGB), the convention most molecular viewers use.
# Keys are upper-case element symbols; ``color()`` upper-cases its argument
# before looking it up here.
CPK_COLORS = {
    "H": (1.00, 1.00, 1.00),
    "C": (0.30, 0.30, 0.30),
    "N": (0.10, 0.10, 0.85),
    "O": (0.85, 0.10, 0.10),
    "S": (0.90, 0.80, 0.20),
    "P": (1.00, 0.50, 0.00),
    "F": (0.30, 0.80, 0.30),
    "CL": (0.20, 0.80, 0.20),
    "BR": (0.60, 0.20, 0.10),
    "I": (0.50, 0.10, 0.60),
    "FE": (0.80, 0.40, 0.10),
    "CA": (0.30, 0.70, 0.70),
    "NA": (0.50, 0.20, 0.80),
    "MG": (0.20, 0.60, 0.20),
    "ZN": (0.50, 0.50, 0.60),
}
# Neutral grey returned by ``color()`` for symbols missing from CPK_COLORS.
DEFAULT_COLOR = (0.50, 0.50, 0.50)

# Covalent radii in angstrom (Cordero et al. 2008, rounded). Used to infer bonds.
COVALENT_RADII = {
    "H": 0.31, "C": 0.76, "N": 0.71, "O": 0.66, "S": 1.05,
    "P": 1.07, "F": 0.57, "CL": 1.02, "BR": 1.20, "I": 1.39,
    "FE": 1.32, "CA": 1.76, "NA": 1.66, "MG": 1.41, "ZN": 1.22,
}
# Fallback radius (angstrom) returned by ``covalent_radius()`` for unknown symbols.
DEFAULT_RADIUS = 0.75


# Standard atomic weights (g/mol). Unknown atoms fall back to 1.0 so that a
# mass-weighted centre over all-unknown elements reduces to the geometric
# centre (the unweighted mean position, i.e. the centroid).
ATOMIC_MASSES = {
    "H": 1.008, "C": 12.011, "N": 14.007, "O": 15.999, "S": 32.06,
    "P": 30.974, "F": 18.998, "CL": 35.45, "BR": 79.904, "I": 126.904,
    "FE": 55.845, "CA": 40.078, "NA": 22.990, "MG": 24.305, "ZN": 65.38,
}
# Fallback mass returned by ``mass()`` for symbols missing from ATOMIC_MASSES.
DEFAULT_MASS = 1.0
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def color(element: str):
    """Return the CPK RGB triple for ``element``, or ``DEFAULT_COLOR`` if unknown.

    The lookup is case-insensitive: the symbol is upper-cased first.
    """
    try:
        return CPK_COLORS[element.upper()]
    except KeyError:
        return DEFAULT_COLOR
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def covalent_radius(element: str) -> float:
    """Return the covalent radius (angstrom) of ``element``.

    The symbol is matched case-insensitively; unknown symbols get
    ``DEFAULT_RADIUS``.
    """
    symbol = element.upper()
    if symbol in COVALENT_RADII:
        return COVALENT_RADII[symbol]
    return DEFAULT_RADIUS
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def mass(element: str) -> float:
    """Return the standard atomic weight (g/mol) of ``element``.

    Case-insensitive; symbols missing from ``ATOMIC_MASSES`` give
    ``DEFAULT_MASS``.
    """
    symbol = element.upper()
    return ATOMIC_MASSES[symbol] if symbol in ATOMIC_MASSES else DEFAULT_MASS
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Atomic numbers for the first four periods (H through KR) plus iodine -- enough
# for biomolecules and most small molecules. Symbols missing here map to 0 in
# ``atomic_number()`` so graph code never crashes.
ATOMIC_NUMBERS = {
    "H": 1, "HE": 2, "LI": 3, "BE": 4, "B": 5, "C": 6, "N": 7, "O": 8,
    "F": 9, "NE": 10, "NA": 11, "MG": 12, "AL": 13, "SI": 14, "P": 15,
    "S": 16, "CL": 17, "AR": 18, "K": 19, "CA": 20, "SC": 21, "TI": 22,
    "V": 23, "CR": 24, "MN": 25, "FE": 26, "CO": 27, "NI": 28, "CU": 29,
    "ZN": 30, "GA": 31, "GE": 32, "AS": 33, "SE": 34, "BR": 35, "KR": 36,
    "I": 53,
}
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def atomic_number(element: str) -> int:
    """Return the atomic number Z of ``element``, or 0 for an unknown symbol.

    Matching is case-insensitive.
    """
    try:
        return ATOMIC_NUMBERS[element.upper()]
    except KeyError:
        return 0
|
molscope/ensemble.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Analysis across a set of structures, e.g. the models of an NMR ensemble."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field, replace
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from .molecule import Molecule
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def align_all(models: list[Molecule], reference: Optional[Molecule] = None) -> list[Molecule]:
    """Kabsch-superpose every model onto ``reference`` (default: the first model).

    Returns the results of ``model.superpose(ref)``, in input order. An empty
    ``models`` sequence yields ``[]``. Without the guard, the default-reference
    path would raise ``IndexError`` on ``models[0]``.
    """
    if len(models) == 0:
        return []
    ref = reference if reference is not None else models[0]
    return [m.superpose(ref) for m in models]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def average(models: list[Molecule], align: bool = True) -> Molecule:
    """Average structure over the ensemble (atoms matched by index).

    When ``align`` is true, models are first superposed onto the first model.
    The result is a copy of the first model (via ``dataclasses.replace``) whose
    coordinates are the per-atom mean and whose name gets an " (average)"
    suffix.

    Raises:
        ValueError: if ``models`` is empty or the atom counts differ.
    """
    _check_consistent(models)
    frames = align_all(models) if align else models
    mean_coords = np.mean([frame.coords for frame in frames], axis=0)
    template = models[0]
    return replace(template, coords=mean_coords, name=f"{template.name} (average)")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def rmsf(models: list[Molecule], align: bool = True) -> np.ndarray:
    """Per-atom root-mean-square fluctuation about the mean position.

    For each atom, this is the square root of the mean, over models, of its
    squared distance from the atom's mean position. When ``align`` is true,
    models are superposed onto the first model first.

    Raises:
        ValueError: if ``models`` is empty or the atom counts differ.
    """
    _check_consistent(models)
    frames = align_all(models) if align else models
    positions = np.array([frame.coords for frame in frames])  # (n_models, n_atoms, 3)
    displacement = positions - positions.mean(axis=0)
    squared_dev = np.sum(displacement ** 2, axis=2)  # (n_models, n_atoms)
    return np.sqrt(squared_dev.mean(axis=0))
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def contact_frequency(models: list[Molecule], cutoff: float = 8.0,
                      level: str = "residue", method: str = "ca"):
    """Fraction of models in which each pair is in contact (an ensemble map).

    Returns a :class:`~molscope.contactmap.ContactMap` whose matrix holds
    values in ``[0, 1]`` — the contact probability for each residue (or atom)
    pair across the ensemble. Useful for NMR variability and folding analysis.
    Labels and resids are taken from the first model's map.
    """
    from .contactmap import ContactMap, contact_map

    _check_consistent(models)
    per_model = []
    for model in models:
        per_model.append(contact_map(model, cutoff=cutoff, level=level, method=method))
    # Mean of 0/1 contact matrices = per-pair fraction of models in contact.
    frequency = np.mean([cmap.matrix for cmap in per_model], axis=0)
    template = per_model[0]
    return ContactMap(
        frequency,
        level=level,
        cutoff=cutoff,
        labels=template.labels,
        resids=template.resids,
    )
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def rmsd_matrix(models: list[Molecule], align: bool = True) -> np.ndarray:
    """Symmetric ``(M, M)`` matrix of pairwise RMSDs between models.

    Only the upper triangle is computed via ``Molecule.rmsd`` and mirrored;
    the diagonal stays 0.

    Raises:
        ValueError: if ``models`` is empty or the atom counts differ.
    """
    _check_consistent(models)
    count = len(models)
    mat = np.zeros((count, count))
    for i, first in enumerate(models):
        for j in range(i + 1, count):
            value = first.rmsd(models[j], align=align)
            mat[i, j] = value
            mat[j, i] = value
    return mat
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _check_consistent(models: list[Molecule]) -> None:
    """Validate that ``models`` is non-empty and every model has the same length.

    Raises:
        ValueError: if there are no models, or if atom counts differ. The
            message lists the distinct counts in sorted order.
    """
    if not models:
        raise ValueError("no models given")
    atom_counts = sorted({len(model) for model in models})
    if len(atom_counts) > 1:
        raise ValueError(f"models have differing atom counts: {atom_counts}")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class Clustering:
    """Result of clustering structures by RMSD.

    ``labels`` gives the 1-based cluster id of each model (same order as the
    input). ``matrix`` is the RMSD matrix used and ``linkage`` the scipy linkage
    (``None`` when no linkage was computed).
    """

    labels: np.ndarray  # one cluster id per model
    matrix: np.ndarray  # (M, M) pairwise RMSD matrix
    linkage: Optional[np.ndarray] = field(default=None)

    @property
    def n_clusters(self) -> int:
        """Number of distinct cluster ids in ``labels``."""
        return int(len(np.unique(self.labels)))

    @property
    def order(self) -> np.ndarray:
        """Model indices sorted by cluster (for a block-diagonal heatmap)."""
        return np.argsort(self.labels, kind="stable")

    def groups(self) -> dict[int, list[int]]:
        """Map each cluster id to the list of model indices it contains."""
        return {int(c): np.where(self.labels == c)[0].tolist()
                for c in np.unique(self.labels)}

    def medoid(self, cluster_id: int) -> int:
        """Index of the most central model of a cluster (min total RMSD).

        Raises:
            ValueError: if no model carries ``cluster_id``.
        """
        members = np.flatnonzero(self.labels == cluster_id)
        if members.size == 0:
            # Without this check numpy fails with a cryptic
            # "attempt to get argmin of an empty sequence".
            raise ValueError(f"no cluster with id {cluster_id!r}")
        sub = self.matrix[np.ix_(members, members)]
        return int(members[sub.sum(axis=1).argmin()])

    def representatives(self) -> dict[int, int]:
        """Map each cluster id to its medoid model index."""
        return {int(c): self.medoid(int(c)) for c in np.unique(self.labels)}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def cluster(models, method: str = "hierarchical", cutoff: Optional[float] = None,
            n_clusters: Optional[int] = None, linkage: str = "average",
            align: bool = True, matrix=None) -> Clustering:
    """Cluster structures by pairwise RMSD.

    Pass ``n_clusters`` to cut the tree into a fixed number of clusters, or
    ``cutoff`` (an RMSD threshold in angstrom). With neither, a data-driven
    cutoff (the mean pairwise RMSD) is used. ``n_clusters`` takes precedence
    when both are given. Reuses ``matrix`` if given, else computes
    :func:`rmsd_matrix`. Requires scipy.

    Raises:
        ValueError: for an unsupported ``method``.
        ImportError: if scipy is not installed.
    """
    if method != "hierarchical":
        raise ValueError(f"unknown method {method!r}; only 'hierarchical' is supported")

    if matrix is None:
        dm = rmsd_matrix(models, align=align)
    else:
        dm = np.asarray(matrix, dtype=float)
    n_models = len(dm)
    if n_models < 2:
        # Zero or one model: nothing to link, so label it (if present) cluster 1.
        return Clustering(labels=np.ones(n_models, dtype=int), matrix=dm)

    try:
        from scipy.cluster.hierarchy import fcluster
        from scipy.cluster.hierarchy import linkage as _linkage
        from scipy.spatial.distance import squareform
    except ImportError as exc:  # pragma: no cover - exercised only without scipy
        raise ImportError(
            "clustering needs scipy; install it with: pip install 'molscope[fast]'"
        ) from exc

    # scipy's linkage expects the condensed (upper-triangle) distance vector.
    tree = _linkage(squareform(dm, checks=False), method=linkage)
    if n_clusters is not None:
        threshold, criterion = n_clusters, "maxclust"
    else:
        if cutoff is None:
            cutoff = float(dm[np.triu_indices_from(dm, k=1)].mean())
        threshold, criterion = cutoff, "distance"
    labels = fcluster(tree, t=threshold, criterion=criterion)
    return Clustering(labels=labels, matrix=dm, linkage=tree)
|
molscope/graph.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Molecular graph layer.
|
|
2
|
+
|
|
3
|
+
``Molecule.to_graph()`` turns 3D coordinates plus inferred bonds into a
|
|
4
|
+
:class:`MolecularGraph` — a small, dependency-free container of node and edge
|
|
5
|
+
data. From there it exports to the common ML graph formats:
|
|
6
|
+
|
|
7
|
+
G = mol.to_graph()
|
|
8
|
+
nxg = mol.to_networkx() # networkx.Graph
|
|
9
|
+
data = mol.to_pyg_data() # torch_geometric.data.Data
|
|
10
|
+
dglg = mol.to_dgl_graph() # dgl.DGLGraph
|
|
11
|
+
|
|
12
|
+
The exporters import their backend lazily, so networkx / PyTorch Geometric / DGL
|
|
13
|
+
are only required if you actually call the matching method.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from . import elements
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class MolecularGraph:
    """Atoms as nodes, bonds as edges, with attributes attached to both.

    Nodes carry element, atomic number, mass, coordinates and (when available)
    atom name, residue name/id and chain. Edges carry the bonded atom pair, the
    interatomic distance, and a bond order (``1.0`` for geometrically inferred
    bonds, whose order is unknown).
    """

    coords: np.ndarray  # (N, 3)
    elements: list[str]
    edges: np.ndarray  # (E, 2), i < j
    edge_distances: np.ndarray  # (E,)
    edge_types: np.ndarray  # (E,) bond order; 1.0 when inferred/unknown
    atom_names: list[str] = field(default_factory=list)
    resnames: list[str] = field(default_factory=list)
    resids: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=int))
    chains: list[str] = field(default_factory=list)
    name: str = ""

    @property
    def n_atoms(self) -> int:
        """Number of nodes (rows of ``coords``)."""
        return len(self.coords)

    @property
    def n_bonds(self) -> int:
        """Number of undirected edges (rows of ``edges``)."""
        return len(self.edges)

    @property
    def atomic_numbers(self) -> np.ndarray:
        """Integer atomic number per node; 0 for symbols the element table lacks.

        The explicit ``dtype=int`` keeps the array integral even for a graph
        with no atoms, where ``np.array([])`` would default to float64.
        """
        # ``elements`` here is the module imported at file level, not the
        # dataclass field (class scope does not leak into methods).
        return np.array([elements.atomic_number(e) for e in self.elements], dtype=int)

    @property
    def masses(self) -> np.ndarray:
        """Float atomic mass per node, from the element table."""
        return np.array([elements.mass(e) for e in self.elements], dtype=float)

    def node_features(self) -> np.ndarray:
        """Default ``(N, 2)`` node feature matrix: ``[atomic_number, mass]``."""
        return np.stack([self.atomic_numbers, self.masses], axis=1).astype(float)

    # -- exporters ----------------------------------------------------------

    def to_networkx(self):
        """Return a ``networkx.Graph`` with node and edge attributes.

        Optional per-node attributes (atom name, resname, resid, chain) are
        added only when the corresponding field is non-empty.
        """
        nx = _require("networkx", "networkx", "pip install networkx")
        g = nx.Graph(name=self.name)
        z, m = self.atomic_numbers, self.masses
        for i in range(self.n_atoms):
            attrs = {
                "element": self.elements[i],
                "atomic_number": int(z[i]),
                "mass": float(m[i]),
                "pos": tuple(float(c) for c in self.coords[i]),
            }
            if self.atom_names:
                attrs["atom_name"] = self.atom_names[i]
            if self.resnames:
                attrs["resname"] = self.resnames[i]
            if len(self.resids):
                attrs["resid"] = int(self.resids[i])
            if self.chains:
                attrs["chain"] = self.chains[i]
            g.add_node(i, **attrs)
        for (i, j), dist, btype in zip(self.edges, self.edge_distances, self.edge_types):
            g.add_edge(int(i), int(j), distance=float(dist), bond_type=float(btype))
        return g

    def to_pyg_data(self):
        """Return a ``torch_geometric.data.Data`` object.

        Populates ``x`` (node features), ``z`` (atomic numbers), ``pos`` (3D
        coordinates), ``edge_index`` and ``edge_attr`` (distances). Edges are
        made bidirectional for message passing.
        """
        torch = _require("torch", "PyTorch Geometric", "pip install torch torch_geometric")
        Data = _require(
            "torch_geometric.data", "PyTorch Geometric",
            "pip install torch torch_geometric", attr="Data",
        )
        src, dst, dist = self._directed_edges()
        return Data(
            x=torch.tensor(self.node_features(), dtype=torch.float),
            z=torch.tensor(self.atomic_numbers, dtype=torch.long),
            pos=torch.tensor(self.coords, dtype=torch.float),
            edge_index=torch.tensor(np.stack([src, dst]), dtype=torch.long),
            edge_attr=torch.tensor(dist[:, None], dtype=torch.float),
            num_nodes=self.n_atoms,
        )

    def to_dgl_graph(self):
        """Return a ``dgl.DGLGraph`` with node/edge feature tensors.

        Node data: ``feat``, ``z``, ``pos``; edge data: ``distance``. Edges are
        added in both directions.
        """
        dgl = _require("dgl", "DGL", "pip install dgl")
        torch = _require("torch", "DGL", "pip install dgl torch")
        src, dst, dist = self._directed_edges()
        g = dgl.graph(
            (torch.tensor(src, dtype=torch.long), torch.tensor(dst, dtype=torch.long)),
            num_nodes=self.n_atoms,
        )
        g.ndata["feat"] = torch.tensor(self.node_features(), dtype=torch.float)
        g.ndata["z"] = torch.tensor(self.atomic_numbers, dtype=torch.long)
        g.ndata["pos"] = torch.tensor(self.coords, dtype=torch.float)
        g.edata["distance"] = torch.tensor(dist, dtype=torch.float)
        return g

    def _directed_edges(self):
        """Edges in both directions: (src, dst, distance) for message passing.

        The forward copies (i -> j) come first, followed by the reversed copies
        (j -> i) in the same order; distances are duplicated to match.
        """
        if self.n_bonds == 0:
            empty_i = np.empty(0, dtype=int)
            return empty_i, empty_i, np.empty(0, dtype=float)
        i, j = self.edges[:, 0], self.edges[:, 1]
        src = np.concatenate([i, j])
        dst = np.concatenate([j, i])
        dist = np.concatenate([self.edge_distances, self.edge_distances])
        return src, dst, dist
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _require(module: str, feature: str, hint: str, attr: Optional[str] = None):
    """Lazily import an optional backend, turning its absence into a helpful error.

    Args:
        module: Dotted module path to import.
        feature: Human-readable feature name used in the error message.
        hint: Installation hint appended to the error message.
        attr: If given (and non-empty), return this attribute of the module
            instead of the module itself.

    Raises:
        ImportError: chained from the original error when ``module`` is missing.
    """
    from importlib import import_module

    try:
        backend = import_module(module)
    except ImportError as exc:  # pragma: no cover - exercised only when missing
        raise ImportError(f"{feature} is required for this export; {hint}") from exc
    if not attr:
        return backend
    return getattr(backend, attr)
|