sawnergy 1.0.6__py3-none-any.whl → 1.0.8__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.


This version of sawnergy might be problematic.

@@ -1,15 +1,5 @@
1
1
  from __future__ import annotations
2
2
 
3
- """
4
- Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
5
-
6
- This module consumes attractive/repulsive walk corpora produced by the walker
7
- pipeline and trains per-frame embeddings using either the PyTorch or PureML
8
- implementations of SGNS. The resulting embeddings can be persisted back into
9
- an ``ArrayStorage`` archive along with rich metadata describing the training
10
- configuration.
11
- """
12
-
13
3
  # third-party
14
4
  import numpy as np
15
5
 
@@ -36,9 +26,8 @@ class Embedder:
36
26
 
37
27
  def __init__(self,
38
28
  WALKS_path: str | Path,
39
- base: Literal["torch", "pureml"],
40
29
  *,
41
- seed: int | None = None
30
+ seed: int | None = None,
42
31
  ) -> None:
43
32
  """Initialize the embedder and load walk tensors.
44
33
 
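
The constructor change above is the central API difference in 1.0.8: ``base`` is no longer an ``__init__`` argument, and the backend is instead selected per training call. A minimal usage sketch, assuming an import path of ``sawnergy.embedding`` and a placeholder WALKS archive name (both are illustrative, not taken from the diff):

```python
# Illustrative sketch only: the import path and archive name are assumptions.
from sawnergy.embedding import Embedder  # assumed module location

# sawnergy 1.0.6 bound the backend at construction time:
#   emb = Embedder("WALKS.zip", "torch", seed=42)

# sawnergy 1.0.8 constructs without a backend; it is chosen per call instead:
emb = Embedder("WALKS.zip", seed=42)
# e.g. emb.embed_frame(..., model_base="torch") or model_base="pureml"
```
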
@@ -50,22 +39,19 @@ class Embedder:
50
39
  ``None`` if that collection is absent), and the metadata
51
40
  ``num_RWs``, ``num_SAWs``, ``node_count``, ``time_stamp_count``,
52
41
  ``walk_length``.
53
- base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
54
42
  seed: Optional seed for the embedder's RNG. If ``None``, a random
55
43
  32-bit seed is chosen.
56
44
 
57
45
  Raises:
58
46
  ValueError: If required metadata is missing or any loaded walk array
59
47
  has an unexpected shape.
60
- ImportError: If the requested backend is not installed.
61
- NameError: If ``base`` is not one of ``{"torch","pureml"}``.
62
48
 
63
49
  Notes:
64
50
  - Walks in storage are 1-based (residue indexing). Internally, this
65
51
  class normalizes to 0-based indices for training utilities.
66
52
  """
67
53
  self._walks_path = Path(WALKS_path)
68
- _logger.info("Initializing Embedder from %s (base=%s)", self._walks_path, base)
54
+ _logger.info("Initializing Embedder from %s", self._walks_path)
69
55
 
70
56
  # placeholders for optional walk collections
71
57
  self.attractive_RWs : np.ndarray | None = None
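
For orientation, the shape checks in the next hunk expect one row per (node, walk) combination and ``walk_length + 1`` visited nodes per row, per frame. A small sketch with toy metadata values (the numbers are made up):

```python
# Toy metadata values; the formulas mirror RWs_expected / SAWs_expected below.
node_count, time_stamp_count, walk_length = 5, 3, 4
num_RWs, num_SAWs = 2, 1

RWs_expected  = (time_stamp_count, node_count * num_RWs,  walk_length + 1)
SAWs_expected = (time_stamp_count, node_count * num_SAWs, walk_length + 1)
print(RWs_expected, SAWs_expected)   # (3, 10, 5) (3, 5, 5)
```
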
@@ -124,53 +110,76 @@ class Embedder:
124
110
  RWs_expected = (time_stamp_count, node_count * num_RWs, walk_length+1) if (num_RWs > 0) else None
125
111
  SAWs_expected = (time_stamp_count, node_count * num_SAWs, walk_length+1) if (num_SAWs > 0) else None
126
112
 
127
- self.vocab_size = int(node_count)
128
- self.frame_count = int(time_stamp_count)
129
- self.walk_length = int(walk_length)
113
+ self.vocab_size = int(node_count)
114
+ self.frame_count = int(time_stamp_count)
115
+ self.walk_length = int(walk_length)
116
+ self.num_RWs = int(num_RWs)
117
+ self.num_SAWs = int(num_SAWs)
118
+ # Keep dataset names for metadata passthrough
119
+ self._attractive_RWs_name = attractive_RWs_name
120
+ self._repulsive_RWs_name = repulsive_RWs_name
121
+ self._attractive_SAWs_name = attractive_SAWs_name
122
+ self._repulsive_SAWs_name = repulsive_SAWs_name
130
123
 
131
124
  # store walks if present
132
125
  if attractive_RWs is not None:
133
126
  if RWs_expected and attractive_RWs.shape != RWs_expected:
134
127
  raise ValueError(f"ATTR RWs: expected {RWs_expected}, got {attractive_RWs.shape}")
135
128
  self.attractive_RWs = attractive_RWs
129
+ _logger.debug("ATTR RWs loaded: %s", self.attractive_RWs.shape)
136
130
 
137
131
  if repulsive_RWs is not None:
138
132
  if RWs_expected and repulsive_RWs.shape != RWs_expected:
139
133
  raise ValueError(f"REP RWs: expected {RWs_expected}, got {repulsive_RWs.shape}")
140
134
  self.repulsive_RWs = repulsive_RWs
135
+ _logger.debug("REP RWs loaded: %s", self.repulsive_RWs.shape)
141
136
 
142
137
  if attractive_SAWs is not None:
143
138
  if SAWs_expected and attractive_SAWs.shape != SAWs_expected:
144
139
  raise ValueError(f"ATTR SAWs: expected {SAWs_expected}, got {attractive_SAWs.shape}")
145
140
  self.attractive_SAWs = attractive_SAWs
141
+ _logger.debug("ATTR SAWs loaded: %s", self.attractive_SAWs.shape)
146
142
 
147
143
  if repulsive_SAWs is not None:
148
144
  if SAWs_expected and repulsive_SAWs.shape != SAWs_expected:
149
145
  raise ValueError(f"REP SAWs: expected {SAWs_expected}, got {repulsive_SAWs.shape}")
150
146
  self.repulsive_SAWs = repulsive_SAWs
147
+ _logger.debug("REP SAWs loaded: %s", self.repulsive_SAWs.shape)
151
148
 
152
149
  # INTERNAL RNG
153
150
  self._seed = np.random.randint(0, 2**32 - 1) if seed is None else int(seed)
154
151
  self.rng = np.random.default_rng(self._seed)
155
152
  _logger.info("RNG initialized from seed=%d", self._seed)
156
153
 
157
- # MODEL HANDLE
158
- self.model_base: Literal["torch", "pureml"] = base
159
- self.model_constructor = self._get_SGNS_constructor_from(base)
160
- _logger.info("SGNS backend resolved: %s", getattr(self.model_constructor, "__name__", repr(self.model_constructor)))
161
-
162
154
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
163
155
 
164
156
  # HELPERS:
165
157
 
166
158
  @staticmethod
167
- def _get_SGNS_constructor_from(base: Literal["torch", "pureml"]):
168
- """Resolve the SGNS implementation class for the selected backend."""
159
+ def _get_NN_constructor_from(base: Literal["torch", "pureml"],
160
+ objective: Literal["sgns", "sg"]):
161
+ """Resolve the SG/SGNS implementation class for the selected backend.
162
+
163
+ Args:
164
+ base: Backend family to use, ``"torch"`` or ``"pureml"``.
165
+ objective: Training objective, ``"sgns"`` or ``"sg"``.
166
+
167
+ Returns:
168
+ A callable class (constructor) implementing the requested model.
169
+
170
+ Raises:
171
+ ImportError: If the requested backend cannot be imported.
172
+ NameError: If ``base`` is not one of the supported values.
173
+ """
174
+ _logger.debug("Resolving model constructor: base=%s objective=%s", base, objective)
169
175
  if base == "torch":
170
176
  try:
171
- from .SGNS_torch import SGNS_Torch
172
- return SGNS_Torch
177
+ from .SGNS_torch import SGNS_Torch, SG_Torch
178
+ ctor = SG_Torch if objective == "sg" else SGNS_Torch
179
+ _logger.debug("Resolved PyTorch class: %s", getattr(ctor, "__name__", str(ctor)))
180
+ return ctor
173
181
  except Exception:
182
+ _logger.exception("Failed to import PyTorch backend.")
174
183
  raise ImportError(
175
184
  "PyTorch is not installed, but base='torch' was requested. "
176
185
  "Install PyTorch first, e.g.: `pip install torch` "
@@ -178,9 +187,12 @@ class Embedder:
178
187
  )
179
188
  elif base == "pureml":
180
189
  try:
181
- from .SGNS_pml import SGNS_PureML
182
- return SGNS_PureML
190
+ from .SGNS_pml import SGNS_PureML, SG_PureML
191
+ ctor = SG_PureML if objective == "sg" else SGNS_PureML
192
+ _logger.debug("Resolved PureML class: %s", getattr(ctor, "__name__", str(ctor)))
193
+ return ctor
183
194
  except Exception:
195
+ _logger.exception("Failed to import PureML backend.")
184
196
  raise ImportError(
185
197
  "PureML is not installed, but base='pureml' was requested. "
186
198
  "Install PureML first via `pip install ym-pure-ml` "
@@ -190,7 +202,18 @@ class Embedder:
190
202
 
191
203
  @staticmethod
192
204
  def _as_zerobase_intp(walks: np.ndarray, *, V: int) -> np.ndarray:
193
- """Validate 1-based uint/int walks 0-based intp; check bounds."""
205
+ """Validate and convert 1-based walks to 0-based ``intp``.
206
+
207
+ Args:
208
+ walks: 2D array of node ids with 1-based indexing.
209
+ V: Vocabulary size for bounds checking.
210
+
211
+ Returns:
212
+ 2D array of dtype ``intp`` with 0-based indices.
213
+
214
+ Raises:
215
+ ValueError: If shape is not 2D or indices are out of bounds.
216
+ """
194
217
  arr = np.asarray(walks)
195
218
  if arr.ndim != 2:
196
219
  raise ValueError("walks must be 2D: (num_walks, walk_len)")
@@ -198,7 +221,9 @@ class Embedder:
198
221
  arr = arr.astype(np.int64, copy=False)
199
222
  # 1-based → 0-based
200
223
  arr = arr - 1
201
- if arr.min() < 0 or arr.max() >= V:
224
+ mn, mx = int(arr.min()), int(arr.max())
225
+ _logger.debug("Zero-basing walks: min=%d max=%d V=%d", mn, mx, V)
226
+ if mn < 0 or mx >= V:
202
227
  raise ValueError("walk ids out of range after 1→0-based normalization")
203
228
  return arr.astype(np.intp, copy=False)
204
229
 
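A condensed, standalone version of what ``_as_zerobase_intp`` does: shift 1-based residue ids down by one and bounds-check against the vocabulary size. This is a re-implementation sketch for illustration, not the packaged code:

```python
import numpy as np

def as_zerobase_intp(walks, V):
    arr = np.asarray(walks)
    if arr.ndim != 2:
        raise ValueError("walks must be 2D: (num_walks, walk_len)")
    arr = arr.astype(np.int64, copy=False) - 1            # 1-based -> 0-based
    if arr.min() < 0 or arr.max() >= V:
        raise ValueError("walk ids out of range after 1->0-based normalization")
    return arr.astype(np.intp, copy=False)

walks = np.array([[1, 2, 3], [3, 2, 1]])                  # 1-based ids, V = 3
print(as_zerobase_intp(walks, V=3))                        # [[0 1 2] [2 1 0]]
```
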
@@ -206,19 +231,29 @@ class Embedder:
206
231
  def _pairs_from_walks(walks0: np.ndarray, window_size: int) -> np.ndarray:
207
232
  """
208
233
  Skip-gram pairs including edge centers (one-sided when needed).
209
- walks0: (W, L) int array (0-based ids).
210
- Returns: (N_pairs, 2) int32 [center, context].
234
+
235
+ Args:
236
+ walks0: (W, L) int array (0-based ids).
237
+ window_size: Symmetric context window radius.
238
+
239
+ Returns:
240
+ Array of shape (N_pairs, 2) int32 with columns [center, context].
241
+
242
+ Raises:
243
+ ValueError: If shape is invalid or ``window_size`` <= 0.
211
244
  """
212
245
  if walks0.ndim != 2:
213
246
  raise ValueError("walks must be 2D: (num_walks, walk_len)")
214
247
 
215
248
  _, L = walks0.shape
216
249
  k = int(window_size)
250
+ _logger.debug("Building SG pairs: L=%d, window=%d", L, k)
217
251
 
218
252
  if k <= 0:
219
253
  raise ValueError("window_size must be positive")
220
254
 
221
255
  if L == 0:
256
+ _logger.debug("Empty walks length; returning 0 pairs.")
222
257
  return np.empty((0, 2), dtype=np.int32)
223
258
 
224
259
  out_chunks = []
@@ -236,18 +271,42 @@ class Embedder:
236
271
  out_chunks.append(np.stack((centers_l, ctx_l), axis=2).reshape(-1, 2))
237
272
 
238
273
  if not out_chunks:
274
+ _logger.debug("No offsets produced pairs; returning empty.")
239
275
  return np.empty((0, 2), dtype=np.int32)
240
276
 
241
- return np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
277
+ pairs = np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
278
+ _logger.debug("Built %d training pairs", pairs.shape[0])
279
+ return pairs
242
280
 
243
281
  @staticmethod
244
282
  def _freq_from_walks(walks0: np.ndarray, *, V: int) -> np.ndarray:
245
- """Node frequencies from walks (0-based)."""
246
- return np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
283
+ """Node frequencies from walks (0-based).
284
+
285
+ Args:
286
+ walks0: 2D array of 0-based node ids.
287
+ V: Vocabulary size (minlength for bincount).
288
+
289
+ Returns:
290
+ 1D array of int64 frequencies with length ``V``.
291
+ """
292
+ freq = np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
293
+ _logger.debug("Frequency mass: total=%d nonzero=%d", int(freq.sum()), int(np.count_nonzero(freq)))
294
+ return freq
247
295
 
248
296
  @staticmethod
249
297
  def _soft_unigram(freq: np.ndarray, *, power: float = 0.75) -> np.ndarray:
250
- """Return normalized Pn(w) ∝ f(w)^power as float64 probs."""
298
+ """Return normalized Pn(w) ∝ f(w)^power as float64 probs.
299
+
300
+ Args:
301
+ freq: 1D array of token frequencies.
302
+ power: Exponent used for smoothing (default 0.75 à la word2vec).
303
+
304
+ Returns:
305
+ 1D array of probabilities summing to 1.0.
306
+
307
+ Raises:
308
+ ValueError: If mass is invalid (all zeros or non-finite).
309
+ """
251
310
  p = np.asarray(freq, dtype=np.float64)
252
311
  if p.sum() == 0:
253
312
  raise ValueError("all frequencies are zero")
@@ -255,13 +314,31 @@ class Embedder:
255
314
  s = p.sum()
256
315
  if not np.isfinite(s) or s <= 0:
257
316
  raise ValueError("invalid unigram mass")
258
- return p / s
317
+ probs = p / s
318
+ _logger.debug("Noise distribution ready (power=%.3f)", power)
319
+ return probs
259
320
 
260
321
  def _materialize_walks(self, frame_id: int, rin: Literal["attr", "repuls"],
261
322
  using: Literal["RW", "SAW", "merged"]) -> np.ndarray:
323
+ """Materialize the requested set of walks for a frame.
324
+
325
+ Args:
326
+ frame_id: 1-based frame index.
327
+ rin: Which RIN to pull from: ``"attr"`` or ``"repuls"``.
328
+ using: Which walk sets to include: ``"RW"``, ``"SAW"``, or ``"merged"``.
329
+ If ``"merged"``, concatenate available RW and SAW along axis 0.
330
+
331
+ Returns:
332
+ A 2D array of walks with shape (num_walks, walk_length+1).
333
+
334
+ Raises:
335
+ IndexError: If ``frame_id`` is out of range.
336
+ ValueError: If no matching walks are available.
337
+ """
262
338
  if not 1 <= frame_id <= int(self.frame_count):
263
339
  raise IndexError(f"frame_id must be in [1, {self.frame_count}]; got {frame_id}")
264
340
 
341
+ _logger.debug("Materializing %s walks at frame=%d using=%s", rin, frame_id, using)
265
342
  frame_id -= 1
266
343
 
267
344
  if rin == "attr":
@@ -288,8 +365,12 @@ class Embedder:
288
365
  if not parts:
289
366
  raise ValueError(f"No walks available for {rin=} with {using=}")
290
367
  if len(parts) == 1:
291
- return parts[0]
292
- return np.concatenate(parts, axis=0)
368
+ out = parts[0]
369
+ else:
370
+ out = np.concatenate(parts, axis=0)
371
+
372
+ _logger.debug("Materialized walks shape: %s", getattr(out, "shape", None))
373
+ return out
293
374
 
294
375
  # INTERFACES: (private)
295
376
 
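``_materialize_walks`` indexes one frame out of the ``(T, W, L+1)`` walk tensors and, when ``using="merged"``, stacks the RW and SAW blocks along the walk axis. A toy sketch of that selection (shapes are made up):

```python
import numpy as np

T, V, L = 3, 5, 4                                          # frames, nodes, walk length
rws  = np.random.randint(1, V + 1, size=(T, V * 2, L + 1)) # e.g. 2 RWs per node
saws = np.random.randint(1, V + 1, size=(T, V * 1, L + 1)) # e.g. 1 SAW per node

frame_id = 2                                               # 1-based, as in the public API
frame = frame_id - 1                                       # 0-based index into the tensors
merged = np.concatenate([rws[frame], saws[frame]], axis=0)
print(merged.shape)                                        # (V*2 + V*1, L+1) == (15, 5)
```
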
@@ -298,6 +379,17 @@ class Embedder:
298
379
  using: Literal["RW", "SAW", "merged"],
299
380
  window_size: int,
300
381
  alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
382
+ """Construct (center, context) pairs and noise distribution for ATTR.
383
+
384
+ Args:
385
+ frame_id: 1-based frame index.
386
+ using: Walk subset to include.
387
+ window_size: Skip-gram window radius.
388
+ alpha: Unigram smoothing exponent.
389
+
390
+ Returns:
391
+ Tuple of (pairs, noise_probs).
392
+ """
301
393
  walks = self._materialize_walks(frame_id, "attr", using)
302
394
  walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
303
395
  attractive_corpus = self._pairs_from_walks(walks0, window_size)
@@ -311,6 +403,17 @@ class Embedder:
311
403
  using: Literal["RW", "SAW", "merged"],
312
404
  window_size: int,
313
405
  alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
406
+ """Construct (center, context) pairs and noise distribution for REP.
407
+
408
+ Args:
409
+ frame_id: 1-based frame index.
410
+ using: Walk subset to include.
411
+ window_size: Skip-gram window radius.
412
+ alpha: Unigram smoothing exponent.
413
+
414
+ Returns:
415
+ Tuple of (pairs, noise_probs).
416
+ """
314
417
  walks = self._materialize_walks(frame_id, "repuls", using)
315
418
  walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
316
419
  repulsive_corpus = self._pairs_from_walks(walks0, window_size)
@@ -322,58 +425,56 @@ class Embedder:
322
425
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
323
426
 
324
427
  def embed_frame(self,
325
- frame_id: int,
326
- RIN_type: Literal["attr", "repuls"],
327
- using: Literal["RW", "SAW", "merged"],
328
- window_size: int,
329
- num_negative_samples: int,
330
- num_epochs: int,
331
- batch_size: int,
332
- *,
333
- lr_step_per_batch: bool = False,
334
- shuffle_data: bool = True,
335
- dimensionality: int = 128,
336
- alpha: float = 0.75,
337
- device: str | None = None,
338
- sgns_kwargs: dict[str, object] | None = None,
339
- _seed: int | None = None
340
- ) -> np.ndarray:
341
- """Train embeddings for a single frame and return the input embedding matrix.
428
+ frame_id: int,
429
+ RIN_type: Literal["attr", "repuls"],
430
+ using: Literal["RW", "SAW", "merged"],
431
+ num_epochs: int,
432
+ negative_sampling: bool = False,
433
+ window_size: int = 5,
434
+ num_negative_samples: int = 10,
435
+ batch_size: int = 1024,
436
+ *,
437
+ lr_step_per_batch: bool = False,
438
+ shuffle_data: bool = True,
439
+ dimensionality: int = 128,
440
+ alpha: float = 0.75,
441
+ device: str | None = None,
442
+ model_base: Literal["torch", "pureml"] = "pureml",
443
+ model_kwargs: dict[str, object] | None = None,
444
+ kind: Literal["in", "out", "avg"] = "in",
445
+ _seed: int | None = None
446
+ ) -> np.ndarray:
447
+ """Train embeddings for a single frame and return the matrix.
342
448
 
343
449
  Args:
344
- frame_id: 1-based frame index to train on.
345
- RIN_type: Interaction channel to use: ``"attr"`` (attractive) or
346
- ``"repuls"`` (repulsive).
347
- using: Which walk collections to include: ``"RW"``, ``"SAW"``, or
348
- ``"merged"`` (concatenates both if available).
349
- window_size: Symmetric skip-gram window size ``k``.
350
- num_negative_samples: Number of negative samples per positive pair.
351
- num_epochs: Number of passes over the pair dataset.
352
- batch_size: Mini-batch size for training.
353
- shuffle_data: Whether to shuffle pairs each epoch.
354
- dimensionality: Embedding dimensionality ``D``.
355
- alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
356
- device: Optional device string for the Torch backend (e.g., ``"cuda"``).
357
- sgns_kwargs: Extra keyword arguments forwarded to the backend SGNS
358
- constructor. For PureML, required keys are:
359
- ``{"optim", "optim_kwargs"}``; ``lr_sched`` is optional, but if
360
- provided then ``lr_sched_kwargs`` must also be provided.
361
- _seed: Optional child seed for this frame's model initialization.
450
+ frame_id: 1-based frame index to embed.
451
+ RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
452
+ using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
453
+ num_epochs: Number of passes over the pairs.
454
+ negative_sampling: If ``True``, use SGNS objective; else plain SG.
455
+ window_size: Skip-gram symmetric window radius.
456
+ num_negative_samples: Negatives per positive pair (SGNS only).
457
+ batch_size: Minibatch size for training.
458
+ lr_step_per_batch: If ``True``, step LR every batch (else per epoch).
459
+ shuffle_data: Shuffle pairs each epoch.
460
+ dimensionality: Embedding dimension ``D``.
461
+ alpha: Unigram smoothing power for noise distribution.
462
+ device: Optional backend device hint (e.g., ``"cuda"``).
463
+ model_base: Backend family (``"torch"`` or ``"pureml"``).
464
+ model_kwargs: Passed through to backend model constructor.
465
+ kind: Which embedding to return: ``"in"``, ``"out"``, or ``"avg"``.
466
+ _seed: Optional override seed for this frame.
362
467
 
363
468
  Returns:
364
- np.ndarray: Learned **input** embedding matrix of shape ``(V, D)``.
365
-
366
- Raises:
367
- ValueError: If requested walks are missing, if no training pairs are
368
- generated, or if required ``sgns_kwargs`` for PureML are absent.
369
- AttributeError: If the SGNS model does not expose embeddings via
370
- ``.embeddings`` or ``.parameters[0]``.
469
+ ``(V, D)`` float32 embedding matrix.
371
470
  """
372
471
  _logger.info(
373
- "Preparing frame %d (rin=%s using=%s window=%d neg=%d epochs=%d batch=%d)",
374
- frame_id, RIN_type, using, window_size, num_negative_samples, num_epochs, batch_size
472
+ "embed_frame: frame=%d RIN=%s using=%s base=%s D=%d epochs=%d batch=%d sgns=%s k=%d alpha=%.3f",
473
+ frame_id, RIN_type, using, model_base, dimensionality, num_epochs, batch_size,
474
+ str(negative_sampling), window_size, alpha
375
475
  )
376
476
 
477
+ # ------------------ resolve training data -----------------
377
478
  if RIN_type == "attr":
378
479
  if self.attractive_RWs is None and self.attractive_SAWs is None:
379
480
  raise ValueError("Attractive random walks are missing")
@@ -383,129 +484,124 @@ class Embedder:
383
484
  raise ValueError("Repulsive random walks are missing")
384
485
  pairs, noise_probs = self._repulsive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
385
486
  else:
386
- raise ValueError(f"Unknown RIN_type: {RIN_type!r}")
387
-
487
+ raise NameError(f"Unknown RIN_type: {RIN_type!r}")
388
488
  if pairs.size == 0:
389
489
  raise ValueError("No training pairs generated for the requested configuration")
490
+ # ----------------------------------------------------------
390
491
 
492
+ # ---------------- construct training corpus ---------------
391
493
  centers = pairs[:, 0].astype(np.int64, copy=False)
392
494
  contexts = pairs[:, 1].astype(np.int64, copy=False)
495
+ _logger.debug("Pairs split: centers=%s contexts=%s", centers.shape, contexts.shape)
496
+ # ----------------------------------------------------------
393
497
 
394
- model_kwargs: dict[str, object] = dict(sgns_kwargs or {})
395
- if self.model_base == "pureml":
396
- required = {"optim", "optim_kwargs"}
397
- missing = required.difference(model_kwargs)
398
- if missing:
399
- raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")
400
- has_sched = ("lr_sched" in model_kwargs and model_kwargs["lr_sched"] is not None)
401
- has_sched_kwargs = ("lr_sched_kwargs" in model_kwargs and model_kwargs["lr_sched_kwargs"] is not None)
402
- if has_sched and not has_sched_kwargs:
403
- raise ValueError("When providing lr_sched for PureML, you must also provide lr_sched_kwargs.")
404
-
405
- child_seed = int(self._seed if _seed is None else _seed)
406
- model_kwargs.update({
498
+ # ------------ resolve model_constructor kwargs ------------
499
+ if model_kwargs is not None:
500
+ if (("lr_sched" in model_kwargs and model_kwargs.get("lr_sched", None) is not None)
501
+ and ("lr_sched_kwargs" in model_kwargs and model_kwargs.get("lr_sched_kwargs", None) is None)):
502
+ raise ValueError("When `lr_sched`, you must also provide `lr_sched_kwargs`.")
503
+
504
+ constructor_kwargs: dict[str, object] = dict(model_kwargs or {})
505
+ constructor_kwargs.update({
407
506
  "V": self.vocab_size,
408
507
  "D": dimensionality,
409
- "seed": child_seed
508
+ "seed": int(self._seed if _seed is None else _seed),
509
+ "device": device
410
510
  })
511
+ _logger.debug("Model constructor kwargs: %s", {k: constructor_kwargs[k] for k in ("V","D","seed","device")})
512
+ # ----------------------------------------------------------
513
+
514
+ # --------------- resolve model constructor ----------------
515
+ model_constructor = self._get_NN_constructor_from(
516
+ model_base, objective=("sgns" if negative_sampling else "sg"))
517
+ # ----------------------------------------------------------
518
+
519
+ # ------------------ initialize the model ------------------
520
+ model = model_constructor(**constructor_kwargs)
521
+ _logger.debug("Model initialized: %s", model_constructor.__name__ if hasattr(model_constructor,"__name__") else str(model_constructor))
522
+ # ----------------------------------------------------------
523
+
524
+ # -------------------- fitting the data --------------------
525
+ _logger.info("Fitting model on %d pairs ...", pairs.shape[0])
526
+ model.fit(centers=centers,
527
+ contexts=contexts,
528
+ num_epochs=num_epochs,
529
+ batch_size=batch_size,
530
+ # -- optional; for SGNS; safely ignored by SG via **_ignore --
531
+ num_negative_samples=num_negative_samples,
532
+ noise_dist=noise_probs,
533
+ # -----------------------------------------
534
+ shuffle_data=shuffle_data,
535
+ lr_step_per_batch=lr_step_per_batch
536
+ )
537
+ _logger.info("Training complete for frame %d", frame_id)
538
+ # ----------------------------------------------------------
539
+
540
+ # OUTPUT:
541
+ embeddings = (model.in_embeddings if kind == "in" else
542
+ model.out_embeddings if kind == "out" else
543
+ model.avg_embeddings if kind == "avg" else
544
+ None
545
+ )
411
546
 
412
- if self.model_base == "torch" and device is not None:
413
- model_kwargs["device"] = device
414
-
415
- self.model = self.model_constructor(**model_kwargs)
416
-
417
- _logger.info(
418
- "Training SGNS base=%s constructor=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
419
- self.model_base,
420
- getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
421
- frame_id,
422
- pairs.shape[0],
423
- dimensionality,
424
- num_epochs,
425
- batch_size,
426
- num_negative_samples,
427
- shuffle_data
428
- )
429
-
430
- self.model.fit(
431
- centers,
432
- contexts,
433
- num_epochs,
434
- batch_size,
435
- num_negative_samples,
436
- noise_probs,
437
- shuffle_data,
438
- lr_step_per_batch
439
- )
440
-
441
- embeddings = getattr(self.model, "embeddings", None)
442
- if embeddings is None:
443
- params = getattr(self.model, "parameters", None)
444
- if isinstance(params, tuple) and params:
445
- embeddings = params[0]
446
547
  if embeddings is None:
447
- raise AttributeError("SGNS model does not expose embeddings via '.embeddings' or '.parameters[0]'")
548
+ if kind not in ("in", "out", "avg"):
549
+ raise NameError(f"Unknown {kind} embeddings kind. Expected: one of ['in', 'out', 'avg']")
448
550
 
449
- embeddings = np.asarray(embeddings)
450
- _logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
451
- return embeddings
551
+ E = np.asarray(embeddings, dtype=np.float32)
552
+ _logger.debug("Returned embeddings shape: %s dtype=%s", E.shape, E.dtype)
553
+ return E
452
554
 
453
555
  def embed_all(
454
556
  self,
455
557
  RIN_type: Literal["attr", "repuls"],
456
558
  using: Literal["RW", "SAW", "merged"],
457
- window_size: int,
458
- num_negative_samples: int,
459
559
  num_epochs: int,
460
- batch_size: int,
560
+ negative_sampling: bool = False,
561
+ window_size: int = 2,
562
+ num_negative_samples: int = 10,
563
+ batch_size: int = 1024,
461
564
  *,
565
+ lr_step_per_batch: bool = False,
462
566
  shuffle_data: bool = True,
463
567
  dimensionality: int = 128,
464
568
  alpha: float = 0.75,
465
569
  device: str | None = None,
466
- sgns_kwargs: dict[str, object] | None = None,
570
+ model_base: Literal["torch", "pureml"] = "pureml",
571
+ model_kwargs: dict[str, object] | None = None,
572
+ kind: Literal["in", "out", "avg"] = "in",
467
573
  output_path: str | Path | None = None,
468
574
  num_matrices_in_compressed_blocks: int = 20,
469
- compression_level: int = 3):
470
- """Train embeddings for all frames and persist them to compressed storage.
575
+ compression_level: int = 3,
576
+ ) -> str:
577
+ """Embed all frames and persist a self-contained archive.
471
578
 
472
- Iterates through all frames (``1..frame_count``), trains an SGNS model
473
- per frame using the configured backend, collects the resulting input
474
- embeddings, and writes them into a new compressed ``ArrayStorage`` archive.
579
+ The resulting file stores a block named ``FRAME_EMBEDDINGS`` with a
580
+ compressed sequence of per-frame matrices (each ``(V, D)``), alongside
581
+ rich metadata mirroring the style of other SAWNERGY modules.
475
582
 
476
583
  Args:
477
- RIN_type: Interaction channel to use: ``"attr"`` or ``"repuls"``.
478
- using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
479
- window_size: Symmetric skip-gram window size ``k``.
480
- num_negative_samples: Number of negative samples per positive pair.
481
- num_epochs: Number of epochs for each frame.
482
- batch_size: Mini-batch size used during training.
483
- shuffle_data: Whether to shuffle pairs each epoch.
484
- dimensionality: Embedding dimensionality ``D``.
485
- alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
486
- device: Optional device string for Torch backend.
487
- sgns_kwargs: Extra constructor kwargs for the SGNS backend (see
488
- :meth:`embed_frame` for PureML requirements).
489
- output_path: Destination path. If ``None``, a new file named
490
- ``EMBEDDINGS_<timestamp>.zip`` is created next to the source
491
- WALKS archive. If the provided path lacks a suffix, ``.zip`` is
492
- appended.
493
- num_matrices_in_compressed_blocks: Number of per-frame matrices to
494
- store per compressed chunk in the output archive.
495
- compression_level: Blosc Zstd compression level (0-9).
584
+ RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
585
+ using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
586
+ num_epochs: Number of epochs to train per frame.
587
+ negative_sampling: If ``True``, use SGNS; otherwise plain SG.
588
+ window_size: Skip-gram window radius.
589
+ num_negative_samples: Negatives per positive pair (SGNS).
590
+ batch_size: Minibatch size for training.
591
+ lr_step_per_batch: If ``True``, step LR per batch (else per epoch).
592
+ shuffle_data: Shuffle pairs each epoch.
593
+ dimensionality: Embedding dimension.
594
+ alpha: Unigram smoothing power for noise distribution.
595
+ device: Backend device hint (e.g., ``"cuda"``).
596
+ model_base: Backend family (``"torch"`` or ``"pureml"``).
597
+ model_kwargs: Passed through to backend model constructor.
598
+ kind: Which embedding to store: ``"in"``, ``"out"``, or ``"avg"``.
599
+ output_path: Optional path for the output archive (``.zip`` inferred).
600
+ num_matrices_in_compressed_blocks: How many frames per compressed chunk.
601
+ compression_level: Integer compression level for the archive.
496
602
 
497
603
  Returns:
498
- str: Filesystem path to the written embeddings archive (``.zip``).
499
-
500
- Raises:
501
- ValueError: If configuration produces no pairs for a frame or if
502
- PureML kwargs are incomplete.
503
- RuntimeError: Propagated from storage operations on failure.
504
-
505
- Notes:
506
- - A deterministic child seed is spawned per frame from the master
507
- seed using ``np.random.SeedSequence`` to ensure reproducibility
508
- across runs.
604
+ Path to the created embeddings archive, as ``str``.
509
605
  """
510
606
  current_time = sawnergy_util.current_time()
511
607
  if output_path is None:
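
For the archive-producing path, a hypothetical end-to-end run under the new ``embed_all`` signature is sketched below; the paths and hyperparameters are placeholders. The return value is the path of the written ``.zip`` archive containing the ``FRAME_EMBEDDINGS`` block and the metadata attributes listed further down:

```python
# Hypothetical end-to-end run; paths and hyperparameters are placeholders.
from sawnergy.embedding import Embedder  # assumed module location

emb = Embedder("WALKS.zip", seed=42)
archive_path = emb.embed_all(
    RIN_type="attr",
    using="merged",
    num_epochs=5,
    negative_sampling=True,
    window_size=5,
    dimensionality=128,
    model_base="pureml",
    output_path="EMBEDDINGS.zip",
)
print(archive_path)   # e.g. "EMBEDDINGS.zip", holding per-frame (V, D) matrices
```
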
@@ -516,68 +612,102 @@ class Embedder:
516
612
  output_path = output_path.with_suffix(".zip")
517
613
 
518
614
  _logger.info(
519
- "Embedding all frames -> %s | frames=%d dim=%d base=%s",
520
- output_path, self.frame_count, dimensionality, self.model_base
615
+ "embed_all: frames=%d D=%d base=%s RIN=%s using=%s out=%s",
616
+ self.frame_count, dimensionality, model_base, RIN_type, using, output_path
521
617
  )
522
618
 
619
+ # Per-frame deterministic seeds
523
620
  master_ss = np.random.SeedSequence(self._seed)
524
621
  child_seeds = master_ss.spawn(self.frame_count)
525
622
 
526
- embeddings = []
527
- for frame_idx, seed_seq in enumerate(child_seeds, start=1):
623
+ embeddings: list[np.ndarray] = []
624
+ for frame_id, seed_seq in enumerate(child_seeds, start=1):
528
625
  child_seed = int(seed_seq.generate_state(1, dtype=np.uint32)[0])
529
- _logger.info("Processing frame %d/%d (child_seed=%d entropy=%d)", frame_idx, self.frame_count, child_seed, seed_seq.entropy)
530
- embeddings.append(
531
- self.embed_frame(
532
- frame_idx,
533
- RIN_type,
534
- using,
535
- window_size,
536
- num_negative_samples,
537
- num_epochs,
538
- batch_size,
626
+ _logger.info("Embedding frame %d/%d with seed=%d", frame_id, self.frame_count, child_seed)
627
+ E = self.embed_frame(
628
+ frame_id=frame_id,
629
+ RIN_type=RIN_type,
630
+ using=using,
631
+ num_epochs=num_epochs,
632
+ negative_sampling=negative_sampling,
633
+ window_size=window_size,
634
+ num_negative_samples=num_negative_samples,
635
+ batch_size=batch_size,
636
+ lr_step_per_batch=lr_step_per_batch,
539
637
  shuffle_data=shuffle_data,
540
638
  dimensionality=dimensionality,
541
639
  alpha=alpha,
542
640
  device=device,
543
- sgns_kwargs=sgns_kwargs,
641
+ model_base=model_base,
642
+ model_kwargs=model_kwargs,
643
+ kind=kind,
544
644
  _seed=child_seed
545
645
  )
546
- )
646
+ embeddings.append(np.asarray(E, dtype=np.float32, copy=False))
647
+ _logger.debug("Frame %d embedded: E.shape=%s", frame_id, E.shape)
547
648
 
548
- embeddings = [np.asarray(e) for e in embeddings]
549
649
  block_name = "FRAME_EMBEDDINGS"
550
650
  with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level=compression_level) as storage:
651
+ _logger.info("Writing %d frame matrices to block '%s' ...", len(embeddings), block_name)
551
652
  storage.write(
552
653
  these_arrays=embeddings,
553
654
  to_block_named=block_name,
554
655
  arrays_per_chunk=num_matrices_in_compressed_blocks
555
656
  )
556
- storage.add_attr("time_created", current_time)
557
- storage.add_attr("seed", int(self._seed))
558
- storage.add_attr("rng_scheme", "SeedSequence.spawn_per_frame_v1")
559
- storage.add_attr("source_walks_path", str(self._walks_path))
560
- storage.add_attr("model_base", self.model_base)
561
- storage.add_attr("rin_type", RIN_type)
562
- storage.add_attr("using_mode", using)
657
+
658
+ # Core dataset discovery (for consumers like the Embeddings Visualizer)
659
+ storage.add_attr("frame_embeddings_name", block_name)
660
+ storage.add_attr("time_stamp_count", int(self.frame_count))
661
+ storage.add_attr("node_count", int(self.vocab_size))
662
+ storage.add_attr("embedding_dim", int(dimensionality))
663
+
664
+ # Provenance of input WALKS
665
+ storage.add_attr("source_WALKS_path", str(self._walks_path))
666
+ storage.add_attr("walk_length", int(self.walk_length))
667
+ storage.add_attr("num_RWs", int(self.num_RWs))
668
+ storage.add_attr("num_SAWs", int(self.num_SAWs))
669
+ storage.add_attr("attractive_RWs_name", self._attractive_RWs_name)
670
+ storage.add_attr("repulsive_RWs_name", self._repulsive_RWs_name)
671
+ storage.add_attr("attractive_SAWs_name", self._attractive_SAWs_name)
672
+ storage.add_attr("repulsive_SAWs_name", self._repulsive_SAWs_name)
673
+
674
+ # Training configuration (sufficient to reproduce)
675
+ storage.add_attr("objective", "sgns" if negative_sampling else "sg")
676
+ storage.add_attr("model_base", model_base)
677
+ storage.add_attr("embedding_kind", kind) # 'in' | 'out' | 'avg'
678
+ storage.add_attr("num_epochs", int(num_epochs))
679
+ storage.add_attr("batch_size", int(batch_size))
563
680
  storage.add_attr("window_size", int(window_size))
564
681
  storage.add_attr("alpha", float(alpha))
565
- storage.add_attr("dimensionality", int(dimensionality))
682
+ storage.add_attr("negative_sampling", bool(negative_sampling))
566
683
  storage.add_attr("num_negative_samples", int(num_negative_samples))
567
- storage.add_attr("num_epochs", int(num_epochs))
568
- storage.add_attr("batch_size", int(batch_size))
684
+ storage.add_attr("lr_step_per_batch", bool(lr_step_per_batch))
569
685
  storage.add_attr("shuffle_data", bool(shuffle_data))
570
- storage.add_attr("frames_written", int(len(embeddings)))
571
- storage.add_attr("vocab_size", int(self.vocab_size))
572
- storage.add_attr("frame_count", int(self.frame_count))
573
- storage.add_attr("embedding_dtype", str(embeddings[0].dtype))
574
- storage.add_attr("frame_embeddings_name", block_name)
686
+ storage.add_attr("device_hint", device if device is not None else "")
687
+ storage.add_attr("model_kwargs_repr", repr(model_kwargs) if model_kwargs is not None else "{}")
688
+
689
+ # Which walks were used to train
690
+ storage.add_attr("RIN_type", RIN_type) # 'attr' or 'repuls'
691
+ storage.add_attr("using", using) # 'RW' | 'SAW' | 'merged'
692
+
693
+ # Reproducibility
694
+ storage.add_attr("master_seed", int(self._seed))
695
+ # Note: this records seeds derived from child SeedSequences at metadata time.
696
+ storage.add_attr("per_frame_seeds", [int(s.generate_state(1, dtype=np.uint32)[0]) for s in child_seeds])
697
+
698
+ # Archive/IO details
575
699
  storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
576
700
  storage.add_attr("compression_level", int(compression_level))
701
+ storage.add_attr("created_at", current_time)
702
+
703
+ _logger.info(
704
+ "Stored embeddings archive: %s | shape=(T,N,D)=(%d,%d,%d)",
705
+ output_path, self.frame_count, self.vocab_size, dimensionality
706
+ )
577
707
 
578
- _logger.info("Embedding archive written to %s", output_path)
579
708
  return str(output_path)
580
709
 
710
+
581
711
  __all__ = ["Embedder"]
582
712
 
583
713
  if __name__ == "__main__":
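
A closing note on reproducibility: the final hunk derives one child seed per frame from the master seed via ``np.random.SeedSequence.spawn`` and records the derived values as ``per_frame_seeds``. A standalone sketch of why this is deterministic (toy seed and frame count):

```python
import numpy as np

master = np.random.SeedSequence(42)       # master seed (toy value)
children = master.spawn(3)                # one SeedSequence per frame

# generate_state is a pure function of the child sequence, so re-deriving the
# seeds later (as the metadata write does) yields the same values.
seeds = [int(c.generate_state(1, dtype=np.uint32)[0]) for c in children]
print(seeds)
```
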