sawnergy 1.0.7__py3-none-any.whl → 1.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy might be problematic. Click here for more details.
- sawnergy/embedding/SGNS_pml.py +36 -38
- sawnergy/embedding/SGNS_torch.py +82 -29
- sawnergy/embedding/embedder.py +325 -245
- sawnergy/embedding/visualizer.py +9 -5
- {sawnergy-1.0.7.dist-info → sawnergy-1.0.8.dist-info}/METADATA +39 -40
- {sawnergy-1.0.7.dist-info → sawnergy-1.0.8.dist-info}/RECORD +10 -10
- {sawnergy-1.0.7.dist-info → sawnergy-1.0.8.dist-info}/WHEEL +0 -0
- {sawnergy-1.0.7.dist-info → sawnergy-1.0.8.dist-info}/licenses/LICENSE +0 -0
- {sawnergy-1.0.7.dist-info → sawnergy-1.0.8.dist-info}/licenses/NOTICE +0 -0
- {sawnergy-1.0.7.dist-info → sawnergy-1.0.8.dist-info}/top_level.txt +0 -0
sawnergy/embedding/embedder.py
CHANGED
|
@@ -1,15 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
"""
|
|
4
|
-
Embedding orchestration for Skip-Gram (SG) and Skip-Gram with Negative Sampling (SGNS).
|
|
5
|
-
|
|
6
|
-
This module consumes attractive/repulsive walk corpora produced by the walker
|
|
7
|
-
pipeline and trains per-frame embeddings using either the PyTorch or PureML
|
|
8
|
-
implementations of SG/SGNS. The resulting embeddings can be persisted back into
|
|
9
|
-
an ``ArrayStorage`` archive along with rich metadata describing the training
|
|
10
|
-
configuration.
|
|
11
|
-
"""
|
|
12
|
-
|
|
13
3
|
# third-party
|
|
14
4
|
import numpy as np
|
|
15
5
|
|
|
@@ -36,10 +26,8 @@ class Embedder:
|
|
|
36
26
|
|
|
37
27
|
def __init__(self,
|
|
38
28
|
WALKS_path: str | Path,
|
|
39
|
-
base: Literal["torch", "pureml"],
|
|
40
29
|
*,
|
|
41
30
|
seed: int | None = None,
|
|
42
|
-
objective: Literal["sgns", "sg"] = "sgns"
|
|
43
31
|
) -> None:
|
|
44
32
|
"""Initialize the embedder and load walk tensors.
|
|
45
33
|
|
|
@@ -51,24 +39,19 @@ class Embedder:
|
|
|
51
39
|
``None`` if that collection is absent), and the metadata
|
|
52
40
|
``num_RWs``, ``num_SAWs``, ``node_count``, ``time_stamp_count``,
|
|
53
41
|
``walk_length``.
|
|
54
|
-
base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
|
|
55
42
|
seed: Optional seed for the embedder's RNG. If ``None``, a random
|
|
56
43
|
32-bit seed is chosen.
|
|
57
|
-
objective: Training objective, either ``"sgns"`` (negative sampling)
|
|
58
|
-
or ``"sg"`` (plain full-softmax Skip-Gram).
|
|
59
44
|
|
|
60
45
|
Raises:
|
|
61
46
|
ValueError: If required metadata is missing or any loaded walk array
|
|
62
47
|
has an unexpected shape.
|
|
63
|
-
ImportError: If the requested backend is not installed.
|
|
64
|
-
NameError: If ``base`` is not one of ``{"torch","pureml"}``.
|
|
65
48
|
|
|
66
49
|
Notes:
|
|
67
50
|
- Walks in storage are 1-based (residue indexing). Internally, this
|
|
68
51
|
class normalizes to 0-based indices for training utilities.
|
|
69
52
|
"""
|
|
70
53
|
self._walks_path = Path(WALKS_path)
|
|
71
|
-
_logger.info("Initializing Embedder from %s
|
|
54
|
+
_logger.info("Initializing Embedder from %s", self._walks_path)
|
|
72
55
|
|
|
73
56
|
# placeholders for optional walk collections
|
|
74
57
|
self.attractive_RWs : np.ndarray | None = None
|
|
@@ -127,58 +110,76 @@ class Embedder:
|
|
|
127
110
|
RWs_expected = (time_stamp_count, node_count * num_RWs, walk_length+1) if (num_RWs > 0) else None
|
|
128
111
|
SAWs_expected = (time_stamp_count, node_count * num_SAWs, walk_length+1) if (num_SAWs > 0) else None
|
|
129
112
|
|
|
130
|
-
self.vocab_size
|
|
131
|
-
self.frame_count
|
|
132
|
-
self.walk_length
|
|
113
|
+
self.vocab_size = int(node_count)
|
|
114
|
+
self.frame_count = int(time_stamp_count)
|
|
115
|
+
self.walk_length = int(walk_length)
|
|
116
|
+
self.num_RWs = int(num_RWs)
|
|
117
|
+
self.num_SAWs = int(num_SAWs)
|
|
118
|
+
# Keep dataset names for metadata passthrough
|
|
119
|
+
self._attractive_RWs_name = attractive_RWs_name
|
|
120
|
+
self._repulsive_RWs_name = repulsive_RWs_name
|
|
121
|
+
self._attractive_SAWs_name = attractive_SAWs_name
|
|
122
|
+
self._repulsive_SAWs_name = repulsive_SAWs_name
|
|
133
123
|
|
|
134
124
|
# store walks if present
|
|
135
125
|
if attractive_RWs is not None:
|
|
136
126
|
if RWs_expected and attractive_RWs.shape != RWs_expected:
|
|
137
127
|
raise ValueError(f"ATTR RWs: expected {RWs_expected}, got {attractive_RWs.shape}")
|
|
138
128
|
self.attractive_RWs = attractive_RWs
|
|
129
|
+
_logger.debug("ATTR RWs loaded: %s", self.attractive_RWs.shape)
|
|
139
130
|
|
|
140
131
|
if repulsive_RWs is not None:
|
|
141
132
|
if RWs_expected and repulsive_RWs.shape != RWs_expected:
|
|
142
133
|
raise ValueError(f"REP RWs: expected {RWs_expected}, got {repulsive_RWs.shape}")
|
|
143
134
|
self.repulsive_RWs = repulsive_RWs
|
|
135
|
+
_logger.debug("REP RWs loaded: %s", self.repulsive_RWs.shape)
|
|
144
136
|
|
|
145
137
|
if attractive_SAWs is not None:
|
|
146
138
|
if SAWs_expected and attractive_SAWs.shape != SAWs_expected:
|
|
147
139
|
raise ValueError(f"ATTR SAWs: expected {SAWs_expected}, got {attractive_SAWs.shape}")
|
|
148
140
|
self.attractive_SAWs = attractive_SAWs
|
|
141
|
+
_logger.debug("ATTR SAWs loaded: %s", self.attractive_SAWs.shape)
|
|
149
142
|
|
|
150
143
|
if repulsive_SAWs is not None:
|
|
151
144
|
if SAWs_expected and repulsive_SAWs.shape != SAWs_expected:
|
|
152
145
|
raise ValueError(f"REP SAWs: expected {SAWs_expected}, got {repulsive_SAWs.shape}")
|
|
153
146
|
self.repulsive_SAWs = repulsive_SAWs
|
|
147
|
+
_logger.debug("REP SAWs loaded: %s", self.repulsive_SAWs.shape)
|
|
154
148
|
|
|
155
149
|
# INTERNAL RNG
|
|
156
150
|
self._seed = np.random.randint(0, 2**32 - 1) if seed is None else int(seed)
|
|
157
151
|
self.rng = np.random.default_rng(self._seed)
|
|
158
152
|
_logger.info("RNG initialized from seed=%d", self._seed)
|
|
159
153
|
|
|
160
|
-
# MODEL HANDLE
|
|
161
|
-
self.model_base: Literal["torch", "pureml"] = base
|
|
162
|
-
self.objective: Literal["sgns", "sg"] = objective
|
|
163
|
-
self.model_constructor = self._get_SGNS_constructor_from(base, objective)
|
|
164
|
-
_logger.info(
|
|
165
|
-
"SG backend resolved: %s (objective=%s)",
|
|
166
|
-
getattr(self.model_constructor, "__name__", repr(self.model_constructor)), self.objective
|
|
167
|
-
)
|
|
168
|
-
|
|
169
154
|
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
|
170
155
|
|
|
171
156
|
# HELPERS:
|
|
172
157
|
|
|
173
158
|
@staticmethod
|
|
174
|
-
def
|
|
175
|
-
|
|
176
|
-
"""Resolve the SG/SGNS implementation class for the selected backend.
|
|
159
|
+
def _get_NN_constructor_from(base: Literal["torch", "pureml"],
|
|
160
|
+
objective: Literal["sgns", "sg"]):
|
|
161
|
+
"""Resolve the SG/SGNS implementation class for the selected backend.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
base: Backend family to use, ``"torch"`` or ``"pureml"``.
|
|
165
|
+
objective: Training objective, ``"sgns"`` or ``"sg"``.
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
A callable class (constructor) implementing the requested model.
|
|
169
|
+
|
|
170
|
+
Raises:
|
|
171
|
+
ImportError: If the requested backend cannot be imported.
|
|
172
|
+
NameError: If ``base`` is not one of the supported values.
|
|
173
|
+
"""
|
|
174
|
+
_logger.debug("Resolving model constructor: base=%s objective=%s", base, objective)
|
|
177
175
|
if base == "torch":
|
|
178
176
|
try:
|
|
179
177
|
from .SGNS_torch import SGNS_Torch, SG_Torch
|
|
180
|
-
|
|
178
|
+
ctor = SG_Torch if objective == "sg" else SGNS_Torch
|
|
179
|
+
_logger.debug("Resolved PyTorch class: %s", getattr(ctor, "__name__", str(ctor)))
|
|
180
|
+
return ctor
|
|
181
181
|
except Exception:
|
|
182
|
+
_logger.exception("Failed to import PyTorch backend.")
|
|
182
183
|
raise ImportError(
|
|
183
184
|
"PyTorch is not installed, but base='torch' was requested. "
|
|
184
185
|
"Install PyTorch first, e.g.: `pip install torch` "
|
|
@@ -187,8 +188,11 @@ class Embedder:
|
|
|
187
188
|
elif base == "pureml":
|
|
188
189
|
try:
|
|
189
190
|
from .SGNS_pml import SGNS_PureML, SG_PureML
|
|
190
|
-
|
|
191
|
+
ctor = SG_PureML if objective == "sg" else SGNS_PureML
|
|
192
|
+
_logger.debug("Resolved PureML class: %s", getattr(ctor, "__name__", str(ctor)))
|
|
193
|
+
return ctor
|
|
191
194
|
except Exception:
|
|
195
|
+
_logger.exception("Failed to import PureML backend.")
|
|
192
196
|
raise ImportError(
|
|
193
197
|
"PureML is not installed, but base='pureml' was requested. "
|
|
194
198
|
"Install PureML first via `pip install ym-pure-ml` "
|
|
@@ -198,7 +202,18 @@ class Embedder:
|
|
|
198
202
|
|
|
199
203
|
@staticmethod
|
|
200
204
|
def _as_zerobase_intp(walks: np.ndarray, *, V: int) -> np.ndarray:
|
|
201
|
-
"""Validate 1-based
|
|
205
|
+
"""Validate and convert 1-based walks to 0-based ``intp``.
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
walks: 2D array of node ids with 1-based indexing.
|
|
209
|
+
V: Vocabulary size for bounds checking.
|
|
210
|
+
|
|
211
|
+
Returns:
|
|
212
|
+
2D array of dtype ``intp`` with 0-based indices.
|
|
213
|
+
|
|
214
|
+
Raises:
|
|
215
|
+
ValueError: If shape is not 2D or indices are out of bounds.
|
|
216
|
+
"""
|
|
202
217
|
arr = np.asarray(walks)
|
|
203
218
|
if arr.ndim != 2:
|
|
204
219
|
raise ValueError("walks must be 2D: (num_walks, walk_len)")
|
|
@@ -206,7 +221,9 @@ class Embedder:
|
|
|
206
221
|
arr = arr.astype(np.int64, copy=False)
|
|
207
222
|
# 1-based → 0-based
|
|
208
223
|
arr = arr - 1
|
|
209
|
-
|
|
224
|
+
mn, mx = int(arr.min()), int(arr.max())
|
|
225
|
+
_logger.debug("Zero-basing walks: min=%d max=%d V=%d", mn, mx, V)
|
|
226
|
+
if mn < 0 or mx >= V:
|
|
210
227
|
raise ValueError("walk ids out of range after 1→0-based normalization")
|
|
211
228
|
return arr.astype(np.intp, copy=False)
|
|
212
229
|
|
|
@@ -214,19 +231,29 @@ class Embedder:
|
|
|
214
231
|
def _pairs_from_walks(walks0: np.ndarray, window_size: int) -> np.ndarray:
|
|
215
232
|
"""
|
|
216
233
|
Skip-gram pairs including edge centers (one-sided when needed).
|
|
217
|
-
|
|
218
|
-
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
walks0: (W, L) int array (0-based ids).
|
|
237
|
+
window_size: Symmetric context window radius.
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
Array of shape (N_pairs, 2) int32 with columns [center, context].
|
|
241
|
+
|
|
242
|
+
Raises:
|
|
243
|
+
ValueError: If shape is invalid or ``window_size`` <= 0.
|
|
219
244
|
"""
|
|
220
245
|
if walks0.ndim != 2:
|
|
221
246
|
raise ValueError("walks must be 2D: (num_walks, walk_len)")
|
|
222
247
|
|
|
223
248
|
_, L = walks0.shape
|
|
224
249
|
k = int(window_size)
|
|
250
|
+
_logger.debug("Building SG pairs: L=%d, window=%d", L, k)
|
|
225
251
|
|
|
226
252
|
if k <= 0:
|
|
227
253
|
raise ValueError("window_size must be positive")
|
|
228
254
|
|
|
229
255
|
if L == 0:
|
|
256
|
+
_logger.debug("Empty walks length; returning 0 pairs.")
|
|
230
257
|
return np.empty((0, 2), dtype=np.int32)
|
|
231
258
|
|
|
232
259
|
out_chunks = []
|
|
@@ -244,18 +271,42 @@ class Embedder:
|
|
|
244
271
|
out_chunks.append(np.stack((centers_l, ctx_l), axis=2).reshape(-1, 2))
|
|
245
272
|
|
|
246
273
|
if not out_chunks:
|
|
274
|
+
_logger.debug("No offsets produced pairs; returning empty.")
|
|
247
275
|
return np.empty((0, 2), dtype=np.int32)
|
|
248
276
|
|
|
249
|
-
|
|
277
|
+
pairs = np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
|
|
278
|
+
_logger.debug("Built %d training pairs", pairs.shape[0])
|
|
279
|
+
return pairs
|
|
250
280
|
|
|
251
281
|
@staticmethod
|
|
252
282
|
def _freq_from_walks(walks0: np.ndarray, *, V: int) -> np.ndarray:
|
|
253
|
-
"""Node frequencies from walks (0-based).
|
|
254
|
-
|
|
283
|
+
"""Node frequencies from walks (0-based).
|
|
284
|
+
|
|
285
|
+
Args:
|
|
286
|
+
walks0: 2D array of 0-based node ids.
|
|
287
|
+
V: Vocabulary size (minlength for bincount).
|
|
288
|
+
|
|
289
|
+
Returns:
|
|
290
|
+
1D array of int64 frequencies with length ``V``.
|
|
291
|
+
"""
|
|
292
|
+
freq = np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
|
|
293
|
+
_logger.debug("Frequency mass: total=%d nonzero=%d", int(freq.sum()), int(np.count_nonzero(freq)))
|
|
294
|
+
return freq
|
|
255
295
|
|
|
256
296
|
@staticmethod
|
|
257
297
|
def _soft_unigram(freq: np.ndarray, *, power: float = 0.75) -> np.ndarray:
|
|
258
|
-
"""Return normalized Pn(w) ∝ f(w)^power as float64 probs.
|
|
298
|
+
"""Return normalized Pn(w) ∝ f(w)^power as float64 probs.
|
|
299
|
+
|
|
300
|
+
Args:
|
|
301
|
+
freq: 1D array of token frequencies.
|
|
302
|
+
power: Exponent used for smoothing (default 0.75 à la word2vec).
|
|
303
|
+
|
|
304
|
+
Returns:
|
|
305
|
+
1D array of probabilities summing to 1.0.
|
|
306
|
+
|
|
307
|
+
Raises:
|
|
308
|
+
ValueError: If mass is invalid (all zeros or non-finite).
|
|
309
|
+
"""
|
|
259
310
|
p = np.asarray(freq, dtype=np.float64)
|
|
260
311
|
if p.sum() == 0:
|
|
261
312
|
raise ValueError("all frequencies are zero")
|
|
@@ -263,13 +314,31 @@ class Embedder:
|
|
|
263
314
|
s = p.sum()
|
|
264
315
|
if not np.isfinite(s) or s <= 0:
|
|
265
316
|
raise ValueError("invalid unigram mass")
|
|
266
|
-
|
|
317
|
+
probs = p / s
|
|
318
|
+
_logger.debug("Noise distribution ready (power=%.3f)", power)
|
|
319
|
+
return probs
|
|
267
320
|
|
|
268
321
|
def _materialize_walks(self, frame_id: int, rin: Literal["attr", "repuls"],
|
|
269
322
|
using: Literal["RW", "SAW", "merged"]) -> np.ndarray:
|
|
323
|
+
"""Materialize the requested set of walks for a frame.
|
|
324
|
+
|
|
325
|
+
Args:
|
|
326
|
+
frame_id: 1-based frame index.
|
|
327
|
+
rin: Which RIN to pull from: ``"attr"`` or ``"repuls"``.
|
|
328
|
+
using: Which walk sets to include: ``"RW"``, ``"SAW"``, or ``"merged"``.
|
|
329
|
+
If ``"merged"``, concatenate available RW and SAW along axis 0.
|
|
330
|
+
|
|
331
|
+
Returns:
|
|
332
|
+
A 2D array of walks with shape (num_walks, walk_length+1).
|
|
333
|
+
|
|
334
|
+
Raises:
|
|
335
|
+
IndexError: If ``frame_id`` is out of range.
|
|
336
|
+
ValueError: If no matching walks are available.
|
|
337
|
+
"""
|
|
270
338
|
if not 1 <= frame_id <= int(self.frame_count):
|
|
271
339
|
raise IndexError(f"frame_id must be in [1, {self.frame_count}]; got {frame_id}")
|
|
272
340
|
|
|
341
|
+
_logger.debug("Materializing %s walks at frame=%d using=%s", rin, frame_id, using)
|
|
273
342
|
frame_id -= 1
|
|
274
343
|
|
|
275
344
|
if rin == "attr":
|
|
@@ -296,8 +365,12 @@ class Embedder:
|
|
|
296
365
|
if not parts:
|
|
297
366
|
raise ValueError(f"No walks available for {rin=} with {using=}")
|
|
298
367
|
if len(parts) == 1:
|
|
299
|
-
|
|
300
|
-
|
|
368
|
+
out = parts[0]
|
|
369
|
+
else:
|
|
370
|
+
out = np.concatenate(parts, axis=0)
|
|
371
|
+
|
|
372
|
+
_logger.debug("Materialized walks shape: %s", getattr(out, "shape", None))
|
|
373
|
+
return out
|
|
301
374
|
|
|
302
375
|
# INTERFACES: (private)
|
|
303
376
|
|
|
@@ -306,6 +379,17 @@ class Embedder:
|
|
|
306
379
|
using: Literal["RW", "SAW", "merged"],
|
|
307
380
|
window_size: int,
|
|
308
381
|
alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
|
|
382
|
+
"""Construct (center, context) pairs and noise distribution for ATTR.
|
|
383
|
+
|
|
384
|
+
Args:
|
|
385
|
+
frame_id: 1-based frame index.
|
|
386
|
+
using: Walk subset to include.
|
|
387
|
+
window_size: Skip-gram window radius.
|
|
388
|
+
alpha: Unigram smoothing exponent.
|
|
389
|
+
|
|
390
|
+
Returns:
|
|
391
|
+
Tuple of (pairs, noise_probs).
|
|
392
|
+
"""
|
|
309
393
|
walks = self._materialize_walks(frame_id, "attr", using)
|
|
310
394
|
walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
|
|
311
395
|
attractive_corpus = self._pairs_from_walks(walks0, window_size)
|
|
@@ -319,6 +403,17 @@ class Embedder:
|
|
|
319
403
|
using: Literal["RW", "SAW", "merged"],
|
|
320
404
|
window_size: int,
|
|
321
405
|
alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
|
|
406
|
+
"""Construct (center, context) pairs and noise distribution for REP.
|
|
407
|
+
|
|
408
|
+
Args:
|
|
409
|
+
frame_id: 1-based frame index.
|
|
410
|
+
using: Walk subset to include.
|
|
411
|
+
window_size: Skip-gram window radius.
|
|
412
|
+
alpha: Unigram smoothing exponent.
|
|
413
|
+
|
|
414
|
+
Returns:
|
|
415
|
+
Tuple of (pairs, noise_probs).
|
|
416
|
+
"""
|
|
322
417
|
walks = self._materialize_walks(frame_id, "repuls", using)
|
|
323
418
|
walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
|
|
324
419
|
repulsive_corpus = self._pairs_from_walks(walks0, window_size)
|
|
@@ -333,61 +428,53 @@ class Embedder:
|
|
|
333
428
|
frame_id: int,
|
|
334
429
|
RIN_type: Literal["attr", "repuls"],
|
|
335
430
|
using: Literal["RW", "SAW", "merged"],
|
|
336
|
-
window_size: int,
|
|
337
|
-
num_negative_samples: int,
|
|
338
431
|
num_epochs: int,
|
|
339
|
-
|
|
432
|
+
negative_sampling: bool = False,
|
|
433
|
+
window_size: int = 5,
|
|
434
|
+
num_negative_samples: int = 10,
|
|
435
|
+
batch_size: int = 1024,
|
|
340
436
|
*,
|
|
341
437
|
lr_step_per_batch: bool = False,
|
|
342
438
|
shuffle_data: bool = True,
|
|
343
439
|
dimensionality: int = 128,
|
|
344
440
|
alpha: float = 0.75,
|
|
345
441
|
device: str | None = None,
|
|
346
|
-
|
|
442
|
+
model_base: Literal["torch", "pureml"] = "pureml",
|
|
443
|
+
model_kwargs: dict[str, object] | None = None,
|
|
347
444
|
kind: Literal["in", "out", "avg"] = "in",
|
|
348
|
-
_seed: int | None = None
|
|
349
|
-
objective: Literal["sgns", "sg"] | None = None
|
|
445
|
+
_seed: int | None = None
|
|
350
446
|
) -> np.ndarray:
|
|
351
|
-
"""Train embeddings for a single frame and return the
|
|
447
|
+
"""Train embeddings for a single frame and return the matrix.
|
|
352
448
|
|
|
353
449
|
Args:
|
|
354
|
-
frame_id: 1-based frame index to
|
|
355
|
-
RIN_type:
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
window_size:
|
|
360
|
-
num_negative_samples:
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
provided then ``lr_sched_kwargs`` must also be provided.
|
|
372
|
-
kind: Which embedding matrix to return: ``"in"``, ``"out"``, or ``"avg"``.
|
|
373
|
-
_seed: Optional child seed for this frame's model initialization.
|
|
374
|
-
objective: Training objective override for this call (``"sgns"`` or
|
|
375
|
-
``"sg"``). If ``None``, uses the value set at construction.
|
|
450
|
+
frame_id: 1-based frame index to embed.
|
|
451
|
+
RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
|
|
452
|
+
using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
|
|
453
|
+
num_epochs: Number of passes over the pairs.
|
|
454
|
+
negative_sampling: If ``True``, use SGNS objective; else plain SG.
|
|
455
|
+
window_size: Skip-gram symmetric window radius.
|
|
456
|
+
num_negative_samples: Negatives per positive pair (SGNS only).
|
|
457
|
+
batch_size: Minibatch size for training.
|
|
458
|
+
lr_step_per_batch: If ``True``, step LR every batch (else per epoch).
|
|
459
|
+
shuffle_data: Shuffle pairs each epoch.
|
|
460
|
+
dimensionality: Embedding dimension ``D``.
|
|
461
|
+
alpha: Unigram smoothing power for noise distribution.
|
|
462
|
+
device: Optional backend device hint (e.g., ``"cuda"``).
|
|
463
|
+
model_base: Backend family (``"torch"`` or ``"pureml"``).
|
|
464
|
+
model_kwargs: Passed through to backend model constructor.
|
|
465
|
+
kind: Which embedding to return: ``"in"``, ``"out"``, or ``"avg"``.
|
|
466
|
+
_seed: Optional override seed for this frame.
|
|
376
467
|
|
|
377
468
|
Returns:
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
Raises:
|
|
381
|
-
ValueError: If requested walks are missing, if no training pairs are
|
|
382
|
-
generated, or if required ``sgns_kwargs`` for PureML are absent.
|
|
383
|
-
AttributeError: If the SGNS model does not expose embeddings via
|
|
384
|
-
``.embeddings`` or ``.parameters[0]``.
|
|
469
|
+
``(V, D)`` float32 embedding matrix.
|
|
385
470
|
"""
|
|
386
471
|
_logger.info(
|
|
387
|
-
"
|
|
388
|
-
frame_id, RIN_type, using,
|
|
472
|
+
"embed_frame: frame=%d RIN=%s using=%s base=%s D=%d epochs=%d batch=%d sgns=%s k=%d alpha=%.3f",
|
|
473
|
+
frame_id, RIN_type, using, model_base, dimensionality, num_epochs, batch_size,
|
|
474
|
+
str(negative_sampling), window_size, alpha
|
|
389
475
|
)
|
|
390
476
|
|
|
477
|
+
# ------------------ resolve training data -----------------
|
|
391
478
|
if RIN_type == "attr":
|
|
392
479
|
if self.attractive_RWs is None and self.attractive_SAWs is None:
|
|
393
480
|
raise ValueError("Attractive random walks are missing")
|
|
@@ -397,162 +484,124 @@ class Embedder:
|
|
|
397
484
|
raise ValueError("Repulsive random walks are missing")
|
|
398
485
|
pairs, noise_probs = self._repulsive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
|
|
399
486
|
else:
|
|
400
|
-
raise
|
|
401
|
-
|
|
487
|
+
raise NameError(f"Unknown RIN_type: {RIN_type!r}")
|
|
402
488
|
if pairs.size == 0:
|
|
403
489
|
raise ValueError("No training pairs generated for the requested configuration")
|
|
490
|
+
# ----------------------------------------------------------
|
|
404
491
|
|
|
492
|
+
# ---------------- construct training corpus ---------------
|
|
405
493
|
centers = pairs[:, 0].astype(np.int64, copy=False)
|
|
406
494
|
contexts = pairs[:, 1].astype(np.int64, copy=False)
|
|
495
|
+
_logger.debug("Pairs split: centers=%s contexts=%s", centers.shape, contexts.shape)
|
|
496
|
+
# ----------------------------------------------------------
|
|
407
497
|
|
|
408
|
-
|
|
409
|
-
if
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
if has_sched and not has_sched_kwargs:
|
|
417
|
-
raise ValueError("When providing lr_sched for PureML, you must also provide lr_sched_kwargs.")
|
|
418
|
-
|
|
419
|
-
child_seed = int(self._seed if _seed is None else _seed)
|
|
420
|
-
model_kwargs.update({
|
|
498
|
+
# ------------ resolve model_constructor kwargs ------------
|
|
499
|
+
if model_kwargs is not None:
|
|
500
|
+
if (("lr_sched" in model_kwargs and model_kwargs.get("lr_sched", None) is not None)
|
|
501
|
+
and ("lr_sched_kwargs" in model_kwargs and model_kwargs.get("lr_sched_kwargs", None) is None)):
|
|
502
|
+
raise ValueError("When `lr_sched`, you must also provide `lr_sched_kwargs`.")
|
|
503
|
+
|
|
504
|
+
constructor_kwargs: dict[str, object] = dict(model_kwargs or {})
|
|
505
|
+
constructor_kwargs.update({
|
|
421
506
|
"V": self.vocab_size,
|
|
422
507
|
"D": dimensionality,
|
|
423
|
-
"seed":
|
|
508
|
+
"seed": int(self._seed if _seed is None else _seed),
|
|
509
|
+
"device": device
|
|
424
510
|
})
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
centers,
|
|
451
|
-
contexts,
|
|
452
|
-
num_epochs,
|
|
453
|
-
batch_size,
|
|
454
|
-
num_negative_samples,
|
|
455
|
-
noise_probs,
|
|
456
|
-
shuffle_data,
|
|
457
|
-
lr_step_per_batch
|
|
458
|
-
)
|
|
459
|
-
else:
|
|
460
|
-
self.model.fit(
|
|
461
|
-
centers,
|
|
462
|
-
contexts,
|
|
463
|
-
num_epochs,
|
|
464
|
-
batch_size,
|
|
465
|
-
shuffle_data,
|
|
466
|
-
lr_step_per_batch
|
|
511
|
+
_logger.debug("Model constructor kwargs: %s", {k: constructor_kwargs[k] for k in ("V","D","seed","device")})
|
|
512
|
+
# ----------------------------------------------------------
|
|
513
|
+
|
|
514
|
+
# --------------- resolve model constructor ----------------
|
|
515
|
+
model_constructor = self._get_NN_constructor_from(
|
|
516
|
+
model_base, objective=("sgns" if negative_sampling else "sg"))
|
|
517
|
+
# ----------------------------------------------------------
|
|
518
|
+
|
|
519
|
+
# ------------------ initialize the model ------------------
|
|
520
|
+
model = model_constructor(**constructor_kwargs)
|
|
521
|
+
_logger.debug("Model initialized: %s", model_constructor.__name__ if hasattr(model_constructor,"__name__") else str(model_constructor))
|
|
522
|
+
# ----------------------------------------------------------
|
|
523
|
+
|
|
524
|
+
# -------------------- fitting the data --------------------
|
|
525
|
+
_logger.info("Fitting model on %d pairs ...", pairs.shape[0])
|
|
526
|
+
model.fit(centers=centers,
|
|
527
|
+
contexts=contexts,
|
|
528
|
+
num_epochs=num_epochs,
|
|
529
|
+
batch_size=batch_size,
|
|
530
|
+
# -- optional; for SGNS; safely ignored by SG via **_ignore --
|
|
531
|
+
num_negative_samples=num_negative_samples,
|
|
532
|
+
noise_dist=noise_probs,
|
|
533
|
+
# -----------------------------------------
|
|
534
|
+
shuffle_data=shuffle_data,
|
|
535
|
+
lr_step_per_batch=lr_step_per_batch
|
|
467
536
|
)
|
|
468
|
-
|
|
469
|
-
#
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
embeddings = params[0]
|
|
478
|
-
elif kind == "out":
|
|
479
|
-
embeddings = getattr(self.model, "out_embeddings", None)
|
|
480
|
-
else: # "avg"
|
|
481
|
-
embeddings = getattr(self.model, "avg_embeddings", None)
|
|
537
|
+
_logger.info("Training complete for frame %d", frame_id)
|
|
538
|
+
# ----------------------------------------------------------
|
|
539
|
+
|
|
540
|
+
# OUTPUT:
|
|
541
|
+
embeddings = (model.in_embeddings if kind == "in" else
|
|
542
|
+
model.out_embeddings if kind == "out" else
|
|
543
|
+
model.avg_embeddings if kind == "avg" else
|
|
544
|
+
None
|
|
545
|
+
)
|
|
482
546
|
|
|
483
547
|
if embeddings is None:
|
|
484
|
-
|
|
485
|
-
"
|
|
486
|
-
f"kind={kind!r}"
|
|
487
|
-
)
|
|
548
|
+
if kind not in ("in", "out", "avg"):
|
|
549
|
+
raise NameError(f"Unknown {kind} embeddings kind. Expected: one of ['in', 'out', 'avg']")
|
|
488
550
|
|
|
489
|
-
|
|
490
|
-
_logger.
|
|
491
|
-
return
|
|
551
|
+
E = np.asarray(embeddings, dtype=np.float32)
|
|
552
|
+
_logger.debug("Returned embeddings shape: %s dtype=%s", E.shape, E.dtype)
|
|
553
|
+
return E
|
|
492
554
|
|
|
493
555
|
def embed_all(
|
|
494
556
|
self,
|
|
495
557
|
RIN_type: Literal["attr", "repuls"],
|
|
496
558
|
using: Literal["RW", "SAW", "merged"],
|
|
497
|
-
window_size: int,
|
|
498
|
-
num_negative_samples: int,
|
|
499
559
|
num_epochs: int,
|
|
500
|
-
|
|
560
|
+
negative_sampling: bool = False,
|
|
561
|
+
window_size: int = 2,
|
|
562
|
+
num_negative_samples: int = 10,
|
|
563
|
+
batch_size: int = 1024,
|
|
501
564
|
*,
|
|
565
|
+
lr_step_per_batch: bool = False,
|
|
502
566
|
shuffle_data: bool = True,
|
|
503
567
|
dimensionality: int = 128,
|
|
504
568
|
alpha: float = 0.75,
|
|
505
569
|
device: str | None = None,
|
|
506
|
-
|
|
570
|
+
model_base: Literal["torch", "pureml"] = "pureml",
|
|
571
|
+
model_kwargs: dict[str, object] | None = None,
|
|
572
|
+
kind: Literal["in", "out", "avg"] = "in",
|
|
507
573
|
output_path: str | Path | None = None,
|
|
508
574
|
num_matrices_in_compressed_blocks: int = 20,
|
|
509
575
|
compression_level: int = 3,
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
"""Train embeddings for all frames and persist them to compressed storage.
|
|
576
|
+
) -> str:
|
|
577
|
+
"""Embed all frames and persist a self-contained archive.
|
|
513
578
|
|
|
514
|
-
|
|
515
|
-
per
|
|
516
|
-
|
|
579
|
+
The resulting file stores a block named ``FRAME_EMBEDDINGS`` with a
|
|
580
|
+
compressed sequence of per-frame matrices (each ``(V, D)``), alongside
|
|
581
|
+
rich metadata mirroring the style of other SAWNERGY modules.
|
|
517
582
|
|
|
518
583
|
Args:
|
|
519
|
-
RIN_type:
|
|
520
|
-
using:
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
batch_size:
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
store per compressed chunk in the output archive.
|
|
538
|
-
compression_level: Blosc Zstd compression level (0-9).
|
|
539
|
-
objective: Training objective for all frames (``"sgns"`` or ``"sg"``).
|
|
540
|
-
If ``None``, uses the value set at construction.
|
|
541
|
-
kind: Which embedding matrix to persist for each frame: ``"in"``,
|
|
542
|
-
``"out"``, or ``"avg"``.
|
|
584
|
+
RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
|
|
585
|
+
using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
|
|
586
|
+
num_epochs: Number of epochs to train per frame.
|
|
587
|
+
negative_sampling: If ``True``, use SGNS; otherwise plain SG.
|
|
588
|
+
window_size: Skip-gram window radius.
|
|
589
|
+
num_negative_samples: Negatives per positive pair (SGNS).
|
|
590
|
+
batch_size: Minibatch size for training.
|
|
591
|
+
lr_step_per_batch: If ``True``, step LR per batch (else per epoch).
|
|
592
|
+
shuffle_data: Shuffle pairs each epoch.
|
|
593
|
+
dimensionality: Embedding dimension.
|
|
594
|
+
alpha: Unigram smoothing power for noise distribution.
|
|
595
|
+
device: Backend device hint (e.g., ``"cuda"``).
|
|
596
|
+
model_base: Backend family (``"torch"`` or ``"pureml"``).
|
|
597
|
+
model_kwargs: Passed through to backend model constructor.
|
|
598
|
+
kind: Which embedding to store: ``"in"``, ``"out"``, or ``"avg"``.
|
|
599
|
+
output_path: Optional path for the output archive (``.zip`` inferred).
|
|
600
|
+
num_matrices_in_compressed_blocks: How many frames per compressed chunk.
|
|
601
|
+
compression_level: Integer compression level for the archive.
|
|
543
602
|
|
|
544
603
|
Returns:
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
Raises:
|
|
548
|
-
ValueError: If configuration produces no pairs for a frame or if
|
|
549
|
-
PureML kwargs are incomplete.
|
|
550
|
-
RuntimeError: Propagated from storage operations on failure.
|
|
551
|
-
|
|
552
|
-
Notes:
|
|
553
|
-
- A deterministic child seed is spawned per frame from the master
|
|
554
|
-
seed using ``np.random.SeedSequence`` to ensure reproducibility
|
|
555
|
-
across runs.
|
|
604
|
+
Path to the created embeddings archive, as ``str``.
|
|
556
605
|
"""
|
|
557
606
|
current_time = sawnergy_util.current_time()
|
|
558
607
|
if output_path is None:
|
|
@@ -563,71 +612,102 @@ class Embedder:
|
|
|
563
612
|
output_path = output_path.with_suffix(".zip")
|
|
564
613
|
|
|
565
614
|
_logger.info(
|
|
566
|
-
"
|
|
567
|
-
|
|
615
|
+
"embed_all: frames=%d D=%d base=%s RIN=%s using=%s out=%s",
|
|
616
|
+
self.frame_count, dimensionality, model_base, RIN_type, using, output_path
|
|
568
617
|
)
|
|
569
618
|
|
|
619
|
+
# Per-frame deterministic seeds
|
|
570
620
|
master_ss = np.random.SeedSequence(self._seed)
|
|
571
621
|
child_seeds = master_ss.spawn(self.frame_count)
|
|
572
622
|
|
|
573
|
-
embeddings = []
|
|
574
|
-
for
|
|
623
|
+
embeddings: list[np.ndarray] = []
|
|
624
|
+
for frame_id, seed_seq in enumerate(child_seeds, start=1):
|
|
575
625
|
child_seed = int(seed_seq.generate_state(1, dtype=np.uint32)[0])
|
|
576
|
-
_logger.info("
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
batch_size,
|
|
626
|
+
_logger.info("Embedding frame %d/%d with seed=%d", frame_id, self.frame_count, child_seed)
|
|
627
|
+
E = self.embed_frame(
|
|
628
|
+
frame_id=frame_id,
|
|
629
|
+
RIN_type=RIN_type,
|
|
630
|
+
using=using,
|
|
631
|
+
num_epochs=num_epochs,
|
|
632
|
+
negative_sampling=negative_sampling,
|
|
633
|
+
window_size=window_size,
|
|
634
|
+
num_negative_samples=num_negative_samples,
|
|
635
|
+
batch_size=batch_size,
|
|
636
|
+
lr_step_per_batch=lr_step_per_batch,
|
|
586
637
|
shuffle_data=shuffle_data,
|
|
587
638
|
dimensionality=dimensionality,
|
|
588
639
|
alpha=alpha,
|
|
589
640
|
device=device,
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
641
|
+
model_base=model_base,
|
|
642
|
+
model_kwargs=model_kwargs,
|
|
643
|
+
kind=kind,
|
|
644
|
+
_seed=child_seed
|
|
594
645
|
)
|
|
595
|
-
)
|
|
646
|
+
embeddings.append(np.asarray(E, dtype=np.float32, copy=False))
|
|
647
|
+
_logger.debug("Frame %d embedded: E.shape=%s", frame_id, E.shape)
|
|
596
648
|
|
|
597
|
-
embeddings = [np.asarray(e) for e in embeddings]
|
|
598
649
|
block_name = "FRAME_EMBEDDINGS"
|
|
599
650
|
with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level=compression_level) as storage:
|
|
651
|
+
_logger.info("Writing %d frame matrices to block '%s' ...", len(embeddings), block_name)
|
|
600
652
|
storage.write(
|
|
601
653
|
these_arrays=embeddings,
|
|
602
654
|
to_block_named=block_name,
|
|
603
655
|
arrays_per_chunk=num_matrices_in_compressed_blocks
|
|
604
656
|
)
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
storage.add_attr("
|
|
608
|
-
storage.add_attr("
|
|
609
|
-
storage.add_attr("
|
|
610
|
-
storage.add_attr("
|
|
611
|
-
|
|
657
|
+
|
|
658
|
+
# Core dataset discovery (for consumers like the Embeddings Visualizer)
|
|
659
|
+
storage.add_attr("frame_embeddings_name", block_name)
|
|
660
|
+
storage.add_attr("time_stamp_count", int(self.frame_count))
|
|
661
|
+
storage.add_attr("node_count", int(self.vocab_size))
|
|
662
|
+
storage.add_attr("embedding_dim", int(dimensionality))
|
|
663
|
+
|
|
664
|
+
# Provenance of input WALKS
|
|
665
|
+
storage.add_attr("source_WALKS_path", str(self._walks_path))
|
|
666
|
+
storage.add_attr("walk_length", int(self.walk_length))
|
|
667
|
+
storage.add_attr("num_RWs", int(self.num_RWs))
|
|
668
|
+
storage.add_attr("num_SAWs", int(self.num_SAWs))
|
|
669
|
+
storage.add_attr("attractive_RWs_name", self._attractive_RWs_name)
|
|
670
|
+
storage.add_attr("repulsive_RWs_name", self._repulsive_RWs_name)
|
|
671
|
+
storage.add_attr("attractive_SAWs_name", self._attractive_SAWs_name)
|
|
672
|
+
storage.add_attr("repulsive_SAWs_name", self._repulsive_SAWs_name)
|
|
673
|
+
|
|
674
|
+
# Training configuration (sufficient to reproduce)
|
|
675
|
+
storage.add_attr("objective", "sgns" if negative_sampling else "sg")
|
|
676
|
+
storage.add_attr("model_base", model_base)
|
|
677
|
+
storage.add_attr("embedding_kind", kind) # 'in' | 'out' | 'avg'
|
|
678
|
+
storage.add_attr("num_epochs", int(num_epochs))
|
|
679
|
+
storage.add_attr("batch_size", int(batch_size))
|
|
612
680
|
storage.add_attr("window_size", int(window_size))
|
|
613
681
|
storage.add_attr("alpha", float(alpha))
|
|
614
|
-
storage.add_attr("
|
|
682
|
+
storage.add_attr("negative_sampling", bool(negative_sampling))
|
|
615
683
|
storage.add_attr("num_negative_samples", int(num_negative_samples))
|
|
616
|
-
storage.add_attr("
|
|
617
|
-
storage.add_attr("batch_size", int(batch_size))
|
|
684
|
+
storage.add_attr("lr_step_per_batch", bool(lr_step_per_batch))
|
|
618
685
|
storage.add_attr("shuffle_data", bool(shuffle_data))
|
|
619
|
-
storage.add_attr("
|
|
620
|
-
storage.add_attr("
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
storage.add_attr("
|
|
686
|
+
storage.add_attr("device_hint", device if device is not None else "")
|
|
687
|
+
storage.add_attr("model_kwargs_repr", repr(model_kwargs) if model_kwargs is not None else "{}")
|
|
688
|
+
|
|
689
|
+
# Which walks were used to train
|
|
690
|
+
storage.add_attr("RIN_type", RIN_type) # 'attr' or 'repuls'
|
|
691
|
+
storage.add_attr("using", using) # 'RW' | 'SAW' | 'merged'
|
|
692
|
+
|
|
693
|
+
# Reproducibility
|
|
694
|
+
storage.add_attr("master_seed", int(self._seed))
|
|
695
|
+
# Note: seeds are re-derived here with the same generate_state(1, dtype=np.uint32) call used in the training loop, so they match the seeds passed to embed_frame.
|
|
696
|
+
storage.add_attr("per_frame_seeds", [int(s.generate_state(1, dtype=np.uint32)[0]) for s in child_seeds])
|
|
697
|
+
|
|
698
|
+
# Archive/IO details
|
|
624
699
|
storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
|
|
625
700
|
storage.add_attr("compression_level", int(compression_level))
|
|
626
|
-
storage.add_attr("
|
|
701
|
+
storage.add_attr("created_at", current_time)
|
|
702
|
+
|
|
703
|
+
_logger.info(
|
|
704
|
+
"Stored embeddings archive: %s | shape=(T,N,D)=(%d,%d,%d)",
|
|
705
|
+
output_path, self.frame_count, self.vocab_size, dimensionality
|
|
706
|
+
)
|
|
627
707
|
|
|
628
|
-
_logger.info("Embedding archive written to %s", output_path)
|
|
629
708
|
return str(output_path)
|
|
630
709
|
|
|
710
|
+
|
|
631
711
|
__all__ = ["Embedder"]
|
|
632
712
|
|
|
633
713
|
if __name__ == "__main__":
|