sawnergy 1.0.3__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,15 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- """
4
- Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
5
-
6
- This module consumes attractive/repulsive walk corpora produced by the walker
7
- pipeline and trains per-frame embeddings using either the PyTorch or PureML
8
- implementations of SGNS. The resulting embeddings can be persisted back into
9
- an ``ArrayStorage`` archive along with rich metadata describing the training
10
- configuration.
11
- """
12
-
13
3
  # third-party
14
4
  import numpy as np
15
5
 
@@ -36,9 +26,8 @@ class Embedder:
36
26
 
37
27
  def __init__(self,
38
28
  WALKS_path: str | Path,
39
- base: Literal["torch", "pureml"],
40
29
  *,
41
- seed: int | None = None
30
+ seed: int | None = None,
42
31
  ) -> None:
43
32
  """Initialize the embedder and load walk tensors.
44
33
 
@@ -50,22 +39,19 @@ class Embedder:
50
39
  ``None`` if that collection is absent), and the metadata
51
40
  ``num_RWs``, ``num_SAWs``, ``node_count``, ``time_stamp_count``,
52
41
  ``walk_length``.
53
- base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
54
42
  seed: Optional seed for the embedder's RNG. If ``None``, a random
55
43
  32-bit seed is chosen.
56
44
 
57
45
  Raises:
58
46
  ValueError: If required metadata is missing or any loaded walk array
59
47
  has an unexpected shape.
60
- ImportError: If the requested backend is not installed.
61
- NameError: If ``base`` is not one of ``{"torch","pureml"}``.
62
48
 
63
49
  Notes:
64
50
  - Walks in storage are 1-based (residue indexing). Internally, this
65
51
  class normalizes to 0-based indices for training utilities.
66
52
  """
67
53
  self._walks_path = Path(WALKS_path)
68
- _logger.info("Initializing Embedder from %s (base=%s)", self._walks_path, base)
54
+ _logger.info("Initializing Embedder from %s", self._walks_path)
69
55
 
70
56
  # placeholders for optional walk collections
71
57
  self.attractive_RWs : np.ndarray | None = None
@@ -124,53 +110,76 @@ class Embedder:
124
110
  RWs_expected = (time_stamp_count, node_count * num_RWs, walk_length+1) if (num_RWs > 0) else None
125
111
  SAWs_expected = (time_stamp_count, node_count * num_SAWs, walk_length+1) if (num_SAWs > 0) else None
126
112
 
127
- self.vocab_size = int(node_count)
128
- self.frame_count = int(time_stamp_count)
129
- self.walk_length = int(walk_length)
113
+ self.vocab_size = int(node_count)
114
+ self.frame_count = int(time_stamp_count)
115
+ self.walk_length = int(walk_length)
116
+ self.num_RWs = int(num_RWs)
117
+ self.num_SAWs = int(num_SAWs)
118
+ # Keep dataset names for metadata passthrough
119
+ self._attractive_RWs_name = attractive_RWs_name
120
+ self._repulsive_RWs_name = repulsive_RWs_name
121
+ self._attractive_SAWs_name = attractive_SAWs_name
122
+ self._repulsive_SAWs_name = repulsive_SAWs_name
130
123
 
131
124
  # store walks if present
132
125
  if attractive_RWs is not None:
133
126
  if RWs_expected and attractive_RWs.shape != RWs_expected:
134
127
  raise ValueError(f"ATTR RWs: expected {RWs_expected}, got {attractive_RWs.shape}")
135
128
  self.attractive_RWs = attractive_RWs
129
+ _logger.debug("ATTR RWs loaded: %s", self.attractive_RWs.shape)
136
130
 
137
131
  if repulsive_RWs is not None:
138
132
  if RWs_expected and repulsive_RWs.shape != RWs_expected:
139
133
  raise ValueError(f"REP RWs: expected {RWs_expected}, got {repulsive_RWs.shape}")
140
134
  self.repulsive_RWs = repulsive_RWs
135
+ _logger.debug("REP RWs loaded: %s", self.repulsive_RWs.shape)
141
136
 
142
137
  if attractive_SAWs is not None:
143
138
  if SAWs_expected and attractive_SAWs.shape != SAWs_expected:
144
139
  raise ValueError(f"ATTR SAWs: expected {SAWs_expected}, got {attractive_SAWs.shape}")
145
140
  self.attractive_SAWs = attractive_SAWs
141
+ _logger.debug("ATTR SAWs loaded: %s", self.attractive_SAWs.shape)
146
142
 
147
143
  if repulsive_SAWs is not None:
148
144
  if SAWs_expected and repulsive_SAWs.shape != SAWs_expected:
149
145
  raise ValueError(f"REP SAWs: expected {SAWs_expected}, got {repulsive_SAWs.shape}")
150
146
  self.repulsive_SAWs = repulsive_SAWs
147
+ _logger.debug("REP SAWs loaded: %s", self.repulsive_SAWs.shape)
151
148
 
152
149
  # INTERNAL RNG
153
150
  self._seed = np.random.randint(0, 2**32 - 1) if seed is None else int(seed)
154
151
  self.rng = np.random.default_rng(self._seed)
155
152
  _logger.info("RNG initialized from seed=%d", self._seed)
156
153
 
157
- # MODEL HANDLE
158
- self.model_base: Literal["torch", "pureml"] = base
159
- self.model_constructor = self._get_SGNS_constructor_from(base)
160
- _logger.info("SGNS backend resolved: %s", getattr(self.model_constructor, "__name__", repr(self.model_constructor)))
161
-
162
154
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
163
155
 
