sawnergy-1.0.7-py3-none-any.whl → sawnergy-1.0.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sawnergy might be problematic; see the registry's advisory page for this release for more details.

@@ -1,15 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- """
4
- Embedding orchestration for Skip-Gram (SG) and Skip-Gram with Negative Sampling (SGNS).
5
-
6
- This module consumes attractive/repulsive walk corpora produced by the walker
7
- pipeline and trains per-frame embeddings using either the PyTorch or PureML
8
- implementations of SG/SGNS. The resulting embeddings can be persisted back into
9
- an ``ArrayStorage`` archive along with rich metadata describing the training
10
- configuration.
11
- """
12
-
13
3
  # third-pary
14
4
  import numpy as np
15
5
 
@@ -36,10 +26,8 @@ class Embedder:
36
26
 
37
27
  def __init__(self,
38
28
  WALKS_path: str | Path,
39
- base: Literal["torch", "pureml"],
40
29
  *,
41
30
  seed: int | None = None,
42
- objective: Literal["sgns", "sg"] = "sgns"
43
31
  ) -> None:
44
32
  """Initialize the embedder and load walk tensors.
45
33
 
@@ -51,24 +39,19 @@ class Embedder:
51
39
  ``None`` if that collection is absent), and the metadata
52
40
  ``num_RWs``, ``num_SAWs``, ``node_count``, ``time_stamp_count``,
53
41
  ``walk_length``.
54
- base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
55
42
  seed: Optional seed for the embedder's RNG. If ``None``, a random
56
43
  32-bit seed is chosen.
57
- objective: Training objective, either ``"sgns"`` (negative sampling)
58
- or ``"sg"`` (plain full-softmax Skip-Gram).
59
44
 
60
45
  Raises:
61
46
  ValueError: If required metadata is missing or any loaded walk array
62
47
  has an unexpected shape.
63
- ImportError: If the requested backend is not installed.
64
- NameError: If ``base`` is not one of ``{"torch","pureml"}``.
65
48
 
66
49
  Notes:
67
50
  - Walks in storage are 1-based (residue indexing). Internally, this
68
51
  class normalizes to 0-based indices for training utilities.
69
52
  """
70
53
  self._walks_path = Path(WALKS_path)
71
- _logger.info("Initializing Embedder from %s (base=%s)", self._walks_path, base)
54
+ _logger.info("Initializing Embedder from %s", self._walks_path)
72
55
 
73
56
  # placeholders for optional walk collections
74
57
  self.attractive_RWs : np.ndarray | None = None
@@ -127,58 +110,76 @@ class Embedder:
127
110
  RWs_expected = (time_stamp_count, node_count * num_RWs, walk_length+1) if (num_RWs > 0) else None
128
111
  SAWs_expected = (time_stamp_count, node_count * num_SAWs, walk_length+1) if (num_SAWs > 0) else None
129
112
 
130
- self.vocab_size = int(node_count)
131
- self.frame_count = int(time_stamp_count)
132
- self.walk_length = int(walk_length)
113
+ self.vocab_size = int(node_count)
114
+ self.frame_count = int(time_stamp_count)
115
+ self.walk_length = int(walk_length)
116
+ self.num_RWs = int(num_RWs)
117
+ self.num_SAWs = int(num_SAWs)
118
+ # Keep dataset names for metadata passthrough
119
+ self._attractive_RWs_name = attractive_RWs_name
120
+ self._repulsive_RWs_name = repulsive_RWs_name
121
+ self._attractive_SAWs_name = attractive_SAWs_name
122
+ self._repulsive_SAWs_name = repulsive_SAWs_name
133
123
 
134
124
  # store walks if present
135
125
  if attractive_RWs is not None:
136
126
  if RWs_expected and attractive_RWs.shape != RWs_expected:
137
127
  raise ValueError(f"ATTR RWs: expected {RWs_expected}, got {attractive_RWs.shape}")
138
128
  self.attractive_RWs = attractive_RWs
129
+ _logger.debug("ATTR RWs loaded: %s", self.attractive_RWs.shape)
139
130
 
140
131
  if repulsive_RWs is not None:
141
132
  if RWs_expected and repulsive_RWs.shape != RWs_expected:
142
133
  raise ValueError(f"REP RWs: expected {RWs_expected}, got {repulsive_RWs.shape}")
143
134
  self.repulsive_RWs = repulsive_RWs
135
+ _logger.debug("REP RWs loaded: %s", self.repulsive_RWs.shape)
144
136
 
145
137
  if attractive_SAWs is not None:
146
138
  if SAWs_expected and attractive_SAWs.shape != SAWs_expected:
147
139
  raise ValueError(f"ATTR SAWs: expected {SAWs_expected}, got {attractive_SAWs.shape}")
148
140
  self.attractive_SAWs = attractive_SAWs
141
+ _logger.debug("ATTR SAWs loaded: %s", self.attractive_SAWs.shape)
149
142
 
150
143
  if repulsive_SAWs is not None:
151
144
  if SAWs_expected and repulsive_SAWs.shape != SAWs_expected:
152
145
  raise ValueError(f"REP SAWs: expected {SAWs_expected}, got {repulsive_SAWs.shape}")
