molscope 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,232 @@
1
+ """Fixed-size molecular descriptors for quick ML feature tables."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import Counter
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+
10
+ from .io import read
11
+
12
# Upper-case element symbols used as the default ``elements_to_count`` of
# ``descriptors``; the order is preserved from the original definition.
DEFAULT_ELEMENTS = (
    "H", "C", "N", "O", "S",
    "P", "F", "CL", "BR", "I",
    "NA", "MG", "CA", "FE", "ZN",
)
15
+
16
+
17
def descriptors(
    molecule,
    *,
    elements_to_count=DEFAULT_ELEMENTS,
    distance_bins: int = 10,
    distance_range: tuple[float, float] = (0.0, 20.0),
    contact_cutoff: float = 5.0,
    residue_contact_cutoff: float = 8.0,
) -> dict:
    """Return a flat descriptor dictionary for a molecule.

    The defaults are fixed-size and suitable for small ML tables. Matrix-valued
    features such as contact maps remain available through ``mol.contact_map()``;
    this function records table-friendly summaries of them.

    Args:
        molecule: Object exposing ``coords``, ``masses``, ``elements``,
            ``dimensions``, ``centroid``, ``center_of_mass``,
            ``radius_of_gyration`` and the methods used by the summary helpers.
        elements_to_count: Element symbols; each yields one
            ``count_<SYMBOL>`` entry (symbols are compared upper-cased).
        distance_bins: ``bins`` argument of ``np.histogram`` for the pairwise
            distance histogram.
        distance_range: ``range`` argument of the same histogram.
        contact_cutoff: Cutoff forwarded to the atom-contact summary.
        residue_contact_cutoff: Cutoff forwarded to the residue-contact summary.

    Returns:
        A dict mapping descriptor names to floats or lists of floats. An empty
        molecule gets the same keys filled with zeros.
    """
    coords = np.asarray(molecule.coords, dtype=float)
    n_atoms = len(molecule)
    masses = molecule.masses if n_atoms else np.empty(0, dtype=float)
    desc = {
        "n_atoms": float(n_atoms),
        "n_residues": float(_n_residues(molecule)),
        "molecular_mass": float(masses.sum()) if n_atoms else 0.0,
    }

    # Element counts are case-insensitive; blank element entries are ignored.
    counts = Counter(e.upper() for e in molecule.elements if e)
    for symbol in elements_to_count:
        desc[f"count_{symbol.upper()}"] = float(counts.get(symbol.upper(), 0))

    # Geometry is undefined without atoms: emit the zero-filled key set instead.
    if n_atoms == 0:
        return _empty_descriptors(desc, distance_bins)

    dims = molecule.dimensions
    centroid = molecule.centroid
    center_of_mass = molecule.center_of_mass
    desc.update({
        "centroid_x": float(centroid[0]),
        "centroid_y": float(centroid[1]),
        "centroid_z": float(centroid[2]),
        "center_of_mass_x": float(center_of_mass[0]),
        "center_of_mass_y": float(center_of_mass[1]),
        "center_of_mass_z": float(center_of_mass[2]),
        # Cast like every other scalar so no NumPy scalar leaks into the table.
        "radius_of_gyration": float(molecule.radius_of_gyration),
        "dim_x": float(dims[0]),
        "dim_y": float(dims[1]),
        "dim_z": float(dims[2]),
        "bbox_volume": float(np.prod(dims)),
        "compactness": _compactness(n_atoms, dims),
    })

    # Principal moments/axes come from the symmetric inertia tensor; they are
    # explicitly sorted ascending and the axes columns reordered to match.
    inertia = inertia_tensor(molecule)
    principal_moments, principal_axes = np.linalg.eigh(inertia)
    order = np.argsort(principal_moments)
    principal_moments = principal_moments[order]
    principal_axes = principal_axes[:, order]
    desc["inertia_tensor"] = inertia.reshape(-1).astype(float).tolist()
    desc["principal_moments"] = principal_moments.astype(float).tolist()
    desc["principal_axes"] = principal_axes.reshape(-1).astype(float).tolist()
    desc["shape_anisotropy"] = shape_anisotropy(principal_moments)

    # Distances outside ``distance_range`` are dropped by ``np.histogram``.
    distances = _pairwise_distances(coords)
    hist, _ = np.histogram(distances, bins=distance_bins, range=distance_range)
    desc["distance_histogram"] = hist.astype(float).tolist()
    desc.update(_bond_length_summary(molecule))
    desc.update(_contact_summary(molecule, contact_cutoff))
    desc.update(_residue_contact_summary(molecule, residue_contact_cutoff))
    return desc
83
+
84
+
85
def inertia_tensor(molecule) -> np.ndarray:
    """Return the 3x3 mass-weighted inertia tensor about the centre of mass.

    An empty molecule yields an all-zero tensor.
    """
    positions = np.asarray(molecule.coords, dtype=float)
    if not len(molecule):
        return np.zeros((3, 3), dtype=float)
    offsets = positions - molecule.center_of_mass
    weights = molecule.masses
    squared_radii = (offsets ** 2).sum(axis=1)
    # I = (sum_i m_i r_i^2) * Identity - sum_i m_i r_i r_i^T
    isotropic_part = np.eye(3) * np.sum(weights * squared_radii)
    outer_part = offsets.T @ (offsets * weights[:, None])
    return isotropic_part - outer_part
96
+
97
+
98
def shape_anisotropy(principal_moments) -> float:
    """Return a dimensionless anisotropy computed from principal moments.

    The value is ``1.5 * sum((m - mean(m))**2) / sum(m**2)``; when every
    moment is zero the result is defined as ``0.0``.
    """
    values = np.asarray(principal_moments, dtype=float)
    sum_of_squares = float(np.sum(values ** 2))
    if sum_of_squares == 0.0:
        # Degenerate input (e.g. a single atom): avoid dividing by zero.
        return 0.0
    centre = float(values.mean())
    spread = np.sum((values - centre) ** 2)
    return float(1.5 * spread / sum_of_squares)
106
+
107
+
108
def featurize_many(
    paths,
    *,
    feature_names: Optional[list[str]] = None,
    return_names: bool = False,
    **descriptor_kwargs,
):
    """Read structures and return a numeric descriptor matrix.

    By default columns are the union of descriptor keys found across the input
    molecules. Pass ``feature_names`` to force a stable column order, or
    ``return_names=True`` to receive ``(X, names)``.

    Args:
        paths: Iterable of structure paths; each is read and featurized.
        feature_names: Column names to use, in order. ``None`` or an empty
            sequence selects the sorted union of keys found in the inputs.
        return_names: When true, return ``(matrix, names)`` instead of just
            the matrix.
        **descriptor_kwargs: Forwarded to ``descriptors``.

    Returns:
        A 2-D float array of shape ``(n_structures, n_features)`` (also for
        empty input), optionally paired with the column names. Features missing
        from a structure are filled with ``0.0``.
    """
    rows = [flatten_descriptors(descriptors(read(path), **descriptor_kwargs)) for path in paths]
    # ``len`` instead of truthiness: array-like names with several entries have
    # no unambiguous truth value and used to raise here.
    if feature_names is not None and len(feature_names) > 0:
        names = feature_names
    else:
        names = sorted({key for row in rows for key in row})
    matrix = np.array([[row.get(name, 0.0) for name in names] for row in rows], dtype=float)
    # With no rows ``np.array`` collapses to shape (0,); keep the result 2-D so
    # callers can always rely on ``matrix.shape[1] == len(names)``.
    matrix = matrix.reshape(len(rows), len(names))
    return (matrix, names) if return_names else matrix
125
+
126
+
127
def flatten_descriptors(desc: dict) -> dict[str, float]:
    """Expand sequence-valued descriptors into one scalar column per element.

    A list, tuple or ndarray stored under ``key`` becomes ``key_0``, ``key_1``,
    ...; every other value is kept under its own key. All values are converted
    with ``float``.
    """
    sequence_types = (list, tuple, np.ndarray)
    flat = {}
    for key, value in desc.items():
        if not isinstance(value, sequence_types):
            flat[key] = float(value)
            continue
        flat.update(
            (f"{key}_{index}", float(entry)) for index, entry in enumerate(value)
        )
    return flat
137
+
138
+
139
+ def _empty_descriptors(desc: dict, distance_bins: int) -> dict:
140
+ desc.update({
141
+ "centroid_x": 0.0,
142
+ "centroid_y": 0.0,
143
+ "centroid_z": 0.0,
144
+ "center_of_mass_x": 0.0,
145
+ "center_of_mass_y": 0.0,
146
+ "center_of_mass_z": 0.0,
147
+ "radius_of_gyration": 0.0,
148
+ "dim_x": 0.0,
149
+ "dim_y": 0.0,
150
+ "dim_z": 0.0,
151
+ "bbox_volume": 0.0,
152
+ "compactness": 0.0,
153
+ "inertia_tensor": [0.0] * 9,
154
+ "principal_moments": [0.0] * 3,
155
+ "principal_axes": [0.0] * 9,
156
+ "shape_anisotropy": 0.0,
157
+ "distance_histogram": [0.0] * distance_bins,
158
+ "bond_count": 0.0,
159
+ "bond_length_mean": 0.0,
160
+ "bond_length_std": 0.0,
161
+ "bond_length_min": 0.0,
162
+ "bond_length_max": 0.0,
163
+ "atom_contact_count": 0.0,
164
+ "atom_contact_density": 0.0,
165
+ "residue_contact_count": 0.0,
166
+ "residue_contact_density": 0.0,
167
+ })
168
+ return desc
169
+
170
+
171
+ def _pairwise_distances(coords: np.ndarray) -> np.ndarray:
172
+ n = len(coords)
173
+ if n < 2:
174
+ return np.empty(0, dtype=float)
175
+ i, j = np.triu_indices(n, k=1)
176
+ return np.linalg.norm(coords[i] - coords[j], axis=1)
177
+
178
+
179
+ def _bond_length_summary(molecule) -> dict[str, float]:
180
+ bonds = molecule.bonds()
181
+ if len(bonds) == 0:
182
+ return {
183
+ "bond_count": 0.0,
184
+ "bond_length_mean": 0.0,
185
+ "bond_length_std": 0.0,
186
+ "bond_length_min": 0.0,
187
+ "bond_length_max": 0.0,
188
+ }
189
+ lengths = np.linalg.norm(molecule.coords[bonds[:, 0]] - molecule.coords[bonds[:, 1]], axis=1)
190
+ return {
191
+ "bond_count": float(len(lengths)),
192
+ "bond_length_mean": float(lengths.mean()),
193
+ "bond_length_std": float(lengths.std()),
194
+ "bond_length_min": float(lengths.min()),
195
+ "bond_length_max": float(lengths.max()),
196
+ }
197
+
198
+
199
+ def _contact_summary(molecule, cutoff: float) -> dict[str, float]:
200
+ contacts = molecule.contacts(cutoff=cutoff)
201
+ possible = len(molecule) * (len(molecule) - 1) / 2
202
+ return {
203
+ "atom_contact_count": float(len(contacts)),
204
+ "atom_contact_density": float(len(contacts) / possible) if possible else 0.0,
205
+ }
206
+
207
+
208
+ def _residue_contact_summary(molecule, cutoff: float) -> dict[str, float]:
209
+ if len(molecule.resids) == 0:
210
+ return {"residue_contact_count": 0.0, "residue_contact_density": 0.0}
211
+ try:
212
+ matrix = molecule.contact_map(cutoff=cutoff, level="residue").matrix
213
+ except ValueError:
214
+ return {"residue_contact_count": 0.0, "residue_contact_density": 0.0}
215
+ n = len(matrix)
216
+ possible = n * (n - 1) / 2
217
+ count = float(np.triu(matrix.astype(bool), k=1).sum())
218
+ return {
219
+ "residue_contact_count": count,
220
+ "residue_contact_density": float(count / possible) if possible else 0.0,
221
+ }
222
+
223
+
224
+ def _n_residues(molecule) -> int:
225
+ if len(molecule.resids) == 0:
226
+ return 0
227
+ return sum(1 for _ in molecule.residue_groups())
228
+
229
+
230
+ def _compactness(n_atoms: int, dims: np.ndarray) -> float:
231
+ volume = float(np.prod(dims))
232
+ return float(n_atoms / volume) if volume > 0.0 else 0.0
molscope/elements.py ADDED
@@ -0,0 +1,75 @@
1
+ """Per-element reference data: CPK colours and covalent radii.
2
+
3
+ Values cover the elements common in the sample structures; anything missing
4
+ falls back to a neutral default so unknown atoms still render and bond.
5
+ """
6
+
7
# CPK colours as normalised (r, g, b) triples, keyed by upper-case element
# symbol. CPK is the convention most molecular viewers use.
CPK_COLORS = {
    "H": (1.0, 1.0, 1.0),
    "C": (0.3, 0.3, 0.3),
    "N": (0.1, 0.1, 0.85),
    "O": (0.85, 0.1, 0.1),
    "S": (0.9, 0.8, 0.2),
    "P": (1.0, 0.5, 0.0),
    "F": (0.3, 0.8, 0.3),
    "CL": (0.2, 0.8, 0.2),
    "BR": (0.6, 0.2, 0.1),
    "I": (0.5, 0.1, 0.6),
    "FE": (0.8, 0.4, 0.1),
    "CA": (0.3, 0.7, 0.7),
    "NA": (0.5, 0.2, 0.8),
    "MG": (0.2, 0.6, 0.2),
    "ZN": (0.5, 0.5, 0.6),
}
# Neutral grey used for any symbol missing from the table above.
DEFAULT_COLOR = (0.5, 0.5, 0.5)
26
+
27
# Covalent radii in angstrom (Cordero et al. 2008, rounded), keyed by
# upper-case element symbol. Used to infer bonds.
COVALENT_RADII = {
    "H": 0.31,
    "C": 0.76,
    "N": 0.71,
    "O": 0.66,
    "S": 1.05,
    "P": 1.07,
    "F": 0.57,
    "CL": 1.02,
    "BR": 1.20,
    "I": 1.39,
    "FE": 1.32,
    "CA": 1.76,
    "NA": 1.66,
    "MG": 1.41,
    "ZN": 1.22,
}
# Radius assumed for any symbol missing from the table above.
DEFAULT_RADIUS = 0.75
34
+
35
+
36
# Standard atomic weights (g/mol), keyed by upper-case element symbol.
# Unknown atoms fall back to 1.0, so a mass-weighted centre over all-unknown
# elements reduces to the plain (unweighted) centroid.
ATOMIC_MASSES = {
    "H": 1.008,
    "C": 12.011,
    "N": 14.007,
    "O": 15.999,
    "S": 32.06,
    "P": 30.974,
    "F": 18.998,
    "CL": 35.45,
    "BR": 79.904,
    "I": 126.904,
    "FE": 55.845,
    "CA": 40.078,
    "NA": 22.990,
    "MG": 24.305,
    "ZN": 65.38,
}
DEFAULT_MASS = 1.0
44
+
45
+
46
def color(element: str):
    """Return the CPK colour for ``element``, matched case-insensitively.

    Symbols absent from ``CPK_COLORS`` get ``DEFAULT_COLOR``.
    """
    symbol = element.upper()
    return CPK_COLORS.get(symbol, DEFAULT_COLOR)
49
+
50
+
51
def covalent_radius(element: str) -> float:
    """Return the covalent radius in angstrom for ``element`` (any case).

    Symbols absent from ``COVALENT_RADII`` get ``DEFAULT_RADIUS``.
    """
    symbol = element.upper()
    return COVALENT_RADII.get(symbol, DEFAULT_RADIUS)
54
+
55
+
56
def mass(element: str) -> float:
    """Return the atomic weight in g/mol for ``element`` (any case).

    Symbols absent from ``ATOMIC_MASSES`` get ``DEFAULT_MASS``.
    """
    symbol = element.upper()
    return ATOMIC_MASSES.get(symbol, DEFAULT_MASS)
59
+
60
+
61
# Atomic numbers for the first four periods (H through Kr) plus iodine, which
# is enough for biomolecules and most small molecules. Unknown symbols map to
# 0 in ``atomic_number`` so graph code never crashes.
ATOMIC_NUMBERS = {
    # period 1
    "H": 1, "HE": 2,
    # period 2
    "LI": 3, "BE": 4, "B": 5, "C": 6, "N": 7, "O": 8, "F": 9, "NE": 10,
    # period 3
    "NA": 11, "MG": 12, "AL": 13, "SI": 14, "P": 15, "S": 16, "CL": 17, "AR": 18,
    # period 4
    "K": 19, "CA": 20, "SC": 21, "TI": 22, "V": 23, "CR": 24, "MN": 25,
    "FE": 26, "CO": 27, "NI": 28, "CU": 29, "ZN": 30, "GA": 31, "GE": 32,
    "AS": 33, "SE": 34, "BR": 35, "KR": 36,
    # period 5 (iodine only)
    "I": 53,
}
71
+
72
+
73
def atomic_number(element: str) -> int:
    """Return the atomic number (Z) for ``element``, matched case-insensitively.

    Symbols absent from ``ATOMIC_NUMBERS`` give ``0``.
    """
    symbol = element.upper()
    return ATOMIC_NUMBERS.get(symbol, 0)
molscope/ensemble.py ADDED
@@ -0,0 +1,143 @@
1
+ """Analysis across a set of structures, e.g. the models of an NMR ensemble."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field, replace
6
+ from typing import Optional
7
+
8
+ import numpy as np
9
+
10
+ from .molecule import Molecule
11
+
12
+
13
def align_all(models: list[Molecule], reference: Optional[Molecule] = None) -> list[Molecule]:
    """Kabsch-superpose every model onto ``reference`` (default: the first model).

    Args:
        models: Structures to align; each must provide ``superpose(ref)``.
        reference: Target structure. When ``None`` the first model is used.

    Returns:
        A new list holding ``m.superpose(ref)`` for each model, in input order.
        An empty ``models`` list returns ``[]`` whether or not a reference is
        given (previously it raised ``IndexError`` without one).
    """
    if len(models) == 0:
        # Nothing to align; do not index into an empty list for the default.
        return []
    ref = reference if reference is not None else models[0]
    return [m.superpose(ref) for m in models]