164
156
  # HELPERS:
165
157
 
166
158
  @staticmethod
167
- def _get_SGNS_constructor_from(base: Literal["torch", "pureml"]):
168
- """Resolve the SGNS implementation class for the selected backend."""
159
+ def _get_NN_constructor_from(base: Literal["torch", "pureml"],
160
+ objective: Literal["sgns", "sg"]):
161
+ """Resolve the SG/SGNS implementation class for the selected backend.
162
+
163
+ Args:
164
+ base: Backend family to use, ``"torch"`` or ``"pureml"``.
165
+ objective: Training objective, ``"sgns"`` or ``"sg"``.
166
+
167
+ Returns:
168
+ A callable class (constructor) implementing the requested model.
169
+
170
+ Raises:
171
+ ImportError: If the requested backend cannot be imported.
172
+ NameError: If ``base`` is not one of the supported values.
173
+ """
174
+ _logger.debug("Resolving model constructor: base=%s objective=%s", base, objective)
169
175
  if base == "torch":
170
176
  try:
171
- from .SGNS_torch import SGNS_Torch
172
- return SGNS_Torch
177
+ from .SGNS_torch import SGNS_Torch, SG_Torch
178
+ ctor = SG_Torch if objective == "sg" else SGNS_Torch
179
+ _logger.debug("Resolved PyTorch class: %s", getattr(ctor, "__name__", str(ctor)))
180
+ return ctor
173
181
  except Exception:
182
+ _logger.exception("Failed to import PyTorch backend.")
174
183
  raise ImportError(
175
184
  "PyTorch is not installed, but base='torch' was requested. "
176
185
  "Install PyTorch first, e.g.: `pip install torch` "
@@ -178,9 +187,12 @@ class Embedder:
178
187
  )
179
188
  elif base == "pureml":
180
189
  try:
181
- from .SGNS_pml import SGNS_PureML
182
- return SGNS_PureML
190
+ from .SGNS_pml import SGNS_PureML, SG_PureML
191
+ ctor = SG_PureML if objective == "sg" else SGNS_PureML
192
+ _logger.debug("Resolved PureML class: %s", getattr(ctor, "__name__", str(ctor)))
193
+ return ctor
183
194
  except Exception:
195
+ _logger.exception("Failed to import PureML backend.")
184
196
  raise ImportError(
185
197
  "PureML is not installed, but base='pureml' was requested. "
186
198
  "Install PureML first via `pip install ym-pure-ml` "
@@ -190,7 +202,18 @@ class Embedder:
190
202
 
191
203
  @staticmethod
192
204
  def _as_zerobase_intp(walks: np.ndarray, *, V: int) -> np.ndarray:
193
- """Validate 1-based uint/int walks 0-based intp; check bounds."""
205
+ """Validate and convert 1-based walks to 0-based ``intp``.
206
+
207
+ Args:
208
+ walks: 2D array of node ids with 1-based indexing.
209
+ V: Vocabulary size for bounds checking.
210
+
211
+ Returns:
212
+ 2D array of dtype ``intp`` with 0-based indices.
213
+
214
+ Raises:
215
+ ValueError: If shape is not 2D or indices are out of bounds.
216
+ """
194
217
  arr = np.asarray(walks)
195
218
  if arr.ndim != 2:
196
219
  raise ValueError("walks must be 2D: (num_walks, walk_len)")
@@ -198,7 +221,9 @@ class Embedder:
198
221
  arr = arr.astype(np.int64, copy=False)
199
222
  # 1-based → 0-based
200
223
  arr = arr - 1
201
- if arr.min() < 0 or arr.max() >= V:
224
+ mn, mx = int(arr.min()), int(arr.max())
225
+ _logger.debug("Zero-basing walks: min=%d max=%d V=%d", mn, mx, V)
226
+ if mn < 0 or mx >= V:
202
227
  raise ValueError("walk ids out of range after 1→0-based normalization")
203
228
  return arr.astype(np.intp, copy=False)
204
229
 
@@ -206,19 +231,29 @@ class Embedder:
206
231
  def _pairs_from_walks(walks0: np.ndarray, window_size: int) -> np.ndarray:
207
232
  """
208
233
  Skip-gram pairs including edge centers (one-sided when needed).
209
- walks0: (W, L) int array (0-based ids).
210
- Returns: (N_pairs, 2) int32 [center, context].
234
+
235
+ Args:
236
+ walks0: (W, L) int array (0-based ids).
237
+ window_size: Symmetric context window radius.
238
+
239
+ Returns:
240
+ Array of shape (N_pairs, 2) int32 with columns [center, context].
241
+
242
+ Raises:
243
+ ValueError: If shape is invalid or ``window_size`` <= 0.
211
244
  """
212
245
  if walks0.ndim != 2:
213
246
  raise ValueError("walks must be 2D: (num_walks, walk_len)")
214
247
 
215
248
  _, L = walks0.shape
216
249
  k = int(window_size)
250
+ _logger.debug("Building SG pairs: L=%d, window=%d", L, k)
217
251
 
218
252
  if k <= 0:
219
253
  raise ValueError("window_size must be positive")
220
254
 
221
255
  if L == 0:
256
+ _logger.debug("Empty walks length; returning 0 pairs.")
222
257
  return np.empty((0, 2), dtype=np.int32)
223
258
 
224
259
  out_chunks = []
@@ -236,18 +271,42 @@ class Embedder:
236
271
  out_chunks.append(np.stack((centers_l, ctx_l), axis=2).reshape(-1, 2))
237
272
 
238
273
  if not out_chunks:
274
+ _logger.debug("No offsets produced pairs; returning empty.")
239
275
  return np.empty((0, 2), dtype=np.int32)
240
276
 
241
- return np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
277
+ pairs = np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
278
+ _logger.debug("Built %d training pairs", pairs.shape[0])
279
+ return pairs
242
280
 
243
281
  @staticmethod
244
282
  def _freq_from_walks(walks0: np.ndarray, *, V: int) -> np.ndarray:
245
- """Node frequencies from walks (0-based)."""
246
- return np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
283
+ """Node frequencies from walks (0-based).
284
+
285
+ Args:
286
+ walks0: 2D array of 0-based node ids.
287
+ V: Vocabulary size (minlength for bincount).
288
+
289
+ Returns:
290
+ 1D array of int64 frequencies with length ``V``.
291
+ """
292
+ freq = np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
293
+ _logger.debug("Frequency mass: total=%d nonzero=%d", int(freq.sum()), int(np.count_nonzero(freq)))
294
+ return freq
247
295
 
248
296
  @staticmethod
249
297
  def _soft_unigram(freq: np.ndarray, *, power: float = 0.75) -> np.ndarray:
250
- """Return normalized Pn(w) ∝ f(w)^power as float64 probs."""
298
+ """Return normalized Pn(w) ∝ f(w)^power as float64 probs.
299
+
300
+ Args:
301
+ freq: 1D array of token frequencies.
302
+ power: Exponent used for smoothing (default 0.75 à la word2vec).
303
+
304
+ Returns:
305
+ 1D array of probabilities summing to 1.0.
306
+
307
+ Raises:
308
+ ValueError: If mass is invalid (all zeros or non-finite).
309
+ """
251
310
  p = np.asarray(freq, dtype=np.float64)
252
311
  if p.sum() == 0:
253
312
  raise ValueError("all frequencies are zero")
@@ -255,13 +314,31 @@ class Embedder:
255
314
  s = p.sum()
256
315
  if not np.isfinite(s) or s <= 0:
257
316
  raise ValueError("invalid unigram mass")
258
- return p / s
317
+ probs = p / s
318
+ _logger.debug("Noise distribution ready (power=%.3f)", power)
319
+ return probs
259
320
 
260
321
  def _materialize_walks(self, frame_id: int, rin: Literal["attr", "repuls"],
261
322
  using: Literal["RW", "SAW", "merged"]) -> np.ndarray:
323
+ """Materialize the requested set of walks for a frame.
324
+
325
+ Args:
326
+ frame_id: 1-based frame index.
327
+ rin: Which RIN to pull from: ``"attr"`` or ``"repuls"``.
328
+ using: Which walk sets to include: ``"RW"``, ``"SAW"``, or ``"merged"``.
329
+ If ``"merged"``, concatenate available RW and SAW along axis 0.
330
+
331
+ Returns:
332
+ A 2D array of walks with shape (num_walks, walk_length+1).
333
+
334
+ Raises:
335
+ IndexError: If ``frame_id`` is out of range.
336
+ ValueError: If no matching walks are available.
337
+ """
262
338
  if not 1 <= frame_id <= int(self.frame_count):
263
339
  raise IndexError(f"frame_id must be in [1, {self.frame_count}]; got {frame_id}")
264
340
 
341
+ _logger.debug("Materializing %s walks at frame=%d using=%s", rin, frame_id, using)
265
342
  frame_id -= 1
266
343
 
267
344
  if rin == "attr":
@@ -288,8 +365,12 @@ class Embedder:
288
365
  if not parts:
289
366
  raise ValueError(f"No walks available for {rin=} with {using=}")
290
367
  if len(parts) == 1:
291
- return parts[0]
292
- return np.concatenate(parts, axis=0)
368
+ out = parts[0]
369
+ else:
370
+ out = np.concatenate(parts, axis=0)
371
+
372
+ _logger.debug("Materialized walks shape: %s", getattr(out, "shape", None))
373
+ return out
293
374
 
294
375
  # INTERFACES: (private)
295
376
 
@@ -298,6 +379,17 @@ class Embedder:
298
379
  using: Literal["RW", "SAW", "merged"],
299
380
  window_size: int,
300
381
  alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
382
+ """Construct (center, context) pairs and noise distribution for ATTR.
383
+
384
+ Args:
385
+ frame_id: 1-based frame index.
386
+ using: Walk subset to include.
387
+ window_size: Skip-gram window radius.
388
+ alpha: Unigram smoothing exponent.
389
+
390
+ Returns:
391
+ Tuple of (pairs, noise_probs).
392
+ """
301
393
  walks = self._materialize_walks(frame_id, "attr", using)
302
394
  walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
303
395
  attractive_corpus = self._pairs_from_walks(walks0, window_size)
@@ -311,6 +403,17 @@ class Embedder:
311
403
  using: Literal["RW", "SAW", "merged"],
312
404
  window_size: int,
313
405
  alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
406
+ """Construct (center, context) pairs and noise distribution for REP.
407
+
408
+ Args:
409
+ frame_id: 1-based frame index.
410
+ using: Walk subset to include.
411
+ window_size: Skip-gram window radius.
412
+ alpha: Unigram smoothing exponent.
413
+
414
+ Returns:
415
+ Tuple of (pairs, noise_probs).
416
+ """
314
417
  walks = self._materialize_walks(frame_id, "repuls", using)
315
418
  walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
316
419
  repulsive_corpus = self._pairs_from_walks(walks0, window_size)
@@ -322,56 +425,63 @@ class Embedder:
322
425
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
323
426
 
324
427
  def embed_frame(self,
325
- frame_id: int,
326
- RIN_type: Literal["attr", "repuls"],
327
- using: Literal["RW", "SAW", "merged"],
328
- window_size: int,
329
- num_negative_samples: int,
330
- num_epochs: int,
331
- batch_size: int,
332
- *,
333
- shuffle_data: bool = True,
334
- dimensionality: int = 128,
335
- alpha: float = 0.75,
336
- device: str | None = None,
337
- sgns_kwargs: dict[str, object] | None = None,
338
- _seed: int | None = None
339
- ) -> np.ndarray:
340
- """Train embeddings for a single frame and return the input embedding matrix.
428
+ frame_id: int,
429
+ RIN_type: Literal["attr", "repuls"],
430
+ using: Literal["RW", "SAW", "merged"],
431
+ num_epochs: int,
432
+ negative_sampling: bool = False,
433
+ window_size: int = 5,
434
+ num_negative_samples: int = 10,
435
+ batch_size: int = 1024,
436
+ *,
437
+ in_weights: np.ndarray | None = None,
438
+ out_weights: np.ndarray | None = None,
439
+ lr_step_per_batch: bool = False,
440
+ shuffle_data: bool = True,
441
+ dimensionality: int = 128,
442
+ alpha: float = 0.75,
443
+ device: str | None = None,
444
+ model_base: Literal["torch", "pureml"] = "pureml",
445
+ model_kwargs: dict[str, object] | None = None,
446
+ kind: tuple[Literal["in", "out", "avg"]] = ("in",),
447
+ _seed: int | None = None
448
+ ) -> list[tuple[np.ndarray, str]]:
449
+ """Train embeddings for a single frame and return the matrix containing embeddings of the specified `kind`.
341
450
 
342
451
  Args:
343
- frame_id: 1-based frame index to train on.
344
- RIN_type: Interaction channel to use: ``"attr"`` (attractive) or
345
- ``"repuls"`` (repulsive).
346
- using: Which walk collections to include: ``"RW"``, ``"SAW"``, or
347
- ``"merged"`` (concatenates both if available).
348
- window_size: Symmetric skip-gram window size ``k``.
349
- num_negative_samples: Number of negative samples per positive pair.
350
- num_epochs: Number of passes over the pair dataset.
351
- batch_size: Mini-batch size for training.
352
- shuffle_data: Whether to shuffle pairs each epoch.
353
- dimensionality: Embedding dimensionality ``D``.
354
- alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
355
- device: Optional device string for the Torch backend (e.g., ``"cuda"``).
356
- sgns_kwargs: Extra keyword arguments forwarded to the backend SGNS
357
- constructor. For PureML, required keys are:
358
- ``{"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}``.
359
- _seed: Optional child seed for this frame's model initialization.
452
+ frame_id: 1-based frame index to embed.
453
+ RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
454
+ using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
455
+ num_epochs: Number of passes over the pairs.
456
+ negative_sampling: If ``True``, use SGNS objective; else plain SG.
457
+ window_size: Skip-gram symmetric window radius.
458
+ num_negative_samples: Negatives per positive pair (SGNS only).
459
+ batch_size: Minibatch size for training.
460
+ in_weights: Optional starting input-embedding matrix of shape (V, D).
461
+ out_weights: Optional starting output-embedding matrix; its expected shape depends on the objective:
462
+ SGNS: shape (V, D)
463
+ SG: shape (D, V)
464
+ lr_step_per_batch: If ``True``, step LR every batch (else per epoch).
465
+ shuffle_data: Shuffle pairs each epoch.
466
+ dimensionality: Embedding dimension ``D``.
467
+ alpha: Unigram smoothing power for noise distribution.
468
+ device: Optional backend device hint (e.g., ``"cuda"``).
469
+ model_base: Backend family (``"torch"`` or ``"pureml"``).
470
+ model_kwargs: Passed through to backend model constructor.
471
+ kind: Tuple of embedding kinds to return; each entry is ``"in"``, ``"out"``, or ``"avg"``.
472
+ _seed: Optional override seed for this frame.
360
473
 
361
474
  Returns:
362
- np.ndarray: Learned **input** embedding matrix of shape ``(V, D)``.
363
-
364
- Raises:
365
- ValueError: If requested walks are missing, if no training pairs are
366
- generated, or if required ``sgns_kwargs`` for PureML are absent.
367
- AttributeError: If the SGNS model does not expose embeddings via
368
- ``.embeddings`` or ``.parameters[0]``.
475
+ list[tuple[np.ndarray, Literal["avg","in","out"]]]:
476
+ (embedding, kind) pairs sorted as 'avg', 'in', 'out'.
369
477
  """
370
478
  _logger.info(
371
- "Preparing frame %d (rin=%s using=%s window=%d neg=%d epochs=%d batch=%d)",
372
- frame_id, RIN_type, using, window_size, num_negative_samples, num_epochs, batch_size
479
+ "embed_frame: frame=%d RIN=%s using=%s base=%s D=%d epochs=%d batch=%d sgns=%s window_size=%d alpha=%.3f",
480
+ frame_id, RIN_type, using, model_base, dimensionality, num_epochs, batch_size,
481
+ str(negative_sampling), window_size, alpha
373
482
  )
374
483
 
484
+ # ------------------ resolve training data -----------------
375
485
  if RIN_type == "attr":
376
486
  if self.attractive_RWs is None and self.attractive_SAWs is None:
377
487
  raise ValueError("Attractive random walks are missing")
@@ -381,125 +491,125 @@ class Embedder:
381
491
  raise ValueError("Repulsive random walks are missing")
382
492
  pairs, noise_probs = self._repulsive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
383
493
  else:
384
- raise ValueError(f"Unknown RIN_type: {RIN_type!r}")
385
-
494
+ raise NameError(f"Unknown RIN_type: {RIN_type!r}")
386
495
  if pairs.size == 0:
387
496
  raise ValueError("No training pairs generated for the requested configuration")
497
+ # ----------------------------------------------------------
388
498
 
499
+ # ---------------- construct training corpus ---------------
389
500
  centers = pairs[:, 0].astype(np.int64, copy=False)
390
501
  contexts = pairs[:, 1].astype(np.int64, copy=False)
502
+ _logger.debug("Pairs split: centers=%s contexts=%s", centers.shape, contexts.shape)
503
+ # ----------------------------------------------------------
391
504
 
392
- model_kwargs: dict[str, object] = dict(sgns_kwargs or {})
393
- if self.model_base == "pureml":
394
- required = {"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}
395
- missing = required.difference(model_kwargs)
396
- if missing:
397
- raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")
505
+ # ------------ resolve model_constructor kwargs ------------
506
+ if model_kwargs is not None:
507
+ if (("lr_sched" in model_kwargs and model_kwargs.get("lr_sched", None) is not None)
508
+ and ("lr_sched_kwargs" in model_kwargs and model_kwargs.get("lr_sched_kwargs", None) is None)):
509
+ raise ValueError("When `lr_sched`, you must also provide `lr_sched_kwargs`.")
398
510
 
399
- child_seed = int(self._seed if _seed is None else _seed)
400
- model_kwargs.update({
511
+ constructor_kwargs: dict[str, object] = dict(model_kwargs or {})
512
+ constructor_kwargs.update({
401
513
  "V": self.vocab_size,
402
514
  "D": dimensionality,
403
- "seed": child_seed
515
+ "in_weights": in_weights,
516
+ "out_weights": out_weights,
517
+ "seed": int(self._seed if _seed is None else _seed),
518
+ "device": device
404
519
  })
520
+ _logger.debug("Model constructor kwargs: %s", {k: constructor_kwargs[k] for k in ("V","D","seed","device")})
521
+ # ----------------------------------------------------------
522
+
523
+ # --------------- resolve model constructor ----------------
524
+ model_constructor = self._get_NN_constructor_from(
525
+ model_base, objective=("sgns" if negative_sampling else "sg"))
526
+ # ----------------------------------------------------------
527
+
528
+ # ------------------ initialize the model ------------------
529
+ model = model_constructor(**constructor_kwargs)
530
+ _logger.debug("Model initialized: %s", model_constructor.__name__ if hasattr(model_constructor,"__name__") else str(model_constructor))
531
+ # ----------------------------------------------------------
532
+
533
+ # -------------------- fitting the data --------------------
534
+ _logger.info("Fitting model on %d pairs ...", pairs.shape[0])
535
+ model.fit(centers=centers,
536
+ contexts=contexts,
537
+ num_epochs=num_epochs,
538
+ batch_size=batch_size,
539
+ # -- optional; for SGNS; safely ignored by SG via **_ignore --
540
+ num_negative_samples=num_negative_samples,
541
+ noise_dist=noise_probs,
542
+ # -----------------------------------------
543
+ shuffle_data=shuffle_data,
544
+ lr_step_per_batch=lr_step_per_batch
545
+ )
546
+ _logger.info("Training complete for frame %d", frame_id)
547
+ # ----------------------------------------------------------
405
548
 
406
- if self.model_base == "torch" and device is not None:
407
- model_kwargs["device"] = device
408
-
409
- self.model = self.model_constructor(**model_kwargs)
410
-
411
- _logger.info(
412
- "Training SGNS base=%s constructor=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
413
- self.model_base,
414
- getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
415
- frame_id,
416
- pairs.shape[0],
417
- dimensionality,
418
- num_epochs,
419
- batch_size,
420
- num_negative_samples,
421
- shuffle_data
422
- )
423
-
424
- self.model.fit(
425
- centers,
426
- contexts,
427
- num_epochs,
428
- batch_size,
429
- num_negative_samples,
430
- noise_probs,
431
- shuffle_data,
432
- lr_step_per_batch=False
433
- )
549
+ if any([k not in ("in", "out", "avg") for k in kind]):
550
+ raise NameError(f"Unknown embeddings kind in {kind}. Expected: one of ['in', 'out', 'avg']")
434
551
 
435
- embeddings = getattr(self.model, "embeddings", None)
436
- if embeddings is None:
437
- params = getattr(self.model, "parameters", None)
438
- if isinstance(params, tuple) and params:
439
- embeddings = params[0]
440
- if embeddings is None:
441
- raise AttributeError("SGNS model does not expose embeddings via '.embeddings' or '.parameters[0]'")
552
+ # OUTPUT:
553
+ embeddings = [(np.asarray(model.in_embeddings, dtype=np.float32), k) if k == "in" else
554
+ (np.asarray(model.out_embeddings, dtype=np.float32), k) if k == "out" else
555
+ (np.asarray(model.avg_embeddings, dtype=np.float32), k) if k == "avg" else
556
+ (None, k)
557
+ for k in kind
558
+ ]
559
+ embeddings.sort(key=lambda pair: pair[1]) # ensures 'avg', 'in', 'out' ordering
442
560
 
443
- embeddings = np.asarray(embeddings)
444
- _logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
445
561
  return embeddings
446
562
 
447
563
  def embed_all(
448
564
  self,
449
565
  RIN_type: Literal["attr", "repuls"],
450
566
  using: Literal["RW", "SAW", "merged"],
451
- window_size: int,
452
- num_negative_samples: int,
453
567
  num_epochs: int,
454
- batch_size: int,
568
+ negative_sampling: bool = False,
569
+ window_size: int = 2,
570
+ num_negative_samples: int = 10,
571
+ batch_size: int = 1024,
455
572
  *,
573
+ lr_step_per_batch: bool = False,
456
574
  shuffle_data: bool = True,
457
575
  dimensionality: int = 128,
458
576
  alpha: float = 0.75,
459
577
  device: str | None = None,
460
- sgns_kwargs: dict[str, object] | None = None,
578
+ model_base: Literal["torch", "pureml"] = "pureml",
579
+ model_kwargs: dict[str, object] | None = None,
580
+ kind: Literal["in", "out", "avg"] = "in",
461
581
  output_path: str | Path | None = None,
462
582
  num_matrices_in_compressed_blocks: int = 20,
463
- compression_level: int = 3):
464
- """Train embeddings for all frames and persist them to compressed storage.
583
+ compression_level: int = 3,
584
+ ) -> str:
585
+ """Embed all frames and persist a self-contained archive.
465
586
 
466
- Iterates through all frames (``1..frame_count``), trains an SGNS model
467
- per frame using the configured backend, collects the resulting input
468
- embeddings, and writes them into a new compressed ``ArrayStorage`` archive.
587
+ The resulting file stores a block named ``FRAME_EMBEDDINGS`` with a
588
+ compressed sequence of per-frame matrices (each ``(V, D)``), alongside
589
+ rich metadata mirroring the style of other SAWNERGY modules.
469
590
 
470
591
  Args:
471
- RIN_type: Interaction channel to use: ``"attr"`` or ``"repuls"``.
472
- using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
473
- window_size: Symmetric skip-gram window size ``k``.
474
- num_negative_samples: Number of negative samples per positive pair.
475
- num_epochs: Number of epochs for each frame.
476
- batch_size: Mini-batch size used during training.
477
- shuffle_data: Whether to shuffle pairs each epoch.
478
- dimensionality: Embedding dimensionality ``D``.
479
- alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
480
- device: Optional device string for Torch backend.
481
- sgns_kwargs: Extra constructor kwargs for the SGNS backend (see
482
- :meth:`embed_frame` for PureML requirements).
483
- output_path: Destination path. If ``None``, a new file named
484
- ``EMBEDDINGS_<timestamp>.zip`` is created next to the source
485
- WALKS archive. If the provided path lacks a suffix, ``.zip`` is
486
- appended.
487
- num_matrices_in_compressed_blocks: Number of per-frame matrices to
488
- store per compressed chunk in the output archive.
489
- compression_level: Blosc Zstd compression level (0-9).
592
+ RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
593
+ using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
594
+ num_epochs: Number of epochs to train per frame.
595
+ negative_sampling: If ``True``, use SGNS; otherwise plain SG.
596
+ window_size: Skip-gram window radius.
597
+ num_negative_samples: Negatives per positive pair (SGNS).
598
+ batch_size: Minibatch size for training.
599
+ lr_step_per_batch: If ``True``, step LR per batch (else per epoch).
600
+ shuffle_data: Shuffle pairs each epoch.
601
+ dimensionality: Embedding dimension.
602
+ alpha: Unigram smoothing power for noise distribution.
603
+ device: Backend device hint (e.g., ``"cuda"``).
604
+ model_base: Backend family (``"torch"`` or ``"pureml"``).
605
+ model_kwargs: Passed through to backend model constructor.
606
+ kind: Which embedding to store: ``"in"``, ``"out"``, or ``"avg"``.
607
+ output_path: Optional path for the output archive (``.zip`` inferred).
608
+ num_matrices_in_compressed_blocks: How many frames per compressed chunk.
609
+ compression_level: Integer compression level for the archive.
490
610
 
491
611
  Returns:
492
- str: Filesystem path to the written embeddings archive (``.zip``).
493
-
494
- Raises:
495
- ValueError: If configuration produces no pairs for a frame or if
496
- PureML kwargs are incomplete.
497
- RuntimeError: Propagated from storage operations on failure.
498
-
499
- Notes:
500
- - A deterministic child seed is spawned per frame from the master
501
- seed using ``np.random.SeedSequence`` to ensure reproducibility
502
- across runs.
612
+ Path to the created embeddings archive, as ``str``.
503
613
  """
504
614
  current_time = sawnergy_util.current_time()
505
615
  if output_path is None:
@@ -510,69 +620,182 @@ class Embedder:
510
620
  output_path = output_path.with_suffix(".zip")
511
621
 
512
622
  _logger.info(
513
- "Embedding all frames -> %s | frames=%d dim=%d base=%s",
514
- output_path, self.frame_count, dimensionality, self.model_base
623
+ "embed_all: frames=%d D=%d base=%s RIN=%s using=%s out=%s",
624
+ self.frame_count, dimensionality, model_base, RIN_type, using, output_path
515
625
  )
516
626
 
627
+ # Per-frame deterministic seeds
517
628
  master_ss = np.random.SeedSequence(self._seed)
518
629
  child_seeds = master_ss.spawn(self.frame_count)
519
630
 
520
- embeddings = []
521
- for frame_idx, seed_seq in enumerate(child_seeds, start=1):
631
+ embeddings: list[np.ndarray] = []
632
+ last_frame_in_embs: np.ndarray = None
633
+ last_frame_out_embs: np.ndarray = None
634
+ used_child_seeds: list[int] = []
635
+ for frame_id, seed_seq in enumerate(child_seeds, start=1):
522
636
  child_seed = int(seed_seq.generate_state(1, dtype=np.uint32)[0])
523
- _logger.info("Processing frame %d/%d (child_seed=%d entropy=%d)", frame_idx, self.frame_count, child_seed, seed_seq.entropy)
524
- embeddings.append(
637
+ used_child_seeds.append(child_seed)
638
+ _logger.info("Embedding frame %d/%d with seed=%d", frame_id, self.frame_count, child_seed)
639
+
640
+ embs_and_kinds: list[tuple[np.ndarray, str]] = \
525
641
  self.embed_frame(
526
- frame_idx,
527
- RIN_type,
528
- using,
529
- window_size,
530
- num_negative_samples,
531
- num_epochs,
532
- batch_size,
642
+ frame_id=frame_id,
643
+ RIN_type=RIN_type,
644
+ using=using,
645
+ num_epochs=num_epochs,
646
+ negative_sampling=negative_sampling,
647
+ window_size=window_size,
648
+ num_negative_samples=num_negative_samples,
649
+ batch_size=batch_size,
650
+ in_weights=last_frame_in_embs,
651
+ out_weights=last_frame_out_embs,
652
+ lr_step_per_batch=lr_step_per_batch,
533
653
  shuffle_data=shuffle_data,
534
654
  dimensionality=dimensionality,
535
655
  alpha=alpha,
536
656
  device=device,
537
- sgns_kwargs=sgns_kwargs,
657
+ model_base=model_base,
658
+ model_kwargs=model_kwargs,
659
+ kind=("in", "out", "avg"),
538
660
  _seed=child_seed
539
661
  )
540
- )
662
+ embs = {K: E for (E, K) in embs_and_kinds}
663
+
664
+ last_frame_in_embs = embs["in"] # (V, D)
665
+ last_frame_out_embs = embs["out"] if negative_sampling else embs["out"].T # SG needs (D, V), SGNS keeps (V, D)
666
+
667
+ resolved_embedding = embs[kind]
668
+ embeddings.append(np.asarray(resolved_embedding, dtype=np.float32, copy=False))
669
+
670
+ _logger.debug("Frame %d embedded: E.shape=%s", frame_id, resolved_embedding.shape)
541
671
 
542
- embeddings = [np.asarray(e) for e in embeddings]
543
672
  block_name = "FRAME_EMBEDDINGS"
544
673
  with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level=compression_level) as storage:
