molscope 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
molscope/__init__.py ADDED
@@ -0,0 +1,53 @@
1
+ """molscope: read, analyse, and plot molecular structures in 3D."""
2
+
3
+ from . import coarsegrain, ensemble
4
+ from .coarsegrain import BeadMapping, BondMapping, CoarseGrainReport, DroppedAtom
5
+ from .contactmap import ContactMap
6
+ from .descriptors import descriptors, featurize_many
7
+ from .ensemble import Clustering, cluster, rmsd_matrix
8
+ from .ensemble import contact_frequency as ensemble_contact_frequency
9
+ from .graph import MolecularGraph
10
+ from .io import (
11
+ fetch,
12
+ read,
13
+ read_cif,
14
+ read_pdb,
15
+ read_pdb_models,
16
+ read_sdf,
17
+ read_xyz,
18
+ read_xyz_frames,
19
+ write_pdb,
20
+ write_xyz,
21
+ )
22
+ from .molecule import Molecule
23
+ from .plotting import plot_rmsd_heatmap
24
+
25
+ __all__ = [
26
+ "Clustering",
27
+ "BeadMapping",
28
+ "BondMapping",
29
+ "CoarseGrainReport",
30
+ "ContactMap",
31
+ "DroppedAtom",
32
+ "Molecule",
33
+ "MolecularGraph",
34
+ "cluster",
35
+ "coarsegrain",
36
+ "descriptors",
37
+ "ensemble",
38
+ "ensemble_contact_frequency",
39
+ "featurize_many",
40
+ "fetch",
41
+ "plot_rmsd_heatmap",
42
+ "read",
43
+ "read_cif",
44
+ "read_pdb",
45
+ "read_pdb_models",
46
+ "read_sdf",
47
+ "read_xyz",
48
+ "read_xyz_frames",
49
+ "rmsd_matrix",
50
+ "write_pdb",
51
+ "write_xyz",
52
+ ]
53
+ __version__ = "0.6.0"
molscope/__main__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .cli import main
2
+
3
+ raise SystemExit(main())
molscope/cli.py ADDED
@@ -0,0 +1,74 @@
1
+ """Command-line entry point: ``python -m molscope FILE [options]``."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+
7
+ from .io import fetch, read
8
+
9
+
10
def main(argv=None) -> int:
    """Run the ``molscope`` command line and return a process exit status.

    Parameters
    ----------
    argv:
        Argument list to parse; ``None`` lets :mod:`argparse` read
        ``sys.argv[1:]``.

    Returns
    -------
    int
        ``0`` on success.  Usage errors do not return: they go through
        ``parser.error`` which prints a message and raises ``SystemExit(2)``.

    The structure is loaded (from a file or via ``--fetch``), then optionally
    selected, centred, translated and rotated in that order, summarised on
    stdout, and finally written as a GIF, saved as an image, or shown.
    """
    parser = argparse.ArgumentParser(
        prog="molscope",
        description="Read a structure (.xyz/.pdb/.cif/.sdf), transform it, and plot in 3D.",
    )
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("file", nargs="?", help="path to a structure file")
    src.add_argument("--fetch", metavar="PDBID", help="download a structure from RCSB by id")

    parser.add_argument(
        "--select", metavar="SPEC",
        help="atom selection, e.g. 'chain=A' or 'atom_name=CA' or 'element=C'",
    )
    parser.add_argument(
        "--translate", type=float, nargs=3, metavar=("DX", "DY", "DZ"),
        help="shift all atoms by this vector before plotting",
    )
    parser.add_argument("--center", action="store_true", help="move the centroid to the origin")
    parser.add_argument(
        "--rotate", nargs=2, metavar=("AXIS", "DEG"),
        help="rotate about AXIS (x/y/z) by DEG degrees, e.g. --rotate z 90",
    )
    parser.add_argument(
        "--color-by", choices=["element", "chain", "residue"], default="element",
        help="how to colour atoms (default: element)",
    )
    bonds = parser.add_mutually_exclusive_group()
    bonds.add_argument("--bonds", dest="bonds", action="store_true", help="force drawing bonds")
    bonds.add_argument("--no-bonds", dest="bonds", action="store_false", help="never draw bonds")
    # None (neither flag given) is passed through so the plotting code can
    # apply its own default.
    parser.set_defaults(bonds=None)
    parser.add_argument("--save", metavar="PATH", help="save the figure instead of showing it")
    parser.add_argument("--gif", metavar="PATH", help="save a spinning animation as a GIF")

    args = parser.parse_args(argv)

    # Validate the free-form options before touching the disk or the network
    # so a typo produces a usage message instead of a traceback (or, worse, a
    # silently meaningless selection such as select(**{"chainA": ""})).
    selection = None
    if args.select:
        key, sep, value = args.select.partition("=")
        key = key.strip()
        if not sep or not key:
            parser.error(
                f"--select expects KEY=VALUE (e.g. 'chain=A'), got {args.select!r}"
            )
        selection = {key: value.strip()}

    rotation = None
    if args.rotate:
        axis, degrees = args.rotate
        try:
            rotation = (axis, float(degrees))
        except ValueError:
            parser.error(f"--rotate DEG must be a number, got {degrees!r}")

    mol = fetch(args.fetch) if args.fetch else read(args.file)
    if selection is not None:
        mol = mol.select(**selection)
    if args.center:
        mol = mol.centered()
    if args.translate:
        mol = mol.translate(args.translate)
    if rotation is not None:
        mol = mol.rotate(*rotation)

    print(mol.summary())

    if args.gif:
        # Imported lazily so the plotting stack is only loaded when needed.
        from .plotting import spin_gif

        spin_gif(mol, args.gif, color_by=args.color_by, show_bonds=args.bonds)
        print(f"saved {args.gif}")
        return 0

    # Only open an interactive window when the figure is not being saved.
    show = args.save is None
    ax = mol.plot(show_bonds=args.bonds, color_by=args.color_by, show=show)
    if args.save:
        ax.figure.savefig(args.save, dpi=150, bbox_inches="tight")
        print(f"saved {args.save}")
    return 0
71
+
72
+
73
+ if __name__ == "__main__":
74
+ raise SystemExit(main())
molscope/coarsegrain.py ADDED
@@ -0,0 +1,411 @@
1
+ """Coarse-graining: map an atomistic structure onto a smaller set of beads.
2
+
3
+ The result is an ordinary :class:`~molscope.molecule.Molecule` whose "atoms"
4
+ are beads, so it plots, transforms and exports to a graph like any other. Bead
5
+ positions are the mass-weighted centre (or geometric centroid) of their member
6
+ atoms, and explicit CG bonds are attached.
7
+
8
+ Built-in modes:
9
+
10
+ - ``"residue_com"`` — one bead per residue at its centre of mass
11
+ - ``"residue_centroid"`` — one bead per residue at its geometric centroid
12
+ - ``"martini"`` — a simplified backbone + side-chain (BB/SC) model
13
+
14
+ Custom mappings come in two forms:
15
+
16
+ - **By residue/atom name** (needs PDB/mmCIF metadata)::
17
+
18
+ {"ALA": {"BB": ["N", "CA", "C", "O"], "SC": ["CB"]}}
19
+
20
+ Bonds are generated automatically (sequential within a residue, plus a
21
+ backbone chain between consecutive residues) unless you pass ``bonds=``.
22
+
23
+ - **By atom index** (works on any structure, including ``.xyz``)::
24
+
25
+ {"head": [0, 1, 2, 3], "tail": [4, 5, 6, 7]}
26
+
27
+ No bonds are added unless you pass ``bonds=`` (see below).
28
+
29
+ ``bonds`` lets you define the bead network explicitly as pairs of bead indices,
30
+ or by bead name when names are unique, e.g. ``bonds=[("head", "tail")]`` or
31
+ ``bonds=[(0, 1)]``. Repeated names such as ``BB``/``SC`` in Martini-like residue
32
+ mappings are ambiguous; use bead indices for those.
33
+
34
+ Intended for teaching and prototyping CG mappings, not as a substitute for
35
+ production Martini parameters.
36
+ """
37
+
38
+ from __future__ import annotations
39
+
40
+ import warnings
41
+ from dataclasses import dataclass, field
42
+ from typing import Optional
43
+
44
+ import numpy as np
45
+
46
+ from . import elements
47
+ from .molecule import Molecule
48
+
49
+ _BACKBONE = ("N", "CA", "C", "O", "OXT")
50
+
51
+
52
@dataclass(frozen=True)
class BeadMapping:
    """Report entry describing a single bead and the atoms merged into it."""

    # Bead label (e.g. "BB", "SC", or a user-chosen name).
    name: str
    # Names of the member atoms, in the order they were assigned.
    atom_names: list[str]
    # How the bead position was derived ("centre of mass" or "centroid").
    reduction: str
    # Residue context; left at the defaults when the source has none.
    resname: str = ""
    resid: Optional[int] = None
    chain: str = ""
62
+
63
+
64
@dataclass(frozen=True)
class DroppedAtom:
    """Report entry for an atom that ended up in no bead."""

    # Atom name as stored on the source molecule ("" when unnamed).
    name: str
    # Element symbol ("" when the source carries no elements).
    element: str
    # Residue context; left at the defaults when the source has none.
    resname: str = ""
    resid: Optional[int] = None
    chain: str = ""
73
+
74
+
75
@dataclass(frozen=True)
class BondMapping:
    """Report entry for one bead-bead bond, generated or supplied by the user."""

    # Names of the two bonded beads.
    a: str
    b: str
    # Why the bond exists, e.g. "within residue" or "(user-defined)".
    reason: str
82
+
83
+
84
@dataclass(frozen=True)
class CoarseGrainReport:
    """Human-readable explanation of a coarse-graining operation.

    ``mapping`` names the mode that was used; ``beads``, ``dropped_atoms`` and
    ``bonds`` list what was built and what was left out.  ``str(report)`` and
    :meth:`format` return the same multi-line text.
    """

    mapping: str
    beads: list[BeadMapping] = field(default_factory=list)
    dropped_atoms: list[DroppedAtom] = field(default_factory=list)
    bonds: list[BondMapping] = field(default_factory=list)

    def __str__(self) -> str:
        return self.format()

    def format(self) -> str:
        """Render the report as text with Beads / Dropped atoms / Generated bonds sections.

        Consecutive beads that belong to the same residue are listed under a
        single residue header instead of repeating the header for every bead.
        An empty section is rendered as ``(none)``.
        """
        lines = [f"Mapping: {self.mapping}", "", "Beads:"]
        if self.beads:
            previous = None
            for bead in self.beads:
                residue = (bead.resid, bead.resname, bead.chain)
                # Emit the header only when the residue changes, so a residue
                # split into several beads (e.g. BB + SC) is introduced once.
                if residue != previous:
                    prefix = _residue_label(bead.resid, bead.resname, bead.chain)
                    lines.append(f" {prefix}:")
                    previous = residue
                atoms = ", ".join(bead.atom_names) if bead.atom_names else "(none)"
                lines.append(f" {bead.name} bead: {atoms} -> {bead.reduction}")
        else:
            lines.append(" (none)")

        lines.extend(["", "Dropped atoms:"])
        if self.dropped_atoms:
            for atom in self.dropped_atoms:
                label = _residue_label(atom.resid, atom.resname, atom.chain)
                # Fall back from atom name to element to a placeholder.
                name = atom.name or atom.element or "(unnamed)"
                lines.append(f" {label}: {name}")
        else:
            lines.append(" (none)")

        lines.extend(["", "Generated bonds:"])
        if self.bonds:
            lines.extend(f" {bond.a}-{bond.b} {bond.reason}" for bond in self.bonds)
        else:
            lines.append(" (none)")
        return "\n".join(lines)
123
+
124
+
125
def coarse_grain(molecule: Molecule, mapping="residue_com", weighted: bool = True,
                 bonds=None, return_report: bool = False):
    """Coarse-grain ``molecule``; see the module docstring for the options.

    Parameters
    ----------
    molecule:
        The atomistic structure to reduce.
    mapping:
        ``"residue_com"``, ``"residue_centroid"``, ``"martini"``, a residue
        mapping ``{resname: {bead: [atom names]}}`` or an index mapping
        ``{bead: [atom indices]}``.
    weighted:
        Use mass-weighted bead centres where the mode allows it.
    bonds:
        Optional explicit bead bonds (pairs of bead indices or unique names).
    return_report:
        When true, return ``(cg_molecule, report)`` instead of the molecule.

    Raises
    ------
    ValueError
        If ``mapping`` is an unknown mode string, or a residue-based mapping
        is requested for a structure without residue information.
    """
    # Reject unknown mode names here: otherwise they reach the custom-mapping
    # code path and fail with an opaque AttributeError on ``str.get``.
    builtin_modes = ("residue_com", "residue_centroid", "martini")
    if isinstance(mapping, str) and mapping not in builtin_modes:
        raise ValueError(
            f"unknown mapping {mapping!r}; expected one of {builtin_modes} "
            "or a mapping dict"
        )

    if _is_index_mapping(mapping):
        cg, report = _by_index(molecule, mapping, weighted, bonds)
        return (cg, report) if return_report else cg

    if len(molecule.resids) == 0:
        raise ValueError(
            "coarse-graining by residue needs residue information; for a file "
            "without it (e.g. .xyz) use an index mapping {bead: [atom_indices]}"
        )
    cg, report = _by_residue(molecule, mapping, weighted, bonds)
    return (cg, report) if return_report else cg
139
+
140
+
141
+ def _is_index_mapping(mapping) -> bool:
142
+ """An index mapping is a dict whose values are index lists, not sub-dicts."""
143
+ if not isinstance(mapping, dict) or not mapping:
144
+ return False
145
+ return not isinstance(next(iter(mapping.values())), dict)
146
+
147
+
148
+ def _by_index(molecule: Molecule, mapping: dict, weighted: bool, bonds) -> Molecule:
149
+ bead_names: list[str] = []
150
+ bead_coords: list[np.ndarray] = []
151
+ bead_report: list[BeadMapping] = []
152
+ assigned: set[int] = set()
153
+ for name, members in mapping.items():
154
+ try:
155
+ members = [int(i) for i in members]
156
+ except (TypeError, ValueError):
157
+ raise ValueError(
158
+ f"bead {name!r}: an index mapping expects integer atom indices. "
159
+ "For atom-name beads use a residue mapping {resname: {bead: [names]}}."
160
+ ) from None
161
+ if not members:
162
+ continue
163
+ assigned.update(members)
164
+ bead_coords.append(_reduce(molecule, members, weighted))
165
+ bead_names.append(name)
166
+ bead_report.append(
167
+ BeadMapping(
168
+ name=name,
169
+ atom_names=_atom_names(molecule, members),
170
+ reduction=_reduction_name(weighted),
171
+ )
172
+ )
173
+
174
+ if not bead_coords:
175
+ raise ValueError("mapping produced no beads")
176
+ dropped = _dropped_atoms(molecule, assigned)
177
+ _warn_dropped(len(molecule), assigned)
178
+ bond_index, bond_report = _resolve_bonds(bonds, bead_names)
179
+ report = CoarseGrainReport(
180
+ mapping="index",
181
+ beads=bead_report,
182
+ dropped_atoms=dropped,
183
+ bonds=bond_report,
184
+ )
185
+ cg = Molecule(
186
+ np.array(bead_coords, dtype=float), elements=[""] * len(bead_coords),
187
+ name=f"{molecule.name} (CG)", atom_names=bead_names,
188
+ bond_index=bond_index, _mapping_report=report,
189
+ )
190
+ return cg, report
191
+
192
+
193
def _by_residue(molecule: Molecule, mapping, weighted: bool,
                bonds) -> tuple[Molecule, CoarseGrainReport]:
    """Build beads residue by residue and return ``(cg_molecule, report)``.

    ``mapping`` is one of the built-in mode strings or a residue mapping
    ``{resname: {bead: [atom names]}}``.  Mass weighting is used only when
    ``weighted`` is true and the mode is not ``"residue_centroid"``.  Bonds
    come from ``bonds`` when given, otherwise they are generated.

    Raises
    ------
    ValueError
        If no bead results from the mapping.
    """
    use_mass = weighted and mapping != "residue_centroid"
    reduction = _reduction_name(use_mass)
    is_custom = isinstance(mapping, dict)

    bead_coords, bead_names = [], []
    bead_resnames, bead_resids, bead_chains = [], [], []
    residue_beads: list[list[int]] = []  # bead indices grouped per residue
    bead_report: list[BeadMapping] = []
    assigned: set[int] = set()

    for atom_idx, resname, resid, chain in molecule.residue_groups():
        local = []
        for bead_name, members in _residue_beads(molecule, atom_idx, resname, mapping):
            if not members:
                continue
            assigned.update(members)
            bead_coords.append(_reduce(molecule, members, use_mass))
            bead_names.append(bead_name)
            bead_resnames.append(resname)
            bead_resids.append(resid)
            bead_chains.append(chain)
            local.append(len(bead_coords) - 1)
            bead_report.append(
                BeadMapping(
                    name=bead_name,
                    atom_names=_atom_names(molecule, members),
                    reduction=reduction,
                    resname=resname,
                    resid=resid,
                    chain=chain,
                )
            )
        # Residues that produced no bead take no part in bond generation.
        if local:
            residue_beads.append(local)

    if not bead_coords:
        raise ValueError("mapping produced no beads")
    dropped = _dropped_atoms(molecule, assigned)
    if is_custom:  # only custom mappings can leave atoms unassigned
        _warn_dropped(len(molecule), assigned)

    if bonds is not None:
        bond_index, bond_report = _resolve_bonds(bonds, bead_names)
    else:
        bond_index, bond_report = _cg_bonds(residue_beads, bead_chains, bead_names)
    report = CoarseGrainReport(
        mapping=_mapping_name(mapping),
        beads=bead_report,
        # Dropped atoms are reported for custom mappings and for "martini".
        dropped_atoms=dropped if is_custom or mapping == "martini" else [],
        bonds=bond_report,
    )
    # Beads carry no element, so every entry of `elements` is blank.
    cg = Molecule(
        np.array(bead_coords, dtype=float), elements=[""] * len(bead_coords),
        name=f"{molecule.name} (CG)", atom_names=bead_names, resnames=bead_resnames,
        resids=np.array(bead_resids, dtype=int), chains=bead_chains,
        bond_index=bond_index, _mapping_report=report,
    )
    return cg, report
249
+
250
+
251
+ def _residue_beads(molecule: Molecule, atom_idx, resname, mapping):
252
+ """Return ``[(bead_name, [atom_index, ...]), ...]`` for one residue."""
253
+ if mapping in ("residue_com", "residue_centroid"):
254
+ return [(resname or "BEAD", atom_idx)]
255
+ if mapping == "martini":
256
+ return _backbone_sidechain(molecule, atom_idx)
257
+
258
+ spec = mapping.get(resname)
259
+ if spec is None:
260
+ warnings.warn(
261
+ f"no mapping for residue {resname!r}; collapsing it to one bead",
262
+ stacklevel=4,
263
+ )
264
+ return [(resname or "BEAD", atom_idx)]
265
+ names = {molecule.atom_names[i]: i for i in atom_idx}
266
+ return [
267
+ (bead, [names[a] for a in atoms if a in names])
268
+ for bead, atoms in spec.items()
269
+ ]
270
+
271
+
272
def _backbone_sidechain(molecule: Molecule, atom_idx):
    """Split one residue into a ``BB`` bead and an ``SC`` bead.

    Atoms whose name is in ``_BACKBONE`` go to ``BB``; every other atom whose
    element is not ``"H"`` goes to ``SC``.  A bead with no members is left
    out of the result.
    """
    backbone, sidechain = [], []
    for i in atom_idx:
        if molecule.atom_names[i] in _BACKBONE:
            backbone.append(i)
        elif molecule.elements[i] != "H":
            sidechain.append(i)
    candidates = (("BB", backbone), ("SC", sidechain))
    return [(label, members) for label, members in candidates if members]
285
+
286
+
287
+ def _reduce(molecule: Molecule, members, use_mass: bool) -> np.ndarray:
288
+ coords = molecule.coords[members]
289
+ if use_mass:
290
+ w = np.array([elements.mass(molecule.elements[i]) for i in members])
291
+ return (w[:, None] * coords).sum(axis=0) / w.sum()
292
+ return coords.mean(axis=0)
293
+
294
+
295
+ def _resolve_bonds(bonds, bead_names):
296
+ """Turn user bond pairs (bead names or indices) into an (E, 2) index array."""
297
+ if bonds is None:
298
+ return None, []
299
+ name_to_idx = _unique_name_index(bead_names)
300
+ pairs = []
301
+ report = []
302
+ for a, b in bonds:
303
+ ai = _resolve_bond_endpoint(a, bead_names, name_to_idx)
304
+ bi = _resolve_bond_endpoint(b, bead_names, name_to_idx)
305
+ pairs.append((ai, bi))
306
+ report.append(BondMapping(bead_names[ai], bead_names[bi], "(user-defined)"))
307
+ return np.array(pairs, dtype=int).reshape(-1, 2), report
308
+
309
+
310
+ def _resolve_bond_endpoint(endpoint, bead_names, name_to_idx) -> int:
311
+ if not isinstance(endpoint, str):
312
+ return int(endpoint)
313
+ if endpoint in name_to_idx:
314
+ return name_to_idx[endpoint]
315
+ if endpoint in bead_names:
316
+ raise ValueError(
317
+ f"bead name {endpoint!r} is repeated and cannot identify one bead; "
318
+ "use bead indices for user-defined bonds"
319
+ )
320
+ raise ValueError(f"unknown bead name {endpoint!r}")
321
+
322
+
323
+ def _unique_name_index(bead_names):
324
+ """Map unique bead names to indices; repeated names cannot identify one bead."""
325
+ counts = {}
326
+ for name in bead_names:
327
+ counts[name] = counts.get(name, 0) + 1
328
+
329
+ name_to_idx = {}
330
+ for i, name in enumerate(bead_names):
331
+ if counts[name] == 1:
332
+ name_to_idx[name] = i
333
+ return name_to_idx
334
+
335
+
336
+ def _cg_bonds(residue_beads, bead_chains, bead_names):
337
+ """Bonds within each residue (sequential) plus a chain between residues."""
338
+ bonds: list[tuple[int, int]] = []
339
+ report: list[BondMapping] = []
340
+ for beads in residue_beads:
341
+ for a, b in zip(beads, beads[1:]):
342
+ bonds.append((a, b))
343
+ report.append(BondMapping(bead_names[a], bead_names[b], "within residue"))
344
+ for prev, curr in zip(residue_beads, residue_beads[1:]):
345
+ if bead_chains[prev[0]] == bead_chains[curr[0]]:
346
+ bonds.append((prev[0], curr[0]))
347
+ report.append(
348
+ BondMapping(bead_names[prev[0]], bead_names[curr[0]], "between residues")
349
+ )
350
+ return np.array(bonds, dtype=int).reshape(-1, 2), report
351
+
352
+
353
+ def _warn_dropped(n_atoms: int, assigned: set) -> None:
354
+ dropped = n_atoms - len(assigned)
355
+ if dropped:
356
+ warnings.warn(
357
+ f"{dropped} atom(s) were not assigned to any bead and were dropped",
358
+ stacklevel=3,
359
+ )
360
+
361
+
362
+ def _atom_names(molecule: Molecule, atom_indices) -> list[str]:
363
+ names = []
364
+ for i in atom_indices:
365
+ if molecule.atom_names:
366
+ names.append(molecule.atom_names[i])
367
+ elif molecule.elements:
368
+ names.append(molecule.elements[i])
369
+ else:
370
+ names.append(str(i))
371
+ return names
372
+
373
+
374
+ def _dropped_atoms(molecule: Molecule, assigned: set) -> list[DroppedAtom]:
375
+ dropped = []
376
+ chains = molecule.chains or [""] * len(molecule)
377
+ resnames = molecule.resnames or [""] * len(molecule)
378
+ resids = molecule.resids if len(molecule.resids) else [None] * len(molecule)
379
+ atom_names = molecule.atom_names or [""] * len(molecule)
380
+ for i in range(len(molecule)):
381
+ if i in assigned:
382
+ continue
383
+ dropped.append(
384
+ DroppedAtom(
385
+ name=atom_names[i],
386
+ element=molecule.elements[i] if molecule.elements else "",
387
+ resname=resnames[i],
388
+ resid=None if resids[i] is None else int(resids[i]),
389
+ chain=chains[i],
390
+ )
391
+ )
392
+ return dropped
393
+
394
+
395
+ def _reduction_name(weighted: bool) -> str:
396
+ return "centre of mass" if weighted else "centroid"
397
+
398
+
399
+ def _mapping_name(mapping) -> str:
400
+ return mapping if isinstance(mapping, str) else "custom residue"
401
+
402
+
403
+ def _residue_label(resid: Optional[int], resname: str, chain: str) -> str:
404
+ parts = ["Residue"]
405
+ if resid is not None:
406
+ parts.append(str(resid))
407
+ if resname:
408
+ parts.append(resname)
409
+ if chain:
410
+ parts.append(chain)
411
+ return " ".join(parts) if len(parts) > 1 else "Molecule"
molscope/contactmap.py ADDED
@@ -0,0 +1,116 @@
1
+ """Contact maps and residue-level contact analysis.
2
+
3
+ A contact map records which pairs of atoms (or residues) are within a distance
4
+ cutoff. Residue contact maps are a staple of protein-folding intuition, peptide
5
+ and NMR-ensemble comparison, and coarse-graining validation.
6
+
7
+ cmap = mol.contact_map(cutoff=8.0, level="residue") # -> ContactMap
8
+ cmap.matrix # (R, R) array
9
+ cmap.plot() # heatmap
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass, field
15
+
16
+ import numpy as np
17
+
18
+ from . import elements
19
+ from .molecule import Molecule
20
+
21
+
22
@dataclass
class ContactMap:
    """Square contact matrix together with its axis metadata.

    For a single structure ``matrix`` contains only 0.0 and 1.0; for an
    ensemble it contains contact frequencies in ``[0, 1]``.  ``labels`` name
    the rows/columns (e.g. ``"A:LYS8"``) and ``resids`` hold the residue
    numbers of a residue-level map.
    """

    matrix: np.ndarray
    level: str
    cutoff: float
    labels: list[str] = field(default_factory=list)
    resids: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=int))

    @property
    def is_frequency(self) -> bool:
        """Whether any matrix entry is something other than 0.0 or 1.0."""
        distinct = np.unique(self.matrix)
        only_binary = np.isin(distinct, (0.0, 1.0)).all()
        return not only_binary

    def plot(self, **kwargs):
        """Render the map as a heatmap via :func:`molscope.plotting.plot_contact_map`.

        Keyword arguments are forwarded unchanged.
        """
        # Deferred import: plotting is only needed when a figure is requested.
        from .plotting import plot_contact_map

        return plot_contact_map(self, **kwargs)
