molforge 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. molforge/__init__.py +18 -0
  2. molforge/core/__init__.py +64 -0
  3. molforge/core/atom.py +154 -0
  4. molforge/core/atom_array.py +318 -0
  5. molforge/core/chain.py +149 -0
  6. molforge/core/constants.py +145 -0
  7. molforge/core/protein.py +166 -0
  8. molforge/core/residue.py +160 -0
  9. molforge/docking/__init__.py +143 -0
  10. molforge/io/__init__.py +88 -0
  11. molforge/io/dispatch.py +156 -0
  12. molforge/io/fasta.py +185 -0
  13. molforge/io/mmcif.py +559 -0
  14. molforge/io/mol2.py +33 -0
  15. molforge/io/pdb.py +635 -0
  16. molforge/io/pdb_alphafold.py +91 -0
  17. molforge/io/pdbqt.py +33 -0
  18. molforge/io/pqr.py +33 -0
  19. molforge/io/sdf.py +33 -0
  20. molforge/md/__init__.py +23 -0
  21. molforge/metrics/__init__.py +25 -0
  22. molforge/ml/__init__.py +25 -0
  23. molforge/plugins/__init__.py +34 -0
  24. molforge/plugins/registry.py +48 -0
  25. molforge/py.typed +0 -0
  26. molforge/sequence/__init__.py +69 -0
  27. molforge/sequence/alignment.py +385 -0
  28. molforge/sequence/composition.py +132 -0
  29. molforge/sequence/matrices.py +99 -0
  30. molforge/sequence/mutations.py +195 -0
  31. molforge/structure/__init__.py +85 -0
  32. molforge/structure/contacts.py +175 -0
  33. molforge/structure/dssp.py +428 -0
  34. molforge/structure/geometry.py +144 -0
  35. molforge/structure/rmsd.py +171 -0
  36. molforge/structure/superposition.py +135 -0
  37. molforge/wrappers/__init__.py +11 -0
  38. molforge/wrappers/docking/__init__.py +34 -0
  39. molforge/wrappers/docking/_base.py +7 -0
  40. molforge/wrappers/docking/diffdock.py +17 -0
  41. molforge/wrappers/docking/vina.py +364 -0
  42. molforge/wrappers/folding/__init__.py +38 -0
  43. molforge/wrappers/folding/_base.py +102 -0
  44. molforge/wrappers/folding/alphafold.py +20 -0
  45. molforge/wrappers/folding/boltz.py +20 -0
  46. molforge/wrappers/folding/esmfold.py +234 -0
  47. molforge/wrappers/folding/rosetta.py +19 -0
  48. molforge/wrappers/md/__init__.py +8 -0
  49. molforge/wrappers/md/_base.py +13 -0
  50. molforge/wrappers/md/gromacs.py +18 -0
  51. molforge/wrappers/md/openmm.py +18 -0
  52. molforge-0.0.1.dist-info/METADATA +246 -0
  53. molforge-0.0.1.dist-info/RECORD +55 -0
  54. molforge-0.0.1.dist-info/WHEEL +4 -0
  55. molforge-0.0.1.dist-info/licenses/LICENSE +21 -0