674
+ _logger.info("Writing %d frame matrices to block '%s' ...", len(embeddings), block_name)
545
675
  storage.write(
546
676
  these_arrays=embeddings,
547
677
  to_block_named=block_name,
548
678
  arrays_per_chunk=num_matrices_in_compressed_blocks
549
679
  )
550
- storage.add_attr("time_created", current_time)
551
- storage.add_attr("seed", int(self._seed))
552
- storage.add_attr("rng_scheme", "SeedSequence.spawn_per_frame_v1")
553
- storage.add_attr("source_walks_path", str(self._walks_path))
554
- storage.add_attr("model_base", self.model_base)
555
- storage.add_attr("rin_type", RIN_type)
556
- storage.add_attr("using_mode", using)
680
+
681
+ # Core dataset discovery (for consumers like the Embeddings Visualizer)
682
+ storage.add_attr("frame_embeddings_name", block_name)
683
+ storage.add_attr("time_stamp_count", int(self.frame_count))
684
+ storage.add_attr("node_count", int(self.vocab_size))
685
+ storage.add_attr("embedding_dim", int(dimensionality))
686
+
687
+ # Provenance of input WALKS
688
+ storage.add_attr("source_WALKS_path", str(self._walks_path))
689
+ storage.add_attr("walk_length", int(self.walk_length))
690
+ storage.add_attr("num_RWs", int(self.num_RWs))
691
+ storage.add_attr("num_SAWs", int(self.num_SAWs))
692
+ storage.add_attr("attractive_RWs_name", self._attractive_RWs_name)
693
+ storage.add_attr("repulsive_RWs_name", self._repulsive_RWs_name)
694
+ storage.add_attr("attractive_SAWs_name", self._attractive_SAWs_name)
695
+ storage.add_attr("repulsive_SAWs_name", self._repulsive_SAWs_name)
696
+
697
+ # Training configuration (sufficient to reproduce)
698
+ storage.add_attr("objective", "sgns" if negative_sampling else "sg")
699
+ storage.add_attr("model_base", model_base)
700
+ storage.add_attr("embedding_kind", kind) # 'in' | 'out' | 'avg'
701
+ storage.add_attr("num_epochs", int(num_epochs))
702
+ storage.add_attr("batch_size", int(batch_size))
557
703
  storage.add_attr("window_size", int(window_size))