153
146
  self.repulsive_SAWs = repulsive_SAWs
147
+ _logger.debug("REP SAWs loaded: %s", self.repulsive_SAWs.shape)
154
148
 
155
149
  # INTERNAL RNG
156
150
  self._seed = np.random.randint(0, 2**32 - 1) if seed is None else int(seed)
157
151
  self.rng = np.random.default_rng(self._seed)
158
152
  _logger.info("RNG initialized from seed=%d", self._seed)
159
153
 
160
- # MODEL HANDLE
161
- self.model_base: Literal["torch", "pureml"] = base
162
- self.objective: Literal["sgns", "sg"] = objective
163
- self.model_constructor = self._get_SGNS_constructor_from(base, objective)
164
- _logger.info(
165
- "SG backend resolved: %s (objective=%s)",
166
- getattr(self.model_constructor, "__name__", repr(self.model_constructor)), self.objective
167
- )
168
-
169
154
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
170
155
 
171
156
  # HELPERS:
172
157
 
173
158
  @staticmethod
174
- def _get_SGNS_constructor_from(base: Literal["torch", "pureml"],
175
- objective: Literal["sgns", "sg"]):
176
- """Resolve the SG/SGNS implementation class for the selected backend."""
159
+ def _get_NN_constructor_from(base: Literal["torch", "pureml"],
160
+ objective: Literal["sgns", "sg"]):
161
+ """Resolve the SG/SGNS implementation class for the selected backend.
162
+
163
+ Args:
164
+ base: Backend family to use, ``"torch"`` or ``"pureml"``.
165
+ objective: Training objective, ``"sgns"`` or ``"sg"``.
166
+
167
+ Returns:
168
+ A callable class (constructor) implementing the requested model.
169
+
170
+ Raises:
171
+ ImportError: If the requested backend cannot be imported.
172
+ NameError: If ``base`` is not one of the supported values.
173
+ """
174
+ _logger.debug("Resolving model constructor: base=%s objective=%s", base, objective)
177
175
  if base == "torch":
178
176
  try:
179
177
  from .SGNS_torch import SGNS_Torch, SG_Torch
180
- return SG_Torch if objective == "sg" else SGNS_Torch
178
+ ctor = SG_Torch if objective == "sg" else SGNS_Torch
179
+ _logger.debug("Resolved PyTorch class: %s", getattr(ctor, "__name__", str(ctor)))
180
+ return ctor
181
181
  except Exception:
182
+ _logger.exception("Failed to import PyTorch backend.")
182
183
  raise ImportError(
183
184
  "PyTorch is not installed, but base='torch' was requested. "
184
185
  "Install PyTorch first, e.g.: `pip install torch` "
@@ -187,8 +188,11 @@ class Embedder:
187
188
  elif base == "pureml":
188
189
  try:
189
190
  from .SGNS_pml import SGNS_PureML, SG_PureML
190
- return SG_PureML if objective == "sg" else SGNS_PureML
191
+ ctor = SG_PureML if objective == "sg" else SGNS_PureML
192
+ _logger.debug("Resolved PureML class: %s", getattr(ctor, "__name__", str(ctor)))
193
+ return ctor
191
194
  except Exception:
195
+ _logger.exception("Failed to import PureML backend.")
192
196
  raise ImportError(
193
197
  "PureML is not installed, but base='pureml' was requested. "
194
198
  "Install PureML first via `pip install ym-pure-ml` "
@@ -198,7 +202,18 @@ class Embedder:
198
202
 
199
203
  @staticmethod
200
204
  def _as_zerobase_intp(walks: np.ndarray, *, V: int) -> np.ndarray:
201
- """Validate 1-based uint/int walks 0-based intp; check bounds."""
205
+ """Validate and convert 1-based walks to 0-based ``intp``.
206
+
207
+ Args:
208
+ walks: 2D array of node ids with 1-based indexing.
209
+ V: Vocabulary size for bounds checking.
210
+
211
+ Returns:
212
+ 2D array of dtype ``intp`` with 0-based indices.
213
+
214
+ Raises:
215
+ ValueError: If shape is not 2D or indices are out of bounds.
216
+ """
202
217
  arr = np.asarray(walks)
203
218
  if arr.ndim != 2:
204
219
  raise ValueError("walks must be 2D: (num_walks, walk_len)")
@@ -206,7 +221,9 @@ class Embedder:
206
221
  arr = arr.astype(np.int64, copy=False)
207
222
  # 1-based → 0-based
208
223
  arr = arr - 1
209
- if arr.min() < 0 or arr.max() >= V:
224
+ mn, mx = int(arr.min()), int(arr.max())
225
+ _logger.debug("Zero-basing walks: min=%d max=%d V=%d", mn, mx, V)
226
+ if mn < 0 or mx >= V:
210
227
  raise ValueError("walk ids out of range after 1→0-based normalization")
211
228
  return arr.astype(np.intp, copy=False)
212
229
 
@@ -214,19 +231,29 @@ class Embedder:
214
231
  def _pairs_from_walks(walks0: np.ndarray, window_size: int) -> np.ndarray:
215
232
  """
216
233
  Skip-gram pairs including edge centers (one-sided when needed).
217
- walks0: (W, L) int array (0-based ids).
218
- Returns: (N_pairs, 2) int32 [center, context].
234
+
235
+ Args:
236
+ walks0: (W, L) int array (0-based ids).
237
+ window_size: Symmetric context window radius.
238
+
239
+ Returns:
240
+ Array of shape (N_pairs, 2) int32 with columns [center, context].
241
+
242
+ Raises:
243
+ ValueError: If shape is invalid or ``window_size`` <= 0.
219
244
  """
220
245
  if walks0.ndim != 2:
221
246
  raise ValueError("walks must be 2D: (num_walks, walk_len)")
222
247
 
223
248
  _, L = walks0.shape
224
249
  k = int(window_size)
250
+ _logger.debug("Building SG pairs: L=%d, window=%d", L, k)
225
251
 
226
252
  if k <= 0:
227
253
  raise ValueError("window_size must be positive")
228
254
 
229
255
  if L == 0:
256
+ _logger.debug("Empty walks length; returning 0 pairs.")
230
257
  return np.empty((0, 2), dtype=np.int32)
231
258
 
232
259
  out_chunks = []
@@ -244,18 +271,42 @@ class Embedder:
244
271
  out_chunks.append(np.stack((centers_l, ctx_l), axis=2).reshape(-1, 2))
245
272
 
246
273
  if not out_chunks:
274
+ _logger.debug("No offsets produced pairs; returning empty.")
247
275
  return np.empty((0, 2), dtype=np.int32)
248
276
 
249
- return np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
277
+ pairs = np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
278
+ _logger.debug("Built %d training pairs", pairs.shape[0])
279
+ return pairs
250
280
 
251
281
  @staticmethod
252
282
  def _freq_from_walks(walks0: np.ndarray, *, V: int) -> np.ndarray:
253
- """Node frequencies from walks (0-based)."""
254
- return np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
283
+ """Node frequencies from walks (0-based).
284
+
285
+ Args:
286
+ walks0: 2D array of 0-based node ids.
287
+ V: Vocabulary size (minlength for bincount).
288
+
289
+ Returns:
290
+ 1D array of int64 frequencies with length ``V``.
291
+ """
292
+ freq = np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
293
+ _logger.debug("Frequency mass: total=%d nonzero=%d", int(freq.sum()), int(np.count_nonzero(freq)))
294
+ return freq
255
295
 
256
296
  @staticmethod
257
297
  def _soft_unigram(freq: np.ndarray, *, power: float = 0.75) -> np.ndarray:
258
- """Return normalized Pn(w) ∝ f(w)^power as float64 probs."""
298
+ """Return normalized Pn(w) ∝ f(w)^power as float64 probs.
299
+
300
+ Args:
301
+ freq: 1D array of token frequencies.
302
+ power: Exponent used for smoothing (default 0.75 à la word2vec).
303
+
304
+ Returns:
305
+ 1D array of probabilities summing to 1.0.
306
+
307
+ Raises:
308
+ ValueError: If mass is invalid (all zeros or non-finite).
309
+ """
259
310
  p = np.asarray(freq, dtype=np.float64)
260
311
  if p.sum() == 0:
261
312
  raise ValueError("all frequencies are zero")
@@ -263,13 +314,31 @@ class Embedder:
263
314
  s = p.sum()
264
315
  if not np.isfinite(s) or s <= 0:
265
316
  raise ValueError("invalid unigram mass")
266
- return p / s
317
+ probs = p / s
318
+ _logger.debug("Noise distribution ready (power=%.3f)", power)
319
+ return probs
267
320
 
268
321
  def _materialize_walks(self, frame_id: int, rin: Literal["attr", "repuls"],
269
322
  using: Literal["RW", "SAW", "merged"]) -> np.ndarray:
323
+ """Materialize the requested set of walks for a frame.
324
+
325
+ Args:
326
+ frame_id: 1-based frame index.
327
+ rin: Which RIN to pull from: ``"attr"`` or ``"repuls"``.
328
+ using: Which walk sets to include: ``"RW"``, ``"SAW"``, or ``"merged"``.
329
+ If ``"merged"``, concatenate available RW and SAW along axis 0.
330
+
331
+ Returns:
332
+ A 2D array of walks with shape (num_walks, walk_length+1).
333
+
334
+ Raises:
335
+ IndexError: If ``frame_id`` is out of range.
336
+ ValueError: If no matching walks are available.
337
+ """
270
338
  if not 1 <= frame_id <= int(self.frame_count):
271
339
  raise IndexError(f"frame_id must be in [1, {self.frame_count}]; got {frame_id}")
272
340
 
341
+ _logger.debug("Materializing %s walks at frame=%d using=%s", rin, frame_id, using)
273
342
  frame_id -= 1
274
343
 
275
344
  if rin == "attr":
@@ -296,8 +365,12 @@ class Embedder:
296
365
  if not parts:
297
366
  raise ValueError(f"No walks available for {rin=} with {using=}")
298
367
  if len(parts) == 1:
299
- return parts[0]
300
- return np.concatenate(parts, axis=0)
368
+ out = parts[0]
369
+ else:
370
+ out = np.concatenate(parts, axis=0)
371
+
372
+ _logger.debug("Materialized walks shape: %s", getattr(out, "shape", None))
373
+ return out
301
374
 
302
375
  # INTERFACES: (private)
303
376
 
@@ -306,6 +379,17 @@ class Embedder:
306
379
  using: Literal["RW", "SAW", "merged"],
307
380
  window_size: int,
308
381
  alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
382
+ """Construct (center, context) pairs and noise distribution for ATTR.
383
+
384
+ Args:
385
+ frame_id: 1-based frame index.
386
+ using: Walk subset to include.
387
+ window_size: Skip-gram window radius.
388
+ alpha: Unigram smoothing exponent.
389
+
390
+ Returns:
391
+ Tuple of (pairs, noise_probs).
392
+ """
309
393
  walks = self._materialize_walks(frame_id, "attr", using)
310
394
  walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
311
395
  attractive_corpus = self._pairs_from_walks(walks0, window_size)
@@ -319,6 +403,17 @@ class Embedder:
319
403
  using: Literal["RW", "SAW", "merged"],
320
404
  window_size: int,
321
405
  alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
406
+ """Construct (center, context) pairs and noise distribution for REP.
407
+
408
+ Args:
409
+ frame_id: 1-based frame index.
410
+ using: Walk subset to include.
411
+ window_size: Skip-gram window radius.
412
+ alpha: Unigram smoothing exponent.
413
+
414
+ Returns:
415
+ Tuple of (pairs, noise_probs).
416
+ """
322
417
  walks = self._materialize_walks(frame_id, "repuls", using)
323
418
  walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
324
419
  repulsive_corpus = self._pairs_from_walks(walks0, window_size)
@@ -333,61 +428,53 @@ class Embedder:
333
428
  frame_id: int,
334
429
  RIN_type: Literal["attr", "repuls"],
335
430
  using: Literal["RW", "SAW", "merged"],
336
- window_size: int,
337
- num_negative_samples: int,
338
431
  num_epochs: int,
339
- batch_size: int,
432
+ negative_sampling: bool = False,
433
+ window_size: int = 5,
434
+ num_negative_samples: int = 10,
435
+ batch_size: int = 1024,
340
436
  *,
341
437
  lr_step_per_batch: bool = False,
342
438
  shuffle_data: bool = True,
343
439
  dimensionality: int = 128,
344
440
  alpha: float = 0.75,
345
441
  device: str | None = None,
346
- sgns_kwargs: dict[str, object] | None = None,
442
+ model_base: Literal["torch", "pureml"] = "pureml",
443
+ model_kwargs: dict[str, object] | None = None,
347
444
  kind: Literal["in", "out", "avg"] = "in",
348
- _seed: int | None = None,
349
- objective: Literal["sgns", "sg"] | None = None
445
+ _seed: int | None = None
350
446
  ) -> np.ndarray:
351
- """Train embeddings for a single frame and return the input embedding matrix.
447
+ """Train embeddings for a single frame and return the matrix.
352
448
 
353
449
  Args:
354
- frame_id: 1-based frame index to train on.
355
- RIN_type: Interaction channel to use: ``"attr"`` (attractive) or
356
- ``"repuls"`` (repulsive).
357
- using: Which walk collections to include: ``"RW"``, ``"SAW"``, or
358
- ``"merged"`` (concatenates both if available).
359
- window_size: Symmetric skip-gram window size ``k``.
360
- num_negative_samples: Number of negative samples per positive pair.
361
- Ignored when ``objective="sg"``.
362
- num_epochs: Number of passes over the pair dataset.
363
- batch_size: Mini-batch size for training.
364
- shuffle_data: Whether to shuffle pairs each epoch.
365
- dimensionality: Embedding dimensionality ``D``.
366
- alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
367
- device: Optional device string for the Torch backend (e.g., ``"cuda"``).
368
- sgns_kwargs: Extra keyword arguments forwarded to the backend SGNS
369
- constructor. For PureML, required keys are:
370
- ``{"optim", "optim_kwargs"}``; ``lr_sched`` is optional, but if
371
- provided then ``lr_sched_kwargs`` must also be provided.
372
- kind: Which embedding matrix to return: ``"in"``, ``"out"``, or ``"avg"``.
373
- _seed: Optional child seed for this frame's model initialization.
374
- objective: Training objective override for this call (``"sgns"`` or
375
- ``"sg"``). If ``None``, uses the value set at construction.
450
+ frame_id: 1-based frame index to embed.
451
+ RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
452
+ using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
453
+ num_epochs: Number of passes over the pairs.
454
+ negative_sampling: If ``True``, use SGNS objective; else plain SG.
455
+ window_size: Skip-gram symmetric window radius.
456
+ num_negative_samples: Negatives per positive pair (SGNS only).
457
+ batch_size: Minibatch size for training.
458
+ lr_step_per_batch: If ``True``, step LR every batch (else per epoch).
459
+ shuffle_data: Shuffle pairs each epoch.
460
+ dimensionality: Embedding dimension ``D``.
461
+ alpha: Unigram smoothing power for noise distribution.
462
+ device: Optional backend device hint (e.g., ``"cuda"``).
463
+ model_base: Backend family (``"torch"`` or ``"pureml"``).
464
+ model_kwargs: Passed through to backend model constructor.
465
+ kind: Which embedding to return: ``"in"``, ``"out"``, or ``"avg"``.
466
+ _seed: Optional override seed for this frame.
376
467
 
377
468
  Returns:
378
- np.ndarray: Learned embedding matrix (selected by ``kind``) of shape ``(V, D)``.
379
-
380
- Raises:
381
- ValueError: If requested walks are missing, if no training pairs are
382
- generated, or if required ``sgns_kwargs`` for PureML are absent.
383
- AttributeError: If the SGNS model does not expose embeddings via
384
- ``.embeddings`` or ``.parameters[0]``.
469
+ ``(V, D)`` float32 embedding matrix.
385
470
  """
386
471
  _logger.info(
387
- "Preparing frame %d (rin=%s using=%s window=%d neg=%d epochs=%d batch=%d)",
388
- frame_id, RIN_type, using, window_size, num_negative_samples, num_epochs, batch_size
472
+ "embed_frame: frame=%d RIN=%s using=%s base=%s D=%d epochs=%d batch=%d sgns=%s k=%d alpha=%.3f",
473
+ frame_id, RIN_type, using, model_base, dimensionality, num_epochs, batch_size,
474
+ str(negative_sampling), window_size, alpha
389
475
  )
390
476
 
477
+ # ------------------ resolve training data -----------------
391
478
  if RIN_type == "attr":
392
479
  if self.attractive_RWs is None and self.attractive_SAWs is None:
393
480
  raise ValueError("Attractive random walks are missing")
@@ -397,162 +484,124 @@ class Embedder:
397
484
  raise ValueError("Repulsive random walks are missing")
398
485
  pairs, noise_probs = self._repulsive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
399
486
  else:
400
- raise ValueError(f"Unknown RIN_type: {RIN_type!r}")
401
-
487
+ raise NameError(f"Unknown RIN_type: {RIN_type!r}")
402
488
  if pairs.size == 0:
403
489
  raise ValueError("No training pairs generated for the requested configuration")
490
+ # ----------------------------------------------------------
404
491
 
492
+ # ---------------- construct training corpus ---------------
405
493
  centers = pairs[:, 0].astype(np.int64, copy=False)
406
494
  contexts = pairs[:, 1].astype(np.int64, copy=False)
495
+ _logger.debug("Pairs split: centers=%s contexts=%s", centers.shape, contexts.shape)
496
+ # ----------------------------------------------------------
407
497
 
408
- model_kwargs: dict[str, object] = dict(sgns_kwargs or {})
409
- if self.model_base == "pureml":
410
- required = {"optim", "optim_kwargs"}
411
- missing = required.difference(model_kwargs)
412
- if missing:
413
- raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")
414
- has_sched = ("lr_sched" in model_kwargs and model_kwargs["lr_sched"] is not None)
415
- has_sched_kwargs = ("lr_sched_kwargs" in model_kwargs and model_kwargs["lr_sched_kwargs"] is not None)
416
- if has_sched and not has_sched_kwargs:
417
- raise ValueError("When providing lr_sched for PureML, you must also provide lr_sched_kwargs.")
418
-
419
- child_seed = int(self._seed if _seed is None else _seed)
420
- model_kwargs.update({
498
+ # ------------ resolve model_constructor kwargs ------------
499
+ if model_kwargs is not None:
500
+ if (("lr_sched" in model_kwargs and model_kwargs.get("lr_sched", None) is not None)
501
+ and ("lr_sched_kwargs" in model_kwargs and model_kwargs.get("lr_sched_kwargs", None) is None)):
502
+ raise ValueError("When `lr_sched`, you must also provide `lr_sched_kwargs`.")
503
+
504
+ constructor_kwargs: dict[str, object] = dict(model_kwargs or {})
505
+ constructor_kwargs.update({
421
506
  "V": self.vocab_size,
422
507
  "D": dimensionality,
423
- "seed": child_seed
508
+ "seed": int(self._seed if _seed is None else _seed),
509
+ "device": device
424
510
  })
425
-
426
- if self.model_base == "torch" and device is not None:
427
- model_kwargs["device"] = device
428
-
429
- # Resolve objective (call-level override beats constructor default)
430
- obj = self.objective if objective is None else objective
431
- self.model_constructor = self._get_SGNS_constructor_from(self.model_base, obj)
432
- self.model = self.model_constructor(**model_kwargs)
433
-
434
- _logger.info(
435
- "Training SG base=%s constructor=%s objective=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
436
- self.model_base,
437
- getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
438
- obj,
439
- frame_id,
440
- pairs.shape[0],
441
- dimensionality,
442
- num_epochs,
443
- batch_size,
444
- num_negative_samples,
445
- shuffle_data
446
- )
447
-
448
- if obj == "sgns":
449
- self.model.fit(
450
- centers,
451
- contexts,
452
- num_epochs,
453
- batch_size,
454
- num_negative_samples,
455
- noise_probs,
456
- shuffle_data,
457
- lr_step_per_batch
458
- )
459
- else:
460
- self.model.fit(
461
- centers,
462
- contexts,
463
- num_epochs,
464
- batch_size,
465
- shuffle_data,
466
- lr_step_per_batch
511
+ _logger.debug("Model constructor kwargs: %s", {k: constructor_kwargs[k] for k in ("V","D","seed","device")})
512
+ # ----------------------------------------------------------
513
+
514
+ # --------------- resolve model constructor ----------------
515
+ model_constructor = self._get_NN_constructor_from(
516
+ model_base, objective=("sgns" if negative_sampling else "sg"))
517
+ # ----------------------------------------------------------
518
+
519
+ # ------------------ initialize the model ------------------
520
+ model = model_constructor(**constructor_kwargs)
521
+ _logger.debug("Model initialized: %s", model_constructor.__name__ if hasattr(model_constructor,"__name__") else str(model_constructor))
522
+ # ----------------------------------------------------------
523
+
524
+ # -------------------- fitting the data --------------------
525
+ _logger.info("Fitting model on %d pairs ...", pairs.shape[0])
526
+ model.fit(centers=centers,
527
+ contexts=contexts,
528
+ num_epochs=num_epochs,
529
+ batch_size=batch_size,
530
+ # -- optional; for SGNS; safely ignored by SG via **_ignore --
531
+ num_negative_samples=num_negative_samples,
532
+ noise_dist=noise_probs,
533
+ # -----------------------------------------
534
+ shuffle_data=shuffle_data,
535
+ lr_step_per_batch=lr_step_per_batch
467
536
  )
468
-
469
- # Select embedding matrix by kind
470
- if kind == "in":
471
- embeddings = getattr(self.model, "in_embeddings", None)
472
- if embeddings is None:
473
- embeddings = getattr(self.model, "embeddings", None)
474
- if embeddings is None:
475
- params = getattr(self.model, "parameters", None)
476
- if isinstance(params, tuple) and params:
477
- embeddings = params[0]
478
- elif kind == "out":
479
- embeddings = getattr(self.model, "out_embeddings", None)
480
- else: # "avg"
481
- embeddings = getattr(self.model, "avg_embeddings", None)
537
+ _logger.info("Training complete for frame %d", frame_id)
538
+ # ----------------------------------------------------------
539
+
540
+ # OUTPUT:
541
+ embeddings = (model.in_embeddings if kind == "in" else
542
+ model.out_embeddings if kind == "out" else
543
+ model.avg_embeddings if kind == "avg" else
544
+ None
545
+ )
482
546
 
483
547
  if embeddings is None:
484
- raise AttributeError(
485
- "SG/SGNS model does not expose the requested embeddings: "
486
- f"kind={kind!r}"
487
- )
548
+ if kind not in ("in", "out", "avg"):
549
+ raise NameError(f"Unknown {kind} embeddings kind. Expected: one of ['in', 'out', 'avg']")
488
550
 
489
- embeddings = np.asarray(embeddings)
490
- _logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
491
- return embeddings
551
+ E = np.asarray(embeddings, dtype=np.float32)
552
+ _logger.debug("Returned embeddings shape: %s dtype=%s", E.shape, E.dtype)
553
+ return E
492
554
 
493
555
  def embed_all(
494
556
  self,
495
557
  RIN_type: Literal["attr", "repuls"],
496
558
  using: Literal["RW", "SAW", "merged"],
497
- window_size: int,
498
- num_negative_samples: int,
499
559
  num_epochs: int,
500
- batch_size: int,
560
+ negative_sampling: bool = False,
561
+ window_size: int = 2,
562
+ num_negative_samples: int = 10,
563
+ batch_size: int = 1024,
501
564
  *,
565
+ lr_step_per_batch: bool = False,
502
566
  shuffle_data: bool = True,
503
567
  dimensionality: int = 128,
504
568
  alpha: float = 0.75,
505
569
  device: str | None = None,
506
- sgns_kwargs: dict[str, object] | None = None,
570
+ model_base: Literal["torch", "pureml"] = "pureml",
571
+ model_kwargs: dict[str, object] | None = None,
572
+ kind: Literal["in", "out", "avg"] = "in",
507
573
  output_path: str | Path | None = None,
508
574
  num_matrices_in_compressed_blocks: int = 20,
509
575
  compression_level: int = 3,
510
- objective: Literal["sgns", "sg"] | None = None,
511
- kind: Literal["in", "out", "avg"] = "in"):
512
- """Train embeddings for all frames and persist them to compressed storage.
576
+ ) -> str:
577
+ """Embed all frames and persist a self-contained archive.
513
578
 
514
- Iterates through all frames (``1..frame_count``), trains an SGNS model
515
- per frame using the configured backend, collects the resulting input
516
- embeddings, and writes them into a new compressed ``ArrayStorage`` archive.
579
+ The resulting file stores a block named ``FRAME_EMBEDDINGS`` with a
580
+ compressed sequence of per-frame matrices (each ``(V, D)``), alongside
581
+ rich metadata mirroring the style of other SAWNERGY modules.
517
582
 
518
583
  Args:
519
- RIN_type: Interaction channel to use: ``"attr"`` or ``"repuls"``.
520
- using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
521
- window_size: Symmetric skip-gram window size ``k``.
522
- num_negative_samples: Number of negative samples per positive pair.
523
- Ignored when ``objective="sg"``.
524
- num_epochs: Number of epochs for each frame.
525
- batch_size: Mini-batch size used during training.
526
- shuffle_data: Whether to shuffle pairs each epoch.
527
- dimensionality: Embedding dimensionality ``D``.
528
- alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
529
- device: Optional device string for Torch backend.
530
- sgns_kwargs: Extra constructor kwargs for the SGNS backend (see
531
- :meth:`embed_frame` for PureML requirements).
532
- output_path: Destination path. If ``None``, a new file named
533
- ``EMBEDDINGS_<timestamp>.zip`` is created next to the source
534
- WALKS archive. If the provided path lacks a suffix, ``.zip`` is
535
- appended.
536
- num_matrices_in_compressed_blocks: Number of per-frame matrices to
537
- store per compressed chunk in the output archive.
538
- compression_level: Blosc Zstd compression level (0-9).
539
- objective: Training objective for all frames (``"sgns"`` or ``"sg"``).
540
- If ``None``, uses the value set at construction.
541
- kind: Which embedding matrix to persist for each frame: ``"in"``,
542
- ``"out"``, or ``"avg"``.
584
+ RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
585
+ using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
586
+ num_epochs: Number of epochs to train per frame.
587
+ negative_sampling: If ``True``, use SGNS; otherwise plain SG.
588
+ window_size: Skip-gram window radius.
589
+ num_negative_samples: Negatives per positive pair (SGNS).
590
+ batch_size: Minibatch size for training.
591
+ lr_step_per_batch: If ``True``, step LR per batch (else per epoch).
592
+ shuffle_data: Shuffle pairs each epoch.
593
+ dimensionality: Embedding dimension.
594
+ alpha: Unigram smoothing power for noise distribution.
595
+ device: Backend device hint (e.g., ``"cuda"``).
596
+ model_base: Backend family (``"torch"`` or ``"pureml"``).
597
+ model_kwargs: Passed through to backend model constructor.
598
+ kind: Which embedding to store: ``"in"``, ``"out"``, or ``"avg"``.
599
+ output_path: Optional path for the output archive (``.zip`` inferred).
600
+ num_matrices_in_compressed_blocks: How many frames per compressed chunk.
601
+ compression_level: Integer compression level for the archive.
543
602
 
544
603
  Returns:
545
- str: Filesystem path to the written embeddings archive (``.zip``).
546
-
547
- Raises:
548
- ValueError: If configuration produces no pairs for a frame or if
549
- PureML kwargs are incomplete.
550
- RuntimeError: Propagated from storage operations on failure.
551
-
552
- Notes:
553
- - A deterministic child seed is spawned per frame from the master
554
- seed using ``np.random.SeedSequence`` to ensure reproducibility
555
- across runs.
604
+ Path to the created embeddings archive, as ``str``.
556
605
  """