17
+
18
+
19
def average(models: list[Molecule], align: bool = True) -> Molecule:
    """Return the mean structure of the ensemble, matching atoms by index.

    Models are superposed first unless ``align`` is false. The result is a copy
    of the first model with averaged coordinates and ``" (average)"`` appended
    to its name.
    """
    _check_consistent(models)
    if align:
        members = align_all(models)
    else:
        members = models
    mean_coords = np.mean([member.coords for member in members], axis=0)
    template = models[0]
    return replace(template, coords=mean_coords, name=f"{template.name} (average)")
25
+
26
+
27
def rmsf(models: list[Molecule], align: bool = True) -> np.ndarray:
    """Return each atom's root-mean-square fluctuation about its mean position.

    Models are superposed first unless ``align`` is false.
    """
    _check_consistent(models)
    members = align_all(models) if align else models
    # Stacked coordinates, indexed (model, atom, xyz).
    stacked = np.array([member.coords for member in members])
    deviations = stacked - stacked.mean(axis=0)
    squared_displacement = (deviations ** 2).sum(axis=2)
    return np.sqrt(squared_displacement.mean(axis=0))
34
+
35
+
36
def contact_frequency(models: list[Molecule], cutoff: float = 8.0,
                      level: str = "residue", method: str = "ca"):
    """Fraction of models in which each pair is in contact (an ensemble map).

    Returns a :class:`~molscope.contactmap.ContactMap` whose matrix is the
    element-wise mean of the per-model contact matrices — values in ``[0, 1]``
    for boolean contact maps. Labels and residue ids are taken from the first
    model's map. Useful for NMR variability and folding analysis.
    """
    from .contactmap import ContactMap, contact_map

    _check_consistent(models)
    settings = {"cutoff": cutoff, "level": level, "method": method}
    per_model = [contact_map(model, **settings) for model in models]
    frequency = np.mean([single.matrix for single in per_model], axis=0)
    template = per_model[0]
    return ContactMap(
        frequency,
        level=level,
        cutoff=cutoff,
        labels=template.labels,
        resids=template.resids,
    )
