molscope 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
molscope/io.py ADDED
@@ -0,0 +1,342 @@
1
+ """Readers and writers for molecular coordinate files.
2
+
3
+ ``read`` dispatches on file extension; the individual readers/writers can also
4
+ be called directly. PDB parsing uses fixed columns (not whitespace splitting),
5
+ which is the only correct way to read the format. Per-atom metadata (atom name,
6
+ residue, chain) is captured where the format provides it.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import gzip
12
+ import os
13
+ import tempfile
14
+ from typing import Optional
15
+ from urllib.request import urlopen
16
+
17
+ import numpy as np
18
+
19
+ from .molecule import Molecule
20
+
21
# altLoc codes (PDB column 17) treated as the primary conformation: a blank
# (no alternates recorded) or "A". ATOM/HETATM records carrying any other code
# are skipped by the PDB parser. The empty string is listed next to the single
# space as a second spelling of "blank".
_PRIMARY_ALTLOCS = (" ", "A", "")
23
+
24
+
25
def read(path: str) -> Molecule:
    """Load a molecule from ``path``, choosing the reader by file extension.

    A trailing ``.gz`` is ignored when selecting the reader, so compressed
    files such as ``.pdb.gz`` or ``.xyz.gz`` are dispatched like their plain
    counterparts.

    Raises:
        ValueError: if the extension is not one of the supported types.
    """
    ext = _data_extension(path)
    readers = {
        ".pdb": read_pdb,
        ".xyz": read_xyz,
        ".cif": read_cif,
        ".mmcif": read_cif,
        ".sdf": read_sdf,
        ".mol": read_sdf,
    }
    reader = readers.get(ext)
    if reader is None:
        raise ValueError(f"Unsupported file type {ext!r}; expected .pdb/.xyz/.cif/.sdf")
    return reader(path)
40
+
41
+
42
def fetch(pdb_id: str, fmt: str = "pdb", cache_dir: Optional[str] = None) -> Molecule:
    """Download a structure from RCSB by its PDB id and read it.

    ``fmt`` is ``"pdb"`` or ``"cif"``. Files are cached (default: a
    ``molscope_cache`` folder in the system temp directory) so repeat calls
    don't re-download. Example: ``ms.fetch("1fqy")``.

    Args:
        pdb_id: Entry identifier; case-insensitive, surrounding whitespace is
            ignored. Only ASCII letters, digits and ``_`` are accepted.
        fmt: File format to download, ``"pdb"`` or ``"cif"``.
        cache_dir: Directory used for the download cache; created if missing.

    Returns:
        The molecule parsed from the (possibly cached) file.

    Raises:
        ValueError: if ``fmt`` is unsupported or ``pdb_id`` contains characters
            other than ASCII letters, digits and ``_``.
    """
    fmt = fmt.lower()
    if fmt not in ("pdb", "cif"):
        raise ValueError("fmt must be 'pdb' or 'cif'")
    pdb_id = pdb_id.strip().lower()
    # The id is spliced into both a URL and a local file name, so reject
    # anything (path separators, "..", query characters) that could escape
    # the cache directory or alter the request.
    if not pdb_id or not all(
        ch.isascii() and (ch.isalnum() or ch == "_") for ch in pdb_id
    ):
        raise ValueError(
            f"invalid PDB id {pdb_id!r}; expected only letters, digits or '_'"
        )
    cache_dir = cache_dir or os.path.join(tempfile.gettempdir(), "molscope_cache")
    os.makedirs(cache_dir, exist_ok=True)
    dest = os.path.join(cache_dir, f"{pdb_id}.{fmt}")
    if not os.path.exists(dest):
        url = f"https://files.rcsb.org/download/{pdb_id}.{fmt}"
        # A timeout keeps a stalled connection from blocking forever.
        with urlopen(url, timeout=60) as resp:
            data = resp.read()
        # Write to a sibling temp file and rename it into place, so an
        # interrupted write never leaves a truncated file that later calls
        # would mistake for a valid cache entry.
        tmp_path = f"{dest}.{os.getpid()}.part"
        try:
            with open(tmp_path, "wb") as fh:
                fh.write(data)
            os.replace(tmp_path, dest)
        except BaseException:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise
    return read(dest)
61
+
62
+
63
def read_xyz(path: str) -> Molecule:
    """Return the first frame of an ``.xyz`` file.

    Both the standard ``element x y z`` layout and bare ``x y z`` coordinate
    dumps (optionally with ``#`` comment lines) are accepted. When the file
    yields no frames at all, an empty molecule named after the file is
    returned instead.
    """
    first = next(iter(read_xyz_frames(path)), None)
    if first is None:
        return Molecule(np.empty((0, 3)), [], name=_stem(path))
    return first
73
+
74
+
75
def read_xyz_frames(path: str) -> list[Molecule]:
    """Read every frame of a (possibly multi-frame) ``.xyz`` trajectory.

    Standard xyz frames begin with an atom-count line followed by one comment
    line and ``count`` atom lines; they are named ``<stem>#<n>`` (1-based).
    Input that does not start with a count line is treated as a bare
    coordinate dump (``#`` comments allowed) and returned as a single frame
    named ``<stem>``.

    Blank lines before, between or after frames are skipped, so a trailing
    empty line does not produce a spurious empty frame. A final frame with
    fewer atom lines than its count declares is returned with the atoms that
    are present.

    Returns:
        The frames in file order; an empty list for an empty or all-blank file.
    """
    with _open(path) as f:
        lines = f.readlines()

    stem = _stem(path)
    frames: list[Molecule] = []
    i = 0
    n_lines = len(lines)
    while i < n_lines:
        tokens = lines[i].split()
        if not tokens:
            # Blank separator / trailing newline: not the start of a frame.
            i += 1
            continue
        if tokens[0].isdigit():
            count = int(tokens[0])
            # Skip the count line and the comment line that follows it.
            block = lines[i + 2:i + 2 + count]
            frames.append(_xyz_block(block, name=f"{stem}#{len(frames) + 1}"))
            i += 2 + count
        else:
            # Bare coordinate dump (no header): consume the rest as one frame.
            frames.append(_xyz_block(lines[i:], name=stem))
            break
    return frames
100
+
101
+
102
def read_pdb(path: str, model: int = 1) -> Molecule:
    """Read the ``ATOM``/``HETATM`` records of one model from a ``.pdb`` file.

    Coordinates, element, atom name, residue name/id and chain come from the
    fixed columns of each record; atoms whose altLoc is not the primary one
    are skipped. ``model`` is the 1-based index of the model to return.
    A file that yields no atoms gives an empty molecule named after the file.

    Raises:
        ValueError: if ``model`` is outside ``1..number_of_models``.
    """
    models = _parse_pdb_models(path)
    n_models = len(models)
    if n_models == 0:
        return Molecule(np.empty((0, 3)), [], name=_stem(path))
    if model < 1 or model > n_models:
        raise ValueError(f"model {model} out of range (1..{n_models})")
    selected = models[model - 1]
    return _molecule_from_record(selected, _stem(path))
116
+
117
+
118
def read_pdb_models(path: str) -> list[Molecule]:
    """Return all models of a ``.pdb`` file, named ``<stem>#1``, ``<stem>#2``, ..."""
    stem = _stem(path)
    molecules: list[Molecule] = []
    for number, rec in enumerate(_parse_pdb_models(path), start=1):
        molecules.append(_molecule_from_record(rec, f"{stem}#{number}"))
    return molecules
125
+
126
+
127
def read_cif(path: str) -> Molecule:
    """Basic mmCIF reader for standard ``_atom_site`` coordinate loops.

    This parser handles simple whitespace-separated atom-site rows. It is not
    a full mmCIF syntax implementation: quoted values containing whitespace,
    multiline fields and complex loop constructs are not supported. Simple
    quoted tokens (e.g. ``"O5'"``) have their surrounding quotes removed.

    To match the PDB reader, only the first model is returned when the loop
    has a ``pdbx_PDB_model_num`` column, and rows whose ``label_alt_id`` is
    neither unset (``.``/``?``) nor ``A`` are skipped.

    Returns:
        A molecule carrying coordinates, elements, atom names, residue
        names/ids and chain ids. Non-numeric residue ids become ``0``.

    Raises:
        ValueError: if no ``_atom_site`` rows are found, a Cartesian
            coordinate column is missing, or a row's coordinates are absent
            or not numeric.
    """
    primary_altlocs = (".", "?", "A")

    def unquote(token: str) -> str:
        # mmCIF quotes values that contain quote characters, e.g. "O5'".
        if len(token) >= 2 and token[0] == token[-1] and token[0] in "'\"":
            return token[1:-1]
        return token

    columns: list[str] = []
    rows: list[list[str]] = []
    in_atom_site = False
    with _open(path) as f:
        for raw in f:
            line = raw.strip()
            if line.startswith("_atom_site."):
                columns.append(line.split(".", 1)[1])
                in_atom_site = True
                continue
            if in_atom_site:
                if not line or line.startswith(("_", "#", "loop_")):
                    break  # end of the data block
                rows.append([unquote(tok) for tok in line.split()])

    if not columns or not rows:
        raise ValueError(f"no _atom_site records found in {path}")
    idx = {name: i for i, name in enumerate(columns)}

    missing = [c for c in ("Cartn_x", "Cartn_y", "Cartn_z") if c not in idx]
    if missing:
        raise ValueError(
            f"{path}: _atom_site loop has no {', '.join(missing)} column"
        )
    ix, iy, iz = idx["Cartn_x"], idx["Cartn_y"], idx["Cartn_z"]

    def col(row, *names, default=""):
        # First of ``names`` that exists as a column and is present in ``row``.
        for nm in names:
            if nm in idx and idx[nm] < len(row):
                return row[idx[nm]]
        return default

    coords, els, anames, rnames, rids, chains = [], [], [], [], [], []
    first_model = None
    for row in rows:
        model = col(row, "pdbx_PDB_model_num", default=None)
        if model is not None:
            if first_model is None:
                first_model = model
            elif model != first_model:
                continue  # atom belongs to a later model
        if col(row, "label_alt_id", default=".") not in primary_altlocs:
            continue  # non-primary alternate conformation
        try:
            xyz = (float(row[ix]), float(row[iy]), float(row[iz]))
        except (IndexError, ValueError) as err:
            raise ValueError(
                f"{path}: malformed _atom_site row {' '.join(row)!r}"
            ) from err
        coords.append(xyz)
        els.append(col(row, "type_symbol"))
        anames.append(col(row, "label_atom_id", "auth_atom_id"))
        rnames.append(col(row, "label_comp_id", "auth_comp_id"))
        chains.append(col(row, "auth_asym_id", "label_asym_id"))
        rid = col(row, "auth_seq_id", "label_seq_id", default="0")
        rids.append(int(rid) if rid.lstrip("-").isdigit() else 0)

    # Keep the (n, 3) shape even if every row was filtered out.
    xyz_array = np.array(coords, dtype=float) if coords else np.empty((0, 3))
    return Molecule(
        xyz_array, els, name=_stem(path),
        atom_names=anames, resnames=rnames, resids=np.array(rids, dtype=int),
        chains=chains,
    )
178
+
179
+
180
def read_sdf(path: str) -> Molecule:
    """Read the first molecule from an SDF / MDL MOL (V2000) file.

    The atom count is taken from the first three columns of the counts line
    (line 4); coordinates and element symbols are sliced from the fixed
    columns of the atom block that follows.

    Raises:
        ValueError: if the file has fewer than four lines, is a V3000 file,
            has an unreadable or negative atom count, ends before the
            declared number of atom lines, or contains a non-numeric
            coordinate.
    """
    with _open(path) as f:
        lines = f.readlines()
    if len(lines) < 4:
        raise ValueError(f"{path}: too short to be a MOL file")
    counts = lines[3]
    # V3000 files carry a zero atom count on this line; reading them as V2000
    # would silently return an empty molecule.
    if "V3000" in counts:
        raise ValueError(f"{path}: V3000 MOL files are not supported (V2000 only)")
    try:
        n_atoms = int(counts[:3])
    except ValueError as err:
        raise ValueError(
            f"{path}: malformed counts line {counts.rstrip()!r}"
        ) from err
    if n_atoms < 0:
        raise ValueError(f"{path}: negative atom count {n_atoms}")

    atom_lines = lines[4:4 + n_atoms]
    if len(atom_lines) < n_atoms:
        raise ValueError(
            f"{path}: expected {n_atoms} atom lines, found {len(atom_lines)}"
        )

    coords, els = [], []
    for lineno, line in enumerate(atom_lines, start=5):
        try:
            coords.append(
                (float(line[0:10]), float(line[10:20]), float(line[20:30]))
            )
        except ValueError as err:
            raise ValueError(
                f"{path}: line {lineno}: malformed atom record"
            ) from err
        els.append(line[31:34].strip())
    # Keep the (n, 3) shape for a zero-atom molecule.
    xyz = np.array(coords, dtype=float) if coords else np.empty((0, 3))
    return Molecule(xyz, els, name=_stem(path))
193
+
194
+
195
def write_xyz(molecule: Molecule, path: str) -> None:
    """Save ``molecule`` as an ``.xyz`` file.

    The header holds the atom count and the molecule name; each atom line is
    the element symbol (``X`` when the element is empty) followed by its
    coordinates with eight decimals.
    """
    with _open(path, "w") as out:
        out.write(f"{len(molecule)}\n{molecule.name}\n")
        for element, xyz in zip(molecule.elements, molecule.coords):
            x, y, z = xyz
            symbol = element or "X"
            out.write(f"{symbol:<2} {x:15.8f} {y:15.8f} {z:15.8f}\n")
201
+
202
+
203
def write_pdb(molecule: Molecule, path: str) -> None:
    """Save ``molecule`` as a ``.pdb`` file, keeping its metadata when present."""
    with _open(path, "w") as out:
        pdb_text = _molecule_to_pdb_string(molecule)
        out.write(pdb_text)
207
+
208
+
209
def _molecule_to_pdb_string(molecule: Molecule) -> str:
    """Render ``molecule`` as PDB text: one ``ATOM`` line per atom, then ``END``.

    Missing metadata falls back to defaults: the element symbol (or ``X``) as
    atom name, ``MOL`` as residue name, ``A`` as chain and ``1`` as residue id.
    """
    n = len(molecule)

    atom_names = molecule.atom_names
    if not atom_names:
        atom_names = [e or "X" for e in molecule.elements]
    residue_names = molecule.resnames
    if not residue_names:
        residue_names = ["MOL"] * n
    chain_ids = molecule.chains
    if not chain_ids:
        chain_ids = ["A"] * n
    if len(molecule.resids):
        residue_ids = molecule.resids
    else:
        residue_ids = np.ones(n, dtype=int)

    records = []
    for i in range(n):
        records.append(
            _pdb_atom_line(
                i + 1,  # PDB serial numbers are 1-based
                atom_names[i],
                residue_names[i],
                chain_ids[i] or "A",
                int(residue_ids[i]),
                molecule.elements[i],
                *molecule.coords[i],
            )
        )
    return "".join(records) + "END\n"
225
+
226
+
227
+ # -- internals --------------------------------------------------------------
228
+
229
+
230
def _xyz_block(block: list[str], name: str) -> Molecule:
    """Build a molecule from the atom lines of one xyz frame.

    Blank lines and lines starting with ``#`` are ignored. A line with at
    least four fields whose first field is not a number is read as
    ``element x y z``; any other line is read as bare ``x y z`` with an empty
    element.

    Raises:
        ValueError: if a line has fewer than three fields or a coordinate is
            not numeric.
    """
    coords, elements = [], []
    for line in block:
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        tokens = stripped.split()
        if len(tokens) >= 4 and not _is_float(tokens[0]):
            elements.append(tokens[0])
            coords.append(tuple(float(t) for t in tokens[1:4]))
        else:
            # Reject short lines explicitly; otherwise they yield a ragged or
            # (n, 2) coordinate array instead of a clear error.
            if len(tokens) < 3:
                raise ValueError(f"{name}: malformed xyz line {stripped!r}")
            elements.append("")
            coords.append(tuple(float(t) for t in tokens[:3]))
    # Keep the (n, 3) shape for a frame without atoms.
    xyz = np.array(coords, dtype=float) if coords else np.empty((0, 3))
    return Molecule(xyz, elements, name=name)
244
+
245
+
246
def _guess_pdb_element(name_field: str, hetero: bool) -> str:
    """Infer an element symbol from the raw atom-name columns (13-16).

    Used only when the element columns are blank. For ``ATOM`` records the
    first letter of the name is taken, so ``" CA "`` (alpha carbon) yields
    ``"C"`` rather than ``"CA"``. For ``HETATM`` records a name that starts
    in column 13 follows the two-letter-element convention (``"FE  "`` gives
    ``"FE"``), while one that starts in column 14 gives its first letter.
    This is a heuristic and cannot resolve every ligand atom name.
    """
    name = name_field.strip().lstrip("0123456789")
    if not name:
        return ""
    if hetero and name_field[:1].isalpha():
        return "".join(ch for ch in name_field[:2] if ch.isalpha())
    return name[:1]


def _parse_pdb_models(path: str) -> list[dict]:
    """Return a list of per-model records (dict of parallel atom arrays).

    ``MODEL``/``ENDMDL`` records delimit models; a file without them yields a
    single record. Only ``ATOM``/``HETATM`` lines are read, and atoms whose
    altLoc is not in ``_PRIMARY_ALTLOCS`` are skipped. Models that end up with
    no atoms are dropped. Non-numeric residue ids become ``0``.

    Raises:
        ValueError: if an atom record has non-numeric coordinate columns.
    """
    models: list[dict] = []
    cur = _new_record()

    def flush():
        # Store the current model (if it has atoms) and start a fresh one.
        nonlocal cur
        if cur["coords"]:
            models.append(cur)
        cur = _new_record()

    with _open(path) as f:
        for lineno, line in enumerate(f, start=1):
            record = line[:6].strip()
            if record in ("MODEL", "ENDMDL"):
                flush()
                continue
            if record not in ("ATOM", "HETATM"):
                continue
            altloc = line[16] if len(line) > 16 else " "
            if altloc not in _PRIMARY_ALTLOCS:
                continue
            try:
                xyz = (float(line[30:38]), float(line[38:46]), float(line[46:54]))
            except ValueError as err:
                raise ValueError(
                    f"{path}: line {lineno}: malformed coordinates in "
                    f"{record} record"
                ) from err
            cur["coords"].append(xyz)
            name_field = line[12:16]
            # Prefer the explicit element columns (77-78); older files leave
            # them blank, in which case the atom name is the only hint.
            element = line[76:78].strip()
            if not element:
                element = _guess_pdb_element(name_field, record == "HETATM")
            cur["elements"].append(element)
            cur["atom_names"].append(name_field.strip())
            cur["resnames"].append(line[17:20].strip())
            # Slice (not index) so a short line gives "" instead of IndexError.
            cur["chains"].append(line[21:22].strip())
            resid = line[22:26].strip()
            cur["resids"].append(int(resid) if resid.lstrip("-").isdigit() else 0)
    flush()
    return models
280
+
281
+
282
def _new_record() -> dict:
    """Create an empty per-model record: one fresh list per atom field."""
    fields = ("coords", "elements", "atom_names", "resnames", "chains", "resids")
    record: dict = {}
    for field in fields:
        record[field] = []
    return record
285
+
286
+
287
def _molecule_from_record(rec: dict, name: str) -> Molecule:
    """Turn one parsed record (dict of parallel lists) into a ``Molecule``.

    Coordinates are converted to a float array and residue ids to an int
    array; the remaining fields are passed through unchanged.
    """
    coords = np.array(rec["coords"], dtype=float)
    metadata = {
        "atom_names": rec["atom_names"],
        "resnames": rec["resnames"],
        "resids": np.array(rec["resids"], dtype=int),
        "chains": rec["chains"],
    }
    return Molecule(coords, rec["elements"], name=name, **metadata)
293
+
294
+
295
def _pdb_atom_line(serial, atom_name, resname, chain, resid, element, x, y, z):
    """Build a single fixed-column ``ATOM`` record, newline-terminated.

    Fields are placed at their PDB columns on an 80-column line; trailing
    blanks are stripped. Occupancy is written as ``1.00`` and the B-factor as
    ``0.00``. Empty values fall back to ``X`` (atom name, element), ``MOL``
    (residue name) and ``A`` (chain).

    Every field is kept inside its own columns:

    * atom names shorter than four characters start in column 14 when the
      element symbol has one letter, and in column 13 for two-letter elements
      (the PDB alignment convention);
    * the serial wraps modulo 100000, and a residue id too wide for four
      columns wraps modulo 10000;
    * the element symbol is truncated to two characters.

    Raises:
        ValueError: if a coordinate does not fit its 8-column field.
    """
    element = (element or "X")[:2]
    name = (atom_name or "X")[:4]
    if len(name) < 4 and len(element) < 2:
        name = " " + name

    resid_text = f"{resid:>4}"
    if len(resid_text) > 4:
        resid_text = f"{resid % 10000:>4}"

    coord_texts = []
    for axis, value in zip("xyz", (x, y, z)):
        text = f"{value:8.3f}"
        if len(text) > 8:
            # An oversized value would overwrite the neighbouring field.
            raise ValueError(
                f"{axis} coordinate {value!r} does not fit the 8-column PDB field"
            )
        coord_texts.append(text)

    line = list(" " * 80)

    def put(value: str, start: int):  # start is 1-based
        line[start - 1:start - 1 + len(value)] = value

    put("ATOM", 1)
    put(f"{serial % 100000:>5}", 7)
    put(f"{name:<4}", 13)
    put(f"{(resname or 'MOL')[:3]:>3}", 18)
    put((chain or "A")[0], 22)
    put(resid_text, 23)
    put(coord_texts[0], 31)
    put(coord_texts[1], 39)
    put(coord_texts[2], 47)
    put(f"{1.0:6.2f}", 55)
    put(f"{0.0:6.2f}", 61)
    put(f"{element:>2}", 77)
    return "".join(line).rstrip() + "\n"
315
+
316
+
317
def _open(path: str, mode: str = "r"):
    """Open ``path`` in ``mode``; a ``.gz`` suffix selects gzip in text mode."""
    if not path.endswith(".gz"):
        return open(path, mode)
    text_mode = mode + "t"
    return gzip.open(path, text_mode)
322
+
323
+
324
def _data_extension(path: str) -> str:
    """Return the lower-cased extension of ``path`` after dropping a final ``.gz``.

    For example ``a.pdb.gz`` gives ``.pdb`` and ``A.XYZ`` gives ``.xyz``.
    """
    base = path
    if base.endswith(".gz"):
        base = base[:-3]
    _root, ext = os.path.splitext(base)
    return ext.lower()
328
+
329
+
330
def _is_float(token: str) -> bool:
    """Tell whether ``float(token)`` succeeds for the given string."""
    try:
        float(token)
    except ValueError:
        return False
    else:
        return True
336
+
337
+
338
def _stem(path: str) -> str:
    """Return the file name of ``path`` without a final ``.gz`` and extension.

    For example ``/data/1abc.pdb.gz`` gives ``1abc``.
    """
    filename = os.path.basename(path)
    without_gz = filename[:-3] if filename.endswith(".gz") else filename
    root, _ext = os.path.splitext(without_gz)
    return root