sawnergy 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy might be problematic. Click here for more details.
- sawnergy/__init__.py +13 -0
- sawnergy/embedding/SGNS_pml.py +135 -0
- sawnergy/embedding/SGNS_torch.py +177 -0
- sawnergy/embedding/__init__.py +34 -0
- sawnergy/embedding/embedder.py +578 -0
- sawnergy/logging_util.py +54 -0
- sawnergy/rin/__init__.py +9 -0
- sawnergy/rin/rin_builder.py +936 -0
- sawnergy/rin/rin_util.py +391 -0
- sawnergy/sawnergy_util.py +1182 -0
- sawnergy/visual/__init__.py +42 -0
- sawnergy/visual/visualizer.py +690 -0
- sawnergy/visual/visualizer_util.py +387 -0
- sawnergy/walks/__init__.py +16 -0
- sawnergy/walks/walker.py +795 -0
- sawnergy/walks/walker_util.py +384 -0
- sawnergy-1.0.0.dist-info/METADATA +290 -0
- sawnergy-1.0.0.dist-info/RECORD +22 -0
- sawnergy-1.0.0.dist-info/WHEEL +5 -0
- sawnergy-1.0.0.dist-info/licenses/LICENSE +201 -0
- sawnergy-1.0.0.dist-info/licenses/NOTICE +4 -0
- sawnergy-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,936 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# third-party
|
|
4
|
+
import numpy as np
|
|
5
|
+
import threadpoolctl
|
|
6
|
+
# built-in
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import logging
|
|
9
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
10
|
+
import re
|
|
11
|
+
import math
|
|
12
|
+
import os
|
|
13
|
+
# local
|
|
14
|
+
from . import rin_util
|
|
15
|
+
from .. import sawnergy_util
|
|
16
|
+
|
|
17
|
+
# *----------------------------------------------------*
|
|
18
|
+
# GLOBALS
|
|
19
|
+
# *----------------------------------------------------*
|
|
20
|
+
|
|
21
|
+
_logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# *----------------------------------------------------*
|
|
24
|
+
# CLASSES
|
|
25
|
+
# *----------------------------------------------------*
|
|
26
|
+
|
|
27
|
+
class RINBuilder:
|
|
28
|
+
"""Builds Residue Interaction Networks (RINs) from MD trajectories.
|
|
29
|
+
|
|
30
|
+
This class orchestrates running cpptraj to:
|
|
31
|
+
* compute per-frame, per-residue centers of mass (COMs),
|
|
32
|
+
* compute pairwise atomic non-bonded energies (electrostatics + van der Waals),
|
|
33
|
+
* project atomic interactions to residue-level interactions,
|
|
34
|
+
* post-process residue matrices (split, prune, remove self-interactions, symmetrize, normalize),
|
|
35
|
+
* package outputs into a compressed archive.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
cpptraj_path (Path | str | None): Optional explicit path to the `cpptraj`
|
|
39
|
+
executable. If not provided, an attempt is made to locate it via
|
|
40
|
+
`rin_util.locate_cpptraj`.
|
|
41
|
+
|
|
42
|
+
Attributes:
|
|
43
|
+
cpptraj (Path): Resolved path to the `cpptraj` executable.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
def __init__(self, cpptraj_path: Path | str | None = None):
    """Resolve the cpptraj executable and store it on the instance.

    Args:
        cpptraj_path (Path | str | None): Optional explicit location of the
            `cpptraj` executable. A plain string is converted to ``Path``;
            ``None`` is forwarded unchanged to `rin_util.locate_cpptraj`.
    """
    _logger.debug("Initializing RINBuilder with cpptraj_path=%s", cpptraj_path)

    # Only plain strings need wrapping; Path and None pass straight through.
    explicit_location = Path(cpptraj_path) if isinstance(cpptraj_path, str) else cpptraj_path

    # NOTE(review): what `verify=True` checks is defined in rin_util, not here.
    resolved = rin_util.locate_cpptraj(explicit=explicit_location, verify=True)
    self.cpptraj = resolved
    _logger.info("Using cpptraj at %s", self.cpptraj)
|
|
53
|
+
|
|
54
|
+
# ---------------------------------------------------------------------------------------------- #
|
|
55
|
+
# CPPTRAJ HELPERS
|
|
56
|
+
# ---------------------------------------------------------------------------------------------- #
|
|
57
|
+
|
|
58
|
+
# NOTE: the pattern might be version specific
# Matches the two consecutive "square2d" print blocks of the pairwise run:
# group 1 is the raw EMAP (electrostatic) text, group 2 the raw VMAP
# (van der Waals) text. The second block ends at the next bracketed command
# echo, a "TIME:" line, or the end of the output.
_elec_vdw_pattern = re.compile(r"""
    ^\s*\[printdata\s+PW\[EMAP\]\s+square2d\s+noheader\]\s*\r?\n
    ([0-9.eE+\-\s]+?)
    ^\s*\[printdata\s+PW\[VMAP\]\s+square2d\s+noheader\]\s*\r?\n
    ([0-9.eE+\-\s]+?)
    (?=^\s*\[|^\s*TIME:|\Z)
""", re.MULTILINE | re.DOTALL | re.VERBOSE)

# NOTE: the pattern might be version specific
def _com_block_pattern(self, N):
    """Compile the regex that isolates the COM data block in cpptraj output.

    The block starts after the header line containing the word ``COMZ<N>``
    and stops right before the ``[quit]`` line.

    Args:
        N: Index embedded in the expected ``COMZ<N>`` header token (callers
            in this class pass the residue count).

    Returns:
        re.Pattern: Compiled pattern whose group 1 is the numeric block text.
    """
    return re.compile(rf"""
        ^[^\n]*\bCOMZ{N}\b[^\n]*\n
        ([0-9.eE+\-\s]+?)
        (?=^\s*\[quit\]\s*$)
    """, re.MULTILINE | re.DOTALL | re.VERBOSE)

# One data row: a leading integer (the frame number, which is dropped)
# followed by the captured remainder of the line.
_com_row_pattern = re.compile(r'^\s*\d+\s+(.+?)\s*$', re.MULTILINE)
|
|
75
|
+
|
|
76
|
+
def _get_number_frames(self,
                       topology_file: str,
                       trajectory_file: str,
                       *,
                       subprocess_env: dict | None = None,
                       timeout: float | None = None) -> int:
    """Ask cpptraj how many frames a trajectory contains.

    Invokes cpptraj with ``-p <topology> -y <trajectory> -tl`` and converts
    the output to an integer after removing the ``"Frames: "`` prefix.

    Args:
        topology_file (str): Path to the topology file.
        trajectory_file (str): Path to the trajectory file.
        subprocess_env (dict | None): Forwarded to `rin_util.run_cpptraj` as ``env``.
        timeout (float | None): Forwarded to `rin_util.run_cpptraj` as ``timeout``.

    Returns:
        int: The parsed frame count.

    Raises:
        RuntimeError: If the remaining text is not a valid integer literal.
    """
    _logger.debug("Requesting number of frames (topology=%s, trajectory=%s, timeout=%s)",
                  topology_file, trajectory_file, timeout)

    cpptraj_args = ["-p", topology_file, "-y", trajectory_file, "-tl"]
    raw_out = rin_util.run_cpptraj(self.cpptraj,
                                   argv=cpptraj_args,
                                   env=subprocess_env,
                                   timeout=timeout)
    _logger.debug("cpptraj -tl raw output: %r", raw_out)

    # int() tolerates surrounding whitespace, so only the label is stripped.
    count_text = raw_out.replace("Frames: ", "")
    try:
        frame_count = int(count_text)
    except ValueError:
        _logger.exception("Failed parsing frame count from cpptraj output: %r", count_text)
        raise RuntimeError(f"Could not retrieve the number of frames from '{trajectory_file}' trajectory")

    _logger.info("Detected %d frames in trajectory %s", frame_count, trajectory_file)
    return frame_count
|
|
112
|
+
|
|
113
|
+
def _get_atomic_composition_of_molecule(self,
                                        topology_file: str,
                                        trajectory_file: str,
                                        molecule_id: int,
                                        *,
                                        subprocess_env: dict | None = None,
                                        timeout: float | None = None) -> dict:
    """Look up the residue/atom composition of one molecule.

    Loads only frame 1, runs the mask-extraction command with its output
    redirected to a temporary ``.dat`` file, and parses that file with
    `rin_util.CpptrajMaskParser.hierarchize_molecular_composition`. The
    temporary file is removed afterwards on a best-effort basis.

    Args:
        topology_file (str): Path to the topology file.
        trajectory_file (str): Path to the trajectory file.
        molecule_id (int): Key to look up in the parsed hierarchy.
        subprocess_env (dict | None): Forwarded to `rin_util.run_cpptraj` as ``env``.
        timeout (float | None): Forwarded to `rin_util.run_cpptraj` as ``timeout``.

    Returns:
        dict: The hierarchy entry stored under ``molecule_id``.

    Raises:
        KeyError: If ``molecule_id`` is missing from the parsed hierarchy
            (an error is logged first, then the lookup itself raises).
    """
    _logger.debug("Extracting atomic composition (molecule_id=%s)", molecule_id)
    tmp_file: Path = sawnergy_util.temporary_file(prefix="mol_comp", suffix=".dat")
    _logger.debug("Temporary composition file: %s", tmp_file)
    try:
        # Compose "load frame 1" + "mask" and redirect the output to the temp file.
        loader = self._load_data_from(topology_file, trajectory_file, 1, 1)
        mask_cmd = self._extract_molecule_compositions()
        redirected = (loader + mask_cmd) > str(tmp_file)
        rendered = redirected.render()
        _logger.debug("Running composition cpptraj script (len=%d chars)", len(rendered))
        rin_util.run_cpptraj(self.cpptraj, script=rendered, env=subprocess_env, timeout=timeout)

        parsed = rin_util.CpptrajMaskParser.hierarchize_molecular_composition(tmp_file)
        if molecule_id not in parsed:
            # Log some context; the subscript below raises the documented KeyError.
            _logger.error("Molecule ID %s not found in composition hierarchy (available keys: %s)",
                          molecule_id, list(parsed.keys())[:10])
        composition = parsed[molecule_id]
        _logger.info("Retrieved composition for molecule %s (residues=%d)", molecule_id, len(composition))
        return composition
    finally:
        # Cleanup must never mask the real outcome, so failures are only logged.
        try:
            tmp_file.unlink()
        except OSError:
            _logger.warning("Failed to remove temp file %s", tmp_file, exc_info=True)
        else:
            _logger.debug("Cleaned up temp file %s", tmp_file)
|
|
160
|
+
|
|
161
|
+
def _calc_avg_atomic_interactions_in_frames(self,
                                            frame_range: tuple[int, int],
                                            topology_file: str,
                                            trajectory_file: str,
                                            molecule_id: int,
                                            *,
                                            subprocess_env: dict | None = None,
                                            timeout: float | None = None) -> np.ndarray:
    """Run the cpptraj pairwise command and return EMAP + VMAP as one matrix.

    The EMAP (electrostatic) and VMAP (van der Waals) text blocks are located
    in the cpptraj output with ``_elec_vdw_pattern``, parsed as flat float32
    sequences, reshaped to square matrices and added together.

    Args:
        frame_range (tuple[int, int]): ``(start_frame, end_frame)`` passed to the loader.
        topology_file (str): Path to the topology file.
        trajectory_file (str): Path to the trajectory file.
        molecule_id (int): Molecule selector handed to the pairwise command.
        subprocess_env (dict | None): Forwarded to `rin_util.run_cpptraj` as ``env``.
        timeout (float | None): Forwarded to `rin_util.run_cpptraj` as ``timeout``.

    Returns:
        np.ndarray: ``float32`` array of shape ``(n, n)`` holding EMAP + VMAP.
        NOTE(review): averaging over the frame range is presumed to happen
        inside cpptraj; this method does no averaging itself.

    Raises:
        ValueError: If the blocks are missing, their sizes differ, or the
            value count is not a perfect square.
    """
    first_frame, last_frame = frame_range
    _logger.debug("Calculating avg atomic interactions (frames=%s..%s, molecule_id=%s)",
                  first_frame, last_frame, molecule_id)

    # Shared tail of every parse-failure message below.
    version_hint = ("Potentially due to cpptraj version mismatch. "
                    "The data retrieval is stable for CPPTRAJ of Version V6.18.1 (AmberTools)")

    loader = self._load_data_from(topology_file, trajectory_file, first_frame, last_frame)
    pairwise_cmd = self._calc_nonbonded_energies_in_molecule(molecule_id)
    rendered = ((loader + pairwise_cmd) > rin_util.PAIRWISE_STDOUT).render()
    _logger.debug("Running pairwise cpptraj script (len=%d chars)", len(rendered))
    cpptraj_out = rin_util.run_cpptraj(self.cpptraj, script=rendered, env=subprocess_env, timeout=timeout)
    _logger.debug("cpptraj pairwise output length: %d", len(cpptraj_out))

    found = self._elec_vdw_pattern.search(cpptraj_out)
    if not found:
        _logger.error("EMAP/VMAP blocks not found in cpptraj output.")
        raise ValueError("Could not find EMAP/VMAP blocks in cpptraj output. " + version_hint)

    # Line structure is irrelevant: every whitespace-separated number is read.
    emap_flat, vmap_flat = (
        np.fromstring(block_text, dtype=np.float32, sep=' ')
        for block_text in found.groups()
    )
    _logger.debug("Parsed EMAP=%d values, VMAP=%d values", emap_flat.size, vmap_flat.size)

    if emap_flat.size != vmap_flat.size:
        _logger.error("Size mismatch EMAP(%d) vs VMAP(%d)", emap_flat.size, vmap_flat.size)
        raise ValueError(f"EMAP and VMAP sizes differ: {emap_flat.size} vs {vmap_flat.size} " + version_hint)

    # The flat block must hold exactly side*side values to form a square matrix.
    side = math.isqrt(emap_flat.size)
    if side * side != emap_flat.size:
        _logger.error("Non-square block: %d values (cannot form nxn)", emap_flat.size)
        raise ValueError(f"Block is not square: {emap_flat.size} values (cannot reshape to nxn). " + version_hint)

    square = (side, side)
    electrostatic = emap_flat.reshape(square)
    van_der_waals = vmap_flat.reshape(square)
    _logger.debug("Reshaped EMAP/VMAP to (%d, %d)", side, side)

    total = (electrostatic + van_der_waals).astype(np.float32)
    _logger.info("Computed interaction matrix shape: %s", total.shape)
    return total
|
|
236
|
+
|
|
237
|
+
def _get_residue_COMs_per_frame(
    self,
    frame_range: tuple[int, int],
    topology_file: str,
    trajectory_file: str,
    molecule_id: int,
    number_residues: int,
    *,
    subprocess_env: dict | None = None,
    timeout: float | None = None,
) -> list[np.ndarray]:
    """Collect per-frame residue COM arrays from a cpptraj run.

    Builds a script from the loader and the residue COM loop, runs it, cuts
    the data block out of the output with ``_com_block_pattern``, splits it
    into per-frame rows with ``_com_row_pattern`` and reshapes each row.

    Args:
        frame_range (tuple[int, int]): ``(start_frame, end_frame)``; both ends
            are counted, so ``end - start + 1`` rows are expected.
        topology_file (str): Path to the topology file.
        trajectory_file (str): Path to the trajectory file.
        molecule_id (int): Molecule selector used for the COM loop and output target.
        number_residues (int): Residue count; used for the ``COMZ<n>`` header
            lookup and for validating the row length (``3 * number_residues``).
        subprocess_env (dict | None): Forwarded to `rin_util.run_cpptraj` as ``env``.
        timeout (float | None): Forwarded to `rin_util.run_cpptraj` as ``timeout``.

    Returns:
        list[np.ndarray]: One ``float32`` array of shape ``(number_residues, 3)``
        per frame, in the order the rows appear in the output.

    Raises:
        ValueError: If ``end_frame < start_frame``.
        RuntimeError: If the COM block is missing, the row count differs from
            the frame count, or any row has the wrong number of values.
    """
    first_frame, last_frame = frame_range
    _logger.debug("Getting COMs per frame (frames=%s..%s, residues=%d, molecule_id=%s)",
                  first_frame, last_frame, number_residues, molecule_id)
    if last_frame < first_frame:
        _logger.error("Bad frame_range %s: end < start", frame_range)
        raise ValueError(f"Bad frame_range {frame_range}: end < start")
    expected_rows = last_frame - first_frame + 1

    # Shared tail of every parse-failure message below.
    version_hint = ("Potentially due to cpptraj version mismatch. "
                    "The data retrieval is stable for CPPTRAJ of Version V6.18.1 (AmberTools)")

    # Assemble and run the cpptraj script.
    loader = self._load_data_from(topology_file, trajectory_file, first_frame, last_frame)
    com_loop = self._compute_residue_COMs_in_molecule(molecule_id)
    rendered = ((loader + com_loop) > rin_util.COM_STDOUT(molecule_id)).render()
    _logger.debug("Running COM cpptraj script (len=%d chars)", len(rendered))
    cpptraj_out = rin_util.run_cpptraj(self.cpptraj, script=rendered,
                                       env=subprocess_env, timeout=timeout)
    _logger.debug("cpptraj COM output length: %d", len(cpptraj_out))

    # Cut out the numeric block, then the coordinate text of each row
    # (the leading frame number is not captured).
    found = self._com_block_pattern(number_residues).search(cpptraj_out)
    if not found:
        _logger.error("COM print block not found in cpptraj output (expected COMZ%d header).",
                      number_residues)
        raise RuntimeError("Could not find COM print block in cpptraj output. " + version_hint)
    row_texts = self._com_row_pattern.findall(found.group(1))
    _logger.debug("Extracted %d COM rows (expected %d)", len(row_texts), expected_rows)

    if len(row_texts) != expected_rows:
        _logger.error("Frame row count mismatch: expected %d, got %d",
                      expected_rows, len(row_texts))
        raise RuntimeError(f"Expected {expected_rows} frame rows, got {len(row_texts)}. " + version_hint)

    # Every row must carry exactly three values per residue.
    values_per_row = number_residues * 3
    parsed_rows = [np.fromstring(text, dtype=np.float32, sep=' ') for text in row_texts]
    wrong_length = [idx for idx, values in enumerate(parsed_rows) if values.size != values_per_row]
    if wrong_length:
        _logger.error("Row(s) with wrong length detected (showing first few): %s", wrong_length[:5])
        raise RuntimeError(
            f"Row(s) {wrong_length[:5]} have wrong length; expected {values_per_row} floats. " + version_hint
        )

    # NOTE(review): reshape(3, n).T assumes each row lists values grouped in
    # three runs of n (one run per coordinate axis) - confirm against cpptraj output.
    per_frame: list[np.ndarray] = [values.reshape(3, number_residues).T for values in parsed_rows]
    _logger.info("Built %d COM arrays of shape %s (one per frame)",
                 len(per_frame), per_frame[0].shape)
    return per_frame
|
|
325
|
+
|
|
326
|
+
# ---------------------------------------------------------------------------------------------- #
|
|
327
|
+
# CPPTRAJ COMMANDS
|
|
328
|
+
# ---------------------------------------------------------------------------------------------- #
|
|
329
|
+
|
|
330
|
+
@staticmethod
def _load_data_from(topology_file: str,
                    trajectory_file: str,
                    start_frame: int,
                    end_frame: int) -> rin_util.CpptrajScript:
    """Build the cpptraj preamble that loads the inputs for a frame range.

    The script consists of three commands: ``parm <topology>``,
    ``trajin <trajectory> <start> <end>`` and ``noprogress silenceactions``.

    Args:
        topology_file (str): Path to the topology file.
        trajectory_file (str): Path to the trajectory file.
        start_frame (int): First frame given to ``trajin``.
        end_frame (int): Last frame given to ``trajin``.

    Returns:
        rin_util.CpptrajScript: Script holding the three commands.
    """
    _logger.debug("Preparing data load (parm=%s, trajin=%s %s %s)",
                  topology_file, trajectory_file, start_frame, end_frame)
    parm_cmd = f"parm {topology_file}"
    trajin_cmd = f"trajin {trajectory_file} {start_frame} {end_frame}"
    quiet_cmd = "noprogress silenceactions"
    return rin_util.CpptrajScript((parm_cmd, trajin_cmd, quiet_cmd))
|
|
351
|
+
|
|
352
|
+
@staticmethod
def _calc_nonbonded_energies_in_molecule(molecule_id: int) -> rin_util.CpptrajScript:
    """Build the single-command script that runs cpptraj ``pairwise``.

    Args:
        molecule_id (int): Inserted into the ``^<molecule_id>`` mask.

    Returns:
        rin_util.CpptrajScript: Script containing
        ``pairwise PW ^<molecule_id> cuteelec 0.0 cutevdw 0.0``.
    """
    _logger.debug("Preparing pairwise command for molecule_id=%s", molecule_id)
    pairwise_cmd = f"pairwise PW ^{molecule_id} cuteelec 0.0 cutevdw 0.0"
    return rin_util.CpptrajScript.from_cmd(pairwise_cmd)
|
|
364
|
+
|
|
365
|
+
@staticmethod
def _extract_molecule_compositions() -> rin_util.CpptrajScript:
    """Build the single-command script ``mask :*``.

    Returns:
        rin_util.CpptrajScript: Script containing only the ``mask :*`` command.
    """
    _logger.debug("Preparing mask extraction command")
    mask_cmd = "mask :*"
    return rin_util.CpptrajScript.from_cmd(mask_cmd)
|
|
374
|
+
|
|
375
|
+
@staticmethod
def _compute_residue_COMs_in_molecule(molecule_id: int):
    """Build the cpptraj commands that define one COM vector per residue.

    The script issues ``autoimage`` and ``unwrap byres``, then loops over the
    residues in mask ``^<molecule_id>`` with a counter starting at 1 and
    defines ``vector COM$i center $R`` for each of them.

    Args:
        molecule_id (int): Inserted into the ``^<molecule_id>`` residue mask.

    Returns:
        rin_util.CpptrajScript: Script holding the five commands.
    """
    _logger.debug("Preparing COM vectors loop for molecule_id=%s", molecule_id)
    preparation = ("autoimage", "unwrap byres")
    residue_loop = (
        f"for residues R inmask ^{molecule_id} i=1;i++",
        "vector COM$i center $R",
        "done",
    )
    return rin_util.CpptrajScript(preparation + residue_loop)
|
|
393
|
+
|
|
394
|
+
# ---------------------------------------------------------------------------------------------- #
|
|
395
|
+
# POST-CPPTRAJ
|
|
396
|
+
# ---------------------------------------------------------------------------------------------- #
|
|
397
|
+
|
|
398
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
399
|
+
# CONVERSION OF ATOMIC LEVEL INTERACTIONS INTO RESIDUE LEVEL
|
|
400
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
401
|
+
@staticmethod
def _compute_residue_membership_matrix(
    res_to_atoms: dict[int, set[int]],
    *,
    dtype=np.float32
) -> np.ndarray:
    """Build the 0/1 atom-to-residue membership matrix.

    Residue IDs (columns) and atom IDs (rows) are each sorted and mapped to
    consecutive indices starting at 0. Entry ``[row(atom), col(residue)]`` is
    1 when the atom is listed under that residue and 0 otherwise.

    Args:
        res_to_atoms (dict[int, set[int]]): Residue ID -> atom IDs of that residue.
        dtype: dtype of the returned array.

    Returns:
        np.ndarray: Array of shape ``(n_atoms, n_residues)``; ``(0, 0)`` when
        the mapping is empty.
    """
    if not res_to_atoms:
        _logger.info("Empty residue->atoms mapping; returning (0,0) matrix.")
        return np.zeros((0, 0), dtype=dtype)

    # Columns: residues in ascending ID order.
    ordered_residues = sorted(res_to_atoms)
    # Rows: every atom seen under any residue, in ascending ID order.
    ordered_atoms = sorted(set().union(*(res_to_atoms[res] for res in ordered_residues)))
    row_of_atom = {atom: row for row, atom in enumerate(ordered_atoms)}

    n_atoms, n_res = len(ordered_atoms), len(ordered_residues)
    _logger.debug("Membership dims: atoms=%d, residues=%d", n_atoms, n_res)

    P = np.zeros((n_atoms, n_res), dtype=dtype)
    for col, residue in enumerate(ordered_residues):
        # Mark all member rows of this residue's column in one assignment.
        member_rows = np.fromiter((row_of_atom[atom] for atom in res_to_atoms[residue]),
                                  dtype=np.intp)
        P[member_rows, col] = 1.0

    cell_count = P.size if P.size else 1.0  # guard the density division
    _logger.info("Built membership matrix with shape %s and density %.6f",
                 P.shape, float(P.sum()) / cell_count)
    return P
|
|
450
|
+
|
|
451
|
+
@staticmethod
def _convert_atomic_to_residue_interactions(atomic_matrix: np.ndarray,
                                            membership_matrix: np.ndarray) -> np.ndarray:
    """Aggregate an atomic interaction matrix to residue level.

    Evaluates ``P.T @ A @ P`` with ``A = atomic_matrix`` and
    ``P = membership_matrix`` and casts the result to ``float32``.

    Args:
        atomic_matrix (np.ndarray): Square 2D array ``(n_atoms, n_atoms)``.
        membership_matrix (np.ndarray): 2D array with ``n_atoms`` rows.

    Returns:
        np.ndarray: ``float32`` array of shape ``(n_residues, n_residues)``,
        where ``n_residues`` is the column count of ``membership_matrix``.

    Raises:
        ValueError: If ``atomic_matrix`` is not square 2D, or
            ``membership_matrix`` is not 2D with a matching row count.
    """
    _logger.debug("Converting atomic->residue: atomic_matrix=%s, membership=%s",
                  atomic_matrix.shape, membership_matrix.shape)

    atomic_is_square = (atomic_matrix.ndim == 2
                        and atomic_matrix.shape[0] == atomic_matrix.shape[1])
    if not atomic_is_square:
        raise ValueError(f"atomic_matrix must be square 2D; got shape {atomic_matrix.shape}")

    rows_agree = (membership_matrix.ndim == 2
                  and membership_matrix.shape[0] == atomic_matrix.shape[0])
    if not rows_agree:
        raise ValueError(
            f"Row count mismatch: atomic_matrix is {atomic_matrix.shape}, "
            f"membership_matrix is {membership_matrix.shape}. Rows must match (#atoms)."
        )

    # Cap the native thread pools at the CPU count reported by the OS
    # (1 when it is unknown) for the duration of the matrix products.
    worker_limit = os.cpu_count() or 1
    with threadpoolctl.threadpool_limits(limits=worker_limit):
        half_projected = membership_matrix.T @ atomic_matrix
        residue_matrix = (half_projected @ membership_matrix).astype(dtype=np.float32)
    _logger.info("Residue interaction matrix shape: %s", residue_matrix.shape)
    return residue_matrix
|
|
482
|
+
|
|
483
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
484
|
+
# POST-PROCESSING OF RESIDUE LEVEL INTERACTIONS
|
|
485
|
+
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
|
|
486
|
+
|
|
487
|
+
def _split_into_attractive_repulsive(self, residue_matrix: np.ndarray) -> np.ndarray:
    """Separate a signed interaction matrix into two non-negative channels.

    Entries ``<= 0`` are negated and placed in channel 0 (attractive);
    entries ``> 0`` are copied to channel 1 (repulsive). All other positions
    are filled with 0.

    Args:
        residue_matrix (np.ndarray): Signed residue interaction matrix.

    Returns:
        np.ndarray: ``float32`` array with the two channels stacked on a new
        leading axis, i.e. shape ``(2, *residue_matrix.shape)``.
    """
    _logger.debug("Splitting matrix into attractive/repulsive channels; input shape=%s",
                  residue_matrix.shape)
    attractive_mask = residue_matrix <= 0
    repulsive_mask = residue_matrix > 0
    attractive = np.where(attractive_mask, -residue_matrix, 0.0).astype(np.float32, copy=False)
    repulsive = np.where(repulsive_mask, residue_matrix, 0.0).astype(np.float32, copy=False)
    two_channel = np.stack((attractive, repulsive), axis=0)
    _logger.info("Two-channel matrix shape: %s", two_channel.shape)
    return two_channel
|
|
506
|
+
|
|
507
|
+
def _prune_low_energies(self, two_channel_residue_matrix: np.ndarray, q: float) -> np.ndarray:
    """Zero every entry that lies below its row's ``q``-quantile.

    Channels 0 and 1 are processed independently, and the input array is
    modified in place.

    Args:
        two_channel_residue_matrix (np.ndarray): Array indexed as
            ``[channel, row, column]``; channels 0 and 1 are pruned.
        q (float): Quantile used as the per-row threshold; must satisfy
            ``0 < q <= 1``.

    Returns:
        np.ndarray: The same array object that was passed in.

    Raises:
        ValueError: If ``q`` is outside ``(0, 1]``.
    """
    _logger.debug("Pruning low energies with q=%s on matrix shape=%s", q, two_channel_residue_matrix.shape)
    q_in_range = 0.0 < q <= 1.0
    if not q_in_range:
        _logger.error("Invalid pruning quantile q=%s", q)
        raise ValueError(f"Invalid 'q' value. Expected a value between 0 and 1; received: {q}")

    for channel_index in (0, 1):
        channel = two_channel_residue_matrix[channel_index]
        # One threshold per row, kept as a column so it broadcasts across the row.
        row_threshold = np.quantile(channel, q, axis=1, keepdims=True)
        two_channel_residue_matrix[channel_index] = np.where(channel < row_threshold, 0.0, channel)

    _logger.info("Pruning done at q=%s", q)
    return two_channel_residue_matrix
|
|
534
|
+
|
|
535
|
+
def _remove_self_interactions(self, two_channel_residue_matrix: np.ndarray) -> np.ndarray:
    """Set the diagonal of channels 0 and 1 to zero, in place.

    Args:
        two_channel_residue_matrix (np.ndarray): Array indexed as
            ``[channel, row, column]``.

    Returns:
        np.ndarray: The same array object that was passed in.
    """
    _logger.debug("Zeroing self-interactions on shape=%s", two_channel_residue_matrix.shape)
    for channel_index in (0, 1):
        # Indexing yields a view, so fill_diagonal writes into the input array.
        np.fill_diagonal(two_channel_residue_matrix[channel_index], 0.0)
    return two_channel_residue_matrix
|
|
547
|
+
|
|
548
|
+
def _symmetrize(self, two_channel_residue_matrix: np.ndarray) -> np.ndarray:
    """Replace channels 0 and 1 by the average of each matrix and its transpose.

    The input array is modified in place.

    Args:
        two_channel_residue_matrix (np.ndarray): Array indexed as
            ``[channel, row, column]`` with square channels.

    Returns:
        np.ndarray: The same array object that was passed in.
    """
    _logger.debug("Symmetrizing two-channel matrix shape=%s", two_channel_residue_matrix.shape)
    for channel_index in (0, 1):
        channel = two_channel_residue_matrix[channel_index]
        # The right-hand side is a fresh array, so reading the view while
        # assigning back into the same slot is safe.
        two_channel_residue_matrix[channel_index] = (channel + channel.T) * 0.5
    _logger.info("Symmetrization complete")
    return two_channel_residue_matrix
|
|
564
|
+
|
|
565
|
+
def _L1_normalize(self, two_channel_residue_matrix: np.ndarray) -> np.ndarray:
    """Divide every row of channels 0 and 1 by its row sum, in place.

    Rows whose sum is not strictly positive are written as all zeros. The
    divisor is clipped from below at ``1e-12``.

    Args:
        two_channel_residue_matrix (np.ndarray): Array indexed as
            ``[channel, row, column]``.

    Returns:
        np.ndarray: The same array object that was passed in.

    Note:
        Because each row is scaled by its own total, a matrix that was
        symmetric before this step is generally no longer symmetric after it.
    """
    _logger.debug("L1-normalizing two-channel matrix shape=%s", two_channel_residue_matrix.shape)
    eps = 1e-12
    row_sums = []
    for channel_index in (0, 1):
        channel = two_channel_residue_matrix[channel_index]
        totals = channel.sum(axis=1, keepdims=True)
        # `out` starts as zeros, and `where` skips rows with a non-positive
        # total, so those rows stay zero instead of being divided.
        two_channel_residue_matrix[channel_index] = np.divide(
            channel, np.clip(totals, eps, None),
            out=np.zeros_like(channel), where=totals > 0)
        row_sums.append(totals)

    attractive_sums, repulsive_sums = row_sums
    _logger.info("L1 normalization complete (zero-row counts: A=%d, R=%d)",
                 int((attractive_sums <= eps).sum()), int((repulsive_sums <= eps).sum()))
    return two_channel_residue_matrix
|
|
595
|
+
|
|
596
|
+
def _store_two_channel_array(
    self,
    arr: np.ndarray,
    storage: sawnergy_util.ArrayStorage,
    arrays_per_chunk: int,
    attractive_dataset_name: str | None,
    repulsive_dataset_name: str | None,
) -> None:
    """Write the channels of a two-channel interaction array to ``storage``.

    Channel ``arr[0]`` (attractive) and channel ``arr[1]`` (repulsive) are
    each handed to ``storage.write`` as a one-element list, attractive
    first, under their respective dataset names.

    Args:
        arr: Array expected to have shape ``(2, N, N)``; index 0 along the
            first axis is the attractive channel, index 1 the repulsive one.
        storage: ``ArrayStorage`` handle that receives the writes.
        arrays_per_chunk: Forwarded unchanged to ``storage.write``.
        attractive_dataset_name: Block name for ``arr[0]``; ``None`` skips
            writing that channel.
        repulsive_dataset_name: Block name for ``arr[1]``; ``None`` skips
            writing that channel.

    Returns:
        None

    Notes:
        - An unexpected shape only triggers a warning; the channels are
          still indexed and written afterwards.
        - Nothing is caught here, so errors raised while indexing ``arr``
          or by ``storage.write`` reach the caller.
    """
    has_expected_layout = arr.ndim == 3 and arr.shape[0] == 2
    if not has_expected_layout:
        _logger.warning(
            "Expected two-channel array with shape (2, N, N); got %s",
            getattr(arr, "shape", None),
        )

    _logger.debug(
        "Storing two-channel array: shape=%s, chunksize=%s, datasets=(%s, %s)",
        arr.shape, arrays_per_chunk, attractive_dataset_name, repulsive_dataset_name
    )

    # Position in this tuple doubles as the channel index along axis 0.
    dataset_names = (attractive_dataset_name, repulsive_dataset_name)
    for channel_idx, dataset_name in enumerate(dataset_names):
        if dataset_name is None:
            # Channel was not requested: do not even index into `arr`.
            continue
        storage.write(
            these_arrays=[arr[channel_idx]],
            to_block_named=dataset_name,
            arrays_per_chunk=arrays_per_chunk
        )

    _logger.info(
        "Stored attractive/repulsive arrays to '%s' / '%s'",
        attractive_dataset_name, repulsive_dataset_name
    )
|
|
660
|
+
|
|
661
|
+
# ---------------------------------------------------------------------------------------------- #
|
|
662
|
+
# PUBLIC API
|
|
663
|
+
# ---------------------------------------------------------------------------------------------- #
|
|
664
|
+
|
|
665
|
+
def build_rin(
    self,
    topology_file: str,
    trajectory_file: str,
    molecule_of_interest: int,
    frame_range: tuple[int, int] | None = None,
    frame_batch_size: int = -1,
    prune_low_energies_frac: float = 0.3,
    output_path: str | Path | None = None,
    keep_prenormalized_energies: bool = True,
    *,
    include_attractive: bool = True,
    include_repulsive: bool = True,
    parallel_cpptraj: bool = False,
    simul_cpptraj_instances: int | None = None,
    num_matrices_in_compressed_blocks: int = 10,
    compression_level: int = 3,
    cpptraj_run_time_limit: float | None = None
) -> str:
    """Build a Residue Interaction Network (RIN) archive from an MD trajectory.

    High-level pipeline:

    1. Discover MD metadata (trajectory frame count; residue membership of the
       target molecule).

    2. For each frame batch:

       a) Run cpptraj `pairwise` on atoms → EMAP + VMAP → sum (atomic matrix).

       b) Project atomic → residue with ``R = Pᵀ @ A @ P``.

       c) Post-process residue matrix:
          split into (attractive, repulsive) channels,
          per-row quantile pruning,
          remove self-interactions,
          symmetrize.

       d) Optionally store **pre-normalized energies** (attractive or
          repulsive or both, depending on `include_<kind>`).

       e) Row-wise L1 normalize (directed transition probabilities) and store.

       f) Compute per-residue COM coordinates for the batch's frame ranges
          and average them over the frames of each range.

    3. After the last batch, store all averaged COM snapshots under the
       ``COM`` block and attach the run metadata as attributes.

    4. Close and compress the temporary store into a zip (Zarr v3). Return path.

    Args:
        topology_file: Path to the topology (parm/prmtop) file.
        trajectory_file: Path to a cpptraj-readable trajectory file.
        molecule_of_interest: Molecule selector/ID used by cpptraj (e.g., ``1``
            for ``^1``).
        frame_range: 1-based inclusive ``(start, end)`` frames to process. If
            ``None``, uses the full trajectory.
        frame_batch_size: Number of frames per batch for pairwise calculations.
            If ``<= 0``, processes all frames in a single batch.
        prune_low_energies_frac: Per-row quantile ``q`` in ``(0, 1]`` used to
            zero out small values independently in both channels.
        output_path: Destination path (with or without ``.zip``). ``.zip`` is
            appended when the name does not already end with it, so dotted
            names such as ``run.v2`` become ``run.v2.zip``. Defaults to
            ``RIN_<timestamp>.zip`` in the current working directory.
        keep_prenormalized_energies: If ``True``, stores the pre-normalized
            attractive/repulsive matrices under ``ATTRACTIVE|REPULSIVE_energies``.
        include_attractive: If ``False``, the attractive channel is not
            written (neither energies nor transitions) and the corresponding
            dataset-name attributes are stored as ``None``.
        include_repulsive: Same as ``include_attractive`` but for the
            repulsive channel.
        parallel_cpptraj: If ``True``, run multiple cpptraj frame batches in
            parallel using threads (safe w.r.t. pickling).
        simul_cpptraj_instances: Maximum concurrent cpptraj tasks (defaults to
            ``os.cpu_count()`` when ``None``).
        num_matrices_in_compressed_blocks: Number of matrices per chunk along
            the leading axis when writing Zarr arrays.
        compression_level: Blosc Zstd compression level for the final ZipStore.
        cpptraj_run_time_limit: Optional timeout (seconds) for cpptraj calls.

    Returns:
        str: Path to the created ``.zip`` archive (Zarr v3).

    Raises:
        RuntimeError: Propagated from helper methods (e.g., cpptraj failures).
        ValueError: If ``frame_range`` is unordered or outside the trajectory;
            also propagated from helper methods (e.g., bad pruning quantile).

    Notes:
        * Row-wise L1 normalization produces **directed** transition
          probabilities (rows sum to 1) and therefore breaks symmetry.
        * All linear algebra runs in a single Python thread; BLAS may use
          multiple threads internally. cpptraj parallelism is optional and uses
          threads to avoid pickling constraints.
    """
    _logger.info(
        "Building RIN (mol=%s, traj=%s, frame_range=%s, frame_batch_size=%s, "
        "keep_abs=%s, parallel_cpptraj=%s, simul_instances=%s, comp_level=%s)",
        molecule_of_interest, trajectory_file, frame_range, frame_batch_size,
        keep_prenormalized_energies, parallel_cpptraj, simul_cpptraj_instances, compression_level
    )

    # ----------------------------------- MD META DATA -------------------------------------
    total_frames = self._get_number_frames(
        topology_file,
        trajectory_file,
        timeout=cpptraj_run_time_limit
    )

    molecule_composition = self._get_atomic_composition_of_molecule(
        topology_file,
        trajectory_file,
        molecule_of_interest,
        timeout=cpptraj_run_time_limit
    )
    number_residues = len(molecule_composition)
    _logger.info("MD metadata: total_frames=%d, residues=%d", total_frames, number_residues)
    # --------------------------------------------------------------------------------------

    # --------------------- AUXILIARY VARIABLES' / TOOLS PREPARATION -----------------------
    current_time = sawnergy_util.current_time()
    attractive_transitions_name = "ATTRACTIVE_transitions"
    repulsive_transitions_name = "REPULSIVE_transitions"
    attractive_energies_name = "ATTRACTIVE_energies"
    repulsive_energies_name = "REPULSIVE_energies"
    simul_cpptraj_instances = simul_cpptraj_instances or (os.cpu_count() or 1)

    output_path = Path(output_path or (Path(os.getcwd()) / f"RIN_{current_time}"))
    if output_path.suffix.lower() == ".zip":
        # Already a zip name: only normalise the suffix spelling.
        output_path = output_path.with_suffix(".zip")
    else:
        # Append instead of `with_suffix`, which would *replace* whatever
        # follows the last dot ("run.v2" -> "run.zip") and silently rename
        # the caller's output.
        output_path = output_path.with_name(output_path.name + ".zip")
    _logger.debug("Output archive path: %s", output_path)

    # -=- FRAMES OF THE MD SIMULATION -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    if frame_batch_size <= 0:
        frame_batch_size = total_frames
    if frame_range is None:
        start_frame, end_frame = 1, total_frames
    else:
        start_frame, end_frame = frame_range

    if not (1 <= start_frame <= end_frame <= total_frames):
        raise ValueError(f"frame_range must lie within [1, {total_frames}] and be ordered; got {frame_range}.")

    # Lazy sequence of inclusive (first, last) frame ranges, each spanning at
    # most `frame_batch_size` frames; the last one is clipped to `end_frame`.
    frames = (
        (s, min(s + frame_batch_size - 1, end_frame))
        for s in range(start_frame, end_frame + 1, max(1, frame_batch_size))
    )
    _logger.debug("Frame selection: [%d..%d], batch_size=%d", start_frame, end_frame, frame_batch_size)
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

    # -=- DATA PROCESSORS -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
    frame_processor = sawnergy_util.elementwise_processor(
        in_parallel=parallel_cpptraj,
        Executor=ThreadPoolExecutor,
        max_workers=simul_cpptraj_instances,
        capture_output=True
    )
    matrix_processor = sawnergy_util.elementwise_processor(
        in_parallel=False,   # <- BLAS handles lin. alg. parallelism &
        capture_output=True  #    the rest of the code is vectorized by default
    )
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=--=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

    # -=- ADJUST CPPTRAJ PARALLELISM -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
    non_bonded_energies_subprocess_env = (
        sawnergy_util.create_updated_subprocess_env(
            OMP_NUM_THREADS=1,  # PREVENTING OVERSUBSCRIPTION
            MKL_NUM_THREADS=1,
            OPENBLAS_NUM_THREADS=1,
            MKL_DYNAMIC=False
        ) if parallel_cpptraj else None
    )
    _logger.debug("cpptraj parallel: %s (instances=%s)", parallel_cpptraj, simul_cpptraj_instances)
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

    # create a membership matrix for atoms in residues
    membership_matrix = self._compute_residue_membership_matrix(molecule_composition)
    _logger.info("Membership matrix ready: shape=%s, nnz=%d",
                 membership_matrix.shape, int(membership_matrix.sum()))
    # --------------------------------------------------------------------------------------

    # -------------------- INTERACTION DATA EXTRACTION AND PROCESSING ----------------------
    pipeline = sawnergy_util.compose_steps(
        (self._convert_atomic_to_residue_interactions, {"membership_matrix": membership_matrix}),
        (self._split_into_attractive_repulsive, None),
        (self._prune_low_energies, {"q": prune_low_energies_frac}),
        (self._remove_self_interactions, None),
        (self._symmetrize, None),
    )
    _logger.debug("Post-processing pipeline assembled")

    with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level) as storage:
        _logger.debug("Opened temporary store for writing")
        com_avgs: list[np.ndarray] = []

        # With parallel cpptraj, one outer iteration carries as many frame
        # ranges as there are workers; otherwise one range at a time.
        for frame_batch in sawnergy_util.batches_of(
            frames, batch_size=simul_cpptraj_instances if parallel_cpptraj else 1
        ):
            _logger.debug("Processing next frame batch")

            atomic_matrices = frame_processor(
                frame_batch,
                self._calc_avg_atomic_interactions_in_frames,
                topology_file,
                trajectory_file,
                molecule_of_interest,
                subprocess_env=non_bonded_energies_subprocess_env,
                timeout=cpptraj_run_time_limit,
            )

            interaction_matrices = matrix_processor(atomic_matrices, pipeline)

            # Energies must be persisted *before* normalization:
            # `_L1_normalize` assigns into the channels of the array it is
            # given, so the raw energies are gone afterwards.
            if keep_prenormalized_energies:
                _logger.debug("Writing absolute energy channels")
                for arr in interaction_matrices:
                    self._store_two_channel_array(
                        arr,
                        storage,
                        num_matrices_in_compressed_blocks,
                        attractive_energies_name if include_attractive else None,
                        repulsive_energies_name if include_repulsive else None
                    )

            transition_matrices = matrix_processor(interaction_matrices, self._L1_normalize)

            _logger.debug("Writing normalized transition channels")
            for arr in transition_matrices:
                self._store_two_channel_array(
                    arr,
                    storage,
                    num_matrices_in_compressed_blocks,
                    attractive_transitions_name if include_attractive else None,
                    repulsive_transitions_name if include_repulsive else None
                )

            _logger.debug("Computing batch COMs for this frame batch (len=%d)", len(frame_batch))
            com_lists_per_range = frame_processor(
                frame_batch,
                self._get_residue_COMs_per_frame,
                topology_file,
                trajectory_file,
                molecule_of_interest,
                number_residues,
                timeout=cpptraj_run_time_limit,
            )

            # One averaged COM snapshot per frame range: stack that range's
            # per-frame results and take the mean over the frame axis.
            for i, com_frames in enumerate(com_lists_per_range):
                avg = np.stack(com_frames, axis=0).mean(axis=0).astype(np.float32, copy=False)
                com_avgs.append(avg)
                _logger.debug("Batch %d: COM avg shape=%s", i, avg.shape)

        _logger.debug("Writing %d batch-averaged COM snapshots (chunk=%d)",
                      len(com_avgs), num_matrices_in_compressed_blocks)

        storage.write(
            com_avgs,
            to_block_named="COM",
            arrays_per_chunk=num_matrices_in_compressed_blocks,
        )

        # ----------------------------- META-DATA --------------------------------
        # Dataset-name attributes are None for anything that was not written,
        # so readers can tell which blocks exist in the archive.
        storage.add_attr("time_created", current_time)
        storage.add_attr("com_name", "COM")
        storage.add_attr("molecule_of_interest", molecule_of_interest)
        storage.add_attr("frame_range", frame_range)
        storage.add_attr("frame_batch_size", frame_batch_size)
        storage.add_attr("prune_low_energies_frac", prune_low_energies_frac)
        storage.add_attr("attractive_transitions_name", attractive_transitions_name if include_attractive else None)
        storage.add_attr("repulsive_transitions_name", repulsive_transitions_name if include_repulsive else None)
        storage.add_attr("attractive_energies_name",
                         attractive_energies_name if include_attractive and keep_prenormalized_energies else None)
        storage.add_attr("repulsive_energies_name",
                         repulsive_energies_name if include_repulsive and keep_prenormalized_energies else None)
        # ------------------------------------------------------------------------

    _logger.info("RIN build complete -> %s", output_path)
    return str(output_path)
|
|
930
|
+
|
|
931
|
+
# Names exported by a star-import of this module.
__all__ = [
    "RINBuilder"
]

# Running this module directly is a deliberate no-op.
if __name__ == "__main__":
    pass
|