sawnergy 1.0.3__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,251 @@
1
+ from __future__ import annotations
2
+
3
+ # third party
4
+ import numpy as np
5
+ import matplotlib as mpl
6
+
7
+ # built-in
8
+ from pathlib import Path
9
+ from typing import Sequence
10
+ import logging
11
+
12
+ # local
13
+ from ..visual import visualizer_util
14
+ from .. import sawnergy_util
15
+
16
+ # *----------------------------------------------------*
17
+ # GLOBALS
18
+ # *----------------------------------------------------*
19
+
20
+ _logger = logging.getLogger(__name__)
21
+
22
+ # *----------------------------------------------------*
23
+ # HELPERS
24
+ # *----------------------------------------------------*
25
+
26
def _safe_svd_pca(X: np.ndarray, k: int, *, row_l2: bool = False) -> tuple[np.ndarray, np.ndarray]:
    """Compute k principal directions via SVD and project onto them.

    The data are mean-centred per feature, optionally rescaled so that every
    centred row has unit L2 norm, and then projected onto the leading right
    singular vectors of the centred matrix.

    Args:
        X: 2D array of shape (N, D), one sample per row.
        k: Number of principal directions to keep; must be 2 or 3.
        row_l2: If True, divide each centred row by its L2 norm (norms are
            clipped at 1e-9 to avoid division by zero) before the SVD.

    Returns:
        A pair ``(proj, comps)`` where ``proj`` has shape (N, k) and ``comps``
        has shape (k, D). When the SVD yields fewer than ``k`` right singular
        vectors (i.e. N < k), the missing rows of ``comps`` -- and therefore
        the matching columns of ``proj`` -- are zero.

    Raises:
        ValueError: If ``X`` is not 2D, ``k`` is not 2 or 3, or ``D < k``.
    """
    if X.ndim != 2:
        raise ValueError(f"PCA expects 2D array (N, D); got {X.shape}")
    N, D = X.shape
    if k not in (2, 3):
        raise ValueError(f"PCA dimensionality must be 2 or 3; got {k}")
    if D < k:
        raise ValueError(f"Requested k={k} exceeds feature dim D={D}")
    if N == 0:
        # Nothing to centre or decompose; keep the documented output shapes.
        return np.zeros((0, k), dtype=float), np.zeros((k, D), dtype=float)

    Xc = X - X.mean(axis=0, keepdims=True)
    if row_l2:
        norms = np.linalg.norm(Xc, axis=1, keepdims=True)
        Xc = Xc / np.clip(norms, 1e-9, None)

    # With full_matrices=False, Vt has min(N, D) rows, which can be < k when
    # only one or two samples are given. Zero-pad so callers always get k axes.
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    available = min(k, Vt.shape[0])
    comps = np.zeros((k, D), dtype=Vt.dtype)
    comps[:available] = Vt[:available]
    proj = Xc @ comps.T
    return proj, comps
43
+
44
def _set_equal_axes_3d(ax, xyz: np.ndarray, *, padding: float = 0.05) -> None:
    """Give a 3D axes a cubic view box that encloses every point in ``xyz``.

    All three axis limits get the same width -- the largest per-axis extent of
    the data -- centred on the midpoint of the data along that axis, and are
    widened on both sides by ``padding`` times that extent (or by ``padding``
    itself when all points coincide). Does nothing for an empty array.

    Args:
        ax: Object exposing ``set_xlim``, ``set_ylim``, ``set_zlim`` and,
            optionally, ``set_box_aspect``.
        xyz: Array whose first three columns hold the x, y and z coordinates.
        padding: Fractional margin added around the cube.
    """
    if xyz.size == 0:
        return

    # Per-axis extrema, gathered in x, y, z order.
    lows = []
    highs = []
    for axis_no in range(3):
        column = xyz[:, axis_no]
        lows.append(float(column.min()))
        highs.append(float(column.max()))

    widest = max(hi - lo for lo, hi in zip(lows, highs))
    margin = padding * (widest if widest > 0 else 1.0)

    for setter_name, lo, hi in zip(("set_xlim", "set_ylim", "set_zlim"), lows, highs):
        mid = (lo + hi) / 2.0
        getattr(ax, setter_name)(mid - widest / 2 - margin, mid + widest / 2 + margin)

    # Best effort: an equal box aspect is cosmetic, so any failure is ignored.
    try:
        ax.set_box_aspect([1, 1, 1])
    except Exception:
        return