52
+
53
+
54
def rmsd_matrix(models: list[Molecule], align: bool = True) -> np.ndarray:
    """Return the symmetric ``(M, M)`` matrix of pairwise RMSDs between models.

    Each unordered pair is evaluated once via ``models[i].rmsd(models[j])`` and
    mirrored; the diagonal stays zero.
    """
    _check_consistent(models)
    count = len(models)
    result = np.zeros((count, count))
    for row, first in enumerate(models):
        for col in range(row + 1, count):
            value = first.rmsd(models[col], align=align)
            result[row, col] = value
            result[col, row] = value
    return result
63
+
64
+
65
+ def _check_consistent(models: list[Molecule]) -> None:
66
+ if not models:
67
+ raise ValueError("no models given")
68
+ sizes = {len(m) for m in models}
69
+ if len(sizes) != 1:
70
+ raise ValueError(f"models have differing atom counts: {sorted(sizes)}")
71
+
72
+
73
@dataclass
class Clustering:
    """Result of clustering structures by RMSD.

    ``labels`` gives the 1-based cluster id of each model (same order as the
    input). ``matrix`` is the RMSD matrix used and ``linkage`` the scipy linkage.

    ``labels`` and ``matrix`` may be any array-like (e.g. plain lists); they
    are converted with ``np.asarray`` wherever element-wise comparison or fancy
    indexing is needed.
    """

    labels: np.ndarray
    matrix: np.ndarray
    linkage: Optional[np.ndarray] = field(default=None)

    def _members(self, cluster_id) -> np.ndarray:
        """Indices of the models labelled ``cluster_id``."""
        # ``np.asarray`` matters: comparing a plain list with ``==`` yields a
        # single bool, which would silently select no members at all.
        return np.flatnonzero(np.asarray(self.labels) == cluster_id)

    @property
    def n_clusters(self) -> int:
        """Number of distinct cluster ids."""
        return int(len(np.unique(self.labels)))

    @property
    def order(self) -> np.ndarray:
        """Model indices sorted by cluster (for a block-diagonal heatmap)."""
        return np.argsort(self.labels, kind="stable")

    def groups(self) -> dict[int, list[int]]:
        """Map each cluster id to the list of model indices it contains."""
        return {int(c): self._members(c).tolist() for c in np.unique(self.labels)}

    def medoid(self, cluster_id: int) -> int:
        """Index of the most central model of a cluster (min total RMSD).

        Raises:
            ValueError: If no model carries ``cluster_id``.
        """
        members = self._members(cluster_id)
        if members.size == 0:
            raise ValueError(f"no models belong to cluster {cluster_id!r}")
        sub = np.asarray(self.matrix)[np.ix_(members, members)]
        return int(members[sub.sum(axis=1).argmin()])

    def representatives(self) -> dict[int, int]:
        """Map each cluster id to its medoid model index."""
        return {int(c): self.medoid(int(c)) for c in np.unique(self.labels)}