557
606
  current_time = sawnergy_util.current_time()
558
607
  if output_path is None:
@@ -563,71 +612,102 @@ class Embedder:
563
612
  output_path = output_path.with_suffix(".zip")
564
613
 
565
614
  _logger.info(
566
- "Embedding all frames -> %s | frames=%d dim=%d base=%s",
567
- output_path, self.frame_count, dimensionality, self.model_base
615
+ "embed_all: frames=%d D=%d base=%s RIN=%s using=%s out=%s",
616
+ self.frame_count, dimensionality, model_base, RIN_type, using, output_path
568
617
  )
569
618
 
619
+ # Per-frame deterministic seeds
570
620
  master_ss = np.random.SeedSequence(self._seed)
571
621
  child_seeds = master_ss.spawn(self.frame_count)
572
622
 
573
- embeddings = []
574
- for frame_idx, seed_seq in enumerate(child_seeds, start=1):
623
+ embeddings: list[np.ndarray] = []
624
+ for frame_id, seed_seq in enumerate(child_seeds, start=1):
575
625
  child_seed = int(seed_seq.generate_state(1, dtype=np.uint32)[0])
576
- _logger.info("Processing frame %d/%d (child_seed=%d entropy=%d)", frame_idx, self.frame_count, child_seed, seed_seq.entropy)
577
- embeddings.append(
578
- self.embed_frame(
579
- frame_idx,
580
- RIN_type,
581
- using,
582
- window_size,
583
- num_negative_samples,
584
- num_epochs,
585
- batch_size,
626
+ _logger.info("Embedding frame %d/%d with seed=%d", frame_id, self.frame_count, child_seed)
627
+ E = self.embed_frame(
628
+ frame_id=frame_id,
629
+ RIN_type=RIN_type,
630
+ using=using,
631
+ num_epochs=num_epochs,
632
+ negative_sampling=negative_sampling,
633
+ window_size=window_size,
634
+ num_negative_samples=num_negative_samples,
635
+ batch_size=batch_size,
636
+ lr_step_per_batch=lr_step_per_batch,
586
637
  shuffle_data=shuffle_data,
587
638
  dimensionality=dimensionality,
588
639
  alpha=alpha,
589
640
  device=device,
590
- sgns_kwargs=sgns_kwargs,
591
- _seed=child_seed,
592
- objective=objective,
593
- kind=kind
641
+ model_base=model_base,
642
+ model_kwargs=model_kwargs,
643
+ kind=kind,
644
+ _seed=child_seed
594
645
  )
595
- )
646
+ embeddings.append(np.asarray(E, dtype=np.float32, copy=False))
647
+ _logger.debug("Frame %d embedded: E.shape=%s", frame_id, E.shape)
596
648
 
597
- embeddings = [np.asarray(e) for e in embeddings]
598
649
  block_name = "FRAME_EMBEDDINGS"
599
650
  with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level=compression_level) as storage:
651
+ _logger.info("Writing %d frame matrices to block '%s' ...", len(embeddings), block_name)
600
652
  storage.write(
601
653
  these_arrays=embeddings,
602
654
  to_block_named=block_name,
603
655
  arrays_per_chunk=num_matrices_in_compressed_blocks
604
656
  )
605
- storage.add_attr("time_created", current_time)
606
- storage.add_attr("seed", int(self._seed))
607
- storage.add_attr("rng_scheme", "SeedSequence.spawn_per_frame_v1")
608
- storage.add_attr("source_walks_path", str(self._walks_path))
609
- storage.add_attr("model_base", self.model_base)
610
- storage.add_attr("rin_type", RIN_type)
611
- storage.add_attr("using_mode", using)
657
+
658
+ # Core dataset discovery (for consumers like the Embeddings Visualizer)
659
+ storage.add_attr("frame_embeddings_name", block_name)
660
+ storage.add_attr("time_stamp_count", int(self.frame_count))
661
+ storage.add_attr("node_count", int(self.vocab_size))
662
+ storage.add_attr("embedding_dim", int(dimensionality))
663
+
664
+ # Provenance of input WALKS
665
+ storage.add_attr("source_WALKS_path", str(self._walks_path))
666
+ storage.add_attr("walk_length", int(self.walk_length))
667
+ storage.add_attr("num_RWs", int(self.num_RWs))
668
+ storage.add_attr("num_SAWs", int(self.num_SAWs))
669
+ storage.add_attr("attractive_RWs_name", self._attractive_RWs_name)
670
+ storage.add_attr("repulsive_RWs_name", self._repulsive_RWs_name)
671
+ storage.add_attr("attractive_SAWs_name", self._attractive_SAWs_name)
672
+ storage.add_attr("repulsive_SAWs_name", self._repulsive_SAWs_name)
673
+
674
+ # Training configuration (sufficient to reproduce)
675
+ storage.add_attr("objective", "sgns" if negative_sampling else "sg")
676
+ storage.add_attr("model_base", model_base)
677
+ storage.add_attr("embedding_kind", kind) # 'in' | 'out' | 'avg'
678
+ storage.add_attr("num_epochs", int(num_epochs))
679
+ storage.add_attr("batch_size", int(batch_size))
612
680
  storage.add_attr("window_size", int(window_size))
613
681
  storage.add_attr("alpha", float(alpha))
614
- storage.add_attr("dimensionality", int(dimensionality))
682
+ storage.add_attr("negative_sampling", bool(negative_sampling))
615
683
  storage.add_attr("num_negative_samples", int(num_negative_samples))
616
- storage.add_attr("num_epochs", int(num_epochs))
617
- storage.add_attr("batch_size", int(batch_size))
684
+ storage.add_attr("lr_step_per_batch", bool(lr_step_per_batch))
618
685
  storage.add_attr("shuffle_data", bool(shuffle_data))
619
- storage.add_attr("frames_written", int(len(embeddings)))
620
- storage.add_attr("vocab_size", int(self.vocab_size))
621
- storage.add_attr("frame_count", int(self.frame_count))
622
- storage.add_attr("embedding_dtype", str(embeddings[0].dtype))
623
- storage.add_attr("frame_embeddings_name", block_name)
686
+ storage.add_attr("device_hint", device if device is not None else "")
687
+ storage.add_attr("model_kwargs_repr", repr(model_kwargs) if model_kwargs is not None else "{}")
688
+
689
+ # Which walks were used to train
690
+ storage.add_attr("RIN_type", RIN_type) # 'attr' or 'repuls'
691
+ storage.add_attr("using", using) # 'RW' | 'SAW' | 'merged'
692
+
693
+ # Reproducibility
694
+ storage.add_attr("master_seed", int(self._seed))
695
+ # Note: this records seeds derived from child SeedSequences at metadata time.
696
+ storage.add_attr("per_frame_seeds", [int(s.generate_state(1, dtype=np.uint32)[0]) for s in child_seeds])
697
+
698
+ # Archive/IO details
624
699
  storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
625
700
  storage.add_attr("compression_level", int(compression_level))
626
- storage.add_attr("objective", self.objective if objective is None else objective)
701
+ storage.add_attr("created_at", current_time)
702
+
703
+ _logger.info(
704
+ "Stored embeddings archive: %s | shape=(T,N,D)=(%d,%d,%d)",
705
+ output_path, self.frame_count, self.vocab_size, dimensionality
706
+ )
627
707
 
628
- _logger.info("Embedding archive written to %s", output_path)
629
708
  return str(output_path)
630
709
 
710
+
631
711
  __all__ = ["Embedder"]
632
712
 
633
713
  if __name__ == "__main__":