molscope 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
molscope/molecule.py ADDED
@@ -0,0 +1,502 @@
1
+ """The :class:`Molecule` value type and its geometric operations.
2
+
3
+ Coordinates are held as an ``(N, 3)`` numpy array. Optional per-atom metadata
4
+ (atom name, residue name, residue id, chain) travels alongside, enabling
5
+ selections such as ``mol.backbone()`` or ``mol.select(chain="A")``.
6
+
7
+ Transformations return a new ``Molecule`` rather than mutating in place, so
8
+ chains like ``mol.centered().rotate("z", 90)`` read top to bottom and never
9
+ alias state.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass, field, replace
15
+ from typing import Any, Optional
16
+
17
+ import numpy as np
18
+
19
+ from . import elements
20
+
21
+ # Above this size the dense O(n^2) bond search is refused; install scipy for the
22
+ # KD-tree path (pip install 'molscope[fast]') to handle larger structures.
23
+ _DENSE_BOND_LIMIT = 8000
24
+
25
+ _BACKBONE_ATOMS = ("N", "CA", "C", "O")
26
+
27
+
28
@dataclass(frozen=True, eq=False)
class Molecule:
    """A set of atoms: an ``(N, 3)`` coordinate array plus optional metadata.

    The dataclass is frozen, and every transformation returns a new instance
    (built with :func:`dataclasses.replace`) instead of mutating this one.
    Per-atom sequences (``elements``, ``atom_names``, ``resnames``, ``resids``,
    ``chains``) are either empty or exactly ``N`` long; ``__post_init__``
    enforces this.
    """

    coords: np.ndarray
    elements: list[str] = field(default_factory=list)
    name: str = ""
    # Optional per-atom metadata; empty when the source format carries none.
    atom_names: list[str] = field(default_factory=list)
    resnames: list[str] = field(default_factory=list)
    resids: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=int))
    chains: list[str] = field(default_factory=list)
    # Optional explicit bonds as an (E, 2) index array. When set, bonds() returns
    # these instead of inferring from geometry (used by coarse-graining).
    bond_index: Optional[np.ndarray] = None
    _mapping_report: Optional[Any] = field(default=None, repr=False, compare=False)

    def __post_init__(self):
        """Normalise the array fields and validate per-atom sequence lengths.

        Raises:
            ValueError: if a non-empty per-atom sequence does not have one
                entry per coordinate row.
        """
        coords = np.asarray(self.coords, dtype=float).reshape(-1, 3)
        # The dataclass is frozen, so normalised values go through
        # object.__setattr__ rather than plain assignment.
        object.__setattr__(self, "coords", coords)
        object.__setattr__(self, "resids", np.asarray(self.resids, dtype=int))
        if self.bond_index is not None:
            object.__setattr__(
                self, "bond_index", np.asarray(self.bond_index, dtype=int).reshape(-1, 2)
            )
        if not self.elements:
            # No element symbols supplied: use one empty symbol per atom.
            object.__setattr__(self, "elements", [""] * len(coords))
        for name in ("elements", "atom_names", "resnames", "chains"):
            seq = getattr(self, name)
            if seq and len(seq) != len(coords):
                raise ValueError(f"{len(seq)} {name} for {len(coords)} coordinates")
        if len(self.resids) and len(self.resids) != len(coords):
            raise ValueError(f"{len(self.resids)} resids for {len(coords)} coordinates")

    def __len__(self) -> int:
        """Return the number of atoms (rows of ``coords``)."""
        return len(self.coords)

    def __eq__(self, other) -> bool:
        """Compare ``name``, ``elements`` and ``coords`` (exact array equality).

        Topology metadata (atom names, residues, chains, bonds) is *not* part
        of the comparison.
        """
        # Auto-generated dataclass __eq__ can't compare the numpy field; do it
        # explicitly. coords are mutable in place, so instances stay unhashable.
        if not isinstance(other, Molecule):
            return NotImplemented
        return (
            self.name == other.name
            and self.elements == other.elements
            and np.array_equal(self.coords, other.coords)
        )

    __hash__ = None

    def __getitem__(self, selector) -> "Molecule":
        """``mol[mask]`` / ``mol[indices]`` -> a subset molecule (see :meth:`take`)."""
        return self.take(selector)

    @property
    def has_topology(self) -> bool:
        """True if per-atom names/residues/chains were parsed."""
        return bool(self.atom_names)

    # -- selection ----------------------------------------------------------

    def take(self, selector) -> "Molecule":
        """Return the subset given by a boolean mask or an array of indices.

        A single integer index yields a one-atom molecule and an empty
        selector yields an empty molecule. Explicit bonds are restricted to
        the kept atoms and renumbered; any coarse-graining report is dropped.
        """
        sel = np.asarray(selector)
        if sel.size == 0 and sel.dtype.kind not in "biu":
            # np.asarray([]) is float64, which numpy refuses as an index.
            sel = sel.astype(int)
        # atleast_1d turns a scalar index into a length-1 selection.
        idx = np.atleast_1d(np.arange(len(self))[sel])

        def sub(seq):
            return [seq[i] for i in idx] if seq else []

        return replace(
            self,
            coords=self.coords[idx],
            elements=sub(self.elements),
            atom_names=sub(self.atom_names),
            resnames=sub(self.resnames),
            resids=self.resids[idx] if len(self.resids) else self.resids,
            chains=sub(self.chains),
            bond_index=self._subset_bonds(idx),
            _mapping_report=None,
        )

    def _subset_bonds(self, idx):
        """Restrict explicit bonds to a kept-atom index set and renumber them."""
        if self.bond_index is None:
            return None
        # old atom index -> position in the subset (last occurrence wins).
        remap = {int(old): new for new, old in enumerate(idx)}
        kept = [
            (remap[i], remap[j])
            for i, j in self.bond_index.tolist()
            if i in remap and j in remap
        ]
        return np.array(kept, dtype=int).reshape(-1, 2)

    def select(
        self,
        element=None,
        chain=None,
        resname=None,
        atom_name=None,
        resid=None,
    ) -> "Molecule":
        """Return the atoms matching every supplied criterion.

        Each of ``element``/``chain``/``resname``/``atom_name`` accepts a single
        value or a collection. ``resid`` accepts an int, a collection of ints,
        or a ``(low, high)`` inclusive range. Selecting on metadata the molecule
        lacks raises ``ValueError``. Element, residue-name and atom-name
        matching is case-insensitive; chain matching is exact.
        """
        mask = np.ones(len(self), dtype=bool)
        mask &= self._field_mask(self.elements, element, "element", upper=True)
        mask &= self._field_mask(self.chains, chain, "chain")
        mask &= self._field_mask(self.resnames, resname, "residue", upper=True)
        mask &= self._field_mask(self.atom_names, atom_name, "atom name", upper=True)
        if resid is not None:
            mask &= self._resid_mask(resid)
        return self.take(mask)

    def backbone(self) -> "Molecule":
        """Protein backbone atoms (N, CA, C, O)."""
        return self.select(atom_name=_BACKBONE_ATOMS)

    def alpha_carbons(self) -> "Molecule":
        """Alpha-carbon (CA) atoms, the usual basis for protein RMSD."""
        return self.select(atom_name="CA")

    def _field_mask(self, values, criterion, label, upper=False):
        """Boolean mask of atoms whose entry in ``values`` matches ``criterion``.

        ``criterion`` of ``None`` matches everything. ``label`` is only used in
        the error raised when ``values`` is empty. With ``upper=True`` both
        sides are upper-cased before comparison.
        """
        if criterion is None:
            return np.ones(len(self), dtype=bool)
        if not values:
            raise ValueError(f"no {label} information in this molecule")
        wanted = {criterion} if isinstance(criterion, str) else set(criterion)
        if upper:
            wanted = {w.upper() for w in wanted}
            return np.array([v.upper() in wanted for v in values], dtype=bool)
        return np.array([v in wanted for v in values], dtype=bool)

    def _resid_mask(self, resid):
        """Boolean mask for a residue id, a collection of ids, or a 2-tuple range.

        A 2-tuple is always read as an inclusive ``(low, high)`` range; pass a
        list to select exactly two ids.
        """
        if len(self.resids) == 0:
            raise ValueError("no residue-id information in this molecule")
        if isinstance(resid, tuple) and len(resid) == 2:
            low, high = resid
            return (self.resids >= low) & (self.resids <= high)
        # np.integer covers ids taken straight from ``mol.resids`` (e.g.
        # np.int64), which are not instances of the builtin ``int``.
        wanted = [resid] if isinstance(resid, (int, np.integer)) else list(resid)
        return np.isin(self.resids, wanted)

    def residue_groups(self):
        """Yield ``(atom_indices, resname, resid, chain)`` per residue, in order.

        Residues are runs of atoms sharing ``(chain, resid)``. Yields nothing if
        the molecule has no residue information.
        """
        n = len(self)
        if len(self.resids) == 0 or n == 0:
            return
        chains = self.chains or [""] * n
        resnames = self.resnames or [""] * n
        resids = self.resids
        start = 0
        for i in range(1, n + 1):
            # i == n closes the final run; it is checked first so that
            # chains[i] / resids[i] are never indexed out of range.
            if i == n or chains[i] != chains[i - 1] or resids[i] != resids[i - 1]:
                yield (
                    list(range(start, i)), resnames[start],
                    int(resids[start]), chains[start],
                )
                start = i

    # -- geometry -----------------------------------------------------------

    @property
    def masses(self) -> np.ndarray:
        """Per-atom atomic weights (g/mol)."""
        return np.array([elements.mass(e) for e in self.elements])

    @property
    def centroid(self) -> np.ndarray:
        """Geometric centre (mean of all atom positions)."""
        return self.coords.mean(axis=0)

    @property
    def center_of_mass(self) -> np.ndarray:
        """Mass-weighted centre of the molecule."""
        m = self.masses
        return (m[:, None] * self.coords).sum(axis=0) / m.sum()

    @property
    def radius_of_gyration(self) -> float:
        """Mass-weighted radius of gyration (angstrom)."""
        m = self.masses
        d2 = ((self.coords - self.center_of_mass) ** 2).sum(axis=1)
        return float(np.sqrt((m * d2).sum() / m.sum()))

    @property
    def dimensions(self) -> np.ndarray:
        """Axis-aligned bounding-box size (dx, dy, dz) in angstrom.

        An empty molecule has a zero-size box.
        """
        if len(self.coords) == 0:
            # max()/min() over zero rows would raise; report a degenerate box.
            return np.zeros(3)
        return self.coords.max(axis=0) - self.coords.min(axis=0)

    @property
    def formula(self) -> str:
        """Hill-order molecular formula, e.g. ``"C6 H12 O6"``.

        Hill order lists carbon first and hydrogen second when carbon is
        present, then the remaining elements alphabetically. Without carbon,
        every element (hydrogen included) is listed alphabetically.
        """
        from collections import Counter

        counts = Counter(e.capitalize() for e in self.elements if e)
        if not counts:
            return ""
        ordered = []
        if "C" in counts:
            ordered.append(("C", counts.pop("C")))
            if "H" in counts:
                ordered.append(("H", counts.pop("H")))
        ordered += sorted(counts.items())
        return " ".join(f"{s}{n}" if n > 1 else s for s, n in ordered)

    # -- transforms (each returns a new Molecule) ---------------------------

    def translate(self, vector) -> "Molecule":
        """Return a copy shifted by ``vector`` (dx, dy, dz)."""
        return replace(self, coords=self.coords + np.asarray(vector, dtype=float))

    def centered(self, weighted: bool = False) -> "Molecule":
        """Return a copy with its centre at the origin.

        By default the geometric centroid is used; pass ``weighted=True`` to
        centre on the mass-weighted centre of mass.
        """
        origin = self.center_of_mass if weighted else self.centroid
        return replace(self, coords=self.coords - origin)

    def rotate(self, axis, angle_deg: float) -> "Molecule":
        """Return a copy rotated ``angle_deg`` degrees about ``axis``.

        ``axis`` may be ``"x"``, ``"y"``, ``"z"`` (case-insensitive) or any
        3-vector (tuple, list or array). Rotation is about the centroid so the
        molecule spins in place.

        Raises:
            ValueError: for an unknown axis name or a zero-length axis vector.
        """
        if isinstance(axis, str):
            named = {
                "x": (1.0, 0.0, 0.0),
                "y": (0.0, 1.0, 0.0),
                "z": (0.0, 0.0, 1.0),
            }
            try:
                vec = named[axis.lower()]
            except KeyError:
                raise ValueError(
                    f"unknown axis {axis!r}; use 'x', 'y', 'z' or a 3-vector"
                ) from None
        else:
            # Lists and arrays are unhashable, so they must never be used as a
            # dictionary key; treat any non-string as the vector itself.
            vec = axis
        rot = _rotation_matrix(np.asarray(vec, dtype=float), np.radians(angle_deg))
        center = self.centroid
        rotated = (self.coords - center) @ rot.T + center
        return replace(self, coords=rotated)

    def superpose(self, reference: "Molecule") -> "Molecule":
        """Return a copy optimally rotated/translated onto ``reference``.

        Uses the Kabsch algorithm (least-squares rigid-body fit). Requires the
        same number of atoms, matched by index.
        """
        if len(self) != len(reference):
            raise ValueError(f"atom count mismatch: {len(self)} vs {len(reference)}")
        p = self.coords - self.centroid
        q = reference.coords - reference.centroid
        u, _, vt = np.linalg.svd(p.T @ q)
        # Flip the last singular direction when the best orthogonal fit would
        # be a reflection, so the result is always a proper rotation.
        d = -1.0 if np.linalg.det(vt.T @ u.T) < 0 else 1.0
        rot = vt.T @ np.diag([1.0, 1.0, d]) @ u.T
        aligned = p @ rot.T + reference.centroid
        return replace(self, coords=aligned)

    # -- measurements & analysis -------------------------------------------

    def distance(self, i: int, j: int) -> float:
        """Distance between atoms ``i`` and ``j`` (angstrom)."""
        return float(np.linalg.norm(self.coords[i] - self.coords[j]))

    def angle(self, i: int, j: int, k: int) -> float:
        """Angle in degrees at atom ``j`` formed by ``i``-``j``-``k``."""
        a = self.coords[i] - self.coords[j]
        b = self.coords[k] - self.coords[j]
        cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        # Clip guards arccos against round-off just outside [-1, 1].
        return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))

    def dihedral(self, a: int, b: int, c: int, d: int) -> float:
        """Dihedral (torsion) angle in degrees about the ``b``-``c`` bond."""
        v0 = self.coords[a] - self.coords[b]
        v1 = self.coords[c] - self.coords[b]
        v2 = self.coords[d] - self.coords[c]
        v1 = v1 / np.linalg.norm(v1)
        # Project the outer bond vectors onto the plane normal to b-c.
        v = v0 - np.dot(v0, v1) * v1
        w = v2 - np.dot(v2, v1) * v1
        x = np.dot(v, w)
        y = np.dot(np.cross(v1, v), w)
        return float(np.degrees(np.arctan2(y, x)))

    def distance_matrix(self) -> np.ndarray:
        """Full ``(N, N)`` pairwise distance matrix (angstrom)."""
        deltas = self.coords[:, None, :] - self.coords[None, :, :]
        return np.sqrt((deltas ** 2).sum(axis=-1))

    def contacts(self, cutoff: float = 5.0) -> np.ndarray:
        """Atom index pairs ``(i, j)`` with ``i < j`` within ``cutoff`` angstrom.

        Uses ``scipy.spatial.cKDTree`` when scipy is importable and a dense
        distance matrix otherwise. The row order of the result is not
        guaranteed to be the same on both paths.
        """
        n = len(self)
        if n < 2:
            return np.empty((0, 2), dtype=int)
        # Only the import is guarded, mirroring bonds(), so an ImportError
        # raised during the search itself is never mistaken for "no scipy".
        try:
            from scipy.spatial import cKDTree
        except ImportError:
            cKDTree = None

        if cKDTree is not None:
            return cKDTree(self.coords).query_pairs(cutoff, output_type="ndarray")
        dist = self.distance_matrix()
        i, j = np.where(np.triu(dist < cutoff, k=1))
        return np.stack([i, j], axis=1)

    def contact_map(self, cutoff: float = 8.0, level: str = "residue", method: str = "ca"):
        """Build a contact map. See :func:`molscope.contactmap.contact_map`.

        ``level`` is ``"atom"`` or ``"residue"``; for residue level ``method`` is
        ``"ca"`` (CA-CA distance), ``"com"`` (centre of mass) or ``"min"``
        (closest inter-residue atom). Returns a :class:`ContactMap`.
        """
        from .contactmap import contact_map

        return contact_map(self, cutoff=cutoff, level=level, method=method)

    def plot_contact_map(self, cutoff: float = 8.0, level: str = "residue",
                         method: str = "ca", **kwargs):
        """Shortcut for ``self.contact_map(...).plot()``."""
        return self.contact_map(cutoff, level, method).plot(**kwargs)

    def rmsd(self, other: "Molecule", align: bool = False) -> float:
        """Root-mean-square deviation from ``other`` (matched by index).

        With ``align=True`` the molecules are Kabsch-superposed first, giving
        the minimum RMSD over all rigid-body orientations.
        """
        if len(self) != len(other):
            raise ValueError(f"atom count mismatch: {len(self)} vs {len(other)}")
        a = self.superpose(other).coords if align else self.coords
        return float(np.sqrt(((a - other.coords) ** 2).sum() / len(self)))

    def bonds(self, tolerance: float = 1.2) -> np.ndarray:
        """Infer bonds as index pairs ``(i, j)``.

        Two atoms bond when their separation is within ``tolerance`` times the
        sum of their covalent radii. Returns an ``(M, 2)`` int array.

        Uses ``scipy.spatial.cKDTree`` when available (scales to large
        structures); otherwise falls back to a dense search that is refused
        above ``_DENSE_BOND_LIMIT`` atoms. If the molecule carries explicit
        bonds (e.g. a coarse-grained model), those are returned directly.
        """
        if self.bond_index is not None:
            return self.bond_index
        n = len(self.coords)
        if n < 2:
            return np.empty((0, 2), dtype=int)
        radii = np.array([elements.covalent_radius(e) for e in self.elements])

        try:
            from scipy.spatial import cKDTree
        except ImportError:
            cKDTree = None

        if cKDTree is not None:
            tree = cKDTree(self.coords)
            # The largest possible bond length bounds the candidate search;
            # the per-pair radius test below does the real filtering.
            cand = tree.query_pairs(tolerance * 2 * radii.max(), output_type="ndarray")
            if len(cand) == 0:
                return np.empty((0, 2), dtype=int)
            i, j = cand[:, 0], cand[:, 1]
        else:
            if n > _DENSE_BOND_LIMIT:
                raise ValueError(
                    f"{n} atoms exceeds the dense bond limit ({_DENSE_BOND_LIMIT}); "
                    "install scipy (pip install 'molscope[fast]') for large "
                    "structures."
                )
            i, j = np.triu_indices(n, k=1)

        dist = np.linalg.norm(self.coords[i] - self.coords[j], axis=1)
        cutoff = tolerance * (radii[i] + radii[j])
        keep = dist < cutoff
        return np.stack([i[keep], j[keep]], axis=1)

    def summary(self) -> str:
        """One-line human-readable description of the molecule."""
        parts = [f"{self.name or 'molecule'}: {len(self)} atoms"]
        formula = self.formula
        if formula:
            parts.append(f"formula {formula}")
        if self.chains:
            parts.append(f"chains {','.join(sorted(set(self.chains)))}")
        dx, dy, dz = self.dimensions
        parts.append(f"size {dx:.1f}x{dy:.1f}x{dz:.1f} A")
        return " | ".join(parts)

    # -- coarse-graining ----------------------------------------------------

    def coarse_grain(self, mapping="residue_com", weighted: bool = True,
                     bonds=None, return_report: bool = False):
        """Map this structure onto CG beads. See :mod:`molscope.coarsegrain`.

        ``mapping`` is ``"residue_com"``, ``"residue_centroid"``, ``"martini"``,
        a ``{resname: {bead: [atom_names]}}`` dict (by residue), or a
        ``{bead: [atom_indices]}`` dict (by index, works on any structure).
        ``bonds`` optionally defines the bead network as pairs of bead indices,
        or bead names when those names are unique. Repeated residue bead names
        such as ``BB``/``SC`` are ambiguous; use indices for those. Returns a
        new ``Molecule`` of beads with CG bonds attached, or ``(molecule,
        report)`` when ``return_report=True``.
        """
        from .coarsegrain import coarse_grain

        return coarse_grain(
            self, mapping=mapping, weighted=weighted, bonds=bonds,
            return_report=return_report,
        )

    def mapping_report(self) -> str:
        """Explain how this coarse-grained molecule was mapped."""
        if self._mapping_report is None:
            raise ValueError("no coarse-graining report is available for this molecule")
        return self._mapping_report.format()

    # -- ML descriptors -----------------------------------------------------

    def descriptors(self, **kwargs) -> dict:
        """Return fixed-size molecular descriptors for quick ML features."""
        from .descriptors import descriptors

        return descriptors(self, **kwargs)

    # -- graph export -------------------------------------------------------

    def to_graph(self, tolerance: float = 1.2, bonds=None):
        """Build a :class:`molscope.graph.MolecularGraph` from this molecule.

        Bonds are inferred from covalent radii (see :meth:`bonds`) unless an
        explicit ``(E, 2)`` array of index pairs is passed. Node and edge
        attributes (element, residue, chain, distance, ...) are carried along.
        """
        from .graph import MolecularGraph

        edges = self.bonds(tolerance) if bonds is None else np.asarray(bonds, dtype=int)
        edges = edges.reshape(-1, 2)
        if len(edges):
            dist = np.linalg.norm(self.coords[edges[:, 0]] - self.coords[edges[:, 1]], axis=1)
        else:
            dist = np.empty(0, dtype=float)
        return MolecularGraph(
            coords=self.coords, elements=self.elements, edges=edges,
            edge_distances=dist, edge_types=np.ones(len(edges)),
            atom_names=self.atom_names, resnames=self.resnames,
            resids=self.resids, chains=self.chains, name=self.name,
        )

    def to_networkx(self, **kwargs):
        """Shortcut for ``self.to_graph(...).to_networkx()``."""
        return self.to_graph(**kwargs).to_networkx()

    def to_pyg_data(self, **kwargs):
        """Shortcut for ``self.to_graph(...).to_pyg_data()`` (PyTorch Geometric)."""
        return self.to_graph(**kwargs).to_pyg_data()

    def to_dgl_graph(self, **kwargs):
        """Shortcut for ``self.to_graph(...).to_dgl_graph()`` (DGL)."""
        return self.to_graph(**kwargs).to_dgl_graph()

    def plot(self, **kwargs):
        """Render the molecule in 3D. See :func:`molscope.plotting.plot`."""
        from .plotting import plot

        return plot(self, **kwargs)

    def view(self, **kwargs):
        """Interactive py3Dmol viewer. See :func:`molscope.plotting.view`."""
        from .plotting import view

        return view(self, **kwargs)


def _rotation_matrix(axis: np.ndarray, angle: float) -> np.ndarray:
    """Rodrigues rotation matrix for ``angle`` radians about ``axis``.

    ``axis`` is normalised here, so it need not be a unit vector.

    Raises:
        ValueError: if ``axis`` has zero length (the rotation is undefined and
            normalising would otherwise fill the matrix with NaN).
    """
    norm = np.linalg.norm(axis)
    if norm == 0:
        raise ValueError("rotation axis must have non-zero length")
    axis = axis / norm
    x, y, z = axis
    c, s = np.cos(angle), np.sin(angle)
    C = 1 - c
    return np.array([
        [c + x * x * C, x * y * C - z * s, x * z * C + y * s],
        [y * x * C + z * s, c + y * y * C, y * z * C - x * s],
        [z * x * C - y * s, z * y * C + x * s, c + z * z * C],
    ])
molscope/plotting.py ADDED
@@ -0,0 +1,191 @@
1
+ """3D visualization of molecules: matplotlib, py3Dmol, and GIF export."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import itertools
6
+ import warnings
7
+ from typing import Optional
8
+
9
+ import numpy as np
10
+
11
+ from . import elements
12
+
13
+
14
def plot(
    molecule,
    show_bonds: Optional[bool] = None,
    bond_tolerance: float = 1.2,
    color_by: str = "element",
    scale: float = 60.0,
    ax=None,
    show: bool = True,
):
    """Draw the atoms as a 3D scatter with equal scaling on all three axes.

    ``color_by`` chooses the colouring scheme: ``"element"`` (CPK), ``"chain"``
    or ``"residue"`` (categorical palette). Marker area is the covalent radius
    times ``scale``. Bonds are drawn when ``show_bonds`` is true; when it is
    ``None`` they are drawn only for molecules of at most 2000 atoms. If bond
    inference raises ``ValueError`` a warning is issued and the bonds are
    skipped. Returns the axes; pass ``show=False`` to suppress ``plt.show()``.
    """
    import matplotlib.pyplot as plt  # imported lazily so the core has no GUI dep

    xyz = molecule.coords
    if ax is None:
        ax = plt.figure().add_subplot(1, 1, 1, projection="3d")

    point_colors = _colors(molecule, color_by)
    radii = np.array([elements.covalent_radius(sym) for sym in molecule.elements])
    ax.scatter(
        xyz[:, 0], xyz[:, 1], xyz[:, 2],
        c=point_colors, s=radii * scale, depthshade=True,
    )

    # None means "decide automatically from the molecule size".
    draw_bonds = len(molecule) <= 2000 if show_bonds is None else show_bonds
    if draw_bonds and len(molecule) > 1:
        try:
            for first, second in molecule.bonds(tolerance=bond_tolerance):
                segment = xyz[[first, second]]
                ax.plot(
                    segment[:, 0], segment[:, 1], segment[:, 2],
                    color="0.5", linewidth=1.0,
                )
        except ValueError as exc:
            warnings.warn(f"skipping bonds: {exc}", stacklevel=2)

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    if molecule.name:
        ax.set_title(molecule.name)
    _set_equal_aspect(ax, xyz)

    if show:
        plt.show()
    return ax