558
704
  storage.add_attr("alpha", float(alpha))
559
- storage.add_attr("dimensionality", int(dimensionality))
705
+ storage.add_attr("negative_sampling", bool(negative_sampling))
560
706
  storage.add_attr("num_negative_samples", int(num_negative_samples))
561
- storage.add_attr("num_epochs", int(num_epochs))
562
- storage.add_attr("batch_size", int(batch_size))
707
+ storage.add_attr("lr_step_per_batch", bool(lr_step_per_batch))
563
708
  storage.add_attr("shuffle_data", bool(shuffle_data))
564
- storage.add_attr("frames_written", int(len(embeddings)))
565
- storage.add_attr("vocab_size", int(self.vocab_size))
566
- storage.add_attr("frame_count", int(self.frame_count))
567
- storage.add_attr("embedding_dtype", str(embeddings[0].dtype))
568
- storage.add_attr("frame_embeddings_name", block_name)
709
+ storage.add_attr("device_hint", device if device is not None else "")
710
+ storage.add_attr("model_kwargs_repr", repr(model_kwargs) if model_kwargs is not None else "{}")
711
+
712
+ # Which walks were used to train
713
+ storage.add_attr("RIN_type", RIN_type) # 'attr' or 'repuls'
714
+ storage.add_attr("using", using) # 'RW' | 'SAW' | 'merged'
715
+
716
+ # Reproducibility
717
+ storage.add_attr("master_seed", int(self._seed))
718
+ storage.add_attr("per_frame_seeds", [int(s) for s in used_child_seeds])
719
+
720
+ # Archive/IO details
569
721
  storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
