sawnergy 1.0.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -0,0 +1,578 @@
+ """
+ Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
+
+ This module consumes attractive/repulsive walk corpora produced by the walker
+ pipeline and trains per-frame embeddings using either the PyTorch or PureML
+ implementations of SGNS. The resulting embeddings can be persisted back into
+ an ``ArrayStorage`` archive along with rich metadata describing the training
+ configuration.
+ """
+
+ from __future__ import annotations
+
+ # third-party
+ import numpy as np
+
+ # built-in
+ from pathlib import Path
+ from typing import Literal
+ import logging
+
+ # local
+ from .. import sawnergy_util
+
+ # *----------------------------------------------------*
+ # GLOBALS
+ # *----------------------------------------------------*
+
+ _logger = logging.getLogger(__name__)
+
+ # *----------------------------------------------------*
+ # CLASSES
+ # *----------------------------------------------------*
+
+ class Embedder:
+     """Skip-gram embedder over attractive/repulsive walk corpora."""
+
+     def __init__(self,
+                  WALKS_path: str | Path,
+                  base: Literal["torch", "pureml"],
+                  *,
+                  seed: int | None = None
+                  ) -> None:
+         """Initialize the embedder and load walk tensors.
+
+         Args:
+             WALKS_path: Path to a ``WALKS_*.zip`` (or ``.zarr``) archive created
+                 by the walker pipeline. The archive's root attrs must include:
+                 ``attractive_RWs_name``, ``repulsive_RWs_name``,
+                 ``attractive_SAWs_name``, ``repulsive_SAWs_name`` (each may be
+                 ``None`` if that collection is absent), and the metadata
+                 ``num_RWs``, ``num_SAWs``, ``node_count``, ``time_stamp_count``,
+                 ``walk_length``.
+             base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
+             seed: Optional seed for the embedder's RNG. If ``None``, a random
+                 32-bit seed is chosen.
+
+         Raises:
+             ValueError: If required metadata is missing or any loaded walk array
+                 has an unexpected shape.
+             ImportError: If the requested backend is not installed.
+             NameError: If ``base`` is not one of ``{"torch", "pureml"}``.
+
+         Notes:
+             - Walks in storage are 1-based (residue indexing). Internally, this
+               class normalizes to 0-based indices for training utilities.
+         """
+         self._walks_path = Path(WALKS_path)
+         _logger.info("Initializing Embedder from %s (base=%s)", self._walks_path, base)
+
+         # placeholders for optional walk collections
+         self.attractive_RWs : np.ndarray | None = None
+         self.repulsive_RWs : np.ndarray | None = None
+         self.attractive_SAWs: np.ndarray | None = None
+         self.repulsive_SAWs : np.ndarray | None = None
+
+         # Load numpy arrays from read-only storage
+         with sawnergy_util.ArrayStorage(self._walks_path, mode="r") as storage:
+             attractive_RWs_name = storage.get_attr("attractive_RWs_name")
+             repulsive_RWs_name = storage.get_attr("repulsive_RWs_name")
+             attractive_SAWs_name = storage.get_attr("attractive_SAWs_name")
+             repulsive_SAWs_name = storage.get_attr("repulsive_SAWs_name")
+
+             attractive_RWs : np.ndarray | None = (
+                 storage.read(attractive_RWs_name, slice(None)) if attractive_RWs_name is not None else None
+             )
+
+             repulsive_RWs : np.ndarray | None = (
+                 storage.read(repulsive_RWs_name, slice(None)) if repulsive_RWs_name is not None else None
+             )
+
+             attractive_SAWs : np.ndarray | None = (
+                 storage.read(attractive_SAWs_name, slice(None)) if attractive_SAWs_name is not None else None
+             )
+
+             repulsive_SAWs : np.ndarray | None = (
+                 storage.read(repulsive_SAWs_name, slice(None)) if repulsive_SAWs_name is not None else None
+             )
+
+             num_RWs = storage.get_attr("num_RWs")
+             num_SAWs = storage.get_attr("num_SAWs")
+             node_count = storage.get_attr("node_count")
+             time_stamp_count = storage.get_attr("time_stamp_count")
+             walk_length = storage.get_attr("walk_length")
+
+         if node_count is None or time_stamp_count is None or walk_length is None:
+             raise ValueError("WALKS metadata missing one of node_count, time_stamp_count, walk_length")
+
+         _logger.debug(
+             ("Loaded WALKS from %s"
+              " | ATTR RWs: %s %s"
+              " | REP RWs: %s %s"
+              " | ATTR SAWs: %s %s"
+              " | REP SAWs: %s %s"
+              " | num_RWs=%d num_SAWs=%d V=%d L=%d T=%d"),
+             self._walks_path,
+             getattr(attractive_RWs, "shape", None), getattr(attractive_RWs, "dtype", None),
+             getattr(repulsive_RWs, "shape", None), getattr(repulsive_RWs, "dtype", None),
+             getattr(attractive_SAWs, "shape", None), getattr(attractive_SAWs, "dtype", None),
+             getattr(repulsive_SAWs, "shape", None), getattr(repulsive_SAWs, "dtype", None),
+             num_RWs, num_SAWs, node_count, walk_length, time_stamp_count
+         )
+
+         # expected shapes
+         RWs_expected = (time_stamp_count, node_count * num_RWs, walk_length+1) if (num_RWs > 0) else None
+         SAWs_expected = (time_stamp_count, node_count * num_SAWs, walk_length+1) if (num_SAWs > 0) else None
+
+         self.vocab_size = int(node_count)
+         self.frame_count = int(time_stamp_count)
+         self.walk_length = int(walk_length)
+
+         # store walks if present
+         if attractive_RWs is not None:
+             if RWs_expected and attractive_RWs.shape != RWs_expected:
+                 raise ValueError(f"ATTR RWs: expected {RWs_expected}, got {attractive_RWs.shape}")
+             self.attractive_RWs = attractive_RWs
+
+         if repulsive_RWs is not None:
+             if RWs_expected and repulsive_RWs.shape != RWs_expected:
+                 raise ValueError(f"REP RWs: expected {RWs_expected}, got {repulsive_RWs.shape}")
+             self.repulsive_RWs = repulsive_RWs
+
+         if attractive_SAWs is not None:
+             if SAWs_expected and attractive_SAWs.shape != SAWs_expected:
+                 raise ValueError(f"ATTR SAWs: expected {SAWs_expected}, got {attractive_SAWs.shape}")
+             self.attractive_SAWs = attractive_SAWs
+
+         if repulsive_SAWs is not None:
+             if SAWs_expected and repulsive_SAWs.shape != SAWs_expected:
+                 raise ValueError(f"REP SAWs: expected {SAWs_expected}, got {repulsive_SAWs.shape}")
+             self.repulsive_SAWs = repulsive_SAWs
+
+         # INTERNAL RNG
+         self._seed = np.random.randint(0, 2**32 - 1) if seed is None else int(seed)
+         self.rng = np.random.default_rng(self._seed)
+         _logger.info("RNG initialized from seed=%d", self._seed)
+
+         # MODEL HANDLE
+         self.model_base: Literal["torch", "pureml"] = base
+         self.model_constructor = self._get_SGNS_constructor_from(base)
+         _logger.info("SGNS backend resolved: %s", getattr(self.model_constructor, "__name__", repr(self.model_constructor)))
+
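For orientation, the root attrs the constructor reads, written out as a hypothetical dict (attr names come from the docstring above; the values are invented for illustration):

    walks_attrs = {
        "attractive_RWs_name": "ATTR_RWS", "repulsive_RWs_name": None,
        "attractive_SAWs_name": "ATTR_SAWS", "repulsive_SAWs_name": None,
        "num_RWs": 10, "num_SAWs": 5,
        "node_count": 120, "time_stamp_count": 50, "walk_length": 20,
    }
    # with these values, each stored RW block is expected to have shape
    # (time_stamp_count, node_count * num_RWs, walk_length + 1) = (50, 1200, 21)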
+     # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
+
+     # HELPERS:
+
+     @staticmethod
+     def _get_SGNS_constructor_from(base: Literal["torch", "pureml"]):
+         """Resolve the SGNS implementation class for the selected backend."""
+         if base == "torch":
+             try:
+                 from .SGNS_torch import SGNS_Torch
+                 return SGNS_Torch
+             except Exception:
+                 raise ImportError(
+                     "PyTorch is not installed, but base='torch' was requested. "
+                     "Install PyTorch first, e.g.: `pip install torch` "
+                     "(see https://pytorch.org/get-started for platform-specific wheels)."
+                 )
+         elif base == "pureml":
+             try:
+                 from .SGNS_pml import SGNS_PureML
+                 return SGNS_PureML
+             except Exception:
+                 raise ImportError(
+                     "PureML is not installed, but base='pureml' was requested. "
+                     "Install PureML first via `pip install ym-pure-ml`."
+                 )
+         else:
+             raise NameError(f"Expected `base` in (\"torch\", \"pureml\"); instead got: {base}")
+
+     @staticmethod
+     def _as_zerobase_intp(walks: np.ndarray, *, V: int) -> np.ndarray:
+         """Validate 1-based uint/int walks → 0-based intp; check bounds."""
+         arr = np.asarray(walks)
+         if arr.ndim != 2:
+             raise ValueError("walks must be 2D: (num_walks, walk_len)")
+         if arr.dtype.kind not in "iu":
+             arr = arr.astype(np.int64, copy=False)
+         # 1-based → 0-based
+         arr = arr - 1
+         if arr.min() < 0 or arr.max() >= V:
+             raise ValueError("walk ids out of range after 1→0-based normalization")
+         return arr.astype(np.intp, copy=False)
+
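A tiny illustrative check of the 1-based storage convention noted in the class docstring (values invented for the example):

    import numpy as np

    walks = np.array([[1, 2, 3]], dtype=np.uint32)    # 1-based residue ids, as stored
    walks0 = Embedder._as_zerobase_intp(walks, V=3)   # -> array([[0, 1, 2]]) with dtype intp
    # an id of 0 or of V + 1 in `walks` would raise ValueError after the 1 -> 0 shift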
+     @staticmethod
+     def _pairs_from_walks(walks0: np.ndarray, window_size: int) -> np.ndarray:
+         """
+         Skip-gram pairs including edge centers (one-sided when needed).
+         walks0: (W, L) int array (0-based ids).
+         Returns: (N_pairs, 2) int32 [center, context].
+         """
+         if walks0.ndim != 2:
+             raise ValueError("walks must be 2D: (num_walks, walk_len)")
+
+         _, L = walks0.shape
+         k = int(window_size)
+
+         if k <= 0:
+             raise ValueError("window_size must be positive")
+
+         if L == 0:
+             return np.empty((0, 2), dtype=np.int32)
+
+         out_chunks = []
+         for d in range(1, k + 1):
+             span = L - d
+             if span <= 0:
+                 break
+             # right contexts: center j pairs with j+d (centers 0..L-d-1)
+             centers_r = walks0[:, :L - d]
+             ctx_r = walks0[:, d:]
+             out_chunks.append(np.stack((centers_r, ctx_r), axis=2).reshape(-1, 2))
+             # left contexts: center j pairs with j-d (centers d..L-1)
+             centers_l = walks0[:, d:]
+             ctx_l = walks0[:, :L - d]
+             out_chunks.append(np.stack((centers_l, ctx_l), axis=2).reshape(-1, 2))
+
+         if not out_chunks:
+             return np.empty((0, 2), dtype=np.int32)
+
+         return np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
+
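As a concrete illustration of the pair construction, a made-up three-node walk with window k = 1:

    import numpy as np

    walks0 = np.array([[3, 1, 2]])                 # one walk over 0-based ids
    pairs = Embedder._pairs_from_walks(walks0, 1)
    # right contexts give (3, 1), (1, 2); left contexts give (1, 3), (2, 1), so:
    # pairs == [[3, 1], [1, 2], [1, 3], [2, 1]]    (dtype int32, right block first)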
+     @staticmethod
+     def _freq_from_walks(walks0: np.ndarray, *, V: int) -> np.ndarray:
+         """Node frequencies from walks (0-based)."""
+         return np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
+
+     @staticmethod
+     def _soft_unigram(freq: np.ndarray, *, power: float = 0.75) -> np.ndarray:
+         """Return normalized Pn(w) ∝ f(w)^power as float64 probs."""
+         p = np.asarray(freq, dtype=np.float64)
+         if p.sum() == 0:
+             raise ValueError("all frequencies are zero")
+         p = np.power(p, float(power))
+         s = p.sum()
+         if not np.isfinite(s) or s <= 0:
+             raise ValueError("invalid unigram mass")
+         return p / s
+
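A quick numeric example of the smoothed noise distribution Pn(w) ∝ f(w)^0.75 (counts invented):

    import numpy as np

    freq = np.array([0, 8, 1, 1])                  # raw visit counts per node
    p = np.power(freq.astype(np.float64), 0.75)
    p /= p.sum()
    # p ≈ [0.000, 0.704, 0.148, 0.148]: the frequent node drops from a raw share of 0.8
    # to 0.704 while rare nodes rise from 0.1 to 0.148, which is what
    # _soft_unigram(freq, power=0.75) returns.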
+     def _materialize_walks(self, frame_id: int, rin: Literal["attr", "repuls"],
+                            using: Literal["RW", "SAW", "merged"]) -> np.ndarray:
+         if not 1 <= frame_id <= int(self.frame_count):
+             raise IndexError(f"frame_id must be in [1, {self.frame_count}]; got {frame_id}")
+
+         frame_id -= 1
+
+         if rin == "attr":
+             parts = []
+             if using in ("RW", "merged"):
+                 arr = getattr(self, "attractive_RWs", None)
+                 if arr is not None:
+                     parts.append(arr[frame_id])
+             if using in ("SAW", "merged"):
+                 arr = getattr(self, "attractive_SAWs", None)
+                 if arr is not None:
+                     parts.append(arr[frame_id])
+         else:
+             parts = []
+             if using in ("RW", "merged"):
+                 arr = getattr(self, "repulsive_RWs", None)
+                 if arr is not None:
+                     parts.append(arr[frame_id])
+             if using in ("SAW", "merged"):
+                 arr = getattr(self, "repulsive_SAWs", None)
+                 if arr is not None:
+                     parts.append(arr[frame_id])
+
+         if not parts:
+             raise ValueError(f"No walks available for {rin=} with {using=}")
+         if len(parts) == 1:
+             return parts[0]
+         return np.concatenate(parts, axis=0)
+
+     # INTERFACES: (private)
+
+     def _attractive_corpus_and_prob(self, *,
+                                     frame_id: int,
+                                     using: Literal["RW", "SAW", "merged"],
+                                     window_size: int,
+                                     alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
+         walks = self._materialize_walks(frame_id, "attr", using)
+         walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
+         attractive_corpus = self._pairs_from_walks(walks0, window_size)
+         attractive_noise_probs = self._soft_unigram(self._freq_from_walks(walks0, V=self.vocab_size), power=alpha)
+         _logger.info("ATTR corpus ready: pairs=%d", 0 if attractive_corpus is None else attractive_corpus.shape[0])
+
+         return attractive_corpus, attractive_noise_probs
+
+     def _repulsive_corpus_and_prob(self, *,
+                                    frame_id: int,
+                                    using: Literal["RW", "SAW", "merged"],
+                                    window_size: int,
+                                    alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
+         walks = self._materialize_walks(frame_id, "repuls", using)
+         walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
+         repulsive_corpus = self._pairs_from_walks(walks0, window_size)
+         repulsive_noise_probs = self._soft_unigram(self._freq_from_walks(walks0, V=self.vocab_size), power=alpha)
+         _logger.info("REP corpus ready: pairs=%d", 0 if repulsive_corpus is None else repulsive_corpus.shape[0])
+
+         return repulsive_corpus, repulsive_noise_probs
+
+     # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
+
+     def embed_frame(self,
+                     frame_id: int,
+                     RIN_type: Literal["attr", "repuls"],
+                     using: Literal["RW", "SAW", "merged"],
+                     window_size: int,
+                     num_negative_samples: int,
+                     num_epochs: int,
+                     batch_size: int,
+                     *,
+                     shuffle_data: bool = True,
+                     dimensionality: int = 128,
+                     alpha: float = 0.75,
+                     device: str | None = None,
+                     sgns_kwargs: dict[str, object] | None = None,
+                     _seed: int | None = None
+                     ) -> np.ndarray:
+         """Train embeddings for a single frame and return the input embedding matrix.
+
+         Args:
+             frame_id: 1-based frame index to train on.
+             RIN_type: Interaction channel to use: ``"attr"`` (attractive) or
+                 ``"repuls"`` (repulsive).
+             using: Which walk collections to include: ``"RW"``, ``"SAW"``, or
+                 ``"merged"`` (concatenates both if available).
+             window_size: Symmetric skip-gram window size ``k``.
+             num_negative_samples: Number of negative samples per positive pair.
+             num_epochs: Number of passes over the pair dataset.
+             batch_size: Mini-batch size for training.
+             shuffle_data: Whether to shuffle pairs each epoch.
+             dimensionality: Embedding dimensionality ``D``.
+             alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
+             device: Optional device string for the Torch backend (e.g., ``"cuda"``).
+             sgns_kwargs: Extra keyword arguments forwarded to the backend SGNS
+                 constructor. For PureML, required keys are:
+                 ``{"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}``.
+             _seed: Optional child seed for this frame's model initialization.
+
+         Returns:
+             np.ndarray: Learned **input** embedding matrix of shape ``(V, D)``.
+
+         Raises:
+             ValueError: If requested walks are missing, if no training pairs are
+                 generated, or if required ``sgns_kwargs`` for PureML are absent.
+             AttributeError: If the SGNS model does not expose embeddings via
+                 ``.embeddings`` or ``.parameters[0]``.
+         """
+         _logger.info(
+             "Preparing frame %d (rin=%s using=%s window=%d neg=%d epochs=%d batch=%d)",
+             frame_id, RIN_type, using, window_size, num_negative_samples, num_epochs, batch_size
+         )
+
+         if RIN_type == "attr":
+             if self.attractive_RWs is None and self.attractive_SAWs is None:
+                 raise ValueError("Attractive random walks are missing")
+             pairs, noise_probs = self._attractive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
+         elif RIN_type == "repuls":
+             if self.repulsive_RWs is None and self.repulsive_SAWs is None:
+                 raise ValueError("Repulsive random walks are missing")
+             pairs, noise_probs = self._repulsive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
+         else:
+             raise ValueError(f"Unknown RIN_type: {RIN_type!r}")
+
+         if pairs.size == 0:
+             raise ValueError("No training pairs generated for the requested configuration")
+
+         centers = pairs[:, 0].astype(np.int64, copy=False)
+         contexts = pairs[:, 1].astype(np.int64, copy=False)
+
+         model_kwargs: dict[str, object] = dict(sgns_kwargs or {})
+         if self.model_base == "pureml":
+             required = {"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}
+             missing = required.difference(model_kwargs)
+             if missing:
+                 raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")
+
+         child_seed = int(self._seed if _seed is None else _seed)
+         model_kwargs.update({
+             "V": self.vocab_size,
+             "D": dimensionality,
+             "seed": child_seed
+         })
+
+         if self.model_base == "torch" and device is not None:
+             model_kwargs["device"] = device
+
+         self.model = self.model_constructor(**model_kwargs)
+
+         _logger.info(
+             "Training SGNS base=%s constructor=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
+             self.model_base,
+             getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
+             frame_id,
+             pairs.shape[0],
+             dimensionality,
+             num_epochs,
+             batch_size,
+             num_negative_samples,
+             shuffle_data
+         )
+
+         self.model.fit(
+             centers,
+             contexts,
+             num_epochs,
+             batch_size,
+             num_negative_samples,
+             noise_probs,
+             shuffle_data,
+             lr_step_per_batch=False
+         )
+
+         embeddings = getattr(self.model, "embeddings", None)
+         if embeddings is None:
+             params = getattr(self.model, "parameters", None)
+             if isinstance(params, tuple) and params:
+                 embeddings = params[0]
+         if embeddings is None:
+             raise AttributeError("SGNS model does not expose embeddings via '.embeddings' or '.parameters[0]'")
+
+         embeddings = np.asarray(embeddings)
+         _logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
+         return embeddings
+
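A minimal, hypothetical single-frame call against the Torch backend (the archive name and hyperparameters below are illustrative, not package defaults):

    emb = Embedder("WALKS_example.zip", base="torch", seed=7)   # path is an assumption
    vectors = emb.embed_frame(
        frame_id=1,
        RIN_type="attr",
        using="merged",
        window_size=5,
        num_negative_samples=5,
        num_epochs=3,
        batch_size=4096,
        dimensionality=128,
    )
    # vectors is the learned input embedding matrix, shape (emb.vocab_size, 128)

With the PureML backend, ``sgns_kwargs`` must additionally supply the four keys checked above (``optim``, ``optim_kwargs``, ``lr_sched``, ``lr_sched_kwargs``).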
+     def embed_all(
+         self,
+         RIN_type: Literal["attr", "repuls"],
+         using: Literal["RW", "SAW", "merged"],
+         window_size: int,
+         num_negative_samples: int,
+         num_epochs: int,
+         batch_size: int,
+         *,
+         shuffle_data: bool = True,
+         dimensionality: int = 128,
+         alpha: float = 0.75,
+         device: str | None = None,
+         sgns_kwargs: dict[str, object] | None = None,
+         output_path: str | Path | None = None,
+         num_matrices_in_compressed_blocks: int = 20,
+         compression_level: int = 3) -> str:
+         """Train embeddings for all frames and persist them to compressed storage.
+
+         Iterates through all frames (``1..frame_count``), trains an SGNS model
+         per frame using the configured backend, collects the resulting input
+         embeddings, and writes them into a new compressed ``ArrayStorage`` archive.
+
+         Args:
+             RIN_type: Interaction channel to use: ``"attr"`` or ``"repuls"``.
+             using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
+             window_size: Symmetric skip-gram window size ``k``.
+             num_negative_samples: Number of negative samples per positive pair.
+             num_epochs: Number of epochs for each frame.
+             batch_size: Mini-batch size used during training.
+             shuffle_data: Whether to shuffle pairs each epoch.
+             dimensionality: Embedding dimensionality ``D``.
+             alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
+             device: Optional device string for Torch backend.
+             sgns_kwargs: Extra constructor kwargs for the SGNS backend (see
+                 :meth:`embed_frame` for PureML requirements).
+             output_path: Destination path. If ``None``, a new file named
+                 ``EMBEDDINGS_<timestamp>.zip`` is created next to the source
+                 WALKS archive. If the provided path lacks a suffix, ``.zip`` is
+                 appended.
+             num_matrices_in_compressed_blocks: Number of per-frame matrices to
+                 store per compressed chunk in the output archive.
+             compression_level: Blosc Zstd compression level (0-9).
+
+         Returns:
+             str: Filesystem path to the written embeddings archive (``.zip``).
+
+         Raises:
+             ValueError: If configuration produces no pairs for a frame or if
+                 PureML kwargs are incomplete.
+             RuntimeError: Propagated from storage operations on failure.
+
+         Notes:
+             - A deterministic child seed is spawned per frame from the master
+               seed using ``np.random.SeedSequence`` to ensure reproducibility
+               across runs.
+         """
+         current_time = sawnergy_util.current_time()
+         if output_path is None:
+             output_path = self._walks_path.with_name(f"EMBEDDINGS_{current_time}").with_suffix(".zip")
+         else:
+             output_path = Path(output_path)
+             if output_path.suffix == "":
+                 output_path = output_path.with_suffix(".zip")
+
+         _logger.info(
+             "Embedding all frames -> %s | frames=%d dim=%d base=%s",
+             output_path, self.frame_count, dimensionality, self.model_base
+         )
+
+         master_ss = np.random.SeedSequence(self._seed)
+         child_seeds = master_ss.spawn(self.frame_count)
+
+         embeddings = []
+         for frame_idx, seed_seq in enumerate(child_seeds, start=1):
+             child_seed = int(seed_seq.generate_state(1, dtype=np.uint32)[0])
+             _logger.info("Processing frame %d/%d (child_seed=%d entropy=%d)", frame_idx, self.frame_count, child_seed, seed_seq.entropy)
+             embeddings.append(
+                 self.embed_frame(
+                     frame_idx,
+                     RIN_type,
+                     using,
+                     window_size,
+                     num_negative_samples,
+                     num_epochs,
+                     batch_size,
+                     shuffle_data=shuffle_data,
+                     dimensionality=dimensionality,
+                     alpha=alpha,
+                     device=device,
+                     sgns_kwargs=sgns_kwargs,
+                     _seed=child_seed
+                 )
+             )
+
+         embeddings = [np.asarray(e) for e in embeddings]
+         block_name = "FRAME_EMBEDDINGS"
+         with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level=compression_level) as storage:
+             storage.write(
+                 these_arrays=embeddings,
+                 to_block_named=block_name,
+                 arrays_per_chunk=num_matrices_in_compressed_blocks
+             )
+             storage.add_attr("time_created", current_time)
+             storage.add_attr("seed", int(self._seed))
+             storage.add_attr("rng_scheme", "SeedSequence.spawn_per_frame_v1")
+             storage.add_attr("source_walks_path", str(self._walks_path))
+             storage.add_attr("model_base", self.model_base)
+             storage.add_attr("rin_type", RIN_type)
+             storage.add_attr("using_mode", using)
+             storage.add_attr("window_size", int(window_size))
+             storage.add_attr("alpha", float(alpha))
+             storage.add_attr("dimensionality", int(dimensionality))
+             storage.add_attr("num_negative_samples", int(num_negative_samples))
+             storage.add_attr("num_epochs", int(num_epochs))
+             storage.add_attr("batch_size", int(batch_size))
+             storage.add_attr("shuffle_data", bool(shuffle_data))
+             storage.add_attr("frames_written", int(len(embeddings)))
+             storage.add_attr("vocab_size", int(self.vocab_size))
+             storage.add_attr("frame_count", int(self.frame_count))
+             storage.add_attr("embedding_dtype", str(embeddings[0].dtype))
+             storage.add_attr("frame_embeddings_name", block_name)
+             storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
+             storage.add_attr("compression_level", int(compression_level))
+
+         _logger.info("Embedding archive written to %s", output_path)
+         return str(output_path)
+
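Continuing the hypothetical sketch from ``embed_frame`` above, a full run plus reading the archive back (the block and attr names follow what ``embed_all`` writes; the shape comment is an expectation, not a guarantee):

    out_path = emb.embed_all(
        "attr", "merged",
        window_size=5, num_negative_samples=5,
        num_epochs=3, batch_size=4096,
        dimensionality=128,
    )

    with sawnergy_util.ArrayStorage(out_path, mode="r") as storage:
        name = storage.get_attr("frame_embeddings_name")   # "FRAME_EMBEDDINGS"
        frames = storage.read(name, slice(None))            # expected (frame_count, V, D)
        dim = storage.get_attr("dimensionality")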
+ __all__ = ["Embedder"]
+
+ if __name__ == "__main__":
+     pass
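Regarding the per-frame seeding scheme described in ``embed_all``, a standalone sketch of why ``SeedSequence.spawn`` yields reproducible child seeds across runs:

    import numpy as np

    children_a = [int(s.generate_state(1, dtype=np.uint32)[0])
                  for s in np.random.SeedSequence(12345).spawn(3)]
    children_b = [int(s.generate_state(1, dtype=np.uint32)[0])
                  for s in np.random.SeedSequence(12345).spawn(3)]
    assert children_a == children_b   # same master seed -> same child seeds, frame by frame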
@@ -0,0 +1,54 @@
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ from pathlib import Path
+ from datetime import datetime
+
+
+ def configure_logging(
+     logs_dir: Path | str,
+     file_level: int = logging.DEBUG,
+     console_level: int = logging.WARNING
+ ) -> None:
+     """
+     Configure the root logger with a timed rotating file handler and a console handler.
+
+     Args:
+         logs_dir: Directory where log files will be stored.
+         file_level: Logging level for the file handler (default: DEBUG).
+         console_level: Logging level for the console handler (default: WARNING).
+     """
+
+     if isinstance(logs_dir, str):
+         logs_dir = Path(logs_dir)
+
+     root = logging.getLogger()
+     if root.handlers:
+         return
+
+     logs_dir.mkdir(parents=True, exist_ok=True)
+
+     fmt = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
+     formatter = logging.Formatter(fmt)
+
+     logfile = logs_dir / f"sawnergy_{datetime.now():%Y-%m-%d_%H%M%S}.log"
+
+     file_h = TimedRotatingFileHandler(
+         logfile,
+         when="midnight",
+         backupCount=7,
+         encoding="utf-8"
+     )
+     file_h.setLevel(file_level)
+     file_h.setFormatter(formatter)
+
+     console_h = logging.StreamHandler()
+     console_h.setLevel(console_level)
+     console_h.setFormatter(formatter)
+
+     # ensure root level is low enough to handle both handlers
+     root.setLevel(min(file_level, console_level))
+     root.addHandler(file_h)
+     root.addHandler(console_h)
+
+
+ __all__ = ["configure_logging"]
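A minimal way to wire this up at program start (the directory name and logger name are illustrative):

    import logging
    from pathlib import Path

    configure_logging(Path("logs"), file_level=logging.DEBUG, console_level=logging.INFO)
    logging.getLogger("sawnergy").info("logging configured")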
@@ -0,0 +1,9 @@
+ # rin
+ from .rin_builder import RINBuilder
+ from .rin_util import run_cpptraj, CpptrajScript
+
+ __all__ = [
+     "RINBuilder",
+     "run_cpptraj",
+     "CpptrajScript"
+ ]