sawnergy 1.0.7__tar.gz → 1.0.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (33)
  1. {sawnergy-1.0.7/sawnergy.egg-info → sawnergy-1.0.8}/PKG-INFO +39 -40
  2. {sawnergy-1.0.7 → sawnergy-1.0.8}/README.md +38 -39
  3. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/SGNS_pml.py +36 -38
  4. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/SGNS_torch.py +82 -29
  5. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/embedder.py +325 -245
  6. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/visualizer.py +9 -5
  7. {sawnergy-1.0.7 → sawnergy-1.0.8/sawnergy.egg-info}/PKG-INFO +39 -40
  8. {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_embedding.py +17 -3
  9. {sawnergy-1.0.7 → sawnergy-1.0.8}/LICENSE +0 -0
  10. {sawnergy-1.0.7 → sawnergy-1.0.8}/NOTICE +0 -0
  11. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/__init__.py +0 -0
  12. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/embedding/__init__.py +0 -0
  13. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/logging_util.py +0 -0
  14. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/rin/__init__.py +0 -0
  15. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/rin/rin_builder.py +0 -0
  16. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/rin/rin_util.py +0 -0
  17. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/sawnergy_util.py +0 -0
  18. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/visual/__init__.py +0 -0
  19. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/visual/visualizer.py +0 -0
  20. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/visual/visualizer_util.py +0 -0
  21. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/walks/__init__.py +0 -0
  22. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/walks/walker.py +0 -0
  23. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy/walks/walker_util.py +0 -0
  24. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy.egg-info/SOURCES.txt +0 -0
  25. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy.egg-info/dependency_links.txt +0 -0
  26. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy.egg-info/requires.txt +0 -0
  27. {sawnergy-1.0.7 → sawnergy-1.0.8}/sawnergy.egg-info/top_level.txt +0 -0
  28. {sawnergy-1.0.7 → sawnergy-1.0.8}/setup.cfg +0 -0
  29. {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_embedding_visualizer.py +0 -0
  30. {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_rin.py +0 -0
  31. {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_storage.py +0 -0
  32. {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_visual.py +0 -0
  33. {sawnergy-1.0.7 → sawnergy-1.0.8}/tests/test_walks.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: sawnergy
- Version: 1.0.7
+ Version: 1.0.8
  Summary: Toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations
  Home-page: https://github.com/Yehor-Mishchyriak/SAWNERGY
  Author: Yehor Mishchyriak
@@ -52,19 +52,31 @@ keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2

  > **Optional:** For GPU training, install PyTorch separately (e.g., `pip install torch`).
  > **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
- > environment variable. Probably the easiest solution: install AmberTools via conda, activate the environment, and SAWNERGY will find cpptraj executable on its own, so just run your code and don't worry about it.
+ > environment variable. Probably the easiest solution: install AmberTools via Conda, activate the environment, and SAWNERGY will find the cpptraj executable on its own, so just run your code and don't worry about it.

  ---

  # UPDATES:

+ ## v1.0.8 — What’s new:
+ - **Temporary deprecation of `SGNS_Torch`**
+ - `sawnergy.embedding.SGNS_Torch` currently produces noisy embeddings in practice. The issue likely stems from **weight initialization**, although the root cause has not yet been conclusively determined.
+ - **Action:** The class and its `__init__` docstring now carry a deprecation notice. Constructing the class emits a **`DeprecationWarning`** and logs a **warning**.
+ - **Use instead:** Prefer **`SG_Torch`** (plain Skip-Gram with full softmax) or the PureML backends **`SGNS_PureML`** / **`SG_PureML`**.
+ - **Compatibility:** No breaking API changes; imports remain stable. PureML backends are unaffected.
+ - **Embedding visualizer update**
+ - Now you can L2 normalize your embeddings before display.
+ - **Small improvements in the embedding module**
+ - Improved API with a lot of good defaults in place to ease usage out of the box.
+ - Small internal model tweaks.
+
  ## v1.0.7 — What’s new:
- - **Added plain SkipGram model**
+ - **Added plain Skip-Gram model**
  - Now, the user can choose if they want to apply the negative sampling technique (two binary classifiers) or train a single classifier over the vocabulary (full softmax). For more detail, see: [node2vec](https://arxiv.org/pdf/1607.00653), [word2vec](https://arxiv.org/pdf/1301.3781), and [negative_sampling](https://arxiv.org/pdf/1402.3722).
  - **Set a harsher default for low interaction energies pruning during RIN construction**
  - Now we zero out 85% of the lowest interaction energies as opposed to the past 30% default, leading to more meaningful embeddings.
  - **BUG FIX: Visualizer**
- - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image / animation would be very laggy. Now, this was fixed, and given high pruning default, the displayed interaction networks are clean and smooth under rotations, dragging, etc.
+ - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image/animation would be very laggy. Now, this was fixed, and given the higher pruning default, the displayed interaction networks are clean and smooth under rotations, dragging, etc.
  - **New Embedding Visualizer (3D)**
  - New lightweight viewer for per-frame embeddings that projects embeddings with PCA to a **3D** scatter. Supports the same node coloring semantics, optional node labels, and the same antialiasing/depthshade controls. Works in headless setups using the same backend guard and uses a blocking `show=True` for scripts.

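The v1.0.8 entry above notes that the embedding visualizer can now L2 normalize embeddings before display. As a rough, illustrative sketch (this helper is not the package's code), L2 normalization rescales each row of a per-frame `(N, D)` embedding matrix to unit Euclidean norm:

```python
# Illustrative only: what "L2 normalize before display" means for an (N, D)
# per-frame embedding matrix. This helper is not part of sawnergy.
import numpy as np

def l2_normalize_rows(emb: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    norms = np.linalg.norm(emb, axis=1, keepdims=True)  # (N, 1) row norms
    return emb / np.maximum(norms, eps)                 # guard against zero rows

emb = np.random.default_rng(0).normal(size=(5, 8)).astype(np.float32)
print(np.linalg.norm(l2_normalize_rows(emb), axis=1))   # ~1.0 for every row
```

After this rescaling only the direction of each embedding matters, which tends to make the PCA scatter easier to compare across frames.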
@@ -77,7 +89,7 @@ keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2
  - **Deterministic, shareable artifacts**: Every stage produces compressed Zarr archives that contain both data and metadata so runs can be reproduced, shared, or inspected later.
  - **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without serialization overhead; archives are written in chunked, compressed form for fast read/write.
  - **Flexible objectives & backends**: Train Skip-Gram with **negative sampling** (`objective="sgns"`) or **plain Skip-Gram** (`objective="sg"`), using either **PureML** (default) or **PyTorch**.
- - **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder
+ - **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder.

  ---

@@ -117,7 +129,7 @@ node indexing, and RNG seeds stay consistent across the toolchain.
  * Wraps the AmberTools `cpptraj` executable to:
  - compute per-frame electrostatic (EMAP) and van der Waals (VMAP) energy matrices at the atomic level,
  - project atom–atom interactions to residue–residue interactions using compositional masks,
- - prune, symmetrize, remove self-interactions, and L1-normalise the matrices,
+ - prune, symmetrize, remove self-interactions, and L1-normalize the matrices,
  - compute per-residue centers of mass (COM) over the same frames.
  * Outputs a compressed Zarr archive with transition matrices, optional pre-normalized energies, COM snapshots, and rich
  metadata (frame range, pruning quantile, molecule ID, etc.).
@@ -142,13 +154,10 @@ node indexing, and RNG seeds stay consistent across the toolchain.

  ### `sawnergy.embedding.Embedder`

- * Consumes walk archives, generates skip-gram pairs, and normalises them to 0-based indices.
- * Provides a unified interface to SGNS implementations:
- - **PureML backend** (`SGNS_PureML`): works with the `pureml` ecosystem, optimistic for CPU training.
- - **PyTorch backend** (`SGNS_Torch`): uses `torch.nn.Embedding` plays nicely with GPUs.
- * Both `SGNS_PureML` and `SGNS_Torch` accept training hyperparameters such as batch_size, LR, optimizer and LR_scheduler, etc.
- * Exposes `embed_frame` (single frame) and `embed_all` (all frames, deterministic seeding per frame) which return the
- learned input embedding matrices and write them to disk when requested.
+ * Consumes walk archives, generates skip-gram pairs, and normalizes them to 0-based indices.
+ * Selects skip-gram (SG / SGNS) backends dynamically via `model_base="pureml"|"torch"` with per-backend overrides supplied through `model_kwargs`.
+ * Handles deterministic per-frame seeding and returns the requested embedding `kind` (`"in"`, `"out"`, or `"avg"`) from `embed_frame` and `embed_all`.
+ * Persists per-frame matrices with rich provenance (walk metadata, objective, hyperparameters, RNG seeds) when `embed_all` targets an output archive.

  ### Supporting Utilities

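The embedding `kind` mentioned in the updated Embedder description above selects which matrix a trained backend hands back. A small sketch of that selection; the `"avg"` branch is an assumption here (elementwise mean of the input and output matrices, mirroring the backends' in/out/avg embedding properties), not verified against the source:

```python
# Sketch of the embedding "kind" selection. The "avg" branch is an assumption
# (elementwise mean of input and output matrices); the real Embedder may differ.
import numpy as np

def pick_embedding(kind: str, in_emb: np.ndarray, out_emb: np.ndarray) -> np.ndarray:
    if kind == "in":
        return in_emb                    # input-side vectors, shape (V, D)
    if kind == "out":
        return out_emb                   # output/context-side vectors, shape (V, D)
    if kind == "avg":
        return (in_emb + out_emb) / 2.0  # assumed: mean of the two matrices
    raise ValueError(f"unknown embedding kind: {kind!r}")
```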
@@ -166,11 +175,11 @@ node indexing, and RNG seeds stay consistent across the toolchain.
  |---|---|---|
  | **RIN** | `ATTRACTIVE_transitions` → **(T, N, N)**, float32 • `REPULSIVE_transitions` → **(T, N, N)**, float32 (optional) • `ATTRACTIVE_energies` → **(T, N, N)**, float32 (optional) • `REPULSIVE_energies` → **(T, N, N)**, float32 (optional) • `COM` → **(T, N, 3)**, float32 | `time_created` (ISO) • `com_name` = `"COM"` • `molecule_of_interest` (int) • `frame_range` = `(start, end)` inclusive • `frame_batch_size` (int) • `prune_low_energies_frac` (float in [0,1]) • `attractive_transitions_name` / `repulsive_transitions_name` (dataset names or `None`) • `attractive_energies_name` / `repulsive_energies_name` (dataset names or `None`) |
  | **Walks** | `ATTRACTIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `REPULSIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `ATTRACTIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) • `REPULSIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) <br/>_Note:_ node IDs are **1-based**.| `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_batch_v1"` • `num_workers` (int) • `in_parallel` (bool) • `batch_size_nodes` (int) • `num_RWs` / `num_SAWs` (ints) • `node_count` (N) • `time_stamp_count` (T) • `walk_length` (L) • `walks_per_node` (int) • `attractive_RWs_name` / `repulsive_RWs_name` / `attractive_SAWs_name` / `repulsive_SAWs_name` (dataset names or `None`) • `walks_layout` = `"time_leading_3d"` |
- | **Embeddings** | `FRAME_EMBEDDINGS` → **(frames_written, vocab_size, D)**, typically float32 | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_frame_v1"` • `source_walks_path` (str) • `model_base` = `"torch"` or `"pureml"` • `rin_type` = `"attr"` or `"repuls"` • `using_mode` = `"RW"|"SAW"|"merged"` • `window_size` (int) • `alpha` (float; noise exponent) • `dimensionality` = D • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `shuffle_data` (bool) • `frames_written` (int) • `vocab_size` (int) • `frame_count` (int) • `embedding_dtype` (str) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `arrays_per_chunk` (int) • `compression_level` (int) • `objective` = `"sgns"` or `"sg"` |
+ | **Embeddings** | `FRAME_EMBEDDINGS` → **(T, N, D)**, float32 | `created_at` (ISO) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `time_stamp_count` = T • `node_count` = N • `embedding_dim` = D • `model_base` = `"torch"` or `"pureml"` • `embedding_kind` = `"in"|"out"|"avg"` • `objective` = `"sgns"` or `"sg"` • `negative_sampling` (bool) • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `window_size` (int) • `alpha` (float) • `lr_step_per_batch` (bool) • `shuffle_data` (bool) • `device_hint` (str) • `model_kwargs_repr` (repr string) • `RIN_type` = `"attr"` or `"repuls"` • `using` = `"RW"|"SAW"|"merged"` • `source_WALKS_path` (str) • `walk_length` (int) • `num_RWs` / `num_SAWs` (ints) • `attractive_*_name` / `repulsive_*_name` (dataset names or `None`) • `master_seed` (int) • `per_frame_seeds` (list[int]) • `arrays_per_chunk` (int) • `compression_level` (int) |

  **Notes**

- - In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalised** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalised** versions used for sampling.
+ - In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalized** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalized** versions used for sampling.
  - All archives are Zarr v3 groups. ArrayStorage also maintains per-block metadata in root attrs: `array_chunk_size_in_block`, `array_shape_in_block`, and `array_dtype_in_block` (dicts keyed by dataset name). You’ll see these in every archive.
  - In **Embeddings**, `alpha` and `num_negative_samples` apply to **SGNS** only and are ignored for `objective="sg"`.

@@ -200,7 +209,7 @@ rin_builder.build_rin(
  prune_low_energies_frac=0.85,
  output_path=rin_path,
  include_attractive=True,
- include_repulsive=False,
+ include_repulsive=False
  )

  # 2. Sample walks from the RIN
@@ -208,44 +217,34 @@ walker = Walker(rin_path, seed=123)
  walks_path = Path("./WALKS_demo.zip")
  walker.sample_walks(
  walk_length=16,
- walks_per_node=32,
+ walks_per_node=100,
  saw_frac=0.25,
  include_attractive=True,
  include_repulsive=False,
  time_aware=False,
  output_path=walks_path,
- in_parallel=False,
+ in_parallel=False
  )
  walker.close()

  # 3. Train embeddings per frame (PyTorch backend)
  import torch

- embedder = Embedder(walks_path, base="torch", seed=999)
+ embedder = Embedder(walks_path, seed=999)
  embeddings_path = embedder.embed_all(
  RIN_type="attr",
  using="merged",
+ num_epochs=10,
+ negative_sampling=False,
  window_size=4,
- objective="sgns",
- num_negative_samples=5,
- num_epochs=5,
- batch_size=1024,
- dimensionality=128,
- shuffle_data=True,
- output_path="./EMBEDDINGS_demo.zip",
- sgns_kwargs={
- "optim": torch.optim.Adam,
- "optim_kwargs": {"lr": 1e-3},
- "lr_sched": torch.optim.lr_scheduler.LambdaLR,
- "lr_sched_kwargs": {"lr_lambda": lambda _: 1.0},
- "device": "cuda" if torch.cuda.is_available() else "cpu",
- },
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ model_base="torch",
+ output_path="./EMBEDDINGS_demo.zip"
  )
  print("Embeddings written to", embeddings_path)
  ```

- > For the PureML backend, supply the relevant optimiser and scheduler via `sgns_kwargs`
- > (for example `optim=pureml.optimizers.Adam`, `lr_sched=pureml.optimizers.CosineAnnealingLR`).
+ > For the PureML backend, set `model_base="pureml"` and pass the optimizer / scheduler classes inside `model_kwargs`.

  ---

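As a companion to the PureML note above, here is a hedged sketch of the same `embed_all` call routed through the PureML backend. It assumes `model_kwargs` is forwarded to the backend constructor shown in `SGNS_pml.py` below (`optim`, `optim_kwargs`, `lr_sched`, `lr_sched_kwargs`); with no overrides the backend now defaults to `SGD` with `lr=0.1`, so `model_kwargs` could be dropped entirely.

```python
# Hedged sketch: the PureML variant of the quick-start call above. model_kwargs
# keys mirror the backend constructor (optim, optim_kwargs, ...); omit them to
# accept the new SGD(lr=0.1) defaults. Paths are placeholders.
from pathlib import Path
from sawnergy.embedding import Embedder
from pureml.optimizers import SGD

embedder = Embedder(Path("./WALKS_demo.zip"), seed=999)
embeddings_path = embedder.embed_all(
    RIN_type="attr",
    using="merged",
    num_epochs=10,
    negative_sampling=False,
    window_size=4,
    model_base="pureml",                      # CPU-only, no torch required
    model_kwargs={"optim": SGD, "optim_kwargs": {"lr": 0.1}},
    output_path="./EMBEDDINGS_pureml_demo.zip",
)
```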
@@ -270,7 +269,7 @@ v.build_frame(1,
  ```python
  from sawnergy.embedding import Visualizer

- viz = sawnergy.embedding.Visualizer("./EMBEDDINGS_demo.zip")
+ viz = Visualizer("./EMBEDDINGS_demo.zip", normalize_rows=True)
  viz.build_frame(1, show=True)
  ```

@@ -280,8 +279,7 @@ viz.build_frame(1, show=True)

  - **Time-aware walks**: Set `time_aware=True`, provide `stickiness` and `on_no_options` when calling `Walker.sample_walks`.
  - **Shared memory lifecycle**: Call `Walker.close()` (or use a context manager) to release shared-memory segments.
- - **PureML vs PyTorch**: Choose the backend via `Embedder(..., base="pureml"|"torch")` and provide backend-specific
- constructor kwargs through `sgns_kwargs` (optimizer, scheduler, device).
+ - **PureML vs PyTorch**: Select the backend at call time with `model_base="pureml"|"torch"` (defaults to `"pureml"`) and pass optimizer / scheduler overrides through `model_kwargs`.
  - **ArrayStorage utilities**: Use `ArrayStorage` directly to peek into archives, append arrays, or manage metadata.

  ---
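The shared-memory tip above pairs naturally with the quick-start arguments. A small sketch, assuming the context-manager support mentioned in the tip calls `close()` on exit; the import path follows the project layout, and the paths are placeholders:

```python
# Sketch of the context-manager form of the shared-memory lifecycle tip above;
# arguments mirror the quick-start, paths are placeholders.
from pathlib import Path
from sawnergy.walks import Walker

with Walker(Path("./RIN_demo.zip"), seed=123) as walker:  # close() runs on exit
    walker.sample_walks(
        walk_length=16,
        walks_per_node=100,
        saw_frac=0.25,
        include_attractive=True,
        include_repulsive=False,
        time_aware=False,
        output_path=Path("./WALKS_demo.zip"),
        in_parallel=False,
    )
```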
@@ -292,8 +290,9 @@ viz.build_frame(1, show=True)
  ├── sawnergy/
  │ ├── rin/ # RINBuilder and cpptraj integration helpers
  │ ├── walks/ # Walker class and shared-memory utilities
- │ ├── embedding/ # Embedder + SGNS backends (PureML / PyTorch)
+ │ ├── embedding/ # Embedder + SG/SGNS backends (PureML / PyTorch)
  │ ├── visual/ # Visualizer and palette utilities
+ │ │
  │ ├── logging_util.py
  │ └── sawnergy_util.py

@@ -302,7 +301,7 @@ viz.build_frame(1, show=True)

  ---

- ## Acknowledgements
+ ## Acknowledgments

  SAWNERGY builds on the AmberTools `cpptraj` ecosystem, NumPy, Matplotlib, Zarr, and PyTorch (for GPU acceleration if necessary; PureML is available by default).
  Big thanks to the upstream communities whose work makes this toolkit possible.
@@ -18,19 +18,31 @@ keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2

  > **Optional:** For GPU training, install PyTorch separately (e.g., `pip install torch`).
  > **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
- > environment variable. Probably the easiest solution: install AmberTools via conda, activate the environment, and SAWNERGY will find cpptraj executable on its own, so just run your code and don't worry about it.
+ > environment variable. Probably the easiest solution: install AmberTools via Conda, activate the environment, and SAWNERGY will find the cpptraj executable on its own, so just run your code and don't worry about it.

  ---

  # UPDATES:

+ ## v1.0.8 — What’s new:
+ - **Temporary deprecation of `SGNS_Torch`**
+ - `sawnergy.embedding.SGNS_Torch` currently produces noisy embeddings in practice. The issue likely stems from **weight initialization**, although the root cause has not yet been conclusively determined.
+ - **Action:** The class and its `__init__` docstring now carry a deprecation notice. Constructing the class emits a **`DeprecationWarning`** and logs a **warning**.
+ - **Use instead:** Prefer **`SG_Torch`** (plain Skip-Gram with full softmax) or the PureML backends **`SGNS_PureML`** / **`SG_PureML`**.
+ - **Compatibility:** No breaking API changes; imports remain stable. PureML backends are unaffected.
+ - **Embedding visualizer update**
+ - Now you can L2 normalize your embeddings before display.
+ - **Small improvements in the embedding module**
+ - Improved API with a lot of good defaults in place to ease usage out of the box.
+ - Small internal model tweaks.
+
  ## v1.0.7 — What’s new:
- - **Added plain SkipGram model**
+ - **Added plain Skip-Gram model**
  - Now, the user can choose if they want to apply the negative sampling technique (two binary classifiers) or train a single classifier over the vocabulary (full softmax). For more detail, see: [node2vec](https://arxiv.org/pdf/1607.00653), [word2vec](https://arxiv.org/pdf/1301.3781), and [negative_sampling](https://arxiv.org/pdf/1402.3722).
  - **Set a harsher default for low interaction energies pruning during RIN construction**
  - Now we zero out 85% of the lowest interaction energies as opposed to the past 30% default, leading to more meaningful embeddings.
  - **BUG FIX: Visualizer**
- - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image / animation would be very laggy. Now, this was fixed, and given high pruning default, the displayed interaction networks are clean and smooth under rotations, dragging, etc.
+ - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image/animation would be very laggy. Now, this was fixed, and given the higher pruning default, the displayed interaction networks are clean and smooth under rotations, dragging, etc.
  - **New Embedding Visualizer (3D)**
  - New lightweight viewer for per-frame embeddings that projects embeddings with PCA to a **3D** scatter. Supports the same node coloring semantics, optional node labels, and the same antialiasing/depthshade controls. Works in headless setups using the same backend guard and uses a blocking `show=True` for scripts.

@@ -43,7 +55,7 @@ keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2
  - **Deterministic, shareable artifacts**: Every stage produces compressed Zarr archives that contain both data and metadata so runs can be reproduced, shared, or inspected later.
  - **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without serialization overhead; archives are written in chunked, compressed form for fast read/write.
  - **Flexible objectives & backends**: Train Skip-Gram with **negative sampling** (`objective="sgns"`) or **plain Skip-Gram** (`objective="sg"`), using either **PureML** (default) or **PyTorch**.
- - **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder
+ - **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder.

  ---

@@ -83,7 +95,7 @@ node indexing, and RNG seeds stay consistent across the toolchain.
  * Wraps the AmberTools `cpptraj` executable to:
  - compute per-frame electrostatic (EMAP) and van der Waals (VMAP) energy matrices at the atomic level,
  - project atom–atom interactions to residue–residue interactions using compositional masks,
- - prune, symmetrize, remove self-interactions, and L1-normalise the matrices,
+ - prune, symmetrize, remove self-interactions, and L1-normalize the matrices,
  - compute per-residue centers of mass (COM) over the same frames.
  * Outputs a compressed Zarr archive with transition matrices, optional pre-normalized energies, COM snapshots, and rich
  metadata (frame range, pruning quantile, molecule ID, etc.).
@@ -108,13 +120,10 @@ node indexing, and RNG seeds stay consistent across the toolchain.

  ### `sawnergy.embedding.Embedder`

- * Consumes walk archives, generates skip-gram pairs, and normalises them to 0-based indices.
- * Provides a unified interface to SGNS implementations:
- - **PureML backend** (`SGNS_PureML`): works with the `pureml` ecosystem, optimistic for CPU training.
- - **PyTorch backend** (`SGNS_Torch`): uses `torch.nn.Embedding` plays nicely with GPUs.
- * Both `SGNS_PureML` and `SGNS_Torch` accept training hyperparameters such as batch_size, LR, optimizer and LR_scheduler, etc.
- * Exposes `embed_frame` (single frame) and `embed_all` (all frames, deterministic seeding per frame) which return the
- learned input embedding matrices and write them to disk when requested.
+ * Consumes walk archives, generates skip-gram pairs, and normalizes them to 0-based indices.
+ * Selects skip-gram (SG / SGNS) backends dynamically via `model_base="pureml"|"torch"` with per-backend overrides supplied through `model_kwargs`.
+ * Handles deterministic per-frame seeding and returns the requested embedding `kind` (`"in"`, `"out"`, or `"avg"`) from `embed_frame` and `embed_all`.
+ * Persists per-frame matrices with rich provenance (walk metadata, objective, hyperparameters, RNG seeds) when `embed_all` targets an output archive.

  ### Supporting Utilities

@@ -132,11 +141,11 @@ node indexing, and RNG seeds stay consistent across the toolchain.
  |---|---|---|
  | **RIN** | `ATTRACTIVE_transitions` → **(T, N, N)**, float32 • `REPULSIVE_transitions` → **(T, N, N)**, float32 (optional) • `ATTRACTIVE_energies` → **(T, N, N)**, float32 (optional) • `REPULSIVE_energies` → **(T, N, N)**, float32 (optional) • `COM` → **(T, N, 3)**, float32 | `time_created` (ISO) • `com_name` = `"COM"` • `molecule_of_interest` (int) • `frame_range` = `(start, end)` inclusive • `frame_batch_size` (int) • `prune_low_energies_frac` (float in [0,1]) • `attractive_transitions_name` / `repulsive_transitions_name` (dataset names or `None`) • `attractive_energies_name` / `repulsive_energies_name` (dataset names or `None`) |
  | **Walks** | `ATTRACTIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `REPULSIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `ATTRACTIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) • `REPULSIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) <br/>_Note:_ node IDs are **1-based**.| `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_batch_v1"` • `num_workers` (int) • `in_parallel` (bool) • `batch_size_nodes` (int) • `num_RWs` / `num_SAWs` (ints) • `node_count` (N) • `time_stamp_count` (T) • `walk_length` (L) • `walks_per_node` (int) • `attractive_RWs_name` / `repulsive_RWs_name` / `attractive_SAWs_name` / `repulsive_SAWs_name` (dataset names or `None`) • `walks_layout` = `"time_leading_3d"` |
- | **Embeddings** | `FRAME_EMBEDDINGS` → **(frames_written, vocab_size, D)**, typically float32 | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_frame_v1"` • `source_walks_path` (str) • `model_base` = `"torch"` or `"pureml"` • `rin_type` = `"attr"` or `"repuls"` • `using_mode` = `"RW"|"SAW"|"merged"` • `window_size` (int) • `alpha` (float; noise exponent) • `dimensionality` = D • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `shuffle_data` (bool) • `frames_written` (int) • `vocab_size` (int) • `frame_count` (int) • `embedding_dtype` (str) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `arrays_per_chunk` (int) • `compression_level` (int) • `objective` = `"sgns"` or `"sg"` |
+ | **Embeddings** | `FRAME_EMBEDDINGS` → **(T, N, D)**, float32 | `created_at` (ISO) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `time_stamp_count` = T • `node_count` = N • `embedding_dim` = D • `model_base` = `"torch"` or `"pureml"` • `embedding_kind` = `"in"|"out"|"avg"` • `objective` = `"sgns"` or `"sg"` • `negative_sampling` (bool) • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `window_size` (int) • `alpha` (float) • `lr_step_per_batch` (bool) • `shuffle_data` (bool) • `device_hint` (str) • `model_kwargs_repr` (repr string) • `RIN_type` = `"attr"` or `"repuls"` • `using` = `"RW"|"SAW"|"merged"` • `source_WALKS_path` (str) • `walk_length` (int) • `num_RWs` / `num_SAWs` (ints) • `attractive_*_name` / `repulsive_*_name` (dataset names or `None`) • `master_seed` (int) • `per_frame_seeds` (list[int]) • `arrays_per_chunk` (int) • `compression_level` (int) |

  **Notes**

- - In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalised** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalised** versions used for sampling.
+ - In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalized** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalized** versions used for sampling.
  - All archives are Zarr v3 groups. ArrayStorage also maintains per-block metadata in root attrs: `array_chunk_size_in_block`, `array_shape_in_block`, and `array_dtype_in_block` (dicts keyed by dataset name). You’ll see these in every archive.
  - In **Embeddings**, `alpha` and `num_negative_samples` apply to **SGNS** only and are ignored for `objective="sg"`.

@@ -166,7 +175,7 @@ rin_builder.build_rin(
  prune_low_energies_frac=0.85,
  output_path=rin_path,
  include_attractive=True,
- include_repulsive=False,
+ include_repulsive=False
  )

  # 2. Sample walks from the RIN
@@ -174,44 +183,34 @@ walker = Walker(rin_path, seed=123)
  walks_path = Path("./WALKS_demo.zip")
  walker.sample_walks(
  walk_length=16,
- walks_per_node=32,
+ walks_per_node=100,
  saw_frac=0.25,
  include_attractive=True,
  include_repulsive=False,
  time_aware=False,
  output_path=walks_path,
- in_parallel=False,
+ in_parallel=False
  )
  walker.close()

  # 3. Train embeddings per frame (PyTorch backend)
  import torch

- embedder = Embedder(walks_path, base="torch", seed=999)
+ embedder = Embedder(walks_path, seed=999)
  embeddings_path = embedder.embed_all(
  RIN_type="attr",
  using="merged",
+ num_epochs=10,
+ negative_sampling=False,
  window_size=4,
- objective="sgns",
- num_negative_samples=5,
- num_epochs=5,
- batch_size=1024,
- dimensionality=128,
- shuffle_data=True,
- output_path="./EMBEDDINGS_demo.zip",
- sgns_kwargs={
- "optim": torch.optim.Adam,
- "optim_kwargs": {"lr": 1e-3},
- "lr_sched": torch.optim.lr_scheduler.LambdaLR,
- "lr_sched_kwargs": {"lr_lambda": lambda _: 1.0},
- "device": "cuda" if torch.cuda.is_available() else "cpu",
- },
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ model_base="torch",
+ output_path="./EMBEDDINGS_demo.zip"
  )
  print("Embeddings written to", embeddings_path)
  ```

- > For the PureML backend, supply the relevant optimiser and scheduler via `sgns_kwargs`
- > (for example `optim=pureml.optimizers.Adam`, `lr_sched=pureml.optimizers.CosineAnnealingLR`).
+ > For the PureML backend, set `model_base="pureml"` and pass the optimizer / scheduler classes inside `model_kwargs`.

  ---

@@ -236,7 +235,7 @@ v.build_frame(1,
  ```python
  from sawnergy.embedding import Visualizer

- viz = sawnergy.embedding.Visualizer("./EMBEDDINGS_demo.zip")
+ viz = Visualizer("./EMBEDDINGS_demo.zip", normalize_rows=True)
  viz.build_frame(1, show=True)
  ```

@@ -246,8 +245,7 @@ viz.build_frame(1, show=True)

  - **Time-aware walks**: Set `time_aware=True`, provide `stickiness` and `on_no_options` when calling `Walker.sample_walks`.
  - **Shared memory lifecycle**: Call `Walker.close()` (or use a context manager) to release shared-memory segments.
- - **PureML vs PyTorch**: Choose the backend via `Embedder(..., base="pureml"|"torch")` and provide backend-specific
- constructor kwargs through `sgns_kwargs` (optimizer, scheduler, device).
+ - **PureML vs PyTorch**: Select the backend at call time with `model_base="pureml"|"torch"` (defaults to `"pureml"`) and pass optimizer / scheduler overrides through `model_kwargs`.
  - **ArrayStorage utilities**: Use `ArrayStorage` directly to peek into archives, append arrays, or manage metadata.

  ---
@@ -258,8 +256,9 @@ viz.build_frame(1, show=True)
  ├── sawnergy/
  │ ├── rin/ # RINBuilder and cpptraj integration helpers
  │ ├── walks/ # Walker class and shared-memory utilities
- │ ├── embedding/ # Embedder + SGNS backends (PureML / PyTorch)
+ │ ├── embedding/ # Embedder + SG/SGNS backends (PureML / PyTorch)
  │ ├── visual/ # Visualizer and palette utilities
+ │ │
  │ ├── logging_util.py
  │ └── sawnergy_util.py

@@ -268,7 +267,7 @@ viz.build_frame(1, show=True)

  ---

- ## Acknowledgements
+ ## Acknowledgments

  SAWNERGY builds on the AmberTools `cpptraj` ecosystem, NumPy, Matplotlib, Zarr, and PyTorch (for GPU acceleration if necessary; PureML is available by default).
  Big thanks to the upstream communities whose work makes this toolkit possible.
@@ -6,7 +6,7 @@ from pureml.machinery import Tensor
  from pureml.layers import Embedding, Affine
  from pureml.losses import BCE, CCE
  from pureml.general_math import sum as t_sum
- from pureml.optimizers import Optim, LRScheduler
+ from pureml.optimizers import Optim, LRScheduler, SGD
  from pureml.training_utils import TensorDataset, DataLoader, one_hot
  from pureml.base import NN

@@ -32,8 +32,8 @@ class SGNS_PureML(NN):
  D: int,
  *,
  seed: int | None = None,
- optim: Type[Optim],
- optim_kwargs: dict,
+ optim: Type[Optim] = SGD,
+ optim_kwargs: dict | None = None,
  lr_sched: Type[LRScheduler] | None = None,
  lr_sched_kwargs: dict | None = None,
  device: str | None = None):
@@ -42,15 +42,15 @@ class SGNS_PureML(NN):
  V: Vocabulary size (number of nodes).
  D: Embedding dimensionality.
  seed: Optional RNG seed for negative sampling.
- optim: Optimizer class to instantiate.
- optim_kwargs: Keyword arguments for the optimizer (required).
+ optim: Optimizer class to instantiate. Defaults to plain SGD.
+ optim_kwargs: Keyword arguments for the optimizer. Defaults to {"lr": 0.1}.
  lr_sched: Optional learning-rate scheduler class.
  lr_sched_kwargs: Keyword arguments for the scheduler (required if lr_sched is provided).
  device: Target device string (e.g. "cuda"); accepted for API parity, ignored by PureML.
  """

- if optim_kwargs is None:
- raise ValueError("optim_kwargs must be provided")
+ optim_kwargs = optim_kwargs or {"lr": 0.1}
+
  if lr_sched is not None and lr_sched_kwargs is None:
  raise ValueError("lr_sched_kwargs required when lr_sched is provided")

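To illustrate the relaxed defaults introduced in the hunk above: `optim` and `optim_kwargs` are now optional and fall back to `SGD` with `lr=0.1`. A minimal construction sketch; the import path is assumed from the file layout (`sawnergy/embedding/SGNS_pml.py`) and the `V`/`D` values are arbitrary:

```python
# Minimal sketch of the new constructor defaults; import path assumed from the
# package layout, V/D values arbitrary.
from sawnergy.embedding.SGNS_pml import SGNS_PureML, SG_PureML

sgns = SGNS_PureML(V=200, D=64, seed=42)  # falls back to SGD with {"lr": 0.1}
sg = SG_PureML(V=200, D=64)               # same defaults, full-softmax objective
```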
@@ -147,7 +147,7 @@ class SGNS_PureML(NN):
  K = int(neg.data.shape[1])
  loss = (
  BCE(y_pos, x_pos_logits, from_logits=True)
- + K*BCE(y_neg, x_neg_logits, from_logits=True)
+ + Tensor(K)*BCE(y_neg, x_neg_logits, from_logits=True)
  )

  self.optim.zero_grad()
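For context on the `Tensor(K)` change above: the scaled term is the binary cross-entropy over the `K` negative columns, and wrapping the Python int in `Tensor` keeps the scaling inside pureml's tensor arithmetic. Assuming `BCE` averages over its elements, multiplying by `K` recovers the summed negative term of the usual skip-gram-with-negative-sampling objective, roughly:

```latex
% Assumed per-batch SGNS loss for the code above (B = batch size, K = negatives
% per positive, \sigma = logistic sigmoid, v = center/"in" vector, u = context/"out" vector).
\mathcal{L} \;=\; -\frac{1}{B}\sum_{b=1}^{B}\Bigl[
      \log \sigma\!\bigl(u_{c_b}^{\top} v_{w_b}\bigr)
    + \sum_{k=1}^{K} \log \sigma\!\bigl(-\,u_{n_{b,k}}^{\top} v_{w_b}\bigr)
\Bigr]
```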
@@ -176,7 +176,9 @@ class SGNS_PureML(NN):
  "Wrong embedding matrix shape: "
  "self.in_emb.parameters[0].shape != (V, D)"
  )
- return W.numpy(copy=True, readonly=True)
+ arr = W.numpy(copy=True, readonly=True) # (V, D)
+ _logger.debug("In emb shape: %s", arr.shape)
+ return arr

  @property
  def out_embeddings(self) -> np.ndarray:
@@ -186,7 +188,9 @@ class SGNS_PureML(NN):
  "Wrong embedding matrix shape: "
  "self.out_emb.parameters[0].shape != (V, D)"
  )
- return W.numpy(copy=True, readonly=True)
+ arr = W.numpy(copy=True, readonly=True) # (V, D)
+ _logger.debug("Out emb shape: %s", arr.shape)
+ return arr

  @property
  def avg_embeddings(self) -> np.ndarray:
@@ -208,37 +212,29 @@ class SG_PureML(NN):
  """

  def __init__(self,
- V: int,
- D: int,
- *,
- seed: int | None = None,
- optim: Type[Optim],
- optim_kwargs: dict,
- lr_sched: Type[LRScheduler] | None = None,
- lr_sched_kwargs: dict | None = None,
- device: str | None = None):
+ V: int,
+ D: int,
+ *,
+ seed: int | None = None,
+ optim: Type[Optim] = SGD,
+ optim_kwargs: dict | None = None,
+ lr_sched: Type[LRScheduler] | None = None,
+ lr_sched_kwargs: dict | None = None,
+ device: str | None = None):
  """Initialize the plain Skip-Gram model (full softmax).

  Args:
  V: Vocabulary size (number of nodes/tokens).
  D: Embedding dimensionality.
  seed: Optional RNG seed (kept for API parity; not used in layer init).
- optim: Optimizer class to instantiate (e.g., `Adam`, `SGD`).
- optim_kwargs: Keyword arguments passed to the optimizer constructor.
+ optim: Optimizer class to instantiate. Defaults to plain SGD.
+ optim_kwargs: Keyword arguments for the optimizer. Defaults to {"lr": 0.1}.
  lr_sched: Optional learning-rate scheduler class.
- lr_sched_kwargs: Keyword arguments for the scheduler
- (required if `lr_sched` is provided).
- device: Device string (e.g., `"cuda"`). Accepted for parity, ignored
- by PureML (CPU-only).
-
- Notes:
- The encoder/decoder are implemented as:
- • `in_emb = Affine(V, D)` (acts on a one-hot center index)
- • `out_emb = Affine(D, V)`
- so forward pass produces vocabulary-sized logits.
+ lr_sched_kwargs: Keyword arguments for the scheduler (required if lr_sched is provided).
+ device: Device string (e.g., "cuda"). Accepted for parity, ignored by PureML (CPU-only).
  """
- if optim_kwargs is None:
- raise ValueError("optim_kwargs must be provided")
+
+ optim_kwargs = optim_kwargs or {"lr": 0.1}
  if lr_sched is not None and lr_sched_kwargs is None:
  raise ValueError("lr_sched_kwargs required when lr_sched is provided")

@@ -249,9 +245,7 @@ class SG_PureML(NN):
  self.out_emb = Affine(self.D, self.V)

  self.seed = None if seed is None else int(seed)
-
- # API compatibility: PureML is CPU-only
- self.device = "cpu"
+ self.device = "cpu" # API parity

  # optimizer / scheduler
  self.optim: Optim = optim(self.parameters, **optim_kwargs)
@@ -344,7 +338,9 @@ class SG_PureML(NN):
  "Wrong embedding matrix shape: "
  "self.in_emb.parameters[0].shape != (V, D)"
  )
- return W.numpy(copy=True, readonly=True) # (V, D)
+ arr = W.numpy(copy=True, readonly=True) # (V, D)
+ _logger.debug("In emb shape: %s", arr.shape)
+ return arr

  @property
  def out_embeddings(self) -> np.ndarray:
@@ -356,7 +352,9 @@ class SG_PureML(NN):
  "Wrong embedding matrix shape: "
  "self.out_emb.parameters[0].shape != (D, V)"
  )
- return W.numpy(copy=True, readonly=True).T # (V, D)
+ arr = W.numpy(copy=True, readonly=True).T # (V, D)
+ _logger.debug("Out emb shape: %s", arr.shape)
+ return arr

  @property
  def avg_embeddings(self) -> np.ndarray: