molscope 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- molscope/__init__.py +53 -0
- molscope/__main__.py +3 -0
- molscope/cli.py +74 -0
- molscope/coarsegrain.py +411 -0
- molscope/contactmap.py +116 -0
- molscope/descriptors.py +232 -0
- molscope/elements.py +75 -0
- molscope/ensemble.py +143 -0
- molscope/graph.py +151 -0
- molscope/io.py +342 -0
- molscope/molecule.py +502 -0
- molscope/plotting.py +191 -0
- molscope-0.6.0.dist-info/METADATA +335 -0
- molscope-0.6.0.dist-info/RECORD +18 -0
- molscope-0.6.0.dist-info/WHEEL +5 -0
- molscope-0.6.0.dist-info/entry_points.txt +2 -0
- molscope-0.6.0.dist-info/licenses/LICENSE +21 -0
- molscope-0.6.0.dist-info/top_level.txt +1 -0
molscope/molecule.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
"""The :class:`Molecule` value type and its geometric operations.
|
|
2
|
+
|
|
3
|
+
Coordinates are held as an ``(N, 3)`` numpy array. Optional per-atom metadata
|
|
4
|
+
(atom name, residue name, residue id, chain) travels alongside, enabling
|
|
5
|
+
selections such as ``mol.backbone()`` or ``mol.select(chain="A")``.
|
|
6
|
+
|
|
7
|
+
Transformations return a new ``Molecule`` rather than mutating in place, so
|
|
8
|
+
chains like ``mol.centered().rotate("z", 90)`` read top to bottom and never
|
|
9
|
+
alias state.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass, field, replace
|
|
15
|
+
from typing import Any, Optional
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from . import elements
|
|
20
|
+
|
|
21
|
+
# Above this size the dense O(n^2) bond search is refused; install scipy for the
|
|
22
|
+
# KD-tree path (pip install 'molscope[fast]') to handle larger structures.
|
|
23
|
+
_DENSE_BOND_LIMIT = 8000
|
|
24
|
+
|
|
25
|
+
_BACKBONE_ATOMS = ("N", "CA", "C", "O")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True, eq=False)
class Molecule:
    """Immutable-by-convention molecule: coordinates plus optional per-atom metadata.

    ``coords`` is normalised to a float ``(N, 3)`` array on construction. The
    per-atom sequences (``elements``, ``atom_names``, ``resnames``, ``chains``,
    ``resids``) are either empty or exactly ``N`` long; ``elements`` is padded
    with empty strings when omitted. Transform methods return new instances
    built with :func:`dataclasses.replace`.
    """

    coords: np.ndarray
    elements: list[str] = field(default_factory=list)
    name: str = ""
    # Optional per-atom metadata; empty when the source format carries none.
    atom_names: list[str] = field(default_factory=list)
    resnames: list[str] = field(default_factory=list)
    resids: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=int))
    chains: list[str] = field(default_factory=list)
    # Optional explicit bonds as an (E, 2) index array. When set, bonds() returns
    # these instead of inferring from geometry (used by coarse-graining).
    bond_index: Optional[np.ndarray] = None
    _mapping_report: Optional[Any] = field(default=None, repr=False, compare=False)

    def __post_init__(self):
        """Normalise array fields and validate per-atom sequence lengths.

        Raises:
            ValueError: if a non-empty per-atom sequence does not have one
                entry per coordinate row.
        """
        # The dataclass is frozen, so normalised values are written through
        # object.__setattr__.
        coords = np.asarray(self.coords, dtype=float).reshape(-1, 3)
        object.__setattr__(self, "coords", coords)
        object.__setattr__(self, "resids", np.asarray(self.resids, dtype=int))
        if self.bond_index is not None:
            object.__setattr__(
                self, "bond_index", np.asarray(self.bond_index, dtype=int).reshape(-1, 2)
            )
        if not self.elements:
            object.__setattr__(self, "elements", [""] * len(coords))
        for name in ("elements", "atom_names", "resnames", "chains"):
            seq = getattr(self, name)
            if seq and len(seq) != len(coords):
                raise ValueError(f"{len(seq)} {name} for {len(coords)} coordinates")
        if len(self.resids) and len(self.resids) != len(coords):
            raise ValueError(f"{len(self.resids)} resids for {len(coords)} coordinates")

    def __len__(self) -> int:
        """Number of atoms (rows of ``coords``)."""
        return len(self.coords)

    def __eq__(self, other) -> bool:
        """Compare ``name``, ``elements`` and ``coords`` (exact array equality).

        Topology metadata and explicit bonds do not take part in the
        comparison.
        """
        # Auto-generated dataclass __eq__ can't compare the numpy field; do it
        # explicitly. coords are mutable in place, so instances stay unhashable.
        if not isinstance(other, Molecule):
            return NotImplemented
        return (
            self.name == other.name
            and self.elements == other.elements
            and np.array_equal(self.coords, other.coords)
        )

    __hash__ = None

    def __getitem__(self, selector) -> Molecule:
        """``mol[mask]`` / ``mol[indices]`` / ``mol[i]`` / ``mol[a:b]`` -> subset (see :meth:`take`)."""
        return self.take(selector)

    @property
    def has_topology(self) -> bool:
        """True if per-atom names were supplied (``atom_names`` is non-empty)."""
        return bool(self.atom_names)

    # -- selection ----------------------------------------------------------

    def take(self, selector) -> Molecule:
        """Return the subset given by a mask, index array, single index or slice.

        An empty selector yields an empty molecule. Explicit bonds are
        restricted to the kept atoms and renumbered; any coarse-graining
        report is dropped because it no longer describes the subset.
        """
        positions = np.arange(len(self))
        if isinstance(selector, slice):
            idx = positions[selector]
        else:
            sel = np.asarray(selector)
            if sel.size == 0 and sel.dtype != bool:
                # np.asarray([]) is float64, which numpy refuses as an index.
                sel = sel.astype(int)
            # atleast_1d turns a scalar index into a one-atom selection.
            idx = np.atleast_1d(positions[sel])

        def sub(seq):
            return [seq[i] for i in idx] if seq else []

        return replace(
            self,
            coords=self.coords[idx],
            elements=sub(self.elements),
            atom_names=sub(self.atom_names),
            resnames=sub(self.resnames),
            resids=self.resids[idx] if len(self.resids) else self.resids,
            chains=sub(self.chains),
            bond_index=self._subset_bonds(idx),
            _mapping_report=None,
        )

    def _subset_bonds(self, idx):
        """Restrict explicit bonds to a kept-atom index set and renumber them."""
        if self.bond_index is None:
            return None
        remap = {int(old): new for new, old in enumerate(idx)}
        kept = [
            (remap[i], remap[j]) for i, j in self.bond_index.tolist()
            if i in remap and j in remap
        ]
        # reshape keeps the (E, 2) shape even when no bond survives.
        return np.array(kept, dtype=int).reshape(-1, 2)

    def select(
        self,
        element=None,
        chain=None,
        resname=None,
        atom_name=None,
        resid=None,
    ) -> Molecule:
        """Return the atoms matching every supplied criterion.

        Each of ``element``/``chain``/``resname``/``atom_name`` accepts a single
        value or a collection. ``resid`` accepts an int, a collection of ints,
        or a ``(low, high)`` inclusive range. Selecting on metadata the molecule
        lacks raises ``ValueError``. Element, residue-name and atom-name
        matching is case-insensitive; chain matching is exact.
        """
        mask = np.ones(len(self), dtype=bool)
        mask &= self._field_mask(self.elements, element, "element", upper=True)
        mask &= self._field_mask(self.chains, chain, "chain")
        mask &= self._field_mask(self.resnames, resname, "residue", upper=True)
        mask &= self._field_mask(self.atom_names, atom_name, "atom name", upper=True)
        if resid is not None:
            mask &= self._resid_mask(resid)
        return self.take(mask)

    def backbone(self) -> Molecule:
        """Protein backbone atoms (N, CA, C, O)."""
        return self.select(atom_name=_BACKBONE_ATOMS)

    def alpha_carbons(self) -> Molecule:
        """Alpha-carbon (CA) atoms, the usual basis for protein RMSD."""
        return self.select(atom_name="CA")

    def _field_mask(self, values, criterion, label, upper=False):
        """Boolean mask of atoms whose entry in ``values`` matches ``criterion``.

        ``criterion`` of ``None`` matches everything. ``label`` only feeds the
        error message raised when ``values`` is empty.
        """
        if criterion is None:
            return np.ones(len(self), dtype=bool)
        if not values:
            raise ValueError(f"no {label} information in this molecule")
        wanted = {criterion} if isinstance(criterion, str) else set(criterion)
        if upper:
            wanted = {w.upper() for w in wanted}
            return np.array([v.upper() in wanted for v in values], dtype=bool)
        return np.array([v in wanted for v in values], dtype=bool)

    def _resid_mask(self, resid):
        """Boolean mask for a residue id, a collection of ids or a ``(low, high)`` range."""
        if len(self.resids) == 0:
            raise ValueError("no residue-id information in this molecule")
        if isinstance(resid, tuple) and len(resid) == 2:
            low, high = resid
            return (self.resids >= low) & (self.resids <= high)
        # numpy integer scalars (e.g. mol.resids[0]) are not ``int`` instances.
        wanted = [resid] if isinstance(resid, (int, np.integer)) else list(resid)
        return np.isin(self.resids, wanted)

    def residue_groups(self):
        """Yield ``(atom_indices, resname, resid, chain)`` per residue, in order.

        Residues are runs of atoms sharing ``(chain, resid)``. Yields nothing if
        the molecule has no residue information.
        """
        n = len(self)
        if len(self.resids) == 0 or n == 0:
            return
        chains = self.chains or [""] * n
        resnames = self.resnames or [""] * n
        resids = self.resids
        start = 0
        # i == n acts as a sentinel that flushes the final run.
        for i in range(1, n + 1):
            if i == n or chains[i] != chains[i - 1] or resids[i] != resids[i - 1]:
                yield (
                    list(range(start, i)), resnames[start],
                    int(resids[start]), chains[start],
                )
                start = i

    # -- geometry -----------------------------------------------------------

    @property
    def masses(self) -> np.ndarray:
        """Per-atom atomic weights (g/mol)."""
        return np.array([elements.mass(e) for e in self.elements])

    @property
    def centroid(self) -> np.ndarray:
        """Geometric centre (mean of all atom positions)."""
        return self.coords.mean(axis=0)

    @property
    def center_of_mass(self) -> np.ndarray:
        """Mass-weighted centre of the molecule."""
        m = self.masses
        return (m[:, None] * self.coords).sum(axis=0) / m.sum()

    @property
    def radius_of_gyration(self) -> float:
        """Mass-weighted radius of gyration (angstrom)."""
        m = self.masses
        d2 = ((self.coords - self.center_of_mass) ** 2).sum(axis=1)
        return float(np.sqrt((m * d2).sum() / m.sum()))

    @property
    def dimensions(self) -> np.ndarray:
        """Axis-aligned bounding-box size (dx, dy, dz) in angstrom.

        An empty molecule has a zero-sized box.
        """
        if len(self.coords) == 0:
            # max/min of a zero-size array raise; an empty box is well defined.
            return np.zeros(3)
        return self.coords.max(axis=0) - self.coords.min(axis=0)

    @property
    def formula(self) -> str:
        """Hill-order molecular formula, e.g. ``"C6 H12 O6"``.

        With carbon present: C, then H, then the rest alphabetically. Without
        carbon every element, hydrogen included, is listed alphabetically.
        Atoms with an empty element symbol are ignored.
        """
        from collections import Counter

        counts = Counter(e.capitalize() for e in self.elements if e)
        if not counts:
            return ""
        ordered = []
        if "C" in counts:
            for sym in ("C", "H"):
                if sym in counts:
                    ordered.append((sym, counts.pop(sym)))
        ordered += sorted(counts.items())
        return " ".join(f"{s}{n}" if n > 1 else s for s, n in ordered)

    # -- transforms (each returns a new Molecule) ---------------------------

    def translate(self, vector) -> Molecule:
        """Return a copy shifted by ``vector`` (dx, dy, dz)."""
        return replace(self, coords=self.coords + np.asarray(vector, dtype=float))

    def centered(self, weighted: bool = False) -> Molecule:
        """Return a copy with its centre at the origin.

        By default the geometric centroid is used; pass ``weighted=True`` to
        centre on the mass-weighted centre of mass.
        """
        origin = self.center_of_mass if weighted else self.centroid
        return replace(self, coords=self.coords - origin)

    def rotate(self, axis, angle_deg: float) -> Molecule:
        """Return a copy rotated ``angle_deg`` degrees about ``axis``.

        ``axis`` may be ``"x"``, ``"y"``, ``"z"`` or any 3-vector (tuple, list
        or array). Rotation is about the centroid so the molecule spins in
        place.

        Raises:
            ValueError: if ``axis`` is a string other than x/y/z.
        """
        if isinstance(axis, str):
            # Only strings go through the lookup: lists and arrays are
            # unhashable and must not be used as dict keys.
            named = {
                "x": (1.0, 0.0, 0.0),
                "y": (0.0, 1.0, 0.0),
                "z": (0.0, 0.0, 1.0),
            }
            try:
                vec = named[axis.lower()]
            except KeyError:
                raise ValueError(
                    f"unknown axis {axis!r}; use 'x', 'y', 'z' or a 3-vector"
                ) from None
        else:
            vec = axis
        rot = _rotation_matrix(np.asarray(vec, dtype=float), np.radians(angle_deg))
        center = self.centroid
        rotated = (self.coords - center) @ rot.T + center
        return replace(self, coords=rotated)

    def superpose(self, reference: Molecule) -> Molecule:
        """Return a copy optimally rotated/translated onto ``reference``.

        Uses the Kabsch algorithm (least-squares rigid-body fit). Requires the
        same number of atoms, matched by index.

        Raises:
            ValueError: if the atom counts differ.
        """
        if len(self) != len(reference):
            raise ValueError(f"atom count mismatch: {len(self)} vs {len(reference)}")
        p = self.coords - self.centroid
        q = reference.coords - reference.centroid
        u, _, vt = np.linalg.svd(p.T @ q)
        # Flip the last singular direction when needed so the result is a
        # proper rotation rather than a reflection.
        d = np.sign(np.linalg.det(vt.T @ u.T))
        rot = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
        aligned = p @ rot.T + reference.centroid
        return replace(self, coords=aligned)

    # -- measurements & analysis -------------------------------------------

    def distance(self, i: int, j: int) -> float:
        """Distance between atoms ``i`` and ``j`` (angstrom)."""
        return float(np.linalg.norm(self.coords[i] - self.coords[j]))

    def angle(self, i: int, j: int, k: int) -> float:
        """Angle in degrees at atom ``j`` formed by ``i``-``j``-``k``."""
        a = self.coords[i] - self.coords[j]
        b = self.coords[k] - self.coords[j]
        cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        # Clip guards arccos against rounding just outside [-1, 1].
        return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))

    def dihedral(self, a: int, b: int, c: int, d: int) -> float:
        """Dihedral (torsion) angle in degrees about the ``b``-``c`` bond."""
        v0 = self.coords[a] - self.coords[b]
        v1 = self.coords[c] - self.coords[b]
        v2 = self.coords[d] - self.coords[c]
        v1 = v1 / np.linalg.norm(v1)
        # Project the outer bonds onto the plane perpendicular to b-c.
        v = v0 - np.dot(v0, v1) * v1
        w = v2 - np.dot(v2, v1) * v1
        x = np.dot(v, w)
        y = np.dot(np.cross(v1, v), w)
        return float(np.degrees(np.arctan2(y, x)))

    def distance_matrix(self) -> np.ndarray:
        """Full ``(N, N)`` pairwise distance matrix (angstrom)."""
        deltas = self.coords[:, None, :] - self.coords[None, :, :]
        return np.sqrt((deltas ** 2).sum(axis=-1))

    def contacts(self, cutoff: float = 5.0) -> np.ndarray:
        """Atom index pairs ``(i, j)`` with ``i < j`` within ``cutoff`` angstrom.

        Uses ``scipy.spatial.cKDTree`` when scipy is importable, otherwise a
        dense distance matrix. The dense path compares with strict ``<``.
        NOTE(review): scipy documents ``query_pairs`` as "at most r", so the
        two paths may differ for pairs exactly at the cutoff - confirm.
        """
        n = len(self)
        if n < 2:
            return np.empty((0, 2), dtype=int)
        # Only the import is guarded, matching bonds().
        try:
            from scipy.spatial import cKDTree
        except ImportError:
            cKDTree = None

        if cKDTree is not None:
            return cKDTree(self.coords).query_pairs(cutoff, output_type="ndarray")
        dist = self.distance_matrix()
        i, j = np.where(np.triu(dist < cutoff, k=1))
        return np.stack([i, j], axis=1)

    def contact_map(self, cutoff: float = 8.0, level: str = "residue", method: str = "ca"):
        """Build a contact map. See :func:`molscope.contactmap.contact_map`.

        ``level`` is ``"atom"`` or ``"residue"``; for residue level ``method`` is
        ``"ca"`` (CA-CA distance), ``"com"`` (centre of mass) or ``"min"``
        (closest inter-residue atom). Returns a :class:`ContactMap`.
        """
        from .contactmap import contact_map

        return contact_map(self, cutoff=cutoff, level=level, method=method)

    def plot_contact_map(self, cutoff: float = 8.0, level: str = "residue",
                         method: str = "ca", **kwargs):
        """Shortcut for ``self.contact_map(...).plot()``."""
        return self.contact_map(cutoff, level, method).plot(**kwargs)

    def rmsd(self, other: Molecule, align: bool = False) -> float:
        """Root-mean-square deviation from ``other`` (matched by index).

        With ``align=True`` the molecules are Kabsch-superposed first, giving
        the minimum RMSD over all rigid-body orientations.

        Raises:
            ValueError: if the atom counts differ.
        """
        if len(self) != len(other):
            raise ValueError(f"atom count mismatch: {len(self)} vs {len(other)}")
        a = self.superpose(other).coords if align else self.coords
        return float(np.sqrt(((a - other.coords) ** 2).sum() / len(self)))

    def bonds(self, tolerance: float = 1.2) -> np.ndarray:
        """Infer bonds as index pairs ``(i, j)``.

        Two atoms bond when their separation is within ``tolerance`` times the
        sum of their covalent radii. Returns an ``(M, 2)`` int array.

        Uses ``scipy.spatial.cKDTree`` when available (scales to large
        structures); otherwise falls back to a dense search that is refused
        above ``_DENSE_BOND_LIMIT`` atoms. If the molecule carries explicit
        bonds (e.g. a coarse-grained model), those are returned directly.

        Raises:
            ValueError: on the dense path when the molecule exceeds
                ``_DENSE_BOND_LIMIT`` atoms.
        """
        if self.bond_index is not None:
            return self.bond_index
        n = len(self.coords)
        if n < 2:
            return np.empty((0, 2), dtype=int)
        radii = np.array([elements.covalent_radius(e) for e in self.elements])

        try:
            from scipy.spatial import cKDTree
        except ImportError:
            cKDTree = None

        if cKDTree is not None:
            tree = cKDTree(self.coords)
            # Largest possible bond length bounds the candidate search radius.
            cand = tree.query_pairs(tolerance * 2 * radii.max(), output_type="ndarray")
            if len(cand) == 0:
                return np.empty((0, 2), dtype=int)
            i, j = cand[:, 0], cand[:, 1]
        else:
            if n > _DENSE_BOND_LIMIT:
                raise ValueError(
                    f"{n} atoms exceeds the dense bond limit ({_DENSE_BOND_LIMIT}); "
                    "install scipy (pip install 'molscope[fast]') for large "
                    "structures."
                )
            i, j = np.triu_indices(n, k=1)

        dist = np.linalg.norm(self.coords[i] - self.coords[j], axis=1)
        cutoff = tolerance * (radii[i] + radii[j])
        keep = dist < cutoff
        return np.stack([i[keep], j[keep]], axis=1)

    def summary(self) -> str:
        """One-line human-readable description of the molecule."""
        parts = [f"{self.name or 'molecule'}: {len(self)} atoms"]
        if self.formula:
            parts.append(f"formula {self.formula}")
        if self.chains:
            parts.append(f"chains {','.join(sorted(set(self.chains)))}")
        dx, dy, dz = self.dimensions
        parts.append(f"size {dx:.1f}x{dy:.1f}x{dz:.1f} A")
        return " | ".join(parts)

    # -- coarse-graining ----------------------------------------------------

    def coarse_grain(self, mapping="residue_com", weighted: bool = True,
                     bonds=None, return_report: bool = False):
        """Map this structure onto CG beads. See :mod:`molscope.coarsegrain`.

        ``mapping`` is ``"residue_com"``, ``"residue_centroid"``, ``"martini"``,
        a ``{resname: {bead: [atom_names]}}`` dict (by residue), or a
        ``{bead: [atom_indices]}`` dict (by index, works on any structure).
        ``bonds`` optionally defines the bead network as pairs of bead indices,
        or bead names when those names are unique. Repeated residue bead names
        such as ``BB``/``SC`` are ambiguous; use indices for those. Returns a
        new ``Molecule`` of beads with CG bonds attached, or ``(molecule,
        report)`` when ``return_report=True``.
        """
        from .coarsegrain import coarse_grain

        return coarse_grain(
            self, mapping=mapping, weighted=weighted, bonds=bonds,
            return_report=return_report,
        )

    def mapping_report(self) -> str:
        """Explain how this coarse-grained molecule was mapped.

        Raises:
            ValueError: if no report is attached to this molecule.
        """
        if self._mapping_report is None:
            raise ValueError("no coarse-graining report is available for this molecule")
        return self._mapping_report.format()

    # -- ML descriptors -----------------------------------------------------

    def descriptors(self, **kwargs) -> dict:
        """Return fixed-size molecular descriptors for quick ML features."""
        from .descriptors import descriptors

        return descriptors(self, **kwargs)

    # -- graph export -------------------------------------------------------

    def to_graph(self, tolerance: float = 1.2, bonds=None):
        """Build a :class:`molscope.graph.MolecularGraph` from this molecule.

        Bonds are inferred from covalent radii (see :meth:`bonds`) unless an
        explicit ``(E, 2)`` array of index pairs is passed. Node and edge
        attributes (element, residue, chain, distance, ...) are carried along.
        """
        from .graph import MolecularGraph

        edges = self.bonds(tolerance) if bonds is None else np.asarray(bonds, dtype=int)
        edges = edges.reshape(-1, 2)
        if len(edges):
            dist = np.linalg.norm(self.coords[edges[:, 0]] - self.coords[edges[:, 1]], axis=1)
        else:
            dist = np.empty(0, dtype=float)
        return MolecularGraph(
            coords=self.coords, elements=self.elements, edges=edges,
            edge_distances=dist, edge_types=np.ones(len(edges)),
            atom_names=self.atom_names, resnames=self.resnames,
            resids=self.resids, chains=self.chains, name=self.name,
        )

    def to_networkx(self, **kwargs):
        """Shortcut for ``self.to_graph(...).to_networkx()``."""
        return self.to_graph(**kwargs).to_networkx()

    def to_pyg_data(self, **kwargs):
        """Shortcut for ``self.to_graph(...).to_pyg_data()`` (PyTorch Geometric)."""
        return self.to_graph(**kwargs).to_pyg_data()

    def to_dgl_graph(self, **kwargs):
        """Shortcut for ``self.to_graph(...).to_dgl_graph()`` (DGL)."""
        return self.to_graph(**kwargs).to_dgl_graph()

    def plot(self, **kwargs):
        """Render the molecule in 3D. See :func:`molscope.plotting.plot`."""
        from .plotting import plot

        return plot(self, **kwargs)

    def view(self, **kwargs):
        """Interactive py3Dmol viewer. See :func:`molscope.plotting.view`."""
        from .plotting import view

        return view(self, **kwargs)
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def _rotation_matrix(axis: np.ndarray, angle: float) -> np.ndarray:
    """Rodrigues rotation matrix for ``angle`` radians about ``axis``.

    ``axis`` is a 3-vector of any non-zero length; it is normalised here.

    Raises:
        ValueError: if ``axis`` has zero or non-finite length, which would
            otherwise yield a matrix full of NaN.
    """
    norm = np.linalg.norm(axis)
    if not np.isfinite(norm) or norm == 0.0:
        raise ValueError("rotation axis must be a finite, non-zero 3-vector")
    x, y, z = axis / norm
    c, s = np.cos(angle), np.sin(angle)
    C = 1 - c
    return np.array([
        [c + x * x * C, x * y * C - z * s, x * z * C + y * s],
        [y * x * C + z * s, c + y * y * C, y * z * C - x * s],
        [z * x * C - y * s, z * y * C + x * s, c + z * z * C],
    ])
|
molscope/plotting.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""3D visualization of molecules: matplotlib, py3Dmol, and GIF export."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import itertools
|
|
6
|
+
import warnings
|
|
7
|
+
from typing import Optional
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from . import elements
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def plot(
    molecule,
    show_bonds: Optional[bool] = None,
    bond_tolerance: float = 1.2,
    color_by: str = "element",
    scale: float = 60.0,
    ax=None,
    show: bool = True,
):
    """Scatter-plot atoms in 3D with an equal aspect ratio.

    ``color_by`` selects the colouring: ``"element"`` (CPK), ``"chain"`` or
    ``"residue"`` (categorical palette). Atom sizes scale with covalent radius.
    Bonds are drawn when ``show_bonds`` is true, or, when ``None``, automatically
    for molecules small enough to infer bonds cheaply. Returns the ``Axes3D``;
    pass ``show=False`` to suppress ``plt.show()``.

    ``bond_tolerance`` is forwarded to ``molecule.bonds``; ``scale`` multiplies
    the covalent radius to give the marker size. When ``ax`` is ``None`` a new
    figure with a 3D subplot is created. A ``ValueError`` raised while
    inferring bonds is downgraded to a warning and the bonds are skipped.
    """
    import matplotlib.pyplot as plt  # imported lazily so the core has no GUI dep

    coords = molecule.coords
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, projection="3d")

    colors = _colors(molecule, color_by)
    sizes = np.array([elements.covalent_radius(e) for e in molecule.elements]) * scale
    ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2], c=colors, s=sizes, depthshade=True)

    # Auto mode: only draw bonds for molecules of at most 2000 atoms.
    if show_bonds is None:
        show_bonds = len(molecule) <= 2000
    if show_bonds and len(molecule) > 1:
        try:
            # One line segment per bonded pair.
            for i, j in molecule.bonds(tolerance=bond_tolerance):
                seg = coords[[i, j]]
                ax.plot(seg[:, 0], seg[:, 1], seg[:, 2], color="0.5", linewidth=1.0)
        except ValueError as exc:
            # Keep the scatter plot usable even when bonds cannot be inferred.
            warnings.warn(f"skipping bonds: {exc}", stacklevel=2)

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    if molecule.name:
        ax.set_title(molecule.name)
    _set_equal_aspect(ax, coords)

    if show:
        plt.show()
    return ax
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def view(molecule, style: str = "stick", width: int = 480, height: int = 360):
    """Build an interactive py3Dmol widget for ``molecule`` (Jupyter use).

    The molecule is serialised to a PDB string and loaded as a single model.
    ``style`` is passed to py3Dmol as the style name, e.g. ``"stick"``,
    ``"sphere"``, ``"line"`` or ``"cartoon"``; ``width`` and ``height`` size
    the widget.

    Raises:
        ImportError: if py3Dmol is not installed (``pip install py3Dmol``).
    """
    try:
        import py3Dmol
    except ImportError as exc:  # pragma: no cover - exercised only without py3Dmol
        message = "view() needs py3Dmol; install it with: pip install py3Dmol"
        raise ImportError(message) from exc
    from .io import _molecule_to_pdb_string

    pdb_text = _molecule_to_pdb_string(molecule)
    widget = py3Dmol.view(width=width, height=height)
    widget.addModel(pdb_text, "pdb")
    style_spec = {style: {"colorscheme": "default"}}
    widget.setStyle(style_spec)
    widget.zoomTo()
    return widget
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def spin_gif(molecule, path: str, frames: int = 36, fps: int = 15, **plot_kwargs):
    """Render a spinning 3D view and save it as an animated GIF.

    Rotates a full turn about the vertical axis over ``frames`` steps. Requires
    Pillow (already a matplotlib dependency). Extra keyword arguments are
    forwarded to :func:`plot`. Returns ``path``.

    The figure is closed even when plotting or saving fails, so a failed
    export does not leave an open matplotlib figure behind.
    """
    import matplotlib.pyplot as plt
    from matplotlib import animation

    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1, projection="3d")
        plot(molecule, ax=ax, show=False, **plot_kwargs)

        def update(i):
            # Fixed elevation; the azimuth sweeps 360 degrees over the frames.
            ax.view_init(elev=20, azim=i * 360 / frames)
            return ()

        anim = animation.FuncAnimation(fig, update, frames=frames, blit=False)
        anim.save(path, writer=animation.PillowWriter(fps=fps))
    finally:
        plt.close(fig)
    return path
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def plot_contact_map(contact_map, ax=None, cmap=None, show: bool = True):
    """Render a :class:`~molscope.contactmap.ContactMap` as a heatmap.

    The colour scale is fixed to ``[0, 1]``. Unless ``cmap`` is given,
    frequency maps use ``"viridis"`` and other maps use ``"Greys"``. A
    colourbar is always attached. When ``ax`` is ``None`` a new 5x4 inch
    figure is created. Returns the matplotlib ``Axes``; pass ``show=False``
    to suppress ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    mat = contact_map.matrix
    freq = contact_map.is_frequency
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 4))

    if cmap:
        colormap = cmap
    elif freq:
        colormap = "viridis"
    else:
        colormap = "Greys"
    im = ax.imshow(
        mat, origin="lower", interpolation="nearest", vmin=0, vmax=1,
        cmap=colormap,
    )

    if contact_map.level == "residue":
        unit = "residue"
    else:
        unit = "atom"
    ax.set_xlabel(f"{unit} index")
    ax.set_ylabel(f"{unit} index")
    if freq:
        label = "contact frequency"
    else:
        label = f"contact (< {contact_map.cutoff} Å)"
    ax.figure.colorbar(im, ax=ax, label=label, fraction=0.046, pad=0.04)
    ax.set_title(f"{unit} contact map ({contact_map.cutoff} Å)")

    if show:
        plt.show()
    return ax
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def plot_rmsd_heatmap(matrix, order=None, ax=None, cmap="viridis", show: bool = True):
    """Render a pairwise-RMSD matrix (angstrom) as a heatmap.

    ``order`` (e.g. ``clustering.order``) permutes rows and columns together,
    so clusters show up as blocks along the diagonal. When ``ax`` is ``None``
    a new 5x4 inch figure is created. Returns the matplotlib ``Axes``; pass
    ``show=False`` to suppress ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    values = np.asarray(matrix, dtype=float)
    if order is not None:
        perm = np.asarray(order)
        # np.ix_ applies the same permutation to both axes.
        values = values[np.ix_(perm, perm)]
    if ax is None:
        _, ax = plt.subplots(figsize=(5, 4))

    image = ax.imshow(values, origin="lower", interpolation="nearest", cmap=cmap)
    ax.set_xlabel("model")
    ax.set_ylabel("model")
    ax.figure.colorbar(image, ax=ax, label="RMSD (Å)", fraction=0.046, pad=0.04)
    ax.set_title("pairwise RMSD")
    if show:
        plt.show()
    return ax
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def _colors(molecule, color_by: str):
    """Return one colour per atom according to ``color_by``.

    ``"element"`` looks each element up with ``elements.color``; ``"chain"``
    and ``"residue"`` assign categorical colours keyed on the chain id or the
    stringified residue id.

    Raises:
        ValueError: for an unknown ``color_by``, or when the molecule has no
            chain / residue information to colour by.
    """
    if color_by == "element":
        return [elements.color(symbol) for symbol in molecule.elements]
    if color_by not in ("chain", "residue"):
        raise ValueError(f"unknown color_by {color_by!r}")

    if color_by == "chain":
        labels = molecule.chains
    else:
        labels = [str(r) for r in molecule.resids] if len(molecule.resids) else []

    if not labels:
        raise ValueError(f"no {color_by} information to colour by")
    return _categorical_colors(labels)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _categorical_colors(keys):
    """Map each key to a ``tab20`` colour, one palette step per distinct key.

    Distinct keys take consecutive palette entries in order of first
    appearance; the palette wraps around after 20 categories. Returns a list
    with one colour per entry of ``keys``.
    """
    import matplotlib.pyplot as plt

    palette = plt.get_cmap("tab20").colors
    wheel = itertools.cycle(palette)
    assigned = {}
    colors = []
    for key in keys:
        # Advance the wheel only for unseen keys. dict.setdefault(k, next(wheel))
        # would evaluate next() for every atom and skip palette entries.
        if key not in assigned:
            assigned[key] = next(wheel)
        colors.append(assigned[key])
    return colors
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
def _set_equal_aspect(ax, coords: np.ndarray) -> None:
    """Give all three axes the same span so the molecule is not distorted.

    Each axis is centred on the midpoint of the bounding box along that axis
    and extends by half of the largest box edge; a zero extent (e.g. a single
    atom) falls back to a half-span of 1.0.
    """
    lower = coords.min(axis=0)
    upper = coords.max(axis=0)
    centers = (upper + lower) / 2
    radius = (upper - lower).max() / 2
    if not radius:
        radius = 1.0
    setters = (ax.set_xlim, ax.set_ylim, ax.set_zlim)
    for set_limits, middle in zip(setters, centers):
        set_limits(middle - radius, middle + radius)
|