108
+
109
+
110
def cluster(models, method: str = "hierarchical", cutoff: Optional[float] = None,
            n_clusters: Optional[int] = None, linkage: str = "average",
            align: bool = True, matrix=None) -> Clustering:
    """Cluster structures hierarchically by pairwise RMSD.

    ``n_clusters`` cuts the tree into a fixed number of clusters and takes
    precedence over ``cutoff`` (an RMSD threshold in angstrom). With neither,
    the mean pairwise RMSD is used as the cutoff. A supplied ``matrix`` is used
    instead of computing :func:`rmsd_matrix`. scipy is needed once there are at
    least two models; fewer are returned as a single cluster without it.

    Raises:
        ValueError: If ``method`` is not ``"hierarchical"``.
        ImportError: If scipy is required but not installed.
    """
    if method != "hierarchical":
        raise ValueError(f"unknown method {method!r}; only 'hierarchical' is supported")

    if matrix is None:
        dm = rmsd_matrix(models, align=align)
    else:
        dm = np.asarray(matrix, dtype=float)

    # Zero or one model: nothing to link, everything is cluster 1.
    if len(dm) < 2:
        return Clustering(labels=np.ones(len(dm), dtype=int), matrix=dm)

    try:
        from scipy.cluster.hierarchy import fcluster
        from scipy.cluster.hierarchy import linkage as _linkage
        from scipy.spatial.distance import squareform
    except ImportError as exc:  # pragma: no cover - exercised only without scipy
        raise ImportError(
            "clustering needs scipy; install it with: pip install 'molscope[fast]'"
        ) from exc

    # scipy's linkage wants the condensed (upper-triangle) distance vector.
    tree = _linkage(squareform(dm, checks=False), method=linkage)

    if n_clusters is not None:
        threshold, criterion = n_clusters, "maxclust"
    elif cutoff is not None:
        threshold, criterion = cutoff, "distance"
    else:
        # Data-driven default: the mean of all off-diagonal pairwise RMSDs.
        threshold = float(dm[np.triu_indices_from(dm, k=1)].mean())
        criterion = "distance"

    assignments = fcluster(tree, t=threshold, criterion=criterion)
    return Clustering(labels=assignments, matrix=dm, linkage=tree)