570
722
  storage.add_attr("compression_level", int(compression_level))
723
+ storage.add_attr("created_at", current_time)
724
+
725
+ _logger.info(
726
+ "Stored embeddings archive: %s | shape=(T,N,D)=(%d,%d,%d)",
727
+ output_path, self.frame_count, self.vocab_size, dimensionality
728
+ )
571
729
 
572
- _logger.info("Embedding archive written to %s", output_path)
573
730
  return str(output_path)
574
731
 
575
- __all__ = ["Embedder"]
732
+ # *----------------------------------------------------*
733
+ # FUNCTIONS
734
+ # *----------------------------------------------------*
735
+
736
def align_frames(this: np.ndarray,
                 to_this: np.ndarray,
                 *,
                 center: bool = True,
                 add_back_mean: bool = True,
                 allow_reflection: bool = False) -> np.ndarray:
    """
    Align `this` to `to_this` via Orthogonal Procrustes.

    Solves: min_{R ∈ O(D)} || X R - Y ||_F
    with X = this, Y = to_this (both shape (N, D)). Returns X aligned.

    Args:
        this: (N, D) matrix (or array-like) to be aligned.
        to_this: (N, D) target matrix (or array-like).
        center: if True, subtract per-dimension means before solving.
        add_back_mean: if truthy (and ``center`` is True), add Y's mean back
            after alignment so the result lives in Y's coordinate frame.
        allow_reflection: if False, enforce det(R) = +1 (proper rotation).

    Returns:
        Aligned copy of `this` with shape (N, D). The result keeps the
        floating dtype of `this`; non-floating inputs (ints, bools, nested
        lists of ints) are returned as float64 so that the aligned
        coordinates are not truncated.

    Raises:
        ValueError: If either input is not 2D, or the two shapes differ.
    """
    # Inspect the input once as an array so that plain sequences (which have
    # no ``.dtype`` attribute) are supported all the way to the return value.
    source = np.asarray(this)
    X = np.asarray(source, dtype=np.float64)
    Y = np.asarray(to_this, dtype=np.float64)

    if X.ndim != 2 or Y.ndim != 2:
        raise ValueError(f"Expected 2D arrays; got {X.ndim=} and {Y.ndim=}")
    if X.shape[1] != Y.shape[1]:
        raise ValueError(f"Dimensionalities must match: X.shape={X.shape}, Y.shape={Y.shape}")
    if X.shape[0] != Y.shape[0]:
        raise ValueError(f"Row counts must match (one-to-one correspondence): {X.shape[0]} vs {Y.shape[0]}")

    # center
    if center:
        X_mean = X.mean(axis=0, keepdims=True)
        Y_mean = Y.mean(axis=0, keepdims=True)
        Xc = X - X_mean
        Yc = Y - Y_mean
    else:
        Xc, Yc = X, Y
        Y_mean = 0.0

    # Cross-covariance and SVD
    # M = Xᵀ Y (D×D); solution R = U Vᵀ for SVD(M) = U Σ Vᵀ
    M = Xc.T @ Yc
    U, _, Vt = np.linalg.svd(M, full_matrices=False)
    R = U @ Vt

    # Enforce a proper rotation unless reflections are allowed: flipping the
    # singular direction with the smallest singular value (last row of Vᵀ)
    # changes the sign of det(R) at the least cost to the fit.
    if not allow_reflection and np.linalg.det(R) < 0:
        Vt[-1, :] *= -1
        R = U @ Vt

    X_aligned = Xc @ R

    # Truthiness (not ``is True``) so that np.bool_ / int flags are honoured.
    if center and add_back_mean:
        X_aligned = X_aligned + Y_mean

    # Match the input dtype only when it can represent the result; casting a
    # rotated matrix back to an integer dtype would silently truncate it.
    if np.issubdtype(source.dtype, np.floating):
        out_dtype = source.dtype
    else:
        out_dtype = np.float64
    return X_aligned.astype(out_dtype, copy=False)
796
+
797
+
798
+ __all__ = ["Embedder", "align_frames"]
576
799
 
577
800
  if __name__ == "__main__":
578
801
  pass