62
+
63
+
64
def view(molecule, style: str = "stick", width: int = 480, height: int = 360):
    """Build an interactive py3Dmol viewer for use in Jupyter notebooks.

    py3Dmol is an optional dependency (``pip install py3Dmol``); an
    ``ImportError`` with install instructions is raised when it is missing.
    ``style`` is passed to py3Dmol as the style name, e.g. ``"stick"``,
    ``"sphere"``, ``"line"`` or ``"cartoon"``. The molecule is handed to the
    viewer as a PDB-format string.
    """
    try:
        import py3Dmol
    except ImportError as exc:  # pragma: no cover - exercised only without py3Dmol
        raise ImportError(
            "view() needs py3Dmol; install it with: pip install py3Dmol"
        ) from exc
    from .io import _molecule_to_pdb_string

    pdb_text = _molecule_to_pdb_string(molecule)
    style_spec = {style: {"colorscheme": "default"}}

    viewer = py3Dmol.view(width=width, height=height)
    viewer.addModel(pdb_text, "pdb")
    viewer.setStyle(style_spec)
    viewer.zoomTo()
    return viewer
83
+
84
+
85
def spin_gif(molecule, path: str, frames: int = 36, fps: int = 15, **plot_kwargs):
    """Render a spinning 3D view and save it as an animated GIF.

    Rotates a full turn about the vertical axis over ``frames`` steps. Requires
    Pillow (already a matplotlib dependency). Extra keyword arguments are
    forwarded to :func:`plot`. Returns ``path``.

    The figure created here is always closed, even when drawing or saving
    fails, so a failed export does not leave an open figure behind.
    """
    import matplotlib.pyplot as plt
    from matplotlib import animation

    fig = plt.figure()
    try:
        ax = fig.add_subplot(1, 1, 1, projection="3d")
        plot(molecule, ax=ax, show=False, **plot_kwargs)

        def update(i):
            # Only the camera moves; no artists change, so nothing to return.
            ax.view_init(elev=20, azim=i * 360 / frames)
            return ()

        anim = animation.FuncAnimation(fig, update, frames=frames, blit=False)
        anim.save(path, writer=animation.PillowWriter(fps=fps))
    finally:
        plt.close(fig)
    return path