48
+
49
+
50
def contact_map(molecule: Molecule, cutoff: float = 8.0, level: str = "residue",
                method: str = "ca") -> ContactMap:
    """Build the contact map of a single structure.

    ``level="atom"`` thresholds the molecule's own distance matrix;
    ``level="residue"`` compares residues using ``method`` (``"ca"``,
    ``"com"`` or ``"min"``).  Pairs closer than ``cutoff`` are marked 1.0 and
    the diagonal is always 0.0.

    Raises
    ------
    ValueError
        For an unknown ``level``, or a residue map requested on a structure
        without residue groups.
    """
    if level not in ("atom", "residue"):
        raise ValueError(f"level must be 'atom' or 'residue', got {level!r}")

    if level == "atom":
        within = (molecule.distance_matrix() < cutoff).astype(float)
        np.fill_diagonal(within, 0.0)  # an atom is not in contact with itself
        return ContactMap(within, level="atom", cutoff=cutoff)

    groups = list(molecule.residue_groups())
    if not groups:
        raise ValueError("residue contact map needs residue information")

    labels = []
    numbers = []
    for _, resname, resid, chain in groups:
        labels.append(_label(chain, resname, resid))
        numbers.append(resid)
    matrix = _residue_contacts(molecule, groups, cutoff, method)
    return ContactMap(
        matrix,
        level="residue",
        cutoff=cutoff,
        labels=labels,
        resids=np.array(numbers, dtype=int),
    )