molforge/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ """molforge — a unified library for structural bioinformatics, MD, and ML.
2
+
3
+ This package exposes a small top-level surface. Subpackages are the primary
4
+ import points; users should typically import them directly:
5
+
6
+ >>> from molforge.core import Protein, Chain, Residue, Atom
7
+ >>> from molforge.io import load, save
8
+ >>> from molforge.structure import rmsd
9
+
10
+ `molforge` is a *library*, not a framework: there is no runtime, no
11
+ orchestration layer, and no required entry point. Import what you need.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ __version__ = "0.0.1"  # version string of this release (the 0.0.1 wheel)
17
+
18
+ __all__ = ["__version__"]  # only the version is exported here; import subpackages directly
molforge/core/__init__.py ADDED
@@ -0,0 +1,64 @@
1
+ """Core data model: hierarchical and linear views of protein structure.
2
+
3
+ The :class:`AtomArray` is the *canonical* representation — a flat,
4
+ NumPy-backed array of all atoms. The hierarchical classes
5
+ (:class:`Protein`, :class:`Chain`, :class:`Residue`, :class:`Atom`) are
6
+ lightweight views that read and write through to the array.
7
+
8
+ Typical usage:
9
+
10
+ >>> from molforge.core import Protein, AtomArray
11
+ >>> protein = Protein(atom_array=AtomArray(0), name="example")
12
+ >>> protein.n_atoms
13
+ 0
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from molforge.core.atom import Atom
19
+ from molforge.core.atom_array import (
20
+ ATOM_FIELDS,
21
+ AtomArray,
22
+ BoolArray,
23
+ FloatArray,
24
+ IntArray,
25
+ StrArray,
26
+ )
27
+ from molforge.core.chain import Chain
28
+ from molforge.core.constants import (
29
+ NUCLEOTIDE_TO_ONE,
30
+ ONE_TO_THREE,
31
+ PROTEIN_BACKBONE_ATOMS,
32
+ THREE_TO_ONE,
33
+ is_ion,
34
+ is_standard_amino_acid,
35
+ is_water,
36
+ three_to_one,
37
+ )
38
+ from molforge.core.protein import Protein
39
+ from molforge.core.residue import Residue
40
+
41
+ __all__ = [ # noqa: RUF022 — grouped by concept, not alphabetical
42
+ # Hierarchical views (read and write through to an AtomArray)
43
+ "Atom",
44
+ "Residue",
45
+ "Chain",
46
+ "Protein",
47
+ # Linear: the canonical flat array and the tuple of its field names
48
+ "AtomArray",
49
+ "ATOM_FIELDS",
50
+ # Type aliases (NumPy NDArray aliases defined in molforge.core.atom_array)
51
+ "BoolArray",
52
+ "FloatArray",
53
+ "IntArray",
54
+ "StrArray",
55
+ # Constants & helpers (re-exported from molforge.core.constants)
56
+ "THREE_TO_ONE",
57
+ "ONE_TO_THREE",
58
+ "NUCLEOTIDE_TO_ONE",
59
+ "PROTEIN_BACKBONE_ATOMS",
60
+ "three_to_one",
61
+ "is_standard_amino_acid",
62
+ "is_water",
63
+ "is_ion",
64
+ ]
molforge/core/atom.py ADDED
@@ -0,0 +1,154 @@
1
+ """Atom — a lightweight view onto a single atom in an :class:`AtomArray`.
2
+
3
+ An ``Atom`` does not own its data. It holds a reference to the parent
4
+ ``AtomArray`` and an integer index. All property accesses read from the
5
+ underlying arrays; all mutations write through.
6
+
7
+ This makes ``Atom`` cheap to create (no copy) and guarantees consistency
8
+ with the linear view.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import TYPE_CHECKING
14
+
15
+ import numpy as np
16
+ from numpy.typing import NDArray
17
+
18
+ if TYPE_CHECKING:
19
+ from molforge.core.atom_array import AtomArray
20
+ from molforge.core.residue import Residue
21
+
22
+ from molforge.core.constants import PROTEIN_BACKBONE_ATOMS
23
+
24
+
25
class Atom:
    """View of a single atom in an :class:`AtomArray`.

    An ``Atom`` stores only a reference to the array, a row index and an
    optional parent residue. Attributes are read/written through to the
    underlying array, so mutating an ``Atom`` mutates the source-of-truth
    representation.
    """

    __slots__ = ("_array", "_index", "_parent")

    def __init__(
        self,
        array: AtomArray,
        index: int,
        *,
        parent: Residue | None = None,
    ) -> None:
        """Bind the view to row ``index`` of ``array``.

        Args:
            array: The array holding the atom's data.
            index: Non-negative row index. Python and NumPy integers are
                both accepted.
            parent: The residue this atom was reached through, if any.

        Raises:
            TypeError: If ``index`` is not an integer.
            IndexError: If ``index`` is outside ``[0, len(array))``.
        """
        # Reject floats and other non-integers up front; otherwise they pass
        # the bounds check below and only fail on the first field access.
        if not isinstance(index, (int, np.integer)):
            raise TypeError(f"index must be an integer, got {type(index).__name__}")
        if not 0 <= index < len(array):
            raise IndexError(f"index {index} out of bounds for array of length {len(array)}")
        self._array = array
        # Normalise NumPy scalars (such as values taken from an int32 index
        # array) so that ``index`` really is a built-in ``int``.
        self._index = int(index)
        self._parent = parent

    # ------------------------------------------------------------------
    # Identity / context
    # ------------------------------------------------------------------
    @property
    def index(self) -> int:
        """The atom's index into the parent :class:`AtomArray`."""
        return self._index

    @property
    def parent(self) -> Residue | None:
        """The containing residue, if known."""
        return self._parent

    # ------------------------------------------------------------------
    # Field accessors (read/write through to the array)
    # ------------------------------------------------------------------
    @property
    def name(self) -> str:
        """The ``atom_name`` entry for this atom, as a Python ``str``."""
        return str(self._array.atom_name[self._index])

    @name.setter
    def name(self, value: str) -> None:
        self._array.atom_name[self._index] = value

    @property
    def element(self) -> str:
        """The ``element`` entry for this atom, as a Python ``str``."""
        return str(self._array.element[self._index])

    @element.setter
    def element(self, value: str) -> None:
        self._array.element[self._index] = value

    @property
    def coord(self) -> NDArray[np.float32]:
        """The atom's row of ``coords``.

        The row is returned by plain integer indexing, so writing into it
        writes into the parent array.
        """
        return self._array.coords[self._index]

    @coord.setter
    def coord(self, value: NDArray[np.float32]) -> None:
        self._array.coords[self._index] = value

    @property
    def b_factor(self) -> float:
        """The ``b_factor`` entry for this atom, as a Python ``float``."""
        return float(self._array.b_factor[self._index])

    @b_factor.setter
    def b_factor(self, value: float) -> None:
        self._array.b_factor[self._index] = value

    @property
    def occupancy(self) -> float:
        """The ``occupancy`` entry for this atom, as a Python ``float``."""
        return float(self._array.occupancy[self._index])

    @occupancy.setter
    def occupancy(self, value: float) -> None:
        self._array.occupancy[self._index] = value

    @property
    def charge(self) -> float:
        """The ``charge`` entry for this atom, as a Python ``float``."""
        return float(self._array.charge[self._index])

    @charge.setter
    def charge(self, value: float) -> None:
        self._array.charge[self._index] = value

    @property
    def serial(self) -> int:
        """The ``serial`` entry for this atom, as a Python ``int``."""
        return int(self._array.serial[self._index])

    @serial.setter
    def serial(self, value: int) -> None:
        self._array.serial[self._index] = value

    @property
    def altloc(self) -> str:
        """The ``altloc`` entry for this atom, as a Python ``str``."""
        return str(self._array.altloc[self._index])

    @altloc.setter
    def altloc(self, value: str) -> None:
        self._array.altloc[self._index] = value

    @property
    def record_type(self) -> str:
        """The ``record_type`` entry for this atom (read-only on the view)."""
        return str(self._array.record_type[self._index])

    # ------------------------------------------------------------------
    # Derived / convenience
    # ------------------------------------------------------------------
    @property
    def is_backbone(self) -> bool:
        """True if the atom name is listed in ``PROTEIN_BACKBONE_ATOMS``."""
        return self.name in PROTEIN_BACKBONE_ATOMS

    @property
    def is_hetero(self) -> bool:
        """True if this atom comes from a HETATM record."""
        return self.record_type == "HETATM"

    def __repr__(self) -> str:
        return f"Atom(name={self.name!r}, element={self.element!r}, index={self._index})"

    def __eq__(self, other: object) -> bool:
        """Two views are equal when they point at the same row of the same array object."""
        if not isinstance(other, Atom):
            return NotImplemented
        return self._array is other._array and self._index == other._index

    def __hash__(self) -> int:
        # Keyed on array identity plus row, consistent with ``__eq__``.
        return hash((id(self._array), self._index))
