sawnergy 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sawnergy might be problematic. Click here for more details.

@@ -0,0 +1,251 @@
1
+ from __future__ import annotations
2
+
3
+ # third party
4
+ import numpy as np
5
+ import matplotlib as mpl
6
+
7
+ # built-in
8
+ from pathlib import Path
9
+ from typing import Sequence
10
+ import logging
11
+
12
+ # local
13
+ from ..visual import visualizer_util
14
+ from .. import sawnergy_util
15
+
16
+ # *----------------------------------------------------*
17
+ # GLOBALS
18
+ # *----------------------------------------------------*
19
+
20
# Module-level logger, named after this module's dotted import path.
_logger = logging.getLogger(__name__)
21
+
22
+ # *----------------------------------------------------*
23
+ # HELPERS
24
+ # *----------------------------------------------------*
25
+
26
+ def _safe_svd_pca(X: np.ndarray, k: int, *, row_l2: bool = False) -> tuple[np.ndarray, np.ndarray]:
27
+ """Compute k principal directions via SVD and project onto them."""
28
+ if X.ndim != 2:
29
+ raise ValueError(f"PCA expects 2D array (N, D); got {X.shape}")
30
+ _, D = X.shape
31
+ if k not in (2, 3):
32
+ raise ValueError(f"PCA dimensionality must be 2 or 3; got {k}")
33
+ if D < k:
34
+ raise ValueError(f"Requested k={k} exceeds feature dim D={D}")
35
+ Xc = X - X.mean(axis=0, keepdims=True)
36
+ if row_l2:
37
+ norms = np.linalg.norm(Xc, axis=1, keepdims=True)
38
+ Xc = Xc / np.clip(norms, 1e-9, None)
39
+ _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
40
+ comps = Vt[:k].copy()
41
+ proj = Xc @ comps.T
42
+ return proj, comps
43
+
44
+ def _set_equal_axes_3d(ax, xyz: np.ndarray, *, padding: float = 0.05) -> None:
45
+ if xyz.size == 0:
46
+ return
47
+ x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
48
+ xmin, xmax = float(x.min()), float(x.max())
49
+ ymin, ymax = float(y.min()), float(y.max())
50
+ zmin, zmax = float(z.min()), float(z.max())
51
+ xr = xmax - xmin
52
+ yr = ymax - ymin
53
+ zr = zmax - zmin
54
+ r = max(xr, yr, zr)
55
+ pad = padding * (r if r > 0 else 1.0)
56
+ cx, cy, cz = (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, (zmin + zmax) / 2.0
57
+ ax.set_xlim(cx - r / 2 - pad, cx + r / 2 + pad)
58
+ ax.set_ylim(cy - r / 2 - pad, cy + r / 2 + pad)
59
+ ax.set_zlim(cz - r / 2 - pad, cz + r / 2 + pad)
60
+ try:
61
+ ax.set_box_aspect([1, 1, 1])
62
+ except Exception:
63
+ pass
64
+
65
+ # *----------------------------------------------------*
66
+ # CLASS
67
+ # *----------------------------------------------------*
68
+
69
class Visualizer:
    """3D PCA scatter-plot viewer for per-frame node embeddings.

    Reads a (T, N, D) array of per-frame embeddings from a SAWNERGY
    embeddings archive. It projects one frame's node embeddings onto their
    leading principal components and draws them as a 3D scatter.
    ``frame_id`` and node IDs are 1-based, and ``node_colors`` follows the
    same conventions as the RIN ``Visualizer``.
    """

    # Class-wide flag: True until the first instance has attempted the
    # one-time Matplotlib warm start.
    no_instances: bool = True

    def __init__(
        self,
        EMB_path: str | Path,
        figsize: tuple[int, int] = (9, 7),
        default_node_color: str = visualizer_util.GRAY,
        depthshade: bool = False,
        antialiased: bool = False,
        init_elev: float = 35,
        init_azim: float = 45,
        *,
        show: bool = False,
        normalize_rows: bool = False,
    ) -> None:
        """Load the embeddings archive and create an empty figure.

        Args:
            EMB_path: Path to the embeddings archive. The dataset to read is
                named by the archive's ``frame_embeddings_name`` attribute.
            figsize: Figure size in inches.
            default_node_color: Fallback node color used by the group-based
                coloring modes of ``build_frame``.
            depthshade: Forwarded to ``Axes3D.scatter``.
            antialiased: Forwarded to ``Axes3D.scatter``.
            init_elev: Camera elevation re-applied on every rendered frame.
            init_azim: Camera azimuth re-applied on every rendered frame.
            show: Passed to ``visualizer_util.ensure_backend`` before pyplot
                is imported (presumably to choose an interactive vs. headless
                backend).
            normalize_rows: If True, each centered embedding row is
                L2-normalized before PCA.

        Raises:
            ValueError: If the stored embeddings are not 3D (T, N, D).
        """
        # Select the backend first, then import pyplot.
        visualizer_util.ensure_backend(show)
        import matplotlib.pyplot as plt
        self._plt = plt

        if Visualizer.no_instances:
            try:
                visualizer_util.warm_start_matplotlib()
            finally:
                # Attempt the warm start at most once per process, even if it failed.
                Visualizer.no_instances = False

        # Load embeddings archive
        EMB_path = Path(EMB_path)
        with sawnergy_util.ArrayStorage(EMB_path, mode="r") as storage:
            name = storage.get_attr("frame_embeddings_name")
            E = storage.read(name, slice(None))
        # Convert before validating so the shape check works for any array-like.
        E = np.asarray(E)
        if E.ndim != 3:
            raise ValueError(f"Expected embeddings of shape (T,N,D); got {E.shape}")
        self.E = E
        self.T, self.N, self.D = map(int, self.E.shape)
        _logger.info("Loaded embeddings: T=%d, N=%d, D=%d", self.T, self.N, self.D)

        # Colormap normalizer over 0-based node positions (parity with RIN Visualizer).
        self._residue_norm = mpl.colors.Normalize(0, max(1, self.N - 1))

        # Figure / axes / artists (axes and scatter are created lazily).
        self._fig = self._plt.figure(figsize=figsize, num="SAWNERGY")
        self._ax = None
        self._scatter = None
        self._labels: list = []
        self._marker_size = 30.0
        self._init_elev = init_elev
        self._init_azim = init_azim
        self.default_node_color = default_node_color
        self._antialiased = bool(antialiased)
        self._depthshade = bool(depthshade)
        self._normalize_rows = bool(normalize_rows)

    # ------------------------------ PRIVATE ------------------------------ #

    def _ensure_axes(self) -> None:
        """Create the 3D axes and an empty scatter artist if they do not exist yet.

        This is a no-op once both exist. Otherwise the figure is cleared and
        the axes and scatter are rebuilt.
        """
        if self._ax is not None and self._scatter is not None:
            return
        self._fig.clf()
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._ax.view_init(self._init_elev, self._init_azim)
        self._scatter = self._ax.scatter(
            [], [], [],
            s=self._marker_size,
            depthshade=self._depthshade,
            edgecolors="none",
            antialiased=self._antialiased,
        )
        try:
            self._ax.set_axis_off()
        except Exception:
            # Hiding the axes is purely cosmetic; never fail rendering over it.
            _logger.debug("set_axis_off failed; keeping visible axes.", exc_info=True)

    def _project3(self, X: np.ndarray) -> np.ndarray:
        """Return a 3D PCA projection of the rows of ``X``, shape (n, 3).

        Three components are requested when D >= 3, otherwise two. Any
        missing coordinate columns are zero-filled so the result always has 3
        columns. Columns can be missing because D < 3, or because there are
        fewer rows than components and the SVD yields fewer directions.
        """
        k = 3 if X.shape[1] >= 3 else 2
        P, _ = _safe_svd_pca(X, k, row_l2=self._normalize_rows)
        missing = 3 - P.shape[1]
        if missing > 0:
            P = np.c_[P, np.zeros((P.shape[0], missing), dtype=P.dtype)]
        return P

    def _select_nodes(self, displayed_nodes: Sequence[int] | str | None) -> np.ndarray:
        """Convert 1-based node IDs into a 0-based ``int64`` index array.

        ``None`` or ``"ALL"`` selects every node. A scalar or nested integer
        input is flattened to 1-D.

        Raises:
            TypeError: If ``displayed_nodes`` is not None, ``"ALL"``, or integers.
            ValueError: If the integer selection is empty.
            IndexError: If any ID lies outside ``[1, N]``.
        """
        # The isinstance guard matters: comparing a NumPy array to "ALL" is
        # elementwise, and the truth value of the resulting array is ambiguous.
        if displayed_nodes is None or (
            isinstance(displayed_nodes, str) and displayed_nodes == "ALL"
        ):
            return np.arange(self.N, dtype=np.int64)
        idx = np.asarray(displayed_nodes).reshape(-1)
        if idx.dtype.kind not in "iu":
            raise TypeError("displayed_nodes must be None, 'ALL', or an integer sequence.")
        if idx.size == 0:
            raise ValueError("displayed_nodes must contain at least one node ID.")
        if idx.min() < 1 or idx.max() > self.N:
            raise IndexError(f"displayed_nodes out of range [1,{self.N}]")
        return idx.astype(np.int64) - 1

    def _apply_colors(self, node_colors, idx: np.ndarray) -> np.ndarray:
        """Return one color per selected node (RIN Visualizer semantics).

        Args:
            node_colors: One of the following:
                - ``str``: a Matplotlib colormap name, sampled at each node's
                  normalized 0-based position.
                - ``None``: colors from ``visualizer_util.map_groups_to_colors``
                  with no groups (presumably all ``default_node_color``).
                - An (N, 3) or (N, 4) array: explicit per-node colors.
                - Anything else: forwarded as ``groups`` to
                  ``visualizer_util.map_groups_to_colors`` with 1-based IDs.
            idx: 0-based indices of the displayed nodes.
        """
        if isinstance(node_colors, str):
            node_cmap = self._plt.get_cmap(node_colors)
            return node_cmap(self._residue_norm(idx))
        if node_colors is None:
            full = visualizer_util.map_groups_to_colors(
                N=self.N, groups=None, default_color=self.default_node_color, one_based=True
            )
            return np.asarray(full)[idx]
        arr = np.asarray(node_colors)
        if arr.ndim == 2 and arr.shape[0] == self.N and arr.shape[1] in (3, 4):
            return arr[idx]
        full = visualizer_util.map_groups_to_colors(
            N=self.N, groups=node_colors, default_color=self.default_node_color, one_based=True
        )
        return np.asarray(full)[idx]

    def _clear_labels(self) -> None:
        """Remove the node-label text artists drawn for the previous frame (best effort)."""
        for txt in getattr(self, "_labels", []):
            try:
                txt.remove()
            except Exception:
                # The artist may already be detached (e.g. after fig.clf()).
                _logger.debug("Could not remove node label %r", txt, exc_info=True)
        self._labels = []

    # ------------------------------ PUBLIC ------------------------------- #

    def build_frame(
        self,
        frame_id: int,
        *,
        node_colors: str | np.ndarray | None = "rainbow",
        displayed_nodes: Sequence[int] | str | None = "ALL",
        show_node_labels: bool = False,
        show: bool = False
    ) -> None:
        """Render a single frame as a PCA **3D** scatter (matches RIN Visualizer API).

        PCA is fit on each call using only the selected nodes of that frame,
        so component orientation may differ between frames or selections.

        Args:
            frame_id: 1-based frame index in ``[1, T]``.
            node_colors: Coloring mode (colormap name, None, (N, 3|4) array,
                or groups). See ``_apply_colors``.
            displayed_nodes: ``"ALL"``/``None`` or a sequence of 1-based node IDs.
            show_node_labels: Draw each node's 1-based ID next to its point.
                Labels from a previous frame are always cleared.
            show: After drawing, call ``plt.show(block=True)``, falling back
                to ``plt.show()``.

        Raises:
            IndexError: If ``frame_id`` or a node ID is out of range.
            TypeError: If ``displayed_nodes`` has an unsupported type.
            ValueError: If ``displayed_nodes`` is an empty integer selection.
        """
        frame0 = int(frame_id) - 1
        if not (0 <= frame0 < self.T):
            raise IndexError(f"frame_id out of range [1,{self.T}]")
        self._ensure_axes()

        idx = self._select_nodes(displayed_nodes)
        X = self.E[frame0, idx, :]   # (n, D)
        P = self._project3(X)        # (n, 3)
        colors = self._apply_colors(node_colors, idx)

        # Update the existing scatter in place instead of creating a new artist.
        self._scatter._offsets3d = (P[:, 0], P[:, 1], P[:, 2])
        self._scatter.set_facecolors(colors)
        _set_equal_axes_3d(self._ax, P, padding=0.05)
        self._ax.view_init(self._init_elev, self._init_azim)

        # Always drop the previous frame's labels. Otherwise they would linger
        # at stale positions whenever this frame is drawn without labels.
        self._clear_labels()
        if show_node_labels:
            for p, nid in zip(P, idx + 1):
                self._labels.append(self._ax.text(p[0], p[1], p[2], str(int(nid)), fontsize=8))

        # Layout and redraw are best effort; some backends/3D axes reject them.
        try:
            self._fig.tight_layout()
        except Exception:
            try:
                self._fig.subplots_adjust()
            except Exception:
                pass
        try:
            self._fig.canvas.draw_idle()
        except Exception:
            pass

        if show:
            try:
                self._plt.show(block=True)
            except TypeError:
                # Some backends' show() does not accept the ``block`` keyword.
                self._plt.show()

    # convenience
    def savefig(self, path: str | Path, *, dpi: int = 150) -> None:
        """Save the current figure to ``path`` at the given ``dpi``."""
        self._fig.savefig(path, dpi=dpi)

    def close(self) -> None:
        """Close the figure, ignoring any error raised while doing so."""
        try:
            self._plt.close(self._fig)
        except Exception:
            pass
249
+
250
+
251
# Public API of this module: only the embedding Visualizer is exported by star-imports.
__all__ = ["Visualizer"]
sawnergy/logging_util.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
6
 
7
7
  def configure_logging(
8
8
  logs_dir: Path | str,
9
- file_level: int = logging.DEBUG,
9
+ file_level: int = logging.WARNING,
10
10
  console_level: int = logging.WARNING
11
11
  ) -> None:
12
12
  """
@@ -669,7 +669,7 @@ class RINBuilder:
669
669
  molecule_of_interest: int,
670
670
  frame_range: tuple[int, int] | None = None,
671
671
  frame_batch_size: int = -1,
672
- prune_low_energies_frac: float = 0.3,
672
+ prune_low_energies_frac: float = 0.85,
673
673
  output_path: str | Path | None = None,
674
674
  keep_prenormalized_energies: bool = True,
675
675
  *,
@@ -107,7 +107,7 @@ class Visualizer:
107
107
  visualizer_util.ensure_backend(show)
108
108
  import matplotlib.pyplot as plt
109
109
  self._plt = plt
110
- # ---------- WARM UP MPL ------------ #
110
+ # ---------- WARM UP MPL ------------ #
111
111
  _logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
112
112
  RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
113
113
  if Visualizer.no_instances:
@@ -116,7 +116,7 @@ class Visualizer:
116
116
  else:
117
117
  _logger.debug("Skipping warm-start (no_instances=False).")
118
118
 
119
- # ---------- LOAD THE DATA ---------- #
119
+ # ---------- LOAD THE DATA ---------- #
120
120
  with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
121
121
  com_name = storage.get_attr("com_name")
122
122
  attr_energies_name = storage.get_attr("attractive_energies_name")
@@ -135,7 +135,7 @@ class Visualizer:
135
135
  self.N = np.size(self.COM_coords[0], axis=0)
136
136
  _logger.debug("Computed N=%d", self.N)
137
137
 
138
- # - SET UP THE CANVAS AND THE AXES - #
138
+ # - SET UP THE CANVAS AND THE AXES - #
139
139
  self._fig = plt.figure(figsize=figsize, num="SAWNERGY")
140
140
  self._ax = self._fig.add_subplot(111, projection="3d")
141
141
  self._fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
@@ -145,14 +145,14 @@ class Visualizer:
145
145
  self._ax.set_axis_off()
146
146
  _logger.debug("Figure and 3D axes initialized.")
147
147
 
148
- # ------ SET UP PLOT ELEMENTS ------ #
148
+ # ------ SET UP PLOT ELEMENTS ------ #
149
149
  self._scatter: PathCollection = self._ax.scatter([], [], [], s=node_size, depthshade=depthshade, edgecolors="none")
150
150
  self._attr: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
151
151
  self._repuls: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
152
152
  self._ax.add_collection3d(self._attr); self._ax.add_collection3d(self._repuls) # set pointers to the attractive and repulsive collections
153
153
  _logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")
154
154
 
155
- # ---------- HELPER FIELDS --------- #
155
+ # ---------- HELPER FIELDS --------- #
156
156
  # NOTE: 'under the hood' everything is 0-base indexed,
157
157
  # BUT, from the API point of view, the indexing is 1-base,
158
158
  # because amino acid residues are 1-base indexed.
@@ -160,7 +160,7 @@ class Visualizer:
160
160
  self.default_node_color = default_node_color
161
161
  _logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", self.N-1, self.default_node_color)
162
162
 
163
- # DISALLOW MPL WARM-UP IN THE FUTURE
163
+ # DISALLOW MPL WARM-UP IN THE FUTURE
164
164
  Visualizer.no_instances = False
165
165
  _logger.debug("Visualizer.no_instances set to False.")
166
166
 
@@ -319,6 +319,9 @@ def build_line_segments(
319
319
  kept = edge_weights >= thresh
320
320
  rows, cols = rows[kept], cols[kept]
321
321
 
322
+ nz = weights[rows, cols] > 0.0
323
+ rows, cols = rows[nz], cols[nz]
324
+
322
325
  if rows.size == 0:
323
326
  _logger.debug("build_line_segments: no edges kept after threshold; returning empties.")
324
327
  return (np.empty((0, 2, 3), dtype=float),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sawnergy
3
- Version: 1.0.6
3
+ Version: 1.0.8
4
4
  Summary: Toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations
5
5
  Home-page: https://github.com/Yehor-Mishchyriak/SAWNERGY
6
6
  Author: Yehor Mishchyriak
@@ -39,19 +39,57 @@ Dynamic: summary
39
39
  ![Python](https://img.shields.io/badge/python-3.11%2B-blue)
40
40
 
41
41
  A toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations, sampling
42
- random and self-avoiding walks, learning node embeddings, and visualising residue interaction networks (RINs). SAWNERGY
42
+ random and self-avoiding walks, learning node embeddings, and visualizing residue interaction networks (RINs). SAWNERGY
43
43
  keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2vec approach) — inside Python, backed by efficient Zarr-based archives and optional GPU acceleration.
44
44
 
45
45
  ---
46
46
 
47
+ ## Installation
48
+
49
+ ```bash
50
+ pip install sawnergy
51
+ ```
52
+
53
+ > **Optional:** For GPU training, install PyTorch separately (e.g., `pip install torch`).
54
+ > **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
55
+ > environment variable. The easiest option is usually to install AmberTools via Conda and activate that environment; SAWNERGY will then locate the cpptraj executable automatically.
56
+
57
+ ---
58
+
59
+ # UPDATES:
60
+
61
+ ## v1.0.8 — What’s new:
62
+ - **Temporary deprecation of `SGNS_Torch`**
63
+ - `sawnergy.embedding.SGNS_Torch` currently produces noisy embeddings in practice. The issue likely stems from **weight initialization**, although the root cause has not yet been conclusively determined.
64
+ - **Action:** The class and its `__init__` docstring now carry a deprecation notice. Constructing the class emits a **`DeprecationWarning`** and logs a **warning**.
65
+ - **Use instead:** Prefer **`SG_Torch`** (plain Skip-Gram with full softmax) or the PureML backends **`SGNS_PureML`** / **`SG_PureML`**.
66
+ - **Compatibility:** No breaking API changes; imports remain stable. PureML backends are unaffected.
67
+ - **Embedding visualizer update**
68
+ - Embeddings can now be L2-normalized row-wise before display (pass `normalize_rows=True`).
69
+ - **Small improvements in the embedding module**
70
+ - Improved API with a lot of good defaults in place to ease usage out of the box.
71
+ - Small internal model tweaks.
72
+
73
+ ## v1.0.7 — What’s new:
74
+ - **Added plain Skip-Gram model**
75
+ - Now, the user can choose if they want to apply the negative sampling technique (two binary classifiers) or train a single classifier over the vocabulary (full softmax). For more detail, see: [node2vec](https://arxiv.org/pdf/1607.00653), [word2vec](https://arxiv.org/pdf/1301.3781), and [negative_sampling](https://arxiv.org/pdf/1402.3722).
76
+ - **Set a harsher default for low interaction energies pruning during RIN construction**
77
+ - The default `prune_low_energies_frac` is now 0.85 (previously 0.3): the lowest 85% of interaction energies are zeroed out, which leads to more meaningful embeddings.
78
+ - **BUG FIX: Visualizer**
79
+ - Previously, the visualizer silently drew zero-magnitude edges. They were rendered but invisible (fully transparent, zero width), which made the displayed image/animation very laggy. Zero-weight edges are now skipped, and with the higher pruning default the displayed interaction networks stay clean and smooth during rotation, dragging, etc.
80
+ - **New Embedding Visualizer (3D)**
81
+ - New lightweight viewer that projects per-frame embeddings onto a **3D** scatter via PCA. It supports the same node-coloring semantics as the RIN Visualizer, optional node labels, and the same antialiasing/depthshade controls. It works in headless setups through the same backend guard, and `show=True` blocks until the window is closed, which suits scripts.
82
+
83
+ ---
84
+
47
85
  ## Why SAWNERGY?
48
86
 
49
87
  - **Bridge simulations and graph ML**: Convert raw MD trajectories into residue interaction networks ready for graph
50
88
  algorithms and downstream machine learning tasks.
51
- - **Deterministic, shareable artefacts**: Every stage produces compressed Zarr archives that contain both data and metadata so runs can be reproduced, shared, or inspected later.
52
- - **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without serealization overhead; archives are written in chunked, compressed form for fast read/write.
53
- - **Flexible embedding backends**: Train skip-gram with negative sampling (SGNS) models using either PureML or PyTorch.
54
- - **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder
89
+ - **Deterministic, shareable artifacts**: Every stage produces compressed Zarr archives that contain both data and metadata so runs can be reproduced, shared, or inspected later.
90
+ - **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without serialization overhead; archives are written in chunked, compressed form for fast read/write.
91
+ - **Flexible objectives & backends**: Train Skip-Gram with **negative sampling** (`objective="sgns"`) or **plain Skip-Gram** (`objective="sg"`), using either **PureML** (default) or **PyTorch**.
92
+ - **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder.
55
93
 
56
94
  ---
57
95
 
@@ -91,9 +129,9 @@ node indexing, and RNG seeds stay consistent across the toolchain.
91
129
  * Wraps the AmberTools `cpptraj` executable to:
92
130
  - compute per-frame electrostatic (EMAP) and van der Waals (VMAP) energy matrices at the atomic level,
93
131
  - project atom–atom interactions to residue–residue interactions using compositional masks,
94
- - prune, symmetrise, remove self-interactions, and L1-normalise the matrices,
95
- - compute per-residue centres of mass (COM) over the same frames.
96
- * Outputs a compressed Zarr archive with transition matrices, optional prenormalised energies, COM snapshots, and rich
132
+ - prune, symmetrize, remove self-interactions, and L1-normalize the matrices,
133
+ - compute per-residue centers of mass (COM) over the same frames.
134
+ * Outputs a compressed Zarr archive with transition matrices, optional pre-normalized energies, COM snapshots, and rich
97
135
  metadata (frame range, pruning quantile, molecule ID, etc.).
98
136
  * Supports parallel `cpptraj` execution, batch processing, and keeps temporary stores tidy via
99
137
  `ArrayStorage.compress_and_cleanup`.
@@ -103,7 +141,7 @@ node indexing, and RNG seeds stay consistent across the toolchain.
103
141
  * Opens RIN archives, resolves dataset names from attributes, and renders nodes plus attractive/repulsive edge bundles
104
142
  in 3D using Matplotlib.
105
143
  * Allows both static frame visualization and trajectory animation.
106
- * Handles backend selection (`Agg` fallback in headless environments) and offers convenient colour palettes via
144
+ * Handles backend selection (`Agg` fallback in headless environments) and offers convenient color palettes via
107
145
  `visualizer_util`.
108
146
 
109
147
  ### `sawnergy.walks.Walker`
@@ -116,13 +154,10 @@ node indexing, and RNG seeds stay consistent across the toolchain.
116
154
 
117
155
  ### `sawnergy.embedding.Embedder`
118
156
 
119
- * Consumes walk archives, generates skip-gram pairs, and normalises them to 0-based indices.
120
- * Provides a unified interface to SGNS implementations:
121
- - **PureML backend** (`SGNS_PureML`): works with the `pureml` ecosystem, optimistic for CPU training.
122
- - **PyTorch backend** (`SGNS_Torch`): uses `torch.nn.Embedding` plays nicely with GPUs.
123
- * Both `SGNS_PureML` and `SGNS_Torch` accept training hyperparameters such as batch_size, LR, optimizer and LR_scheduler, etc.
124
- * Exposes `embed_frame` (single frame) and `embed_all` (all frames, deterministic seeding per frame) which return the
125
- learned input embedding matrices and write them to disk when requested.
157
+ * Consumes walk archives, generates skip-gram pairs, and normalizes them to 0-based indices.
158
+ * Selects skip-gram (SG / SGNS) backends dynamically via `model_base="pureml"|"torch"` with per-backend overrides supplied through `model_kwargs`.
159
+ * Handles deterministic per-frame seeding and returns the requested embedding `kind` (`"in"`, `"out"`, or `"avg"`) from `embed_frame` and `embed_all`.
160
+ * Persists per-frame matrices with rich provenance (walk metadata, objective, hyperparameters, RNG seeds) when `embed_all` targets an output archive.
126
161
 
127
162
  ### Supporting Utilities
128
163
 
@@ -140,23 +175,13 @@ node indexing, and RNG seeds stay consistent across the toolchain.
140
175
  |---|---|---|
141
176
  | **RIN** | `ATTRACTIVE_transitions` → **(T, N, N)**, float32 • `REPULSIVE_transitions` → **(T, N, N)**, float32 (optional) • `ATTRACTIVE_energies` → **(T, N, N)**, float32 (optional) • `REPULSIVE_energies` → **(T, N, N)**, float32 (optional) • `COM` → **(T, N, 3)**, float32 | `time_created` (ISO) • `com_name` = `"COM"` • `molecule_of_interest` (int) • `frame_range` = `(start, end)` inclusive • `frame_batch_size` (int) • `prune_low_energies_frac` (float in [0,1]) • `attractive_transitions_name` / `repulsive_transitions_name` (dataset names or `None`) • `attractive_energies_name` / `repulsive_energies_name` (dataset names or `None`) |
142
177
  | **Walks** | `ATTRACTIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `REPULSIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `ATTRACTIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) • `REPULSIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) <br/>_Note:_ node IDs are **1-based**.| `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_batch_v1"` • `num_workers` (int) • `in_parallel` (bool) • `batch_size_nodes` (int) • `num_RWs` / `num_SAWs` (ints) • `node_count` (N) • `time_stamp_count` (T) • `walk_length` (L) • `walks_per_node` (int) • `attractive_RWs_name` / `repulsive_RWs_name` / `attractive_SAWs_name` / `repulsive_SAWs_name` (dataset names or `None`) • `walks_layout` = `"time_leading_3d"` |
143
- | **Embeddings** | `FRAME_EMBEDDINGS` → **(frames_written, vocab_size, D)**, typically float32 | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_frame_v1"` • `source_walks_path` (str) • `model_base` = `"torch"` or `"pureml"` • `rin_type` = `"attr"` or `"repuls"` • `using_mode` = `"RW"|"SAW"|"merged"` • `window_size` (int) • `alpha` (float; noise exponent) • `dimensionality` = D • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `shuffle_data` (bool) • `frames_written` (int) • `vocab_size` (int) • `frame_count` (int) • `embedding_dtype` (str) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `arrays_per_chunk` (int) • `compression_level` (int) |
178
+ | **Embeddings** | `FRAME_EMBEDDINGS` → **(T, N, D)**, float32 | `created_at` (ISO) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `time_stamp_count` = T • `node_count` = N • `embedding_dim` = D • `model_base` = `"torch"` or `"pureml"` • `embedding_kind` = `"in"|"out"|"avg"` • `objective` = `"sgns"` or `"sg"` • `negative_sampling` (bool) • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `window_size` (int) • `alpha` (float) • `lr_step_per_batch` (bool) • `shuffle_data` (bool) • `device_hint` (str) • `model_kwargs_repr` (repr string) • `RIN_type` = `"attr"` or `"repuls"` • `using` = `"RW"|"SAW"|"merged"` • `source_WALKS_path` (str) • `walk_length` (int) • `num_RWs` / `num_SAWs` (ints) • `attractive_*_name` / `repulsive_*_name` (dataset names or `None`) • `master_seed` (int) • `per_frame_seeds` (list[int]) • `arrays_per_chunk` (int) • `compression_level` (int) |
144
179
 
145
180
  **Notes**
146
181
 
147
- - In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalised** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalised** versions used for sampling.
182
+ - In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalized** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalized** versions used for sampling.
148
183
  - All archives are Zarr v3 groups. ArrayStorage also maintains per-block metadata in root attrs: `array_chunk_size_in_block`, `array_shape_in_block`, and `array_dtype_in_block` (dicts keyed by dataset name). You’ll see these in every archive.
149
-
150
- ---
151
-
152
- ## Installation
153
-
154
- ```bash
155
- pip install sawnergy
156
- ```
157
-
158
- > **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
159
- > environment variable.
184
+ - In **Embeddings**, `alpha` and `num_negative_samples` apply to **SGNS** only and are ignored for `objective="sg"`.
160
185
 
161
186
  ---
162
187
 
@@ -181,10 +206,10 @@ rin_builder.build_rin(
181
206
  molecule_of_interest=1,
182
207
  frame_range=(1, 100),
183
208
  frame_batch_size=10,
184
- prune_low_energies_frac=0.3,
209
+ prune_low_energies_frac=0.85,
185
210
  output_path=rin_path,
186
211
  include_attractive=True,
187
- include_repulsive=False,
212
+ include_repulsive=False
188
213
  )
189
214
 
190
215
  # 2. Sample walks from the RIN
@@ -192,52 +217,43 @@ walker = Walker(rin_path, seed=123)
192
217
  walks_path = Path("./WALKS_demo.zip")
193
218
  walker.sample_walks(
194
219
  walk_length=16,
195
- walks_per_node=32,
220
+ walks_per_node=100,
196
221
  saw_frac=0.25,
197
222
  include_attractive=True,
198
223
  include_repulsive=False,
199
224
  time_aware=False,
200
225
  output_path=walks_path,
201
- in_parallel=False,
226
+ in_parallel=False
202
227
  )
203
228
  walker.close()
204
229
 
205
230
  # 3. Train embeddings per frame (PyTorch backend)
206
231
  import torch
207
232
 
208
- embedder = Embedder(walks_path, base="torch", seed=999)
233
+ embedder = Embedder(walks_path, seed=999)
209
234
  embeddings_path = embedder.embed_all(
210
235
  RIN_type="attr",
211
236
  using="merged",
237
+ num_epochs=10,
238
+ negative_sampling=False,
212
239
  window_size=4,
213
- num_negative_samples=5,
214
- num_epochs=5,
215
- batch_size=1024,
216
- dimensionality=128,
217
- shuffle_data=True,
218
- output_path="./EMBEDDINGS_demo.zip",
219
- sgns_kwargs={
220
- "optim": torch.optim.Adam,
221
- "optim_kwargs": {"lr": 1e-3},
222
- "lr_sched": torch.optim.lr_scheduler.LambdaLR,
223
- "lr_sched_kwargs": {"lr_lambda": lambda _: 1.0},
224
- "device": "cuda" if torch.cuda.is_available() else "cpu",
225
- },
240
+ device="cuda" if torch.cuda.is_available() else "cpu",
241
+ model_base="torch",
242
+ output_path="./EMBEDDINGS_demo.zip"
226
243
  )
227
244
  print("Embeddings written to", embeddings_path)
228
245
  ```
229
246
 
230
- > For the PureML backend, supply the relevant optimiser and scheduler via `sgns_kwargs`
231
- > (for example `optim=pureml.optimizers.Adam`, `lr_sched=pureml.optimizers.CosineAnnealingLR`).
247
+ > For the PureML backend, set `model_base="pureml"` and pass the optimizer / scheduler classes inside `model_kwargs`.
232
248
 
233
249
  ---
234
250
 
235
- ## Visualisation
251
+ ## Visualization
236
252
 
237
253
  ```python
238
254
  from sawnergy.visual import Visualizer
239
255
 
240
- v = sawnergy.visual.Visualizer("./RIN_demo.zip")
256
+ v = Visualizer("./RIN_demo.zip")
241
257
  v.build_frame(1,
242
258
  node_colors="rainbow",
243
259
  displayed_nodes="ALL",
@@ -250,14 +266,20 @@ v.build_frame(1,
250
266
 
251
267
  `Visualizer` lazily loads datasets and works even in headless environments (falls back to the `Agg` backend).
252
268
 
269
+ ```python
270
+ from sawnergy.embedding import Visualizer
271
+
272
+ viz = Visualizer("./EMBEDDINGS_demo.zip", normalize_rows=True)
273
+ viz.build_frame(1, show=True)
274
+ ```
275
+
253
276
  ---
254
277
 
255
278
  ## Advanced Notes
256
279
 
257
280
  - **Time-aware walks**: Set `time_aware=True`, provide `stickiness` and `on_no_options` when calling `Walker.sample_walks`.
258
281
  - **Shared memory lifecycle**: Call `Walker.close()` (or use a context manager) to release shared-memory segments.
259
- - **PureML vs PyTorch**: Choose the backend via `Embedder(..., base="pureml"|"torch")` and provide backend-specific
260
- constructor kwargs through `sgns_kwargs` (optimizer, scheduler, device).
282
+ - **PureML vs PyTorch**: Select the backend at call time with `model_base="pureml"|"torch"` (defaults to `"pureml"`) and pass optimizer / scheduler overrides through `model_kwargs`.
261
283
  - **ArrayStorage utilities**: Use `ArrayStorage` directly to peek into archives, append arrays, or manage metadata.
262
284
 
263
285
  ---
@@ -268,8 +290,9 @@ v.build_frame(1,
268
290
  ├── sawnergy/
269
291
  │ ├── rin/ # RINBuilder and cpptraj integration helpers
270
292
  │ ├── walks/ # Walker class and shared-memory utilities
271
- │ ├── embedding/ # Embedder + SGNS backends (PureML / PyTorch)
293
+ │ ├── embedding/ # Embedder + SG/SGNS backends (PureML / PyTorch)
272
294
  │ ├── visual/ # Visualizer and palette utilities
295
+ │ │
273
296
  │ ├── logging_util.py
274
297
  │ └── sawnergy_util.py
275
298
 
@@ -278,7 +301,7 @@ v.build_frame(1,
278
301
 
279
302
  ---
280
303
 
281
- ## Acknowledgements
304
+ ## Acknowledgments
282
305
 
283
306
  SAWNERGY builds on the AmberTools `cpptraj` ecosystem, NumPy, Matplotlib, Zarr, and PyTorch (for GPU acceleration if necessary; PureML is available by default).
284
307
  Big thanks to the upstream communities whose work makes this toolkit possible.
@@ -0,0 +1,23 @@
1
+ sawnergy/__init__.py,sha256=Dq1U38ah6nPRFEDKN41mYphcTynKfnItca6QkYkpSbs,248
2
+ sawnergy/logging_util.py,sha256=mfYw8IsYtOfCXayjkd4g9jHuupluxRNbqyFegRkiAhQ,1476
3
+ sawnergy/sawnergy_util.py,sha256=Htx9wr0S8TXt5aHT2mtEdYf1TCo_BC1IUwNNuZdIR-4,49432
4
+ sawnergy/embedding/SGNS_pml.py,sha256=-S7K7qwbDGUO_KW4gnA3dGyxuezN1ZK-WikPm7krEvs,14291
5
+ sawnergy/embedding/SGNS_torch.py,sha256=NgVQnMtRSYY0IsPhB3XV7K1-uVSah0P77a8ID8zZ7Qw,13940
6
+ sawnergy/embedding/__init__.py,sha256=T1YXb7S5Zyy_kIqlarDSX3imd_FGFH6nDuvLQ3hMKsE,1764
7
+ sawnergy/embedding/embedder.py,sha256=02pcf3ies3Nuo19sCoJdMAYg7BFUHj4-wf4AZ5R6PAE,32492
8
+ sawnergy/embedding/visualizer.py,sha256=x0BiSG9_nk9AUQm9RsZ2syKeCiaxX1gTlC85aYycMXY,8830
9
+ sawnergy/rin/__init__.py,sha256=z19hLfEIp3bwzY-eCHQBQf0NRTCJzVz_FLIpVV5q0W4,162
10
+ sawnergy/rin/rin_builder.py,sha256=d1cC4KKY9zzNlqhxHWTFM-QyXRXubd2zlCrSM-dV5pc,44624
11
+ sawnergy/rin/rin_util.py,sha256=5TKywA5qfm76Gl4Cyz7oBPasmE5chclR7UM4hawwQOg,14939
12
+ sawnergy/visual/__init__.py,sha256=p_ByFtfrP19b5_qiJlkAnYesZN3M1LjIo421LUgVVbw,502
13
+ sawnergy/visual/visualizer.py,sha256=GVD_rFavDXFz9-h28eFf5nPBujUvRncn_zYoHcFHZ3Q,33155
14
+ sawnergy/visual/visualizer_util.py,sha256=7y3kWjHxDQMoG0dmimceHKTC5veVChoyvW7d0qXH23k,15100
15
+ sawnergy/walks/__init__.py,sha256=Z_Kaffhn3oUX13z9jbY0V5Ncdwj9Cnr--n9D-s7gh5k,250
16
+ sawnergy/walks/walker.py,sha256=scvfZFrSL4AwpmspD0Jb0uhnrVIRRwE_hPCE3bG6zpg,37729
17
+ sawnergy/walks/walker_util.py,sha256=ETdyPNIDwDQCA8Z5t38keBhYBJ56_ksT_0NhOCY-tHE,15361
18
+ sawnergy-1.0.8.dist-info/licenses/LICENSE,sha256=cElK4bCsDhyAEON3H05s35bQZvxBcXBiCOrOdiUhDCY,11346
19
+ sawnergy-1.0.8.dist-info/licenses/NOTICE,sha256=eVTbuSasZrmMJVtKoWOzsKyu4ZNm7Ks7dzI3Tx5tEHc,109
20
+ sawnergy-1.0.8.dist-info/METADATA,sha256=_0u1smFM5oMqaO0xuc4ZX094B6F2swQqUrolOkpikVM,16084
21
+ sawnergy-1.0.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
22
+ sawnergy-1.0.8.dist-info/top_level.txt,sha256=-67FQD6FD9Gjt74WTmO9hNYA3MLB4HaSxci0sEKC5Lo,9
23
+ sawnergy-1.0.8.dist-info/RECORD,,
@@ -1,22 +0,0 @@
1
- sawnergy/__init__.py,sha256=Dq1U38ah6nPRFEDKN41mYphcTynKfnItca6QkYkpSbs,248
2
- sawnergy/logging_util.py,sha256=tnhToHchnWaORHU73dxzBuL1e_C-AXFdPExDZTEI6tE,1474
3
- sawnergy/sawnergy_util.py,sha256=Htx9wr0S8TXt5aHT2mtEdYf1TCo_BC1IUwNNuZdIR-4,49432
4
- sawnergy/embedding/SGNS_pml.py,sha256=xF_0DksJTUH5DxchTwkg-Ol975lwH1O259Wa0ZSbmDA,6298
5
- sawnergy/embedding/SGNS_torch.py,sha256=3Pa_mk5mzsl27M87q4tNmitOouxDdG5ZzxpdaOSyGt8,6411
6
- sawnergy/embedding/__init__.py,sha256=sxUh2RcZyPs8aCdvec8x843Bm3DBaYQNrBF8VyvLQ-k,965
7
- sawnergy/embedding/embedder.py,sha256=0DRkEfjWqnKCHdr0AxN3wjqclezMOOw6THZE7GlxihE,26266
8
- sawnergy/rin/__init__.py,sha256=z19hLfEIp3bwzY-eCHQBQf0NRTCJzVz_FLIpVV5q0W4,162
9
- sawnergy/rin/rin_builder.py,sha256=z5hCvW-jHnnv7ZgHlQlruRAMKa-TnKFdvkMcoHBhX78,44623
10
- sawnergy/rin/rin_util.py,sha256=5TKywA5qfm76Gl4Cyz7oBPasmE5chclR7UM4hawwQOg,14939
11
- sawnergy/visual/__init__.py,sha256=p_ByFtfrP19b5_qiJlkAnYesZN3M1LjIo421LUgVVbw,502
12
- sawnergy/visual/visualizer.py,sha256=qqggoLRNi6t0awXEt-Hy2ut9S0Y8_uKznyozlGLR1Q8,33131
13
- sawnergy/visual/visualizer_util.py,sha256=C9W22CJmfJuTV5_uYsEnG8YChR4nH7OHKbNz26hAyB0,15028
14
- sawnergy/walks/__init__.py,sha256=Z_Kaffhn3oUX13z9jbY0V5Ncdwj9Cnr--n9D-s7gh5k,250
15
- sawnergy/walks/walker.py,sha256=scvfZFrSL4AwpmspD0Jb0uhnrVIRRwE_hPCE3bG6zpg,37729
16
- sawnergy/walks/walker_util.py,sha256=ETdyPNIDwDQCA8Z5t38keBhYBJ56_ksT_0NhOCY-tHE,15361
17
- sawnergy-1.0.6.dist-info/licenses/LICENSE,sha256=cElK4bCsDhyAEON3H05s35bQZvxBcXBiCOrOdiUhDCY,11346
18
- sawnergy-1.0.6.dist-info/licenses/NOTICE,sha256=eVTbuSasZrmMJVtKoWOzsKyu4ZNm7Ks7dzI3Tx5tEHc,109
19
- sawnergy-1.0.6.dist-info/METADATA,sha256=9_ocluBr8baUZfTcZdBkdNx_AIu3VOtKADEyMuTc3CY,13367
20
- sawnergy-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
21
- sawnergy-1.0.6.dist-info/top_level.txt,sha256=-67FQD6FD9Gjt74WTmO9hNYA3MLB4HaSxci0sEKC5Lo,9
22
- sawnergy-1.0.6.dist-info/RECORD,,