69
+
70
+
71
+ def _residue_contacts(molecule, groups, cutoff, method) -> np.ndarray:
72
+ if method in ("ca", "com"):
73
+ reps = np.array([_representative(molecule, idx, method) for idx, *_ in groups])
74
+ deltas = reps[:, None, :] - reps[None, :, :]
75
+ dist = np.sqrt((deltas ** 2).sum(axis=-1))
76
+ mat = (dist < cutoff).astype(float)
77
+ elif method == "min":
78
+ mat = _min_distance_contacts(molecule, groups, cutoff)
79
+ else:
80
+ raise ValueError(f"method must be 'ca', 'com' or 'min', got {method!r}")
81
+ np.fill_diagonal(mat, 0.0)
82
+ return mat
83
+
84
+
85
+ def _representative(molecule, idx, method) -> np.ndarray:
86
+ if method == "ca" and molecule.atom_names:
87
+ ca = [i for i in idx if molecule.atom_names[i] == "CA"]
88
+ if ca:
89
+ return molecule.coords[ca[0]]
90
+ if method == "com":
91
+ w = np.array([elements.mass(molecule.elements[i]) for i in idx])
92
+ return (w[:, None] * molecule.coords[idx]).sum(axis=0) / w.sum()
93
+ return molecule.coords[idx].mean(axis=0) # CA fallback: residue centroid
94
+
95
+
96
+ def _min_distance_contacts(molecule, groups, cutoff) -> np.ndarray:
97
+ try:
98
+ from scipy.spatial.distance import cdist
99
+ except ImportError:
100
+ def cdist(a, b):
101
+ d = a[:, None, :] - b[None, :, :]
102
+ return np.sqrt((d ** 2).sum(axis=-1))
103
+
104
+ n = len(groups)
105
+ coords = [molecule.coords[idx] for idx, *_ in groups]
106
+ mat = np.zeros((n, n))
107
+ for a in range(n):
108
+ for b in range(a + 1, n):
109
+ if cdist(coords[a], coords[b]).min() < cutoff:
110
+ mat[a, b] = mat[b, a] = 1.0
111
+ return mat
112
+
113
+
114
+ def _label(chain, resname, resid) -> str:
115
+ base = f"{resname or 'RES'}{resid}"
116
+ return f"{chain}:{base}" if chain else base