molforge/core/atom_array.py ADDED
@@ -0,0 +1,318 @@
1
+ """AtomArray — the canonical, flat, NumPy-backed representation.
2
+
3
+ This is the *source of truth* for a protein's atomic data. The hierarchical
4
+ classes (`Atom`, `Residue`, `Chain`, `Protein`) are thin accessors that hold
5
+ an :class:`AtomArray` reference plus index slices into it.
6
+
7
+ Design notes:
8
+ - All per-atom fields are parallel NumPy arrays of length ``N``.
9
+ - Coordinates are ``(N, 3)`` float32 — float32 is enough for any single PDB
10
+ and halves the memory of float64. Promote on demand if you need it.
11
+ - String fields use NumPy unicode dtypes (``"U1"``, ``"U3"``, ``"U4"``)
12
+ rather than Python ``object`` arrays — keeps memory predictable and
13
+ enables vectorized comparisons.
14
+ - The class exposes a residue-/chain-boundary index (`chain_starts`,
15
+ `residue_starts`) that's built lazily and cached. In-place edits to the
16
+ id arrays do not reset it; call `_invalidate_cache()` after such edits.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from typing import TYPE_CHECKING
22
+
23
+ import numpy as np
24
+ from numpy.typing import NDArray
25
+
26
+ if TYPE_CHECKING:
27
+ from collections.abc import Iterable
28
+
29
# Public type aliases for users.
FloatArray = NDArray[np.float32]
IntArray = NDArray[np.int32]
StrArray = NDArray[np.str_]
BoolArray = NDArray[np.bool_]


# Field schema — single source of truth for column names, dtypes, and defaults.
# Used by AtomArray constructors and IO routines so adding a column is a
# one-line change here.
_FIELD_SCHEMA: dict[str, tuple[str, object]] = {
    # name: (dtype, default)
    "coords": ("float32", 0.0),  # special-cased: shape (N, 3)
    "element": ("U2", ""),
    "atom_name": ("U4", ""),
    "residue_name": ("U3", ""),
    "residue_id": ("int32", 0),
    "insertion_code": ("U1", ""),
    "chain_id": ("U4", ""),  # 4 chars to support mmCIF auth_asym_id
    "b_factor": ("float32", 0.0),
    "occupancy": ("float32", 1.0),
    "charge": ("float32", 0.0),
    "serial": ("int32", 0),
    "record_type": ("U6", "ATOM"),  # "ATOM" or "HETATM"
    "entity_type": ("U8", "protein"),  # protein, dna, rna, ligand, water, ion, other
    "altloc": ("U1", ""),
    "model_id": ("int32", 0),  # for multi-model NMR / trajectories
}

ATOM_FIELDS: tuple[str, ...] = tuple(_FIELD_SCHEMA.keys())


class AtomArray:
    """Flat, NumPy-backed array of atoms.

    This is the canonical representation; hierarchical views read from
    these arrays. All per-atom fields have shape ``(N,)`` except
    ``coords`` which has shape ``(N, 3)``.

    The chain/residue boundary indices are cached after first use. The
    cache is reset by the constructors and by methods that build a new
    array, but *not* by in-place writes to the field arrays; call
    :meth:`_invalidate_cache` after editing ``chain_id``, ``residue_id``,
    ``insertion_code`` or ``model_id`` in place.

    Example:
        >>> aa = AtomArray.empty(3)
        >>> aa.element[:] = ["C", "N", "O"]
        >>> aa.coords[0] = [1.0, 2.0, 3.0]
        >>> len(aa)
        3
    """

    __slots__ = (
        "_chain_starts_cache",
        "_residue_starts_cache",
        "altloc",
        "atom_name",
        "b_factor",
        "chain_id",
        "charge",
        "coords",
        "element",
        "entity_type",
        "insertion_code",
        "model_id",
        "occupancy",
        "record_type",
        "residue_id",
        "residue_name",
        "serial",
    )

    coords: FloatArray
    element: StrArray
    atom_name: StrArray
    residue_name: StrArray
    residue_id: IntArray
    insertion_code: StrArray
    chain_id: StrArray
    b_factor: FloatArray
    occupancy: FloatArray
    charge: FloatArray
    serial: IntArray
    record_type: StrArray
    entity_type: StrArray
    altloc: StrArray
    model_id: IntArray

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------
    def __init__(self, n: int = 0) -> None:
        """Create an array of ``n`` atoms, all fields at their schema defaults.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        for field, (dtype, default) in _FIELD_SCHEMA.items():
            if field == "coords":
                values = np.zeros((n, 3), dtype=dtype)
            else:
                values = np.full(n, default, dtype=dtype)
            setattr(self, field, values)
        self._chain_starts_cache = None
        self._residue_starts_cache = None

    @classmethod
    def empty(cls, n: int) -> AtomArray:
        """Alias for ``AtomArray(n)`` — more readable at call sites."""
        return cls(n)

    @staticmethod
    def _coerce_field(name: str, values: object) -> NDArray:
        """Convert ``values`` to an array compatible with field ``name``.

        Numeric fields are cast to the schema dtype. String fields are
        converted to a NumPy unicode array but keep their natural width, so
        caller-supplied identifiers are never truncated.
        """
        dtype = np.dtype(_FIELD_SCHEMA[name][0])
        if dtype.kind == "U":
            return np.asarray(values, dtype=np.str_)
        return np.asarray(values, dtype=dtype)

    @classmethod
    def from_dict(cls, data: dict[str, NDArray]) -> AtomArray:
        """Construct from a dict of equal-length arrays (or array-likes).

        Args:
            data: Mapping field-name -> array. Must include ``coords``;
                missing fields are filled with their schema defaults.
                Numeric inputs are cast to the schema dtype (for example
                float64 coordinates become float32).

        Raises:
            KeyError: If ``coords`` is missing or a field name is unknown.
            ValueError: If ``coords`` is not ``(N, 3)`` or a field is not
                one-dimensional with length ``N``.
        """
        if "coords" not in data:
            raise KeyError("`coords` is required to construct an AtomArray")
        # Validate names first so nothing is converted for a doomed input.
        for name in data:
            if name not in _FIELD_SCHEMA:
                raise KeyError(f"Unknown field {name!r}; valid: {ATOM_FIELDS}")

        coords = cls._coerce_field("coords", data["coords"])
        if coords.ndim != 2 or coords.shape[1] != 3:
            raise ValueError(f"`coords` must have shape (N, 3), got {coords.shape}")
        n = coords.shape[0]

        columns: dict[str, NDArray] = {"coords": coords}
        for name, raw in data.items():
            if name == "coords":
                continue
            column = cls._coerce_field(name, raw)
            if column.ndim != 1:
                raise ValueError(f"Field {name!r} has shape {column.shape}, expected ({n},)")
            if column.shape[0] != n:
                raise ValueError(f"Field {name!r} has length {column.shape[0]}, expected {n}")
            columns[name] = column

        out = cls(n)
        for name, column in columns.items():
            setattr(out, name, column)
        out._invalidate_cache()
        return out

    # ------------------------------------------------------------------
    # Dunder methods
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        return int(self.coords.shape[0])

    def __repr__(self) -> str:
        return f"AtomArray(n_atoms={len(self)})"

    def __getitem__(self, key: int | slice | NDArray) -> AtomArray:
        """Slice or fancy-index the array; returns a new AtomArray (copy).

        An integer key (Python or NumPy) selects a single atom but still
        returns an ``AtomArray`` of length 1 so the API stays uniform.
        Negative integers count from the end. The result never shares
        memory with this array.

        Raises:
            IndexError: If an integer key is out of range.
        """
        if isinstance(key, (int, np.integer)):
            n = len(self)
            position = int(key)
            if position < 0:
                position += n
            if not 0 <= position < n:
                raise IndexError(f"index {int(key)} out of bounds for array of length {n}")
            key = slice(position, position + 1)

        out = AtomArray(0)
        for field in ATOM_FIELDS:
            source = getattr(self, field)
            picked = source[key]
            # Basic slicing yields a view; copy so edits to the result can
            # never leak back into this array. Fancy indexing already copies.
            if np.may_share_memory(picked, source):
                picked = picked.copy()
            setattr(out, field, np.ascontiguousarray(picked))
        out._invalidate_cache()
        return out

    # ------------------------------------------------------------------
    # Mutation helpers
    # ------------------------------------------------------------------
    def _invalidate_cache(self) -> None:
        """Drop the residue / chain boundary caches. Call after mutations."""
        self._chain_starts_cache = None
        self._residue_starts_cache = None

    def append(self, other: AtomArray) -> AtomArray:
        """Return a new array with ``other`` concatenated after this one.

        Raises:
            TypeError: If ``other`` is not an :class:`AtomArray`.
        """
        if not isinstance(other, AtomArray):
            raise TypeError(f"expected AtomArray, got {type(other).__name__}")
        out = AtomArray(0)
        for field in ATOM_FIELDS:
            # axis=0 covers both the (N,) fields and the (N, 3) coordinates.
            joined = np.concatenate([getattr(self, field), getattr(other, field)], axis=0)
            setattr(out, field, joined)
        out._invalidate_cache()
        return out

    # ------------------------------------------------------------------
    # Selection
    # ------------------------------------------------------------------
    def select(self, mask: BoolArray) -> AtomArray:
        """Return a new AtomArray containing only atoms where ``mask`` is True.

        Args:
            mask: Boolean array of length ``len(self)``.

        Raises:
            ValueError: If ``mask`` does not have shape ``(len(self),)``.
        """
        mask = np.asarray(mask, dtype=bool)
        if mask.shape != (len(self),):
            raise ValueError(f"mask shape {mask.shape} does not match atom count ({len(self)},)")
        return self[mask]

    def where(self, **filters: object) -> BoolArray:
        """Build a boolean mask from equality filters on any field.

        A scalar value is matched with ``==``; a list, tuple, set,
        frozenset or array is matched with membership. Filters are
        combined with logical AND.

        Example:
            >>> mask = aa.where(chain_id="A", atom_name="CA")
            >>> ca_atoms = aa.select(mask)

        Raises:
            KeyError: If a filter names an unknown field.
        """
        mask = np.ones(len(self), dtype=bool)
        for field, value in filters.items():
            if field not in _FIELD_SCHEMA:
                raise KeyError(f"Unknown field {field!r}; valid: {ATOM_FIELDS}")
            arr = getattr(self, field)
            if isinstance(value, (list, tuple, set, frozenset, np.ndarray)):
                mask &= np.isin(arr, list(value))
            else:
                mask &= arr == value
        return mask

    # ------------------------------------------------------------------
    # Boundary indices (chains / residues)
    # ------------------------------------------------------------------
    @property
    def chain_starts(self) -> IntArray:
        """Indices of the first atom of each chain, in order.

        A chain boundary is any change in ``chain_id`` or ``model_id``.
        """
        cached = self._chain_starts_cache
        if cached is None:
            cached = self._compute_chain_starts()
            self._chain_starts_cache = cached
        return cached

    @property
    def residue_starts(self) -> IntArray:
        """Indices of the first atom of each residue, in order.

        A residue boundary is any change in
        ``(chain_id, residue_id, insertion_code, model_id)``.
        """
        cached = self._residue_starts_cache
        if cached is None:
            cached = self._compute_residue_starts()
            self._residue_starts_cache = cached
        return cached

    def _starts_where_changed(self, *fields: str) -> IntArray:
        """Return indices where any of ``fields`` differs from the previous atom.

        Index 0 is always included for a non-empty array.
        """
        n = len(self)
        if n == 0:
            return np.empty(0, dtype=np.int32)
        changed = np.zeros(n, dtype=bool)
        changed[0] = True
        for field in fields:
            column = getattr(self, field)
            changed[1:] |= column[1:] != column[:-1]
        return np.nonzero(changed)[0].astype(np.int32)

    def _compute_chain_starts(self) -> IntArray:
        """Compute chain boundaries from ``chain_id`` and ``model_id``."""
        return self._starts_where_changed("chain_id", "model_id")

    def _compute_residue_starts(self) -> IntArray:
        """Compute residue boundaries from the four identifying fields."""
        return self._starts_where_changed(
            "residue_id", "chain_id", "insertion_code", "model_id"
        )

    @property
    def n_atoms(self) -> int:
        """Number of atoms (same as ``len(self)``)."""
        return len(self)

    @property
    def n_residues(self) -> int:
        """Number of residue boundaries found by :attr:`residue_starts`."""
        return int(self.residue_starts.shape[0])

    @property
    def n_chains(self) -> int:
        """Number of chain boundaries found by :attr:`chain_starts`."""
        return int(self.chain_starts.shape[0])

    # ------------------------------------------------------------------
    # Iteration helpers
    # ------------------------------------------------------------------
    def _iter_slices(self, starts: IntArray) -> Iterable[slice]:
        """Yield ``slice(start, next_start)`` for consecutive entries of ``starts``.

        The final slice ends at ``len(self)``.
        """
        bounds = [*starts.tolist(), len(self)]
        for begin, stop in zip(bounds, bounds[1:]):
            yield slice(begin, stop)

    def iter_residue_slices(self) -> Iterable[slice]:
        """Yield a ``slice`` for each residue's atoms (in array order)."""
        yield from self._iter_slices(self.residue_starts)

    def iter_chain_slices(self) -> Iterable[slice]:
        """Yield a ``slice`` for each chain's atoms (in array order)."""
        yield from self._iter_slices(self.chain_starts)
