sawnergy 1.0.6__tar.gz → 1.0.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of sawnergy might be problematic.
- {sawnergy-1.0.6/sawnergy.egg-info → sawnergy-1.0.7}/PKG-INFO +48 -24
- {sawnergy-1.0.6 → sawnergy-1.0.7}/README.md +47 -23
- sawnergy-1.0.7/sawnergy/embedding/SGNS_pml.py +370 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/embedding/SGNS_torch.py +145 -11
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/embedding/__init__.py +24 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/embedding/embedder.py +99 -49
- sawnergy-1.0.7/sawnergy/embedding/visualizer.py +247 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/logging_util.py +1 -1
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/rin/rin_builder.py +1 -1
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/visual/visualizer.py +6 -6
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/visual/visualizer_util.py +3 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7/sawnergy.egg-info}/PKG-INFO +48 -24
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy.egg-info/SOURCES.txt +2 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/tests/test_embedding.py +86 -3
- sawnergy-1.0.7/tests/test_embedding_visualizer.py +58 -0
- sawnergy-1.0.6/sawnergy/embedding/SGNS_pml.py +0 -172
- {sawnergy-1.0.6 → sawnergy-1.0.7}/LICENSE +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/NOTICE +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/__init__.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/rin/__init__.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/rin/rin_util.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/sawnergy_util.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/visual/__init__.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/walks/__init__.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/walks/walker.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy/walks/walker_util.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy.egg-info/dependency_links.txt +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy.egg-info/requires.txt +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/sawnergy.egg-info/top_level.txt +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/setup.cfg +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/tests/test_rin.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/tests/test_storage.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/tests/test_visual.py +0 -0
- {sawnergy-1.0.6 → sawnergy-1.0.7}/tests/test_walks.py +0 -0
**{sawnergy-1.0.6/sawnergy.egg-info → sawnergy-1.0.7}/PKG-INFO**

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: sawnergy
-Version: 1.0.6
+Version: 1.0.7
 Summary: Toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations
 Home-page: https://github.com/Yehor-Mishchyriak/SAWNERGY
 Author: Yehor Mishchyriak
@@ -39,18 +39,44 @@ Dynamic: summary
 
 
 A toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations, sampling
-random and self-avoiding walks, learning node embeddings, and
+random and self-avoiding walks, learning node embeddings, and visualizing residue interaction networks (RINs). SAWNERGY
 keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2vec approach) — inside Python, backed by efficient Zarr-based archives and optional GPU acceleration.
 
 ---
 
+## Installation
+
+```bash
+pip install sawnergy
+```
+
+> **Optional:** For GPU training, install PyTorch separately (e.g., `pip install torch`).
+> **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
+> environment variable. Probably the easiest solution: install AmberTools via conda, activate the environment, and SAWNERGY will find the cpptraj executable on its own, so just run your code and don't worry about it.
+
+---
+
+# UPDATES:
+
+## v1.0.7 — What’s new:
+- **Added plain SkipGram model**
+  - The user can now choose whether to apply the negative sampling technique (two binary classifiers) or to train a single classifier over the vocabulary (full softmax). For more detail, see: [node2vec](https://arxiv.org/pdf/1607.00653), [word2vec](https://arxiv.org/pdf/1301.3781), and [negative_sampling](https://arxiv.org/pdf/1402.3722).
+- **Set a harsher default for low interaction energies pruning during RIN construction**
+  - We now zero out 85% of the lowest interaction energies, as opposed to the past 30% default, leading to more meaningful embeddings.
+- **BUG FIX: Visualizer**
+  - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image / animation would be very laggy. This is now fixed, and given the high pruning default, the displayed interaction networks are clean and smooth under rotations, dragging, etc.
+- **New Embedding Visualizer (3D)**
+  - New lightweight viewer for per-frame embeddings that projects embeddings with PCA to a **3D** scatter. Supports the same node coloring semantics, optional node labels, and the same antialiasing/depthshade controls. Works in headless setups using the same backend guard and uses a blocking `show=True` for scripts.
+
+---
+
 ## Why SAWNERGY?
 
 - **Bridge simulations and graph ML**: Convert raw MD trajectories into residue interaction networks ready for graph
 algorithms and downstream machine learning tasks.
-- **Deterministic, shareable
-- **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without
-- **Flexible
+- **Deterministic, shareable artifacts**: Every stage produces compressed Zarr archives that contain both data and metadata so runs can be reproduced, shared, or inspected later.
+- **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without serialization overhead; archives are written in chunked, compressed form for fast read/write.
+- **Flexible objectives & backends**: Train Skip-Gram with **negative sampling** (`objective="sgns"`) or **plain Skip-Gram** (`objective="sg"`), using either **PureML** (default) or **PyTorch**.
 - **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder
 
 ---
@@ -91,9 +117,9 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 * Wraps the AmberTools `cpptraj` executable to:
   - compute per-frame electrostatic (EMAP) and van der Waals (VMAP) energy matrices at the atomic level,
   - project atom–atom interactions to residue–residue interactions using compositional masks,
-  - prune,
-  - compute per-residue
-* Outputs a compressed Zarr archive with transition matrices, optional
+  - prune, symmetrize, remove self-interactions, and L1-normalise the matrices,
+  - compute per-residue centers of mass (COM) over the same frames.
+* Outputs a compressed Zarr archive with transition matrices, optional pre-normalized energies, COM snapshots, and rich
 metadata (frame range, pruning quantile, molecule ID, etc.).
 * Supports parallel `cpptraj` execution, batch processing, and keeps temporary stores tidy via
 `ArrayStorage.compress_and_cleanup`.
@@ -103,7 +129,7 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 * Opens RIN archives, resolves dataset names from attributes, and renders nodes plus attractive/repulsive edge bundles
 in 3D using Matplotlib.
 * Allows both static frame visualization and trajectory animation.
-* Handles backend selection (`Agg` fallback in headless environments) and offers convenient
+* Handles backend selection (`Agg` fallback in headless environments) and offers convenient color palettes via
 `visualizer_util`.
 
 ### `sawnergy.walks.Walker`
@@ -140,23 +166,13 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 |---|---|---|
 | **RIN** | `ATTRACTIVE_transitions` → **(T, N, N)**, float32 • `REPULSIVE_transitions` → **(T, N, N)**, float32 (optional) • `ATTRACTIVE_energies` → **(T, N, N)**, float32 (optional) • `REPULSIVE_energies` → **(T, N, N)**, float32 (optional) • `COM` → **(T, N, 3)**, float32 | `time_created` (ISO) • `com_name` = `"COM"` • `molecule_of_interest` (int) • `frame_range` = `(start, end)` inclusive • `frame_batch_size` (int) • `prune_low_energies_frac` (float in [0,1]) • `attractive_transitions_name` / `repulsive_transitions_name` (dataset names or `None`) • `attractive_energies_name` / `repulsive_energies_name` (dataset names or `None`) |
 | **Walks** | `ATTRACTIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `REPULSIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `ATTRACTIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) • `REPULSIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) <br/>_Note:_ node IDs are **1-based**.| `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_batch_v1"` • `num_workers` (int) • `in_parallel` (bool) • `batch_size_nodes` (int) • `num_RWs` / `num_SAWs` (ints) • `node_count` (N) • `time_stamp_count` (T) • `walk_length` (L) • `walks_per_node` (int) • `attractive_RWs_name` / `repulsive_RWs_name` / `attractive_SAWs_name` / `repulsive_SAWs_name` (dataset names or `None`) • `walks_layout` = `"time_leading_3d"` |
-| **Embeddings** | `FRAME_EMBEDDINGS` → **(frames_written, vocab_size, D)**, typically float32 | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_frame_v1"` • `source_walks_path` (str) • `model_base` = `"torch"` or `"pureml"` • `rin_type` = `"attr"` or `"repuls"` • `using_mode` = `"RW"|"SAW"|"merged"` • `window_size` (int) • `alpha` (float; noise exponent) • `dimensionality` = D • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `shuffle_data` (bool) • `frames_written` (int) • `vocab_size` (int) • `frame_count` (int) • `embedding_dtype` (str) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `arrays_per_chunk` (int) • `compression_level` (int) |
+| **Embeddings** | `FRAME_EMBEDDINGS` → **(frames_written, vocab_size, D)**, typically float32 | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_frame_v1"` • `source_walks_path` (str) • `model_base` = `"torch"` or `"pureml"` • `rin_type` = `"attr"` or `"repuls"` • `using_mode` = `"RW"|"SAW"|"merged"` • `window_size` (int) • `alpha` (float; noise exponent) • `dimensionality` = D • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `shuffle_data` (bool) • `frames_written` (int) • `vocab_size` (int) • `frame_count` (int) • `embedding_dtype` (str) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `arrays_per_chunk` (int) • `compression_level` (int) • `objective` = `"sgns"` or `"sg"` |
 
 **Notes**
 
 - In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalised** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalised** versions used for sampling.
 - All archives are Zarr v3 groups. ArrayStorage also maintains per-block metadata in root attrs: `array_chunk_size_in_block`, `array_shape_in_block`, and `array_dtype_in_block` (dicts keyed by dataset name). You’ll see these in every archive.
-
----
-
-## Installation
-
-```bash
-pip install sawnergy
-```
-
-> **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
-> environment variable.
+- In **Embeddings**, `alpha` and `num_negative_samples` apply to **SGNS** only and are ignored for `objective="sg"`.
 
 ---
 
@@ -181,7 +197,7 @@ rin_builder.build_rin(
     molecule_of_interest=1,
     frame_range=(1, 100),
     frame_batch_size=10,
-    prune_low_energies_frac=0.3,
+    prune_low_energies_frac=0.85,
     output_path=rin_path,
     include_attractive=True,
     include_repulsive=False,
@@ -210,6 +226,7 @@ embeddings_path = embedder.embed_all(
     RIN_type="attr",
     using="merged",
     window_size=4,
+    objective="sgns",
     num_negative_samples=5,
     num_epochs=5,
     batch_size=1024,
@@ -232,12 +249,12 @@ print("Embeddings written to", embeddings_path)
 
 ---
 
-##
+## Visualization
 
 ```python
 from sawnergy.visual import Visualizer
 
-v =
+v = Visualizer("./RIN_demo.zip")
 v.build_frame(1,
     node_colors="rainbow",
     displayed_nodes="ALL",
@@ -250,6 +267,13 @@ v.build_frame(1,
 
 `Visualizer` lazily loads datasets and works even in headless environments (falls back to the `Agg` backend).
 
+```python
+from sawnergy.embedding import Visualizer
+
+viz = Visualizer("./EMBEDDINGS_demo.zip")
+viz.build_frame(1, show=True)
+```
+
 ---
 
 ## Advanced Notes
````
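The `objective` switch shown above is the headline 1.0.7 change. A minimal sketch of how the two objectives might be selected through `Embedder.embed_all`; the keyword names mirror the README snippet in the diff, but the `Embedder` constructor call and the walks-archive path are assumptions for illustration, not confirmed API:

```python
# Hypothetical sketch: choosing between the two training objectives added in
# sawnergy 1.0.7. embed_all keywords are taken from the README diff above;
# the Embedder constructor and "./WALKS_demo.zip" path are assumed.
from sawnergy.embedding import Embedder

embedder = Embedder("./WALKS_demo.zip")  # assumed constructor signature

# Skip-Gram with negative sampling (SGNS): two binary classifiers per pair.
sgns_path = embedder.embed_all(
    RIN_type="attr",
    using="merged",
    window_size=4,
    objective="sgns",
    num_negative_samples=5,  # only meaningful for SGNS
    num_epochs=5,
    batch_size=1024,
)

# Plain Skip-Gram (full softmax): num_negative_samples and alpha are ignored.
sg_path = embedder.embed_all(
    RIN_type="attr",
    using="merged",
    window_size=4,
    objective="sg",
    num_epochs=5,
    batch_size=1024,
)
```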
**{sawnergy-1.0.6 → sawnergy-1.0.7}/README.md**

````diff
@@ -5,18 +5,44 @@
 
 
 A toolkit for transforming molecular dynamics (MD) trajectories into rich graph representations, sampling
-random and self-avoiding walks, learning node embeddings, and
+random and self-avoiding walks, learning node embeddings, and visualizing residue interaction networks (RINs). SAWNERGY
 keeps the full workflow — from `cpptraj` output to skip-gram embeddings (node2vec approach) — inside Python, backed by efficient Zarr-based archives and optional GPU acceleration.
 
 ---
 
+## Installation
+
+```bash
+pip install sawnergy
+```
+
+> **Optional:** For GPU training, install PyTorch separately (e.g., `pip install torch`).
+> **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
+> environment variable. Probably the easiest solution: install AmberTools via conda, activate the environment, and SAWNERGY will find the cpptraj executable on its own, so just run your code and don't worry about it.
+
+---
+
+# UPDATES:
+
+## v1.0.7 — What’s new:
+- **Added plain SkipGram model**
+  - The user can now choose whether to apply the negative sampling technique (two binary classifiers) or to train a single classifier over the vocabulary (full softmax). For more detail, see: [node2vec](https://arxiv.org/pdf/1607.00653), [word2vec](https://arxiv.org/pdf/1301.3781), and [negative_sampling](https://arxiv.org/pdf/1402.3722).
+- **Set a harsher default for low interaction energies pruning during RIN construction**
+  - We now zero out 85% of the lowest interaction energies, as opposed to the past 30% default, leading to more meaningful embeddings.
+- **BUG FIX: Visualizer**
+  - Previously, the visualizer would silently draw edges of 0 magnitude, meaning they were actually being drawn but were invisible due to full transparency and 0 width. As a result, the displayed image / animation would be very laggy. This is now fixed, and given the high pruning default, the displayed interaction networks are clean and smooth under rotations, dragging, etc.
+- **New Embedding Visualizer (3D)**
+  - New lightweight viewer for per-frame embeddings that projects embeddings with PCA to a **3D** scatter. Supports the same node coloring semantics, optional node labels, and the same antialiasing/depthshade controls. Works in headless setups using the same backend guard and uses a blocking `show=True` for scripts.
+
+---
+
 ## Why SAWNERGY?
 
 - **Bridge simulations and graph ML**: Convert raw MD trajectories into residue interaction networks ready for graph
 algorithms and downstream machine learning tasks.
-- **Deterministic, shareable
-- **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without
-- **Flexible
+- **Deterministic, shareable artifacts**: Every stage produces compressed Zarr archives that contain both data and metadata so runs can be reproduced, shared, or inspected later.
+- **High-performance data handling**: Heavy arrays live in shared memory during walk sampling to allow parallel processing without serialization overhead; archives are written in chunked, compressed form for fast read/write.
+- **Flexible objectives & backends**: Train Skip-Gram with **negative sampling** (`objective="sgns"`) or **plain Skip-Gram** (`objective="sg"`), using either **PureML** (default) or **PyTorch**.
 - **Visualization out of the box**: Plot and animate residue networks without leaving Python, using the data produced by RINBuilder
 
 ---
@@ -57,9 +83,9 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 * Wraps the AmberTools `cpptraj` executable to:
   - compute per-frame electrostatic (EMAP) and van der Waals (VMAP) energy matrices at the atomic level,
   - project atom–atom interactions to residue–residue interactions using compositional masks,
-  - prune,
-  - compute per-residue
-* Outputs a compressed Zarr archive with transition matrices, optional
+  - prune, symmetrize, remove self-interactions, and L1-normalise the matrices,
+  - compute per-residue centers of mass (COM) over the same frames.
+* Outputs a compressed Zarr archive with transition matrices, optional pre-normalized energies, COM snapshots, and rich
 metadata (frame range, pruning quantile, molecule ID, etc.).
 * Supports parallel `cpptraj` execution, batch processing, and keeps temporary stores tidy via
 `ArrayStorage.compress_and_cleanup`.
@@ -69,7 +95,7 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 * Opens RIN archives, resolves dataset names from attributes, and renders nodes plus attractive/repulsive edge bundles
 in 3D using Matplotlib.
 * Allows both static frame visualization and trajectory animation.
-* Handles backend selection (`Agg` fallback in headless environments) and offers convenient
+* Handles backend selection (`Agg` fallback in headless environments) and offers convenient color palettes via
 `visualizer_util`.
 
 ### `sawnergy.walks.Walker`
@@ -106,23 +132,13 @@ node indexing, and RNG seeds stay consistent across the toolchain.
 |---|---|---|
 | **RIN** | `ATTRACTIVE_transitions` → **(T, N, N)**, float32 • `REPULSIVE_transitions` → **(T, N, N)**, float32 (optional) • `ATTRACTIVE_energies` → **(T, N, N)**, float32 (optional) • `REPULSIVE_energies` → **(T, N, N)**, float32 (optional) • `COM` → **(T, N, 3)**, float32 | `time_created` (ISO) • `com_name` = `"COM"` • `molecule_of_interest` (int) • `frame_range` = `(start, end)` inclusive • `frame_batch_size` (int) • `prune_low_energies_frac` (float in [0,1]) • `attractive_transitions_name` / `repulsive_transitions_name` (dataset names or `None`) • `attractive_energies_name` / `repulsive_energies_name` (dataset names or `None`) |
 | **Walks** | `ATTRACTIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `REPULSIVE_RWs` → **(T, N·num_RWs, L+1)**, int32 (optional) • `ATTRACTIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) • `REPULSIVE_SAWs` → **(T, N·num_SAWs, L+1)**, int32 (optional) <br/>_Note:_ node IDs are **1-based**.| `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_batch_v1"` • `num_workers` (int) • `in_parallel` (bool) • `batch_size_nodes` (int) • `num_RWs` / `num_SAWs` (ints) • `node_count` (N) • `time_stamp_count` (T) • `walk_length` (L) • `walks_per_node` (int) • `attractive_RWs_name` / `repulsive_RWs_name` / `attractive_SAWs_name` / `repulsive_SAWs_name` (dataset names or `None`) • `walks_layout` = `"time_leading_3d"` |
-| **Embeddings** | `FRAME_EMBEDDINGS` → **(frames_written, vocab_size, D)**, typically float32 | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_frame_v1"` • `source_walks_path` (str) • `model_base` = `"torch"` or `"pureml"` • `rin_type` = `"attr"` or `"repuls"` • `using_mode` = `"RW"|"SAW"|"merged"` • `window_size` (int) • `alpha` (float; noise exponent) • `dimensionality` = D • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `shuffle_data` (bool) • `frames_written` (int) • `vocab_size` (int) • `frame_count` (int) • `embedding_dtype` (str) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `arrays_per_chunk` (int) • `compression_level` (int) |
+| **Embeddings** | `FRAME_EMBEDDINGS` → **(frames_written, vocab_size, D)**, typically float32 | `time_created` (ISO) • `seed` (int) • `rng_scheme` = `"SeedSequence.spawn_per_frame_v1"` • `source_walks_path` (str) • `model_base` = `"torch"` or `"pureml"` • `rin_type` = `"attr"` or `"repuls"` • `using_mode` = `"RW"|"SAW"|"merged"` • `window_size` (int) • `alpha` (float; noise exponent) • `dimensionality` = D • `num_negative_samples` (int) • `num_epochs` (int) • `batch_size` (int) • `shuffle_data` (bool) • `frames_written` (int) • `vocab_size` (int) • `frame_count` (int) • `embedding_dtype` (str) • `frame_embeddings_name` = `"FRAME_EMBEDDINGS"` • `arrays_per_chunk` (int) • `compression_level` (int) • `objective` = `"sgns"` or `"sg"` |
 
 **Notes**
 
 - In **RIN**, `T` equals the number of frame **batches** written (i.e., `frame_range` swept in steps of `frame_batch_size`). `ATTRACTIVE/REPULSIVE_energies` are **pre-normalised** absolute energies (written only when `keep_prenormalized_energies=True`), whereas `ATTRACTIVE/REPULSIVE_transitions` are the **row-wise L1-normalised** versions used for sampling.
 - All archives are Zarr v3 groups. ArrayStorage also maintains per-block metadata in root attrs: `array_chunk_size_in_block`, `array_shape_in_block`, and `array_dtype_in_block` (dicts keyed by dataset name). You’ll see these in every archive.
-
----
-
-## Installation
-
-```bash
-pip install sawnergy
-```
-
-> **Note:** RIN building requires `cpptraj` (AmberTools). Ensure it is discoverable via `$PATH` or the `CPPTRAJ`
-> environment variable.
+- In **Embeddings**, `alpha` and `num_negative_samples` apply to **SGNS** only and are ignored for `objective="sg"`.
 
 ---
 
@@ -147,7 +163,7 @@ rin_builder.build_rin(
     molecule_of_interest=1,
     frame_range=(1, 100),
     frame_batch_size=10,
-    prune_low_energies_frac=0.3,
+    prune_low_energies_frac=0.85,
     output_path=rin_path,
     include_attractive=True,
     include_repulsive=False,
@@ -176,6 +192,7 @@ embeddings_path = embedder.embed_all(
     RIN_type="attr",
     using="merged",
     window_size=4,
+    objective="sgns",
     num_negative_samples=5,
     num_epochs=5,
     batch_size=1024,
@@ -198,12 +215,12 @@ print("Embeddings written to", embeddings_path)
 
 ---
 
-##
+## Visualization
 
 ```python
 from sawnergy.visual import Visualizer
 
-v =
+v = Visualizer("./RIN_demo.zip")
 v.build_frame(1,
     node_colors="rainbow",
     displayed_nodes="ALL",
@@ -216,6 +233,13 @@ v.build_frame(1,
 
 `Visualizer` lazily loads datasets and works even in headless environments (falls back to the `Agg` backend).
 
+```python
+from sawnergy.embedding import Visualizer
+
+viz = Visualizer("./EMBEDDINGS_demo.zip")
+viz.build_frame(1, show=True)
+```
+
 ---
 
 ## Advanced Notes
````
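The changelog describes the new embedding viewer as projecting per-frame `(vocab_size, D)` embeddings to 3D with PCA before scattering them. A plain-NumPy sketch of that projection step, shown for illustration only (this is not sawnergy's internal code):

```python
# PCA-to-3D projection as described in the v1.0.7 notes: a standalone NumPy
# sketch of the idea, not the package's implementation.
import numpy as np

def pca_3d(frame_embeddings: np.ndarray) -> np.ndarray:
    """Project (N, D) embeddings to (N, 3) along the top-3 principal axes."""
    X = frame_embeddings - frame_embeddings.mean(axis=0, keepdims=True)
    # SVD of the centered matrix yields the principal directions in Vt.
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    return X @ Vt[:3].T

rng = np.random.default_rng(0)
coords = pca_3d(rng.normal(size=(50, 16)))  # e.g., 50 residues in 16 dims
print(coords.shape)  # (50, 3)
```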
**sawnergy-1.0.7/sawnergy/embedding/SGNS_pml.py** (new file, +370 lines)

```python
from __future__ import annotations

# third party
import numpy as np
from pureml.machinery import Tensor
from pureml.layers import Embedding, Affine
from pureml.losses import BCE, CCE
from pureml.general_math import sum as t_sum
from pureml.optimizers import Optim, LRScheduler
from pureml.training_utils import TensorDataset, DataLoader, one_hot
from pureml.base import NN

# built-in
import logging
from typing import Type

# *----------------------------------------------------*
#                       GLOBALS
# *----------------------------------------------------*

_logger = logging.getLogger(__name__)

# *----------------------------------------------------*
#                       CLASSES
# *----------------------------------------------------*

class SGNS_PureML(NN):
    """PureML implementation of Skip-Gram with Negative Sampling."""

    def __init__(self,
                 V: int,
                 D: int,
                 *,
                 seed: int | None = None,
                 optim: Type[Optim],
                 optim_kwargs: dict,
                 lr_sched: Type[LRScheduler] | None = None,
                 lr_sched_kwargs: dict | None = None,
                 device: str | None = None):
        """
        Args:
            V: Vocabulary size (number of nodes).
            D: Embedding dimensionality.
            seed: Optional RNG seed for negative sampling.
            optim: Optimizer class to instantiate.
            optim_kwargs: Keyword arguments for the optimizer (required).
            lr_sched: Optional learning-rate scheduler class.
            lr_sched_kwargs: Keyword arguments for the scheduler (required if lr_sched is provided).
            device: Target device string (e.g. "cuda"); accepted for API parity, ignored by PureML.
        """

        if optim_kwargs is None:
            raise ValueError("optim_kwargs must be provided")
        if lr_sched is not None and lr_sched_kwargs is None:
            raise ValueError("lr_sched_kwargs required when lr_sched is provided")

        self.V, self.D = int(V), int(D)

        # embeddings
        self.in_emb = Embedding(self.V, self.D)
        self.out_emb = Embedding(self.V, self.D)

        # seed + RNG for negative sampling
        self.seed = None if seed is None else int(seed)
        self._rng = np.random.default_rng(self.seed)
        if self.seed is not None:
            # optional: also set global NumPy seed for any non-RNG paths
            np.random.seed(self.seed)

        # API compatibility: PureML is CPU-only
        self.device = "cpu"

        # optimizer / scheduler
        self.optim: Optim = optim(self.parameters, **optim_kwargs)
        self.lr_sched: LRScheduler | None = (
            lr_sched(optim=self.optim, **lr_sched_kwargs) if lr_sched is not None else None
        )

        _logger.info(
            "SGNS_PureML init: V=%d D=%d device=%s seed=%s",
            self.V, self.D, self.device, self.seed
        )

    def _sample_neg(self, B: int, K: int, dist: np.ndarray) -> np.ndarray:
        return self._rng.choice(self.V, size=(B, K), replace=True, p=dist)

    def predict(self, center: Tensor, pos: Tensor, neg: Tensor) -> tuple[Tensor, Tensor]:
        """Compute positive/negative logits for SGNS.

        Shapes:
            center: (B,)
            pos: (B,)
            neg: (B, K)
        Returns:
            pos_logits: (B,)
            neg_logits: (B, K)
        """
        c = self.in_emb(center)    # (B, D)
        pos_e = self.out_emb(pos)  # (B, D)
        neg_e = self.out_emb(neg)  # (B, K, D)

        pos_logits = t_sum(c * pos_e, axis=-1)              # (B,)
        neg_logits = t_sum(c[:, None, :] * neg_e, axis=-1)  # (B, K)
        return pos_logits, neg_logits

    def fit(self,
            centers: np.ndarray,
            contexts: np.ndarray,
            num_epochs: int,
            batch_size: int,
            num_negative_samples: int,
            noise_dist: np.ndarray,
            shuffle_data: bool,
            lr_step_per_batch: bool):
        """Train SGNS on the provided center/context pairs."""
        _logger.info(
            "SGNS_PureML fit: epochs=%d batch=%d negatives=%d shuffle=%s",
            num_epochs, batch_size, num_negative_samples, shuffle_data
        )

        if noise_dist.ndim != 1 or noise_dist.size != self.V:
            raise ValueError(f"noise_dist must be 1-D with length {self.V}; got {noise_dist.shape}")
        dist = np.asarray(noise_dist, dtype=np.float64)
        if np.any(dist < 0):
            raise ValueError("noise_dist has negative entries")
        s = dist.sum()
        if not np.isfinite(s) or s <= 0:
            raise ValueError("noise_dist must have positive finite sum")
        if abs(s - 1.0) > 1e-6:
            dist = dist / s

        data = TensorDataset(centers, contexts)
        for epoch in range(1, num_epochs + 1):
            epoch_loss = 0.0
            batches = 0

            for cen, pos in DataLoader(data, batch_size=batch_size, shuffle=shuffle_data):
                B = cen.data.shape[0] if isinstance(cen, Tensor) else len(cen)

                neg_idx_np = self._sample_neg(B, num_negative_samples, dist)
                neg = Tensor(neg_idx_np, requires_grad=False)
                x_pos_logits, x_neg_logits = self(cen, pos, neg)

                y_pos = Tensor(np.ones_like(x_pos_logits.numpy(copy=False)), requires_grad=False)
                y_neg = Tensor(np.zeros_like(x_neg_logits.numpy(copy=False)), requires_grad=False)

                K = int(neg.data.shape[1])
                loss = (
                    BCE(y_pos, x_pos_logits, from_logits=True)
                    + K*BCE(y_neg, x_neg_logits, from_logits=True)
                )

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                if lr_step_per_batch and self.lr_sched is not None:
                    self.lr_sched.step()

                loss_value = float(np.asarray(loss.data))
                epoch_loss += loss_value
                batches += 1
                _logger.debug("Epoch %d batch %d loss=%.6f", epoch, batches, loss_value)

            if (not lr_step_per_batch) and (self.lr_sched is not None):
                self.lr_sched.step()

            mean_loss = epoch_loss / max(batches, 1)
            _logger.info("Epoch %d/%d mean_loss=%.6f", epoch, num_epochs, mean_loss)

    @property
    def in_embeddings(self) -> np.ndarray:
        W: Tensor = self.in_emb.parameters[0]  # (V, D)
        if W.shape != (self.V, self.D):
            raise RuntimeError(
                "Wrong embedding matrix shape: "
                "self.in_emb.parameters[0].shape != (V, D)"
            )
        return W.numpy(copy=True, readonly=True)

    @property
    def out_embeddings(self) -> np.ndarray:
        W: Tensor = self.out_emb.parameters[0]  # (V, D)
        if W.shape != (self.V, self.D):
            raise RuntimeError(
                "Wrong embedding matrix shape: "
                "self.out_emb.parameters[0].shape != (V, D)"
            )
        return W.numpy(copy=True, readonly=True)

    @property
    def avg_embeddings(self) -> np.ndarray:
        return 0.5 * (self.in_embeddings + self.out_embeddings)

class SG_PureML(NN):
    """Plain Skip-Gram (full softmax) in PureML.

    Trains two affine layers to emulate the classic Skip-Gram objective with a
    **full** softmax over the vocabulary (no negative sampling):

        x = one_hot(center, V)        # (B, V)
        y = x @ W_in + b_in           # (B, D)
        logits = y @ W_out + b_out    # (B, V)
        loss = CCE(one_hot(context, V), logits, from_logits=True)

    The learnable “input” embeddings are the rows of `W_in` (shape `(V, D)`), and
    the “output” embeddings are the rows of `W_outᵀ` (also `(V, D)`).
    """

    def __init__(self,
                 V: int,
                 D: int,
                 *,
                 seed: int | None = None,
                 optim: Type[Optim],
                 optim_kwargs: dict,
                 lr_sched: Type[LRScheduler] | None = None,
                 lr_sched_kwargs: dict | None = None,
                 device: str | None = None):
        """Initialize the plain Skip-Gram model (full softmax).

        Args:
            V: Vocabulary size (number of nodes/tokens).
            D: Embedding dimensionality.
            seed: Optional RNG seed (kept for API parity; not used in layer init).
            optim: Optimizer class to instantiate (e.g., `Adam`, `SGD`).
            optim_kwargs: Keyword arguments passed to the optimizer constructor.
            lr_sched: Optional learning-rate scheduler class.
            lr_sched_kwargs: Keyword arguments for the scheduler
                (required if `lr_sched` is provided).
            device: Device string (e.g., `"cuda"`). Accepted for parity, ignored
                by PureML (CPU-only).

        Notes:
            The encoder/decoder are implemented as:
              • `in_emb = Affine(V, D)` (acts on a one-hot center index)
              • `out_emb = Affine(D, V)`
            so forward pass produces vocabulary-sized logits.
        """
        if optim_kwargs is None:
            raise ValueError("optim_kwargs must be provided")
        if lr_sched is not None and lr_sched_kwargs is None:
            raise ValueError("lr_sched_kwargs required when lr_sched is provided")

        self.V, self.D = int(V), int(D)

        # input/output “embedding” projections
        self.in_emb = Affine(self.V, self.D)
        self.out_emb = Affine(self.D, self.V)

        self.seed = None if seed is None else int(seed)

        # API compatibility: PureML is CPU-only
        self.device = "cpu"

        # optimizer / scheduler
        self.optim: Optim = optim(self.parameters, **optim_kwargs)
        self.lr_sched: LRScheduler | None = (
            lr_sched(optim=self.optim, **lr_sched_kwargs) if lr_sched is not None else None
        )

        _logger.info(
            "SG_PureML init: V=%d D=%d device=%s seed=%s",
            self.V, self.D, self.device, self.seed
        )

    def predict(self, center: Tensor) -> Tensor:
        """Return vocabulary logits for each center index.

        Args:
            center: Tensor of center indices with shape `(B,)` and integer dtype.

        Returns:
            Tensor: Logits over the vocabulary with shape `(B, V)`.
        """
        c = one_hot(dims=self.V, label=center)  # (B, V)
        y = self.in_emb(c)                      # (B, D)
        z = self.out_emb(y)                     # (B, V)
        return z

    def fit(self,
            centers: np.ndarray,
            contexts: np.ndarray,
            num_epochs: int,
            batch_size: int,
            shuffle_data: bool,
            lr_step_per_batch: bool,
            **_ignore):
        """Train Skip-Gram with full softmax on center/context pairs.

        Args:
            centers: Array of center indices, shape `(N,)`, dtype integer in `[0, V)`.
            contexts: Array of context (target) indices, shape `(N,)`, dtype integer.
            num_epochs: Number of passes over the dataset.
            batch_size: Mini-batch size.
            shuffle_data: Whether to shuffle pairs each epoch.
            lr_step_per_batch: If True, call `lr_sched.step()` after every batch
                (when a scheduler is provided). If False, step once per epoch.
            **_ignore: Ignored kwargs for API compatibility with SGNS.

        Optimization:
            Uses `CCE(one_hot(context), logits, from_logits=True)` where
            `logits = predict(center)`. Scheduler stepping obeys `lr_step_per_batch`.
        """
        _logger.info(
            "SG_PureML fit: epochs=%d batch=%d shuffle=%s",
            num_epochs, batch_size, shuffle_data
        )
        data = TensorDataset(centers, contexts)

        for epoch in range(1, num_epochs + 1):
            epoch_loss = 0.0
            batches = 0

            for cen, ctx in DataLoader(data, batch_size=batch_size, shuffle=shuffle_data):
                logits = self(cen)                        # (B, V)
                y = one_hot(self.V, label=ctx)            # (B, V)
                loss = CCE(y, logits, from_logits=True)   # scalar

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                if lr_step_per_batch and self.lr_sched is not None:
                    self.lr_sched.step()

                loss_value = float(np.asarray(loss.data))
                epoch_loss += loss_value
                batches += 1
                _logger.debug("Epoch %d batch %d loss=%.6f", epoch, batches, loss_value)

            if (not lr_step_per_batch) and (self.lr_sched is not None):
                self.lr_sched.step()

            mean_loss = epoch_loss / max(batches, 1)
            _logger.info("Epoch %d/%d mean_loss=%.6f", epoch, num_epochs, mean_loss)

    @property
    def in_embeddings(self) -> np.ndarray:
        """Input embeddings matrix `W_in` as `(V, D)` (copy, read-only)."""
        W = self.in_emb.parameters[0]  # (V, D)
        if W.shape != (self.V, self.D):
            raise RuntimeError(
                "Wrong embedding matrix shape: "
                "self.in_emb.parameters[0].shape != (V, D)"
            )
        return W.numpy(copy=True, readonly=True)  # (V, D)

    @property
    def out_embeddings(self) -> np.ndarray:
        """Output embeddings matrix `W_outᵀ` as `(V, D)` (copy, read-only).
        (`out_emb.parameters[0]` is `(D, V)`, so we transpose.)"""
        W = self.out_emb.parameters[0]  # (D, V)
        if W.shape != (self.D, self.V):
            raise RuntimeError(
                "Wrong embedding matrix shape: "
                "self.out_emb.parameters[0].shape != (D, V)"
            )
        return W.numpy(copy=True, readonly=True).T  # (V, D)

    @property
    def avg_embeddings(self) -> np.ndarray:
        """Elementwise average of input/output embeddings, shape `(V, D)`."""
        return 0.5 * (self.in_embeddings + self.out_embeddings)  # (V, D)


__all__ = ["SGNS_PureML", "SG_PureML"]

if __name__ == "__main__":
    pass
```