106
+
107
+
108
def plot_contact_map(contact_map, ax=None, cmap=None, show: bool = True):
    """Render a :class:`~molscope.contactmap.ContactMap` as a heatmap.

    The colour range is fixed to [0, 1]. When ``cmap`` is not given,
    frequency maps use ``"viridis"`` and boolean maps use ``"Greys"``. A
    colourbar is always attached. Returns the matplotlib ``Axes``; pass
    ``show=False`` to suppress ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    is_freq = contact_map.is_frequency
    if cmap:
        colormap = cmap
    elif is_freq:
        colormap = "viridis"
    else:
        colormap = "Greys"

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 4))
    image = ax.imshow(
        contact_map.matrix,
        origin="lower",
        interpolation="nearest",
        vmin=0,
        vmax=1,
        cmap=colormap,
    )

    unit = "residue" if contact_map.level == "residue" else "atom"
    ax.set_xlabel(f"{unit} index")
    ax.set_ylabel(f"{unit} index")
    if is_freq:
        bar_label = "contact frequency"
    else:
        bar_label = f"contact (< {contact_map.cutoff} Å)"
    ax.figure.colorbar(image, ax=ax, label=bar_label, fraction=0.046, pad=0.04)
    ax.set_title(f"{unit} contact map ({contact_map.cutoff} Å)")

    if show:
        plt.show()
    return ax
135
+
136
+
137
def plot_rmsd_heatmap(matrix, order=None, ax=None, cmap="viridis", show: bool = True):
    """Render a pairwise-RMSD matrix (angstrom) as a heatmap.

    ``order``, when given (e.g. ``clustering.order``), permutes rows and
    columns together so clusters show up as blocks on the diagonal. Returns
    the matplotlib ``Axes``; pass ``show=False`` to suppress ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    values = np.asarray(matrix, dtype=float)
    if order is not None:
        permutation = np.asarray(order)
        # np.ix_ applies the same permutation to both axes.
        values = values[np.ix_(permutation, permutation)]

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 4))
    image = ax.imshow(values, origin="lower", interpolation="nearest", cmap=cmap)
    ax.set_xlabel("model")
    ax.set_ylabel("model")
    ax.figure.colorbar(image, ax=ax, label="RMSD (Å)", fraction=0.046, pad=0.04)
    ax.set_title("pairwise RMSD")

    if show:
        plt.show()
    return ax
159
+
160
+
161
+ def _colors(molecule, color_by: str):
162
+ if color_by == "element":
163
+ return [elements.color(e) for e in molecule.elements]
164
+ if color_by == "chain":
165
+ keys = molecule.chains
166
+ elif color_by == "residue":
167
+ keys = [str(r) for r in molecule.resids] if len(molecule.resids) else []
168
+ else:
169
+ raise ValueError(f"unknown color_by {color_by!r}")
170
+ if not keys:
171
+ raise ValueError(f"no {color_by} information to colour by")
172
+ return _categorical_colors(keys)
173
+
174
+
175
def _categorical_colors(keys):
    """Map each key to a ``tab20`` colour, one palette step per *distinct* key.

    Distinct keys receive consecutive palette entries in order of first
    appearance, wrapping around after 20 categories. Returns a list of
    colours parallel to ``keys``.
    """
    import matplotlib.pyplot as plt

    palette = plt.get_cmap("tab20").colors
    wheel = itertools.cycle(palette)
    assigned = {}
    colors = []
    for key in keys:
        # Advance the palette only for unseen keys. Calling next(wheel)
        # unconditionally (e.g. as a setdefault argument) would consume one
        # colour per atom and let different categories collide.
        if key not in assigned:
            assigned[key] = next(wheel)
        colors.append(assigned[key])
    return colors
182
+
183
+
184
+ def _set_equal_aspect(ax, coords: np.ndarray) -> None:
185
+ """Force equal scaling on all axes so the molecule isn't distorted."""
186
+ mins, maxs = coords.min(axis=0), coords.max(axis=0)
187
+ centers = (maxs + mins) / 2
188
+ radius = (maxs - mins).max() / 2 or 1.0
189
+ ax.set_xlim(centers[0] - radius, centers[0] + radius)
190
+ ax.set_ylim(centers[1] - radius, centers[1] + radius)
191
+ ax.set_zlim(centers[2] - radius, centers[2] + radius)