molscope/graph.py ADDED
@@ -0,0 +1,151 @@
1
+ """Molecular graph layer.
2
+
3
+ ``Molecule.to_graph()`` turns 3D coordinates plus inferred bonds into a
4
+ :class:`MolecularGraph` — a small, dependency-free container of node and edge
5
+ data. From there it exports to the common ML graph formats:
6
+
7
+ G = mol.to_graph()
8
+ nxg = mol.to_networkx() # networkx.Graph
9
+ data = mol.to_pyg_data() # torch_geometric.data.Data
10
+ dglg = mol.to_dgl_graph() # dgl.DGLGraph
11
+
12
+ The exporters import their backend lazily, so networkx / PyTorch Geometric / DGL
13
+ are only required if you actually call the matching method.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional
20
+
21
+ import numpy as np
22
+
23
+ from . import elements
24
+
25
+
26
@dataclass
class MolecularGraph:
    """Atoms as nodes, bonds as edges, with attributes attached to both.

    Nodes carry element, atomic number, mass, coordinates and (when available)
    atom name, residue name/id and chain. Edges carry the bonded atom pair, the
    interatomic distance, and a bond order (``1.0`` for geometrically inferred
    bonds, whose order is unknown).
    """

    coords: np.ndarray  # (N, 3)
    elements: list[str]
    edges: np.ndarray  # (E, 2), i < j
    edge_distances: np.ndarray  # (E,)
    edge_types: np.ndarray  # (E,) bond order; 1.0 when inferred/unknown
    atom_names: list[str] = field(default_factory=list)
    resnames: list[str] = field(default_factory=list)
    resids: np.ndarray = field(default_factory=lambda: np.empty(0, dtype=int))
    chains: list[str] = field(default_factory=list)
    name: str = ""

    @property
    def n_atoms(self) -> int:
        """Number of nodes (rows of ``coords``)."""
        return len(self.coords)

    @property
    def n_bonds(self) -> int:
        """Number of undirected edges (rows of ``edges``)."""
        return len(self.edges)

    @property
    def atomic_numbers(self) -> np.ndarray:
        """Atomic number of every node, looked up from its element symbol."""
        return np.array([elements.atomic_number(symbol) for symbol in self.elements])

    @property
    def masses(self) -> np.ndarray:
        """Atomic mass of every node, looked up from its element symbol."""
        return np.array([elements.mass(symbol) for symbol in self.elements])

    def node_features(self) -> np.ndarray:
        """Default ``(N, 2)`` node feature matrix: ``[atomic_number, mass]``."""
        columns = [self.atomic_numbers, self.masses]
        return np.stack(columns, axis=1).astype(float)

    # -- exporters ----------------------------------------------------------

    def to_networkx(self):
        """Return a ``networkx.Graph`` with node and edge attributes.

        Optional per-atom attributes (``atom_name``, ``resname``, ``resid``,
        ``chain``) are attached only when the corresponding field is non-empty.
        """
        nx = _require("networkx", "networkx", "pip install networkx")
        graph = nx.Graph(name=self.name)
        numbers, weights = self.atomic_numbers, self.masses
        for index in range(self.n_atoms):
            node = dict(
                element=self.elements[index],
                atomic_number=int(numbers[index]),
                mass=float(weights[index]),
                pos=tuple(float(c) for c in self.coords[index]),
            )
            if self.atom_names:
                node["atom_name"] = self.atom_names[index]
            if self.resnames:
                node["resname"] = self.resnames[index]
            if len(self.resids):
                node["resid"] = int(self.resids[index])
            if self.chains:
                node["chain"] = self.chains[index]
            graph.add_node(index, **node)
        bonds = zip(self.edges, self.edge_distances, self.edge_types)
        for (first, second), length, order in bonds:
            graph.add_edge(
                int(first), int(second),
                distance=float(length), bond_type=float(order),
            )
        return graph

    def to_pyg_data(self):
        """Return a ``torch_geometric.data.Data`` object.

        Populates ``x`` (node features), ``z`` (atomic numbers), ``pos`` (3D
        coordinates), ``edge_index`` and ``edge_attr`` (distances). Edges are
        made bidirectional for message passing.
        """
        torch = _require("torch", "PyTorch Geometric", "pip install torch torch_geometric")
        Data = _require(
            "torch_geometric.data", "PyTorch Geometric",
            "pip install torch torch_geometric", attr="Data",
        )
        sources, targets, lengths = self._directed_edges()
        features = torch.tensor(self.node_features(), dtype=torch.float)
        numbers = torch.tensor(self.atomic_numbers, dtype=torch.long)
        positions = torch.tensor(self.coords, dtype=torch.float)
        connectivity = torch.tensor(np.stack([sources, targets]), dtype=torch.long)
        # One feature column per directed edge: shape (2E, 1).
        edge_lengths = torch.tensor(lengths[:, None], dtype=torch.float)
        return Data(
            x=features,
            z=numbers,
            pos=positions,
            edge_index=connectivity,
            edge_attr=edge_lengths,
            num_nodes=self.n_atoms,
        )

    def to_dgl_graph(self):
        """Return a ``dgl.DGLGraph`` with node/edge feature tensors."""
        dgl = _require("dgl", "DGL", "pip install dgl")
        torch = _require("torch", "DGL", "pip install dgl torch")
        sources, targets, lengths = self._directed_edges()
        endpoints = (
            torch.tensor(sources, dtype=torch.long),
            torch.tensor(targets, dtype=torch.long),
        )
        graph = dgl.graph(endpoints, num_nodes=self.n_atoms)
        graph.ndata["feat"] = torch.tensor(self.node_features(), dtype=torch.float)
        graph.ndata["z"] = torch.tensor(self.atomic_numbers, dtype=torch.long)
        graph.ndata["pos"] = torch.tensor(self.coords, dtype=torch.float)
        graph.edata["distance"] = torch.tensor(lengths, dtype=torch.float)
        return graph

    def _directed_edges(self):
        """Edges in both directions: (src, dst, distance) for message passing.

        All forward edges come first, followed by their reverses, so each array
        has length ``2 * n_bonds``.
        """
        if self.n_bonds == 0:
            no_index = np.empty(0, dtype=int)
            return no_index, no_index, np.empty(0, dtype=float)
        heads, tails = self.edges[:, 0], self.edges[:, 1]
        sources = np.concatenate([heads, tails])
        targets = np.concatenate([tails, heads])
        lengths = np.concatenate([self.edge_distances, self.edge_distances])
        return sources, targets, lengths
141
+
142
+
143
+ def _require(module: str, feature: str, hint: str, attr: Optional[str] = None):
144
+ """Import a backend module (or attribute), raising a friendly error if absent."""
145
+ import importlib
146
+
147
+ try:
148
+ mod = importlib.import_module(module)
149
+ except ImportError as exc: # pragma: no cover - exercised only when missing
150
+ raise ImportError(f"{feature} is required for this export; {hint}") from exc
151
+ return getattr(mod, attr) if attr else mod