64
+
65
+ # *----------------------------------------------------*
66
+ # CLASS
67
+ # *----------------------------------------------------*
68
+
69
class Visualizer:
    """3D PCA visualizer for per-frame embeddings.

    Loads an embedding array of shape (T, N, D) from an ``ArrayStorage``
    archive and renders single frames as a 3D scatter of their PCA projection.
    Frame ids and node ids are 1-based in the public API and converted to
    0-based indices internally.
    """

    # True until the first instance is constructed; gates the one-off
    # ``visualizer_util.warm_start_matplotlib()`` call.
    no_instances: bool = True

    def __init__(
        self,
        EMB_path: str | Path,
        figsize: tuple[int, int] = (9, 7),
        default_node_color: str = visualizer_util.GRAY,
        depthshade: bool = False,
        antialiased: bool = False,
        init_elev: float = 35,
        init_azim: float = 45,
        *,
        show: bool = False,
        normalize_rows: bool = False,
    ) -> None:
        """Load the embeddings and create the (still empty) figure.

        Args:
            EMB_path: Path of the archive; the array name is read from its
                ``frame_embeddings_name`` attribute.
            figsize: Figure size forwarded to ``pyplot.figure``.
            default_node_color: Colour used for nodes without an explicit one.
            depthshade: Forwarded to the 3D ``scatter`` call.
            antialiased: Forwarded to the 3D ``scatter`` call.
            init_elev: Elevation passed to ``view_init``.
            init_azim: Azimuth passed to ``view_init``.
            show: Forwarded to ``visualizer_util.ensure_backend``
                (presumably selects an interactive backend -- confirm there).
            normalize_rows: If True, rows are L2-normalised after centring
                and before the PCA projection.

        Raises:
            ValueError: If the stored embeddings are not 3-dimensional.
        """
        # Backend must be settled before pyplot is imported.
        visualizer_util.ensure_backend(show)
        import matplotlib.pyplot as plt
        self._plt = plt

        if Visualizer.no_instances:
            try:
                visualizer_util.warm_start_matplotlib()
            finally:
                # Never retry the warm start, even if it failed.
                Visualizer.no_instances = False

        # Load embeddings archive; materialise while the storage is still open.
        EMB_path = Path(EMB_path)
        with sawnergy_util.ArrayStorage(EMB_path, mode="r") as storage:
            name = storage.get_attr("frame_embeddings_name")
            E = np.asarray(storage.read(name, slice(None)))
            if E.ndim != 3:
                raise ValueError(f"Expected embeddings of shape (T,N,D); got {E.shape}")
        self.E = E
        self.T, self.N, self.D = map(int, self.E.shape)
        _logger.info("Loaded embeddings: T=%d, N=%d, D=%d", self.T, self.N, self.D)

        # Coloring normalizer (parity with RIN Visualizer)
        self._residue_norm = mpl.colors.Normalize(0, max(1, self.N - 1))

        # Figure / axes / artists (axes and scatter are created lazily)
        self._fig = self._plt.figure(figsize=figsize, num="SAWNERGY")
        self._ax = None
        self._scatter = None
        self._labels = []
        self._marker_size = 30.0
        self._init_elev = init_elev
        self._init_azim = init_azim
        self.default_node_color = default_node_color
        self._antialiased = bool(antialiased)
        self._depthshade = bool(depthshade)
        self._normalize_rows = bool(normalize_rows)

    # ------------------------------ PRIVATE ------------------------------ #

    def _ensure_axes(self) -> None:
        """Create the 3D axes and the empty scatter artist on first use."""
        if self._ax is not None and self._scatter is not None:
            return
        self._fig.clf()
        # Clearing the figure discards any text artists that were on it.
        self._labels = []
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._ax.view_init(self._init_elev, self._init_azim)
        self._scatter = self._ax.scatter(
            [], [], [],
            s=self._marker_size,
            depthshade=self._depthshade,
            edgecolors="none",
            antialiased=self._antialiased,
        )
        try:
            self._ax.set_axis_off()
        except Exception:
            # Hiding the axes is cosmetic; keep going with visible axes.
            _logger.debug("set_axis_off() failed; axes stay visible", exc_info=True)

    def _clear_labels(self) -> None:
        """Remove the text labels drawn by a previous ``build_frame`` call."""
        for txt in getattr(self, "_labels", ()):
            try:
                txt.remove()
            except Exception:
                # Best effort: the artist may already be detached.
                pass
        self._labels = []

    def _project3(self, X: np.ndarray) -> np.ndarray:
        """Return a 3D PCA projection of embeddings (always 3 coordinates).

        If fewer than three principal coordinates are available -- because the
        embedding dimensionality D is below 3 or because fewer than three rows
        were passed -- the remaining coordinate(s) are set to 0, so the
        returned array always has shape (n, 3).
        """
        n, d = X.shape
        if d >= 2:
            P, _ = _safe_svd_pca(X, 3 if d >= 3 else 2, row_l2=self._normalize_rows)
        else:
            # The PCA helper only accepts k in {2, 3}; for D < 2 the centred
            # data are already the (at most one) principal coordinate.
            P = X - X.mean(axis=0, keepdims=True)
            if self._normalize_rows:
                norms = np.linalg.norm(P, axis=1, keepdims=True)
                P = P / np.clip(norms, 1e-9, None)
        if P.shape[1] >= 3:
            return P
        padded = np.zeros((n, 3), dtype=P.dtype)
        padded[:, :P.shape[1]] = P
        return padded

    def _select_nodes(self, displayed_nodes: Sequence[int] | str | None) -> np.ndarray:
        """Translate a 1-based node selection into 0-based int64 indices.

        Args:
            displayed_nodes: ``None`` or the string ``"ALL"`` for every node,
                otherwise a sequence/array of 1-based integer node ids.

        Raises:
            TypeError: If the selection is not of integer dtype.
            ValueError: If the selection is an empty integer array.
            IndexError: If any id lies outside ``[1, N]``.
        """
        # isinstance guard: comparing an ndarray to "ALL" with == is
        # elementwise and must not be used as a truth value.
        if displayed_nodes is None or (isinstance(displayed_nodes, str) and displayed_nodes == "ALL"):
            return np.arange(self.N, dtype=np.int64)
        idx = np.asarray(displayed_nodes)
        if idx.dtype.kind not in "iu":
            raise TypeError("displayed_nodes must be None, 'ALL', or an integer sequence.")
        if idx.size == 0:
            raise ValueError("displayed_nodes must contain at least one node id.")
        if idx.min() < 1 or idx.max() > self.N:
            raise IndexError(f"displayed_nodes out of range [1,{self.N}]")
        return idx.astype(np.int64) - 1

    def _apply_colors(self, node_colors, idx: np.ndarray) -> np.ndarray:
        """Resolve ``node_colors`` into one colour per selected node.

        Resolution order (RIN Visualizer semantics):
          * ``str``: name of a colormap, sampled at each node's 0-based index
            through the residue normalizer;
          * array-like of shape (N, 3) or (N, 4): used directly as per-node
            colours;
          * anything else (including ``None``): handed to
            ``visualizer_util.map_groups_to_colors`` as the ``groups`` spec.
        """
        if isinstance(node_colors, str):
            node_cmap = self._plt.get_cmap(node_colors)
            return node_cmap(self._residue_norm(idx))
        if node_colors is not None:
            try:
                arr = np.asarray(node_colors)
            except (ValueError, TypeError):
                # Ragged nested sequences cannot become a regular array in
                # recent NumPy; presumably a group spec -- let the mapper decide.
                arr = None
            if arr is not None and arr.ndim == 2 and arr.shape[0] == self.N and arr.shape[1] in (3, 4):
                return arr[idx]
        full = visualizer_util.map_groups_to_colors(
            N=self.N, groups=node_colors, default_color=self.default_node_color, one_based=True
        )
        return np.asarray(full)[idx]

    # ------------------------------ PUBLIC ------------------------------- #

    def build_frame(
        self,
        frame_id: int,
        *,
        node_colors: str | np.ndarray | None = "rainbow",
        displayed_nodes: Sequence[int] | str | None = "ALL",
        show_node_labels: bool = False,
        show: bool = False
    ) -> None:
        """Render a single frame as a PCA **3D** scatter (matches RIN Visualizer API).

        The PCA is computed on the selected nodes of the requested frame only.

        Args:
            frame_id: 1-based frame number.
            node_colors: Colormap name, (N, 3|4) colour array, group spec, or
                ``None`` for the default colour (see ``_apply_colors``).
            displayed_nodes: ``"ALL"``/``None`` or 1-based node ids to draw.
            show_node_labels: If True, annotate every point with its 1-based
                node id. Labels from earlier calls are always removed.
            show: If True, call ``pyplot.show`` (blocking) after drawing.

        Raises:
            IndexError: If ``frame_id`` is outside ``[1, T]`` or a node id is
                out of range.
        """
        frame0 = int(frame_id) - 1
        if not (0 <= frame0 < self.T):
            raise IndexError(f"frame_id out of range [1,{self.T}]")
        self._ensure_axes()

        idx = self._select_nodes(displayed_nodes)
        X = self.E[frame0, idx, :]   # (n, D)
        P = self._project3(X)        # (n, 3)
        colors = self._apply_colors(node_colors, idx)

        x, y, z = P[:, 0], P[:, 1], P[:, 2]
        # Update the existing 3D scatter in place through its private
        # ``_offsets3d`` attribute, as the original code does.
        self._scatter._offsets3d = (x, y, z)
        self._scatter.set_facecolors(colors)
        _set_equal_axes_3d(self._ax, P, padding=0.05)
        self._ax.view_init(self._init_elev, self._init_azim)

        # Always drop labels of the previously drawn frame so they do not
        # linger when this frame is rendered without labels.
        self._clear_labels()
        if show_node_labels:
            for p, nid in zip(P, idx + 1):
                self._labels.append(self._ax.text(p[0], p[1], p[2], str(int(nid)), fontsize=8))

        # Layout and redraw are best effort; failures must not abort rendering.
        try:
            self._fig.tight_layout()
        except Exception:
            try:
                self._fig.subplots_adjust()
            except Exception:
                pass
        try:
            self._fig.canvas.draw_idle()
        except Exception:
            pass

        if show:
            try:
                self._plt.show(block=True)
            except TypeError:
                # Fallback for a ``show`` that does not accept ``block``.
                self._plt.show()

    # convenience
    def savefig(self, path: str | Path, *, dpi: int = 150) -> None:
        """Save the current figure to ``path`` at the given resolution."""
        self._fig.savefig(path, dpi=dpi)

    def close(self) -> None:
        """Close the figure; errors raised while closing are ignored."""
        try:
            self._plt.close(self._fig)
        except Exception:
            pass
249
+
250
+
251
+ __all__ = ["Visualizer"]
sawnergy/logging_util.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
6
 
7
7
  def configure_logging(
8
8
  logs_dir: Path | str,
9
- file_level: int = logging.DEBUG,
9
+ file_level: int = logging.WARNING,
10
10
  console_level: int = logging.WARNING
11
11
  ) -> None:
12
12
  """
@@ -669,7 +669,7 @@ class RINBuilder:
669
669
  molecule_of_interest: int,
670
670
  frame_range: tuple[int, int] | None = None,
671
671
  frame_batch_size: int = -1,
672
- prune_low_energies_frac: float = 0.3,
672
+ prune_low_energies_frac: float = 0.85,
673
673
  output_path: str | Path | None = None,
674
674
  keep_prenormalized_energies: bool = True,
675
675
  *,
@@ -690,9 +690,9 @@ class RINBuilder:
690
690
 
691
691
  2. For each frame batch:
692
692
 
693
- a) Run cpptraj `pairwise` on atoms → EMAP + VMAP → sum (atomic matrix).
693
+ a) Run cpptraj 'pairwise' on atoms → EMAP + VMAP → sum (atomic matrix).
694
694
 
695
- b) Project atomic → residue with ``R = Pᵀ @ A @ P``.
695
+ b) Project atomic → residue with 'R = Pᵀ @ A @ P'.
696
696
 
697
697
  c) Post-process residue matrix:
698
698
  split into (attractive, repulsive) channels,
@@ -700,7 +700,7 @@ class RINBuilder:
700
700
  remove self-interactions,
701
701
  symmetrize.
702
702
 
703
- d. Optionally store **pre-normalized energies** (attractive or repulsive or both, depending on `include_<kind>`).
703
+ d. Optionally store **pre-normalized energies** (attractive or repulsive or both, depending on 'include_<kind>').
704
704
 
705
705
  e. Row-wise L1 normalize (directed transition probabilities) and store.
706
706
 
@@ -107,7 +107,7 @@ class Visualizer:
107
107
  visualizer_util.ensure_backend(show)
108
108
  import matplotlib.pyplot as plt
109
109
  self._plt = plt
110
- # ---------- WARM UP MPL ------------ #
110
+ # ---------- WARM UP MPL ------------ #
111
111
  _logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
112
112
  RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
113
113
  if Visualizer.no_instances:
@@ -116,7 +116,7 @@ class Visualizer:
116
116
  else:
117
117
  _logger.debug("Skipping warm-start (no_instances=False).")
118
118
 
119
- # ---------- LOAD THE DATA ---------- #
119
+ # ---------- LOAD THE DATA ---------- #
120
120
  with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
121
121
  com_name = storage.get_attr("com_name")
122
122
  attr_energies_name = storage.get_attr("attractive_energies_name")
@@ -135,7 +135,7 @@ class Visualizer:
135
135
  self.N = np.size(self.COM_coords[0], axis=0)
136
136
  _logger.debug("Computed N=%d", self.N)
137
137
 
138
- # - SET UP THE CANVAS AND THE AXES - #
138
+ # - SET UP THE CANVAS AND THE AXES - #
139
139
  self._fig = plt.figure(figsize=figsize, num="SAWNERGY")
140
140
  self._ax = self._fig.add_subplot(111, projection="3d")
141
141
  self._fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
@@ -145,14 +145,14 @@ class Visualizer:
145
145
  self._ax.set_axis_off()
146
146
  _logger.debug("Figure and 3D axes initialized.")
147
147
 
148
- # ------ SET UP PLOT ELEMENTS ------ #
148
+ # ------ SET UP PLOT ELEMENTS ------ #
149
149
  self._scatter: PathCollection = self._ax.scatter([], [], [], s=node_size, depthshade=depthshade, edgecolors="none")
150
150
  self._attr: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
151
151
  self._repuls: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
152
152
  self._ax.add_collection3d(self._attr); self._ax.add_collection3d(self._repuls) # set pointers to the attractive and repulsive collections
153
153
  _logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")
154
154
 
155
- # ---------- HELPER FIELDS --------- #
155
+ # ---------- HELPER FIELDS --------- #
156
156
  # NOTE: 'under the hood' everything is 0-base indexed,
157
157
  # BUT, from the API point of view, the indexing is 1-base,
158
158
  # because amino acid residues are 1-base indexed.
@@ -160,7 +160,7 @@ class Visualizer:
160
160
  self.default_node_color = default_node_color
161
161
  _logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", self.N-1, self.default_node_color)
162
162
 
163
- # DISALLOW MPL WARM-UP IN THE FUTURE
163
+ # DISALLOW MPL WARM-UP IN THE FUTURE
164
164
  Visualizer.no_instances = False
165
165
  _logger.debug("Visualizer.no_instances set to False.")
166
166
 
@@ -319,6 +319,9 @@ def build_line_segments(
319
319
  kept = edge_weights >= thresh
320
320
  rows, cols = rows[kept], cols[kept]
321
321
 
322
+ nz = weights[rows, cols] > 0.0
323
+ rows, cols = rows[nz], cols[nz]
324
+
322
325
  if rows.size == 0:
323
326
  _logger.debug("build_line_segments: no edges kept after threshold; returning empties.")
324
327
  return (np.empty((0, 2, 3), dtype=float),
sawnergy/walks/walker.py CHANGED
@@ -63,7 +63,7 @@ class Walker:
63
63
 
64
64
  # Load numpy arrays from read-only storage
65
65
  with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
66
- attr_name = storage.get_attr("attractive_transitions_name")
66
+ attr_name = storage.get_attr("attractive_transitions_name")
67
67
  repuls_name = storage.get_attr("repulsive_transitions_name")
68
68
  attr_matrices : np.ndarray | None = (
69
69
  storage.read(attr_name, slice(None)) if attr_name is not None else None
@@ -133,6 +133,9 @@ class Walker:
133
133
  ) if repuls_matrices is not None else None
134
134
  )
135
135
 
136
+ self._attr_owner_pid = os.getpid() if self.attr_matrices is not None else None
137
+ self._repuls_owner_pid = os.getpid() if self.repuls_matrices is not None else None
138
+
136
139
  _logger.debug(
137
140
  "SharedNDArray created | attr name=%r; repuls name=%r",
138
141
  getattr(self.attr_matrices, "name", None),
@@ -159,38 +162,56 @@ class Walker:
159
162
 
160
163
  # explicit resource cleanup
161
164
  def close(self) -> None:
162
- """Close shared-memory handles and (in main process) unlink segments.
165
+ """Release shared-memory resources used by this Walker.
166
+
167
+ This method:
168
+ - Closes local handles to the shared-memory backed arrays
169
+ (`self.attr_matrices`, `self.repuls_matrices`) in **the current process**.
170
+ - If the current process is the **creator** of a segment (its PID matches
171
+ `_attr_owner_pid` / `_repuls_owner_pid`), it also **unlinks** that segment
172
+ so the OS can reclaim it once all handles are closed.
173
+
174
+ Behavior & guarantees
175
+ ---------------------
176
+ - **Idempotent:** safe to call multiple times; subsequent calls are no-ops.
177
+ - **Multi-process aware:** non-creator processes only close their handles;
178
+ creators close **and** unlink. This prevents `resource_tracker` “leaked
179
+ shared_memory” warnings when using `ProcessPoolExecutor`/spawn.
180
+ - **Best-effort unlink:** `FileNotFoundError` during unlink (already unlinked
181
+ elsewhere) is swallowed.
182
+ - Invoked automatically by the context manager (`__exit__`) and destructor
183
+ (`__del__`), but it's fine to call explicitly.
184
+
185
+ After calling `close()`, any operation that relies on the shared arrays may
186
+ fail; treat the instance as finalized.
163
187
 
164
- Idempotent: if cleanup already occurred, returns immediately. Always
165
- closes local handles in the current process. If the caller is the main
166
- process (per ``sawnergy_util.is_main_process()``), also attempts to
167
- unlink the underlying shared-memory segments (best-effort; suppresses
168
- ``FileNotFoundError`` if already unlinked elsewhere).
188
+ Returns:
189
+ None
169
190
  """
170
191
  if self._memory_cleaned_up:
171
192
  _logger.debug("close(): already cleaned up; returning")
172
193
  return
173
- _logger.debug("Closing Walker resources (is_main=%s)", sawnergy_util.is_main_process())
194
+ _logger.debug("Closing Walker resources (pid=%s)", os.getpid())
174
195
  try:
175
196
  if self.attr_matrices is not None:
176
197
  self.attr_matrices.close()
177
198
  if self.repuls_matrices is not None:
178
199
  self.repuls_matrices.close()
179
200
  _logger.debug("SharedNDArray handles closed")
180
- if sawnergy_util.is_main_process():
181
- _logger.debug("Attempting to unlink shared memory segments (main process)")
182
- try:
183
- if self.attr_matrices is not None:
184
- self.attr_matrices.unlink()
185
- except FileNotFoundError:
186
- _logger.warning("attr SharedMemory already unlinked")
187
- try:
188
- if self.repuls_matrices is not None:
189
- self.repuls_matrices.unlink()
190
- except FileNotFoundError:
191
- _logger.warning("repuls SharedMemory already unlinked")
192
- else:
193
- _logger.debug("Not main process; skipping unlink")
201
+
202
+ # Unlink in whichever process actually CREATED the segment(s)
203
+ try:
204
+ if self.attr_matrices is not None and getattr(self, "_attr_owner_pid", None) == os.getpid():
205
+ self.attr_matrices.unlink()
206
+ except FileNotFoundError:
207
+ _logger.debug("attr SharedMemory already unlinked elsewhere")
208
+
209
+ try:
210
+ if self.repuls_matrices is not None and getattr(self, "_repuls_owner_pid", None) == os.getpid():
211
+ self.repuls_matrices.unlink()
212
+ except FileNotFoundError:
213
+ _logger.debug("repuls SharedMemory already unlinked elsewhere")
214
+
194
215
  finally:
195
216
  self._memory_cleaned_up = True
196
217
  _logger.debug("Cleanup complete")