sawnergy 1.0.7.tar.gz → 1.0.8.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sawnergy might be problematic.
- {sawnergy-1.0.7/sawnergy.egg-info → sawnergy-1.0.8}/PKG-INFO +39 -40
- {sawnergy-1.0.7 → sawnergy-1.0.8}/README.md +38 -39
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/SGNS_pml.py +36 -38
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/SGNS_torch.py +82 -29
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/embedder.py +325 -245
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/visualizer.py +9 -5
- {sawnergy-1.0.7 → sawnergy-1.0.8/sawnergy.egg-info}/PKG-INFO +39 -40
- {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_embedding.py +17 -3
- {sawnergy-1.0.7 → sawnergy-1.0.8}/LICENSE +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/NOTICE +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/__init__.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/__init__.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/logging_util.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/rin/__init__.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/rin/rin_builder.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/rin/rin_util.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/sawnergy_util.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/visual/__init__.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/visual/visualizer.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/visual/visualizer_util.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/walks/__init__.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/walks/walker.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/walks/walker_util.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy.egg-info/SOURCES.txt +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy.egg-info/dependency_links.txt +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy.egg-info/requires.txt +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy.egg-info/top_level.txt +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/setup.cfg +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_embedding_visualizer.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_rin.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_storage.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_visual.py +0 -0
- {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_walks.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sawnergy
-Version: 1.0.7
+Version: 1.0.8
 Summary: Toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations
 Home-page: https://github.com/Yehor-Mishchyriak/SAWNERGY
 Author: Yehor Mishchyriak
@@ -52,19 +52,31 @@ keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2
 
 > **Optional:** For GPU training, install PyTorch separately (e.g., `pip install torch`).
 > **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
-> environment variable. Probably the easiest solution: install AmberTools via
+> environment variable. The easiest solution is usually to install AmberTools via Conda and activate the environment; SAWNERGY will then find the cpptraj executable on its own.
 
 ---
 
 # UPDATES:
 
+## v1.0.8 — What’s new:
+- **Temporary deprecation of `SGNS_Torch`**
+  - `sawnergy.embedding.SGNS_Torch` currently produces noisy embeddings in practice. The issue likely stems from **weight initialization**, although the root cause has not yet been conclusively determined.
+  - **Action:** The class and its `__init__` docstring now carry a deprecation notice. Constructing the class emits a **`DeprecationWarning`** and logs a **warning**.
+  - **Use instead:** Prefer **`SG_Torch`** (plain Skip-Gram with full softmax) or the PureML backends **`SGNS_PureML`** / **`SG_PureML`**.
+  - **Compatibility:** No breaking API changes; imports remain stable. PureML backends are unaffected.
+- **Embedding visualizer update**
+  - Embeddings can now be L2-normalized before display.
+- **Small improvements in the embedding module**
+  - Improved API with sensible defaults in place to ease usage out of the box.
+  - Small internal model tweaks.
+
 ## v1.0.7 — What’s new:
-- **Added plain
+- **Added plain Skip-Gram model**
   - Now, the user can choose whether to apply the negative sampling technique (two binary classifiers) or train a single classifier over the vocabulary (full softmax). For more detail, see: [node2vec](https://arxiv.org/pdf/1607.00653), [word2vec](https://arxiv.org/pdf/1301.3781), and [negative_sampling](https://arxiv.org/pdf/1402.3722).
 - **Set a harsher default for low interaction energies pruning during RIN construction**
   - Now, 85% of the lowest interaction energies are zeroed out, as opposed to the previous 30% default, leading to more meaningful embeddings.
 - **BUG FIX: Visualizer**
-  - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image
+  - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image/animation would be very laggy. This has been fixed, and given the higher pruning default, the displayed interaction networks are clean and smooth under rotation, dragging, etc.
 - **New Embedding Visualizer (3D)**
   - New lightweight viewer for per-frame embeddings that projects embeddings with PCA to a **3D** scatter. Supports the same node coloring semantics, optional node labels, and the same antialiasing/depthshade controls. Works in headless setups using the same backend guard and uses a blocking `show=True` for scripts.
 
@@ -77,7 +89,7 @@ keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2
 - **Deterministic, shareable artifacts**: Every stage produces compressed Zarr archives that contain both data and metadata so runs can be reproduced, shared, or inspected later.
 - **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without serialization overhead; archives are written in chunked, compressed form for fast read/write.
 - **Flexible objectives & backends**: Train Skip-Gram with **negative sampling** (`objective="sgns"`) or **plain Skip-Gram** (`objective="sg"`), using either **PureML** (default) or **PyTorch**.
-- **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder
+- **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder.
 
 ---
 
@@ -117,7 +129,7 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 * Wraps the AmberTools `cpptraj` executable to:
   - compute per-frame electrostatic (EMAP) and van der Waals (VMAP) energy matrices at the atomic level,
   - project atom–atom interactions to residue–residue interactions using compositional masks,
-  - prune, symmetrize, remove self-interactions, and L1-
+  - prune, symmetrize, remove self-interactions, and L1-normalize the matrices,
   - compute per-residue centers of mass (COM) over the same frames.
 * Outputs a compressed Zarr archive with transition matrices, optional pre-normalized energies, COM snapshots, and rich
   metadata (frame range, pruning quantile, molecule ID, etc.).
@@ -142,13 +154,10 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 
 ### `sawnergy.embedding.Embedder`
 
-* Consumes walk archives, generates skip-gram pairs, and
-*
-
-
-* Both `SGNS_PureML` and `SGNS_Torch` accept training hyperparameters such as batch_size, LR, optimizer and LR_scheduler, etc.
-* Exposes `embed_frame` (single frame) and `embed_all` (all frames, deterministic seeding per frame) which return the
-  learned input embedding matrices and write them to disk when requested.
+* Consumes walk archives, generates skip-gram pairs, and normalizes them to 0-based indices.
+* Selects skip-gram (SG / SGNS) backends dynamically via `model_base="pureml"|"torch"` with per-backend overrides supplied through `model_kwargs`.
+* Handles deterministic per-frame seeding and returns the requested embedding `kind` (`"in"`, `"out"`, or `"avg"`) from `embed_frame` and `embed_all`.
+* Persists per-frame matrices with rich provenance (walk metadata, objective, hyperparameters, RNG seeds) when `embed_all` targets an output archive.
 
 ### Supporting Utilities
 
@@ -166,11 +175,11 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 |---|---|---|
 | **RIN** | `ATTRACTIVE_transitions` → **(T, N, N)**, float32 • `REPULSIVE_transitions` → **(T, N, N)**, float32 (optional) • `ATTRACTIVE_energies` → **(T, N, N)**, float32 (optional) • `REPULSIVE_energies` → **(T, N, N)**, float32 (optional) • `COM` → **(T, N, 3)**, float32 | `time_created` (ISO) • `com_name` = `"COM"` • `molecule_of_interest` (int) • `frame_range` = `(start, end)` inclusive • `frame_batch_size` (int) • `prune_low_energies_frac` (float in [0,1]) • `attractive_transitions_name` / `repulsive_transitions_name` (dataset names or `None`) • `attractive_energies_name` / `repulsive_energies_name` (dataset names or `None`) |
 | **Walks** | `ATTRACTIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `REPULSIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `ATTRACTIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) • `REPULSIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) <br/>_Note:_ node IDs are **1-based**. | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_batch_v1"` • `num_workers` (int) • `in_parallel` (bool) • `batch_size_nodes` (int) • `num_RWs` / `num_SAWs` (ints) • `node_count` (N) • `time_stamp_count` (T) • `walk_length` (L) • `walks_per_node` (int) • `attractive_RWs_name` / `repulsive_RWs_name` / `attractive_SAWs_name` / `repulsive_SAWs_name` (dataset names or `None`) • `walks_layout` = `"time_leading_3d"` |
-| **Embeddings** | `FRAME_EMBEDDINGS` → **(
+| **Embeddings** | `FRAME_EMBEDDINGS` → **(T, N, D)**, float32 | `created_at` (ISO) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `time_stamp_count` = T • `node_count` = N • `embedding_dim` = D • `model_base` = `"torch"` or `"pureml"` • `embedding_kind` = `"in"|"out"|"avg"` • `objective` = `"sgns"` or `"sg"` • `negative_sampling` (bool) • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `window_size` (int) • `alpha` (float) • `lr_step_per_batch` (bool) • `shuffle_data` (bool) • `device_hint` (str) • `model_kwargs_repr` (repr string) • `RIN_type` = `"attr"` or `"repuls"` • `using` = `"RW"|"SAW"|"merged"` • `source_WALKS_path` (str) • `walk_length` (int) • `num_RWs` / `num_SAWs` (ints) • `attractive_*_name` / `repulsive_*_name` (dataset names or `None`) • `master_seed` (int) • `per_frame_seeds` (list[int]) • `arrays_per_chunk` (int) • `compression_level` (int) |
 
 **Notes**
 
-- In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-
+- In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalized** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalized** versions used for sampling.
 - All archives are Zarr v3 groups. ArrayStorage also maintains per-block metadata in root attrs: `array_chunk_size_in_block`, `array_shape_in_block`, and `array_dtype_in_block` (dicts keyed by dataset name). You’ll see these in every archive.
 - In **Embeddings**, `alpha` and `num_negative_samples` apply to **SGNS** only and are ignored for `objective="sg"`.
 
@@ -200,7 +209,7 @@ rin_builder.build_rin(
     prune_low_energies_frac=0.85,
     output_path=rin_path,
     include_attractive=True,
-    include_repulsive=False
+    include_repulsive=False
 )
 
 # 2. Sample walks from the RIN
@@ -208,44 +217,34 @@ walker = Walker(rin_path, seed=123)
 walks_path = Path("./WALKS_demo.zip")
 walker.sample_walks(
     walk_length=16,
-    walks_per_node=
+    walks_per_node=100,
     saw_frac=0.25,
     include_attractive=True,
     include_repulsive=False,
     time_aware=False,
     output_path=walks_path,
-    in_parallel=False
+    in_parallel=False
 )
 walker.close()
 
 # 3. Train embeddings per frame (PyTorch backend)
 import torch
 
-embedder = Embedder(walks_path,
+embedder = Embedder(walks_path, seed=999)
 embeddings_path = embedder.embed_all(
     RIN_type="attr",
     using="merged",
+    num_epochs=10,
+    negative_sampling=False,
     window_size=4,
-
-
-
-    batch_size=1024,
-    dimensionality=128,
-    shuffle_data=True,
-    output_path="./EMBEDDINGS_demo.zip",
-    sgns_kwargs={
-        "optim": torch.optim.Adam,
-        "optim_kwargs": {"lr": 1e-3},
-        "lr_sched": torch.optim.lr_scheduler.LambdaLR,
-        "lr_sched_kwargs": {"lr_lambda": lambda _: 1.0},
-        "device": "cuda" if torch.cuda.is_available() else "cpu",
-    },
+    device="cuda" if torch.cuda.is_available() else "cpu",
+    model_base="torch",
+    output_path="./EMBEDDINGS_demo.zip"
 )
 print("Embeddings written to", embeddings_path)
 ```
 
-> For the PureML backend,
-> (for example `optim=pureml.optimizers.Adam`, `lr_sched=pureml.optimizers.CosineAnnealingLR`).
+> For the PureML backend, set `model_base="pureml"` and pass the optimizer / scheduler classes inside `model_kwargs`.
 
 ---
 
@@ -270,7 +269,7 @@ v.build_frame(1,
 ```python
 from sawnergy.embedding import Visualizer
 
-viz =
+viz = Visualizer("./EMBEDDINGS_demo.zip", normalize_rows=True)
 viz.build_frame(1, show=True)
 ```
 
@@ -280,8 +279,7 @@ viz.build_frame(1, show=True)
 
 - **Time-aware walks**: Set `time_aware=True`, provide `stickiness` and `on_no_options` when calling `Walker.sample_walks`.
 - **Shared memory lifecycle**: Call `Walker.close()` (or use a context manager) to release shared-memory segments.
-- **PureML vs PyTorch**:
-  constructor kwargs through `sgns_kwargs` (optimizer, scheduler, device).
+- **PureML vs PyTorch**: Select the backend at call time with `model_base="pureml"|"torch"` (defaults to `"pureml"`) and pass optimizer / scheduler overrides through `model_kwargs`.
 - **ArrayStorage utilities**: Use `ArrayStorage` directly to peek into archives, append arrays, or manage metadata.
 
 ---
|
|
|
292
290
|
├── sawnergy/
|
|
293
291
|
│ ├── rin/ # RINBuilder and cpptraj integration helpers
|
|
294
292
|
│ ├── walks/ # Walker class and shared-memory utilities
|
|
295
|
-
│ ├── embedding/ # Embedder + SGNS backends (PureML / PyTorch)
|
|
293
|
+
│ ├── embedding/ # Embedder + SG/SGNS backends (PureML / PyTorch)
|
|
296
294
|
│ ├── visual/ # Visualizer and palette utilities
|
|
295
|
+
│ │
|
|
297
296
|
│ ├── logging_util.py
|
|
298
297
|
│ └── sawnergy_util.py
|
|
299
298
|
│
|
|
@@ -302,7 +301,7 @@ viz.build_frame(1, show=True)
 
 ---
 
-##
+## Acknowledgments
 
 SAWNERGY builds on the AmberTools `cpptraj` ecosystem, NumPy, Matplotlib, Zarr, and PyTorch (for GPU acceleration if necessary; PureML is available by default).
 Big thanks to the upstream communities whose work makes this toolkit possible.
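The v1.0.8 notes above state that constructing `SGNS_Torch` now emits a `DeprecationWarning` and recommend `SG_Torch` or the PureML backends instead. Below is a minimal sketch of how calling code might enforce that recommendation; the positional `(V, D)` constructor signature is taken from `SGNS_PureML` in this diff and is assumed to carry over to `SGNS_Torch`:

```python
import warnings

from sawnergy.embedding import SGNS_PureML

V, D = 200, 128  # hypothetical vocabulary size and embedding dimensionality

# Promote DeprecationWarning to an error so the deprecated Torch backend is
# never constructed silently; fall back to the PureML implementation instead.
with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    try:
        from sawnergy.embedding import SGNS_Torch
        model = SGNS_Torch(V, D)   # assumed to mirror SGNS_PureML's signature
    except DeprecationWarning:
        model = SGNS_PureML(V, D)  # v1.0.8 defaults: optim=SGD, lr=0.1
```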
README.md
@@ -18,19 +18,31 @@ keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2
 
 > **Optional:** For GPU training, install PyTorch separately (e.g., `pip install torch`).
 > **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
-> environment variable. Probably the easiest solution: install AmberTools via
+> environment variable. The easiest solution is usually to install AmberTools via Conda and activate the environment; SAWNERGY will then find the cpptraj executable on its own.
 
 ---
 
 # UPDATES:
 
+## v1.0.8 — What’s new:
+- **Temporary deprecation of `SGNS_Torch`**
+  - `sawnergy.embedding.SGNS_Torch` currently produces noisy embeddings in practice. The issue likely stems from **weight initialization**, although the root cause has not yet been conclusively determined.
+  - **Action:** The class and its `__init__` docstring now carry a deprecation notice. Constructing the class emits a **`DeprecationWarning`** and logs a **warning**.
+  - **Use instead:** Prefer **`SG_Torch`** (plain Skip-Gram with full softmax) or the PureML backends **`SGNS_PureML`** / **`SG_PureML`**.
+  - **Compatibility:** No breaking API changes; imports remain stable. PureML backends are unaffected.
+- **Embedding visualizer update**
+  - Embeddings can now be L2-normalized before display.
+- **Small improvements in the embedding module**
+  - Improved API with sensible defaults in place to ease usage out of the box.
+  - Small internal model tweaks.
+
 ## v1.0.7 — What’s new:
-- **Added plain
+- **Added plain Skip-Gram model**
   - Now, the user can choose whether to apply the negative sampling technique (two binary classifiers) or train a single classifier over the vocabulary (full softmax). For more detail, see: [node2vec](https://arxiv.org/pdf/1607.00653), [word2vec](https://arxiv.org/pdf/1301.3781), and [negative_sampling](https://arxiv.org/pdf/1402.3722).
 - **Set a harsher default for low interaction energies pruning during RIN construction**
   - Now, 85% of the lowest interaction energies are zeroed out, as opposed to the previous 30% default, leading to more meaningful embeddings.
 - **BUG FIX: Visualizer**
-  - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image
+  - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image/animation would be very laggy. This has been fixed, and given the higher pruning default, the displayed interaction networks are clean and smooth under rotation, dragging, etc.
 - **New Embedding Visualizer (3D)**
   - New lightweight viewer for per-frame embeddings that projects embeddings with PCA to a **3D** scatter. Supports the same node coloring semantics, optional node labels, and the same antialiasing/depthshade controls. Works in headless setups using the same backend guard and uses a blocking `show=True` for scripts.
 
@@ -43,7 +55,7 @@ keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2
 - **Deterministic, shareable artifacts**: Every stage produces compressed Zarr archives that contain both data and metadata so runs can be reproduced, shared, or inspected later.
 - **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without serialization overhead; archives are written in chunked, compressed form for fast read/write.
 - **Flexible objectives & backends**: Train Skip-Gram with **negative sampling** (`objective="sgns"`) or **plain Skip-Gram** (`objective="sg"`), using either **PureML** (default) or **PyTorch**.
-- **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder
+- **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder.
 
 ---
 
@@ -83,7 +95,7 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 * Wraps the AmberTools `cpptraj` executable to:
   - compute per-frame electrostatic (EMAP) and van der Waals (VMAP) energy matrices at the atomic level,
   - project atom–atom interactions to residue–residue interactions using compositional masks,
-  - prune, symmetrize, remove self-interactions, and L1-
+  - prune, symmetrize, remove self-interactions, and L1-normalize the matrices,
   - compute per-residue centers of mass (COM) over the same frames.
 * Outputs a compressed Zarr archive with transition matrices, optional pre-normalized energies, COM snapshots, and rich
   metadata (frame range, pruning quantile, molecule ID, etc.).
@@ -108,13 +120,10 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 
 ### `sawnergy.embedding.Embedder`
 
-* Consumes walk archives, generates skip-gram pairs, and
-*
-
-
-* Both `SGNS_PureML` and `SGNS_Torch` accept training hyperparameters such as batch_size, LR, optimizer and LR_scheduler, etc.
-* Exposes `embed_frame` (single frame) and `embed_all` (all frames, deterministic seeding per frame) which return the
-  learned input embedding matrices and write them to disk when requested.
+* Consumes walk archives, generates skip-gram pairs, and normalizes them to 0-based indices.
+* Selects skip-gram (SG / SGNS) backends dynamically via `model_base="pureml"|"torch"` with per-backend overrides supplied through `model_kwargs`.
+* Handles deterministic per-frame seeding and returns the requested embedding `kind` (`"in"`, `"out"`, or `"avg"`) from `embed_frame` and `embed_all`.
+* Persists per-frame matrices with rich provenance (walk metadata, objective, hyperparameters, RNG seeds) when `embed_all` targets an output archive.
 
 ### Supporting Utilities
 
@@ -132,11 +141,11 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 |---|---|---|
 | **RIN** | `ATTRACTIVE_transitions` → **(T, N, N)**, float32 • `REPULSIVE_transitions` → **(T, N, N)**, float32 (optional) • `ATTRACTIVE_energies` → **(T, N, N)**, float32 (optional) • `REPULSIVE_energies` → **(T, N, N)**, float32 (optional) • `COM` → **(T, N, 3)**, float32 | `time_created` (ISO) • `com_name` = `"COM"` • `molecule_of_interest` (int) • `frame_range` = `(start, end)` inclusive • `frame_batch_size` (int) • `prune_low_energies_frac` (float in [0,1]) • `attractive_transitions_name` / `repulsive_transitions_name` (dataset names or `None`) • `attractive_energies_name` / `repulsive_energies_name` (dataset names or `None`) |
 | **Walks** | `ATTRACTIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `REPULSIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `ATTRACTIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) • `REPULSIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) <br/>_Note:_ node IDs are **1-based**. | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_batch_v1"` • `num_workers` (int) • `in_parallel` (bool) • `batch_size_nodes` (int) • `num_RWs` / `num_SAWs` (ints) • `node_count` (N) • `time_stamp_count` (T) • `walk_length` (L) • `walks_per_node` (int) • `attractive_RWs_name` / `repulsive_RWs_name` / `attractive_SAWs_name` / `repulsive_SAWs_name` (dataset names or `None`) • `walks_layout` = `"time_leading_3d"` |
-| **Embeddings** | `FRAME_EMBEDDINGS` → **(
+| **Embeddings** | `FRAME_EMBEDDINGS` → **(T, N, D)**, float32 | `created_at` (ISO) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `time_stamp_count` = T • `node_count` = N • `embedding_dim` = D • `model_base` = `"torch"` or `"pureml"` • `embedding_kind` = `"in"|"out"|"avg"` • `objective` = `"sgns"` or `"sg"` • `negative_sampling` (bool) • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `window_size` (int) • `alpha` (float) • `lr_step_per_batch` (bool) • `shuffle_data` (bool) • `device_hint` (str) • `model_kwargs_repr` (repr string) • `RIN_type` = `"attr"` or `"repuls"` • `using` = `"RW"|"SAW"|"merged"` • `source_WALKS_path` (str) • `walk_length` (int) • `num_RWs` / `num_SAWs` (ints) • `attractive_*_name` / `repulsive_*_name` (dataset names or `None`) • `master_seed` (int) • `per_frame_seeds` (list[int]) • `arrays_per_chunk` (int) • `compression_level` (int) |
 
 **Notes**
 
-- In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-
+- In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalized** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalized** versions used for sampling.
 - All archives are Zarr v3 groups. ArrayStorage also maintains per-block metadata in root attrs: `array_chunk_size_in_block`, `array_shape_in_block`, and `array_dtype_in_block` (dicts keyed by dataset name). You’ll see these in every archive.
 - In **Embeddings**, `alpha` and `num_negative_samples` apply to **SGNS** only and are ignored for `objective="sg"`.
 
@@ -166,7 +175,7 @@ rin_builder.build_rin(
     prune_low_energies_frac=0.85,
     output_path=rin_path,
     include_attractive=True,
-    include_repulsive=False
+    include_repulsive=False
 )
 
 # 2. Sample walks from the RIN
@@ -174,44 +183,34 @@ walker = Walker(rin_path, seed=123)
 walks_path = Path("./WALKS_demo.zip")
 walker.sample_walks(
     walk_length=16,
-    walks_per_node=
+    walks_per_node=100,
     saw_frac=0.25,
     include_attractive=True,
     include_repulsive=False,
     time_aware=False,
     output_path=walks_path,
-    in_parallel=False
+    in_parallel=False
 )
 walker.close()
 
 # 3. Train embeddings per frame (PyTorch backend)
 import torch
 
-embedder = Embedder(walks_path,
+embedder = Embedder(walks_path, seed=999)
 embeddings_path = embedder.embed_all(
     RIN_type="attr",
     using="merged",
+    num_epochs=10,
+    negative_sampling=False,
     window_size=4,
-
-
-
-    batch_size=1024,
-    dimensionality=128,
-    shuffle_data=True,
-    output_path="./EMBEDDINGS_demo.zip",
-    sgns_kwargs={
-        "optim": torch.optim.Adam,
-        "optim_kwargs": {"lr": 1e-3},
-        "lr_sched": torch.optim.lr_scheduler.LambdaLR,
-        "lr_sched_kwargs": {"lr_lambda": lambda _: 1.0},
-        "device": "cuda" if torch.cuda.is_available() else "cpu",
-    },
+    device="cuda" if torch.cuda.is_available() else "cpu",
+    model_base="torch",
+    output_path="./EMBEDDINGS_demo.zip"
 )
 print("Embeddings written to", embeddings_path)
 ```
 
-> For the PureML backend,
-> (for example `optim=pureml.optimizers.Adam`, `lr_sched=pureml.optimizers.CosineAnnealingLR`).
+> For the PureML backend, set `model_base="pureml"` and pass the optimizer / scheduler classes inside `model_kwargs`.
 
 ---
 
@@ -236,7 +235,7 @@ v.build_frame(1,
 ```python
 from sawnergy.embedding import Visualizer
 
-viz =
+viz = Visualizer("./EMBEDDINGS_demo.zip", normalize_rows=True)
 viz.build_frame(1, show=True)
 ```
 
@@ -246,8 +245,7 @@ viz.build_frame(1, show=True)
 
 - **Time-aware walks**: Set `time_aware=True`, provide `stickiness` and `on_no_options` when calling `Walker.sample_walks`.
 - **Shared memory lifecycle**: Call `Walker.close()` (or use a context manager) to release shared-memory segments.
-- **PureML vs PyTorch**:
-  constructor kwargs through `sgns_kwargs` (optimizer, scheduler, device).
+- **PureML vs PyTorch**: Select the backend at call time with `model_base="pureml"|"torch"` (defaults to `"pureml"`) and pass optimizer / scheduler overrides through `model_kwargs`.
 - **ArrayStorage utilities**: Use `ArrayStorage` directly to peek into archives, append arrays, or manage metadata.
 
 ---
@@ -258,8 +256,9 @@ viz.build_frame(1, show=True)
 ├── sawnergy/
 │   ├── rin/          # RINBuilder and cpptraj integration helpers
 │   ├── walks/        # Walker class and shared-memory utilities
-│   ├── embedding/    # Embedder + SGNS backends (PureML / PyTorch)
+│   ├── embedding/    # Embedder + SG/SGNS backends (PureML / PyTorch)
 │   ├── visual/       # Visualizer and palette utilities
+│   │
 │   ├── logging_util.py
 │   └── sawnergy_util.py
 │
@@ -268,7 +267,7 @@ viz.build_frame(1, show=True)
 
 ---
 
-##
+## Acknowledgments
 
 SAWNERGY builds on the AmberTools `cpptraj` ecosystem, NumPy, Matplotlib, Zarr, and PyTorch (for GPU acceleration if necessary; PureML is available by default).
 Big thanks to the upstream communities whose work makes this toolkit possible.
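The quick-start above demonstrates `model_base="torch"`; the note following it says to set `model_base="pureml"` and pass the optimizer / scheduler classes inside `model_kwargs`. A sketch of that variant, assuming the `model_kwargs` keys mirror the backend constructor parameters (`optim`, `optim_kwargs`, `lr_sched`, `lr_sched_kwargs`) shown in `SGNS_pml.py` below:

```python
from sawnergy.embedding import Embedder
from pureml.optimizers import SGD  # imported by SGNS_pml.py in this diff

embedder = Embedder("./WALKS_demo.zip", seed=999)
embeddings_path = embedder.embed_all(
    RIN_type="attr",
    using="merged",
    num_epochs=10,
    negative_sampling=False,  # objective="sg": plain Skip-Gram, full softmax
    window_size=4,
    model_base="pureml",      # CPU-only backend; the package default
    model_kwargs={
        "optim": SGD,                 # optimizer class, not an instance
        "optim_kwargs": {"lr": 0.1},  # matches the backend's new default
    },
    output_path="./EMBEDDINGS_pureml_demo.zip",
)
print("Embeddings written to", embeddings_path)
```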
sawnergy/embedding/SGNS_pml.py
@@ -6,7 +6,7 @@ from pureml.machinery import Tensor
 from pureml.layers import Embedding, Affine
 from pureml.losses import BCE, CCE
 from pureml.general_math import sum as t_sum
-from pureml.optimizers import Optim, LRScheduler
+from pureml.optimizers import Optim, LRScheduler, SGD
 from pureml.training_utils import TensorDataset, DataLoader, one_hot
 from pureml.base import NN
 
@@ -32,8 +32,8 @@ class SGNS_PureML(NN):
                  D: int,
                  *,
                  seed: int | None = None,
-                 optim: Type[Optim],
-                 optim_kwargs: dict,
+                 optim: Type[Optim] = SGD,
+                 optim_kwargs: dict | None = None,
                  lr_sched: Type[LRScheduler] | None = None,
                  lr_sched_kwargs: dict | None = None,
                  device: str | None = None):
@@ -42,15 +42,15 @@ class SGNS_PureML(NN):
             V: Vocabulary size (number of nodes).
             D: Embedding dimensionality.
             seed: Optional RNG seed for negative sampling.
-            optim: Optimizer class to instantiate.
-            optim_kwargs: Keyword arguments for the optimizer
+            optim: Optimizer class to instantiate. Defaults to plain SGD.
+            optim_kwargs: Keyword arguments for the optimizer. Defaults to {"lr": 0.1}.
             lr_sched: Optional learning-rate scheduler class.
             lr_sched_kwargs: Keyword arguments for the scheduler (required if lr_sched is provided).
             device: Target device string (e.g. "cuda"); accepted for API parity, ignored by PureML.
         """
 
-
-
+        optim_kwargs = optim_kwargs or {"lr": 0.1}
+
         if lr_sched is not None and lr_sched_kwargs is None:
             raise ValueError("lr_sched_kwargs required when lr_sched is provided")
 
@@ -147,7 +147,7 @@ class SGNS_PureML(NN):
         K = int(neg.data.shape[1])
         loss = (
             BCE(y_pos, x_pos_logits, from_logits=True)
-            + K*BCE(y_neg, x_neg_logits, from_logits=True)
+            + Tensor(K)*BCE(y_neg, x_neg_logits, from_logits=True)
         )
 
         self.optim.zero_grad()
@@ -176,7 +176,9 @@ class SGNS_PureML(NN):
                 "Wrong embedding matrix shape: "
                 "self.in_emb.parameters[0].shape != (V, D)"
             )
-
+        arr = W.numpy(copy=True, readonly=True)  # (V, D)
+        _logger.debug("In emb shape: %s", arr.shape)
+        return arr
 
     @property
     def out_embeddings(self) -> np.ndarray:
@@ -186,7 +188,9 @@ class SGNS_PureML(NN):
                 "Wrong embedding matrix shape: "
                 "self.out_emb.parameters[0].shape != (V, D)"
             )
-
+        arr = W.numpy(copy=True, readonly=True)  # (V, D)
+        _logger.debug("Out emb shape: %s", arr.shape)
+        return arr
 
     @property
     def avg_embeddings(self) -> np.ndarray:
@@ -208,37 +212,29 @@ class SG_PureML(NN):
     """
 
     def __init__(self,
-
-
-
-
-
-
-
-
-
+                 V: int,
+                 D: int,
+                 *,
+                 seed: int | None = None,
+                 optim: Type[Optim] = SGD,
+                 optim_kwargs: dict | None = None,
+                 lr_sched: Type[LRScheduler] | None = None,
+                 lr_sched_kwargs: dict | None = None,
+                 device: str | None = None):
         """Initialize the plain Skip-Gram model (full softmax).
 
         Args:
             V: Vocabulary size (number of nodes/tokens).
             D: Embedding dimensionality.
             seed: Optional RNG seed (kept for API parity; not used in layer init).
-            optim: Optimizer class to instantiate
-            optim_kwargs: Keyword arguments
+            optim: Optimizer class to instantiate. Defaults to plain SGD.
+            optim_kwargs: Keyword arguments for the optimizer. Defaults to {"lr": 0.1}.
             lr_sched: Optional learning-rate scheduler class.
-            lr_sched_kwargs: Keyword arguments for the scheduler
-
-            device: Device string (e.g., `"cuda"`). Accepted for parity, ignored
-            by PureML (CPU-only).
-
-        Notes:
-            The encoder/decoder are implemented as:
-            • `in_emb = Affine(V, D)` (acts on a one-hot center index)
-            • `out_emb = Affine(D, V)`
-            so forward pass produces vocabulary-sized logits.
+            lr_sched_kwargs: Keyword arguments for the scheduler (required if lr_sched is provided).
+            device: Device string (e.g., "cuda"). Accepted for parity, ignored by PureML (CPU-only).
         """
-
-
+
+        optim_kwargs = optim_kwargs or {"lr": 0.1}
         if lr_sched is not None and lr_sched_kwargs is None:
             raise ValueError("lr_sched_kwargs required when lr_sched is provided")
 
@@ -249,9 +245,7 @@ class SG_PureML(NN):
         self.out_emb = Affine(self.D, self.V)
 
         self.seed = None if seed is None else int(seed)
-
-        # API compatibility: PureML is CPU-only
-        self.device = "cpu"
+        self.device = "cpu"  # API parity
 
         # optimizer / scheduler
         self.optim: Optim = optim(self.parameters, **optim_kwargs)
@@ -344,7 +338,9 @@ class SG_PureML(NN):
                 "Wrong embedding matrix shape: "
                 "self.in_emb.parameters[0].shape != (V, D)"
             )
-
+        arr = W.numpy(copy=True, readonly=True)  # (V, D)
+        _logger.debug("In emb shape: %s", arr.shape)
+        return arr
 
     @property
     def out_embeddings(self) -> np.ndarray:
@@ -356,7 +352,9 @@ class SG_PureML(NN):
                 "Wrong embedding matrix shape: "
                 "self.out_emb.parameters[0].shape != (D, V)"
            )
-
+        arr = W.numpy(copy=True, readonly=True).T  # (V, D)
+        _logger.debug("Out emb shape: %s", arr.shape)
+        return arr
 
     @property
     def avg_embeddings(self) -> np.ndarray:
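For reference, the loss assembled in the hunks above (`BCE(y_pos, …) + Tensor(K)*BCE(y_neg, …)`) corresponds to the skip-gram negative-sampling objective from the papers cited in the README. Assuming PureML's `BCE` averages over the `K` negative logits per pair (a reduction detail this diff does not show), multiplying by `K` restores the sum over negatives:

```latex
% Per (center c, context o) pair, with K negatives w_k drawn from the
% noise distribution P_n (Mikolov et al., 2013):
\mathcal{L}(c, o) = -\log \sigma\!\left(u_o^{\top} v_c\right)
    - \sum_{k=1}^{K} \log \sigma\!\left(-u_{w_k}^{\top} v_c\right)
```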