molforge/core/chain.py ADDED
@@ -0,0 +1,149 @@
1
+ """Chain — a view over the residues of a single chain.
2
+
3
+ A chain corresponds to a contiguous slice of an :class:`AtomArray` sharing
4
+ the same ``(chain_id, model_id)``. Residue boundaries within the chain
5
+ are resolved on demand from the underlying array's
6
+ :attr:`~molforge.core.atom_array.AtomArray.residue_starts` cache.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import TYPE_CHECKING
12
+
13
+ import numpy as np
14
+ from numpy.typing import NDArray
15
+
16
+ if TYPE_CHECKING:
17
+ from molforge.core.atom_array import AtomArray
18
+ from molforge.core.protein import Protein
19
+
20
+ from molforge.core.residue import Residue
21
+
22
+
23
class Chain:
    """View over a chain's atoms inside an :class:`AtomArray`.

    The chain is the half-open atom range ``[start, end)`` of the array.
    No atom data is copied; residues are resolved on demand from the
    array's ``residue_starts`` index.
    """

    __slots__ = ("_array", "_end", "_parent", "_start")

    def __init__(
        self,
        array: AtomArray,
        start: int,
        end: int,
        *,
        parent: Protein | None = None,
    ) -> None:
        """Bind the view to atoms ``[start, end)`` of ``array``.

        Raises:
            IndexError: If the range is empty or not inside the array.
        """
        if not 0 <= start < end <= len(array):
            raise IndexError(
                f"invalid chain slice [{start}, {end}) for array of length {len(array)}"
            )
        self._array = array
        self._start = int(start)
        self._end = int(end)
        self._parent = parent

    # ------------------------------------------------------------------
    # Identity
    # ------------------------------------------------------------------
    @property
    def chain_id(self) -> str:
        """The ``chain_id`` of the chain's first atom."""
        return str(self._array.chain_id[self._start])

    @property
    def model_id(self) -> int:
        """The ``model_id`` of the chain's first atom."""
        return int(self._array.model_id[self._start])

    @property
    def parent(self) -> Protein | None:
        """The containing protein, if known."""
        return self._parent

    # ------------------------------------------------------------------
    # Residue access
    # ------------------------------------------------------------------
    def _residue_slice_bounds(self) -> NDArray[np.int32]:
        """Return the global residue-start indices that fall within this chain.

        ``residue_starts`` is sorted ascending, so a binary search finds the
        window ``start <= s < end`` without scanning every residue.
        """
        starts = self._array.residue_starts
        first, last = np.searchsorted(starts, [self._start, self._end], side="left")
        return starts[first:last]

    def __iter__(self):  # type: ignore[no-untyped-def]
        """Yield a :class:`Residue` view for each residue, in array order."""
        # Append the chain end so consecutive pairs give every [begin, stop).
        bounds = self._residue_slice_bounds().tolist()
        bounds.append(self._end)
        for begin, stop in zip(bounds, bounds[1:]):
            yield Residue(self._array, begin, stop, parent=self)

    @property
    def residues(self) -> list[Residue]:
        """All residues in this chain, in array order."""
        return list(self)

    def __len__(self) -> int:
        """Number of residues in the chain."""
        return int(self._residue_slice_bounds().shape[0])

    def __getitem__(self, key: int | tuple[int, str]) -> Residue:
        """Look up a residue.

        - ``chain[42]`` returns the residue with ``seq_id == 42`` (no insertion code).
        - ``chain[(42, "A")]`` returns the residue with ``seq_id == 42`` and
          ``insertion_code == "A"``.

        The first matching residue in array order is returned.

        Raises:
            KeyError: If no matching residue exists.
        """
        if isinstance(key, tuple):
            seq_id, ins = key
        else:
            seq_id, ins = key, ""
        for res in self:
            if res.seq_id == seq_id and res.insertion_code == ins:
                return res
        raise KeyError(f"chain {self.chain_id!r} has no residue {seq_id}{ins or ''}")

    @property
    def sequence(self) -> str:
        """One-letter sequence for this chain.

        Residues whose ``entity_type`` is not ``protein``, ``dna`` or ``rna``
        are skipped; every other residue contributes its ``one_letter`` code.
        """
        polymer_types = {"protein", "dna", "rna"}
        return "".join(res.one_letter for res in self if res.entity_type in polymer_types)

    @property
    def n_atoms(self) -> int:
        """Number of atoms in the chain's range."""
        return self._end - self._start

    @property
    def n_residues(self) -> int:
        """Number of residues (same as ``len(self)``)."""
        return len(self)

    @property
    def coords(self) -> NDArray[np.float32]:
        """All atom coordinates for this chain, shape ``(n_atoms, 3)``."""
        return self._array.coords[self._start : self._end]

    @property
    def slice(self) -> slice:
        """The chain's atom range as a ``slice`` into the parent array."""
        return slice(self._start, self._end)

    def __repr__(self) -> str:
        return f"Chain(chain_id={self.chain_id!r}, n_residues={len(self)}, n_atoms={self.n_atoms})"

    def __eq__(self, other: object) -> bool:
        """Two views are equal when they cover the same range of the same array object."""
        if not isinstance(other, Chain):
            return NotImplemented
        return (
            self._array is other._array and self._start == other._start and self._end == other._end
        )

    def __hash__(self) -> int:
        # Keyed on array identity plus range, consistent with ``__eq__``.
        return hash((id(self._array), self._start, self._end))