sawnergy 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy might be problematic. Click here for more details.
- sawnergy/__init__.py +13 -0
- sawnergy/embedding/SGNS_pml.py +135 -0
- sawnergy/embedding/SGNS_torch.py +177 -0
- sawnergy/embedding/__init__.py +34 -0
- sawnergy/embedding/embedder.py +578 -0
- sawnergy/logging_util.py +54 -0
- sawnergy/rin/__init__.py +9 -0
- sawnergy/rin/rin_builder.py +936 -0
- sawnergy/rin/rin_util.py +391 -0
- sawnergy/sawnergy_util.py +1182 -0
- sawnergy/visual/__init__.py +42 -0
- sawnergy/visual/visualizer.py +690 -0
- sawnergy/visual/visualizer_util.py +387 -0
- sawnergy/walks/__init__.py +16 -0
- sawnergy/walks/walker.py +795 -0
- sawnergy/walks/walker_util.py +384 -0
- sawnergy-1.0.0.dist-info/METADATA +290 -0
- sawnergy-1.0.0.dist-info/RECORD +22 -0
- sawnergy-1.0.0.dist-info/WHEEL +5 -0
- sawnergy-1.0.0.dist-info/licenses/LICENSE +201 -0
- sawnergy-1.0.0.dist-info/licenses/NOTICE +4 -0
- sawnergy-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,578 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""
|
|
4
|
+
Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
|
|
5
|
+
|
|
6
|
+
This module consumes attractive/repulsive walk corpora produced by the walker
|
|
7
|
+
pipeline and trains per-frame embeddings using either the PyTorch or PureML
|
|
8
|
+
implementations of SGNS. The resulting embeddings can be persisted back into
|
|
9
|
+
an ``ArrayStorage`` archive along with rich metadata describing the training
|
|
10
|
+
configuration.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
# third-party
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
# built-in
|
|
17
|
+
from pathlib import Path
|
|
18
|
+
from typing import Literal
|
|
19
|
+
import logging
|
|
20
|
+
|
|
21
|
+
# local
|
|
22
|
+
from .. import sawnergy_util
|
|
23
|
+
|
|
24
|
+
# *----------------------------------------------------*
#                        GLOBALS
# *----------------------------------------------------*

# Logger named after this module; this module attaches no handlers itself.
_logger: logging.Logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
# *----------------------------------------------------*
|
|
31
|
+
# CLASSES
|
|
32
|
+
# *----------------------------------------------------*
|
|
33
|
+
|
|
34
|
+
class Embedder:
    """Skip-gram embedder over attractive/repulsive walk corpora.

    Loads random walks (RWs) and self-avoiding walks (SAWs) from a WALKS
    archive, turns each frame's walks into skip-gram ``(center, context)``
    pairs plus a smoothed-unigram noise distribution, and trains one SGNS
    model per frame with the selected backend (PyTorch or PureML).
    """

    def __init__(self,
                 WALKS_path: str | Path,
                 base: Literal["torch", "pureml"],
                 *,
                 seed: int | None = None
                 ) -> None:
        """Initialize the embedder and load walk tensors.

        Args:
            WALKS_path: Path to a ``WALKS_*.zip`` (or ``.zarr``) archive created
                by the walker pipeline. The archive's root attrs must include:
                ``attractive_RWs_name``, ``repulsive_RWs_name``,
                ``attractive_SAWs_name``, ``repulsive_SAWs_name`` (each may be
                ``None`` if that collection is absent), and the metadata
                ``num_RWs``, ``num_SAWs``, ``node_count``, ``time_stamp_count``,
                ``walk_length``.
            base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
                The backend is resolved before the archive is read, so an
                invalid or unavailable backend fails fast.
            seed: Optional seed for the embedder's RNG. If ``None``, a random
                32-bit seed is chosen.

        Raises:
            ValueError: If required metadata is missing or any loaded walk array
                has an unexpected shape.
            ImportError: If the requested backend is not installed.
            NameError: If ``base`` is not one of ``{"torch","pureml"}``.

        Notes:
            - Walks in storage are 1-based (residue indexing). Internally, this
              class normalizes to 0-based indices for training utilities.
        """
        self._walks_path = Path(WALKS_path)
        _logger.info("Initializing Embedder from %s (base=%s)", self._walks_path, base)

        # MODEL HANDLE -- resolved before reading the archive so a typo in
        # `base` or a missing optional dependency does not cost a large read.
        self.model_base: Literal["torch", "pureml"] = base
        self.model_constructor = self._get_SGNS_constructor_from(base)
        _logger.info("SGNS backend resolved: %s", getattr(self.model_constructor, "__name__", repr(self.model_constructor)))

        # placeholders for optional walk collections
        self.attractive_RWs: np.ndarray | None = None
        self.repulsive_RWs: np.ndarray | None = None
        self.attractive_SAWs: np.ndarray | None = None
        self.repulsive_SAWs: np.ndarray | None = None

        # Load numpy arrays from read-only storage
        with sawnergy_util.ArrayStorage(self._walks_path, mode="r") as storage:
            attractive_RWs = self._read_named_block(storage, "attractive_RWs_name")
            repulsive_RWs = self._read_named_block(storage, "repulsive_RWs_name")
            attractive_SAWs = self._read_named_block(storage, "attractive_SAWs_name")
            repulsive_SAWs = self._read_named_block(storage, "repulsive_SAWs_name")
            meta = {
                key: storage.get_attr(key)
                for key in ("num_RWs", "num_SAWs", "node_count", "time_stamp_count", "walk_length")
            }

        # All five counts are required. num_RWs / num_SAWs used to be
        # unchecked, so a missing value crashed with a TypeError on `> 0`.
        missing = [key for key, value in meta.items() if value is None]
        if missing:
            raise ValueError(f"WALKS metadata missing required attrs: {missing}")

        num_RWs = int(meta["num_RWs"])
        num_SAWs = int(meta["num_SAWs"])
        node_count = int(meta["node_count"])
        time_stamp_count = int(meta["time_stamp_count"])
        walk_length = int(meta["walk_length"])

        def _describe(arr):
            # (shape, dtype) for logging; (None, None) for absent collections
            return getattr(arr, "shape", None), getattr(arr, "dtype", None)

        _logger.debug(
            ("Loaded WALKS from %s"
             " | ATTR RWs: %s %s"
             " | REP RWs: %s %s"
             " | ATTR SAWs: %s %s"
             " | REP SAWs: %s %s"
             " | num_RWs=%d num_SAWs=%d V=%d L=%d T=%d"),
            self._walks_path,
            *_describe(attractive_RWs),
            *_describe(repulsive_RWs),
            *_describe(attractive_SAWs),
            *_describe(repulsive_SAWs),
            num_RWs, num_SAWs, node_count, walk_length, time_stamp_count
        )

        # expected shapes: (T, node_count * walks_per_node, walk_length + 1);
        # no check is possible when the per-node count is 0
        RWs_expected = (time_stamp_count, node_count * num_RWs, walk_length + 1) if num_RWs > 0 else None
        SAWs_expected = (time_stamp_count, node_count * num_SAWs, walk_length + 1) if num_SAWs > 0 else None

        self.vocab_size = node_count
        self.frame_count = time_stamp_count
        self.walk_length = walk_length

        # store walks if present
        self.attractive_RWs = self._checked_walks("ATTR RWs", attractive_RWs, RWs_expected)
        self.repulsive_RWs = self._checked_walks("REP RWs", repulsive_RWs, RWs_expected)
        self.attractive_SAWs = self._checked_walks("ATTR SAWs", attractive_SAWs, SAWs_expected)
        self.repulsive_SAWs = self._checked_walks("REP SAWs", repulsive_SAWs, SAWs_expected)

        # INTERNAL RNG
        # Generator.integers defaults to int64 on every platform; the legacy
        # np.random.randint(0, 2**32 - 1) overflows where the default integer
        # is 32-bit (e.g. NumPy < 2 on Windows).
        self._seed = int(np.random.default_rng().integers(0, 2**32)) if seed is None else int(seed)
        self.rng = np.random.default_rng(self._seed)
        _logger.info("RNG initialized from seed=%d", self._seed)

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

    # HELPERS:

    @staticmethod
    def _read_named_block(storage, attr_name: str) -> np.ndarray | None:
        """Read the whole block whose name is stored in root attr ``attr_name``.

        Returns ``None`` when the attr is ``None`` (collection absent).
        """
        block_name = storage.get_attr(attr_name)
        if block_name is None:
            return None
        return storage.read(block_name, slice(None))

    @staticmethod
    def _checked_walks(label: str, walks: np.ndarray | None,
                       expected: tuple[int, int, int] | None) -> np.ndarray | None:
        """Return ``walks`` unchanged after verifying its shape when both are known.

        Raises:
            ValueError: If ``walks`` is present and its shape differs from ``expected``.
        """
        if walks is not None and expected is not None and walks.shape != expected:
            raise ValueError(f"{label}: expected {expected}, got {walks.shape}")
        return walks

    @staticmethod
    def _get_SGNS_constructor_from(base: Literal["torch", "pureml"]):
        """Resolve the SGNS implementation class for the selected backend.

        Raises:
            ImportError: If the backend module cannot be imported. The original
                exception is chained as ``__cause__`` so the real reason (missing
                package, broken install, native library failure, ...) remains
                visible in the traceback.
            NameError: If ``base`` is not ``"torch"`` or ``"pureml"``.
        """
        if base == "torch":
            try:
                from .SGNS_torch import SGNS_Torch
            except Exception as exc:
                raise ImportError(
                    "PyTorch is not installed, but base='torch' was requested. "
                    "Install PyTorch first, e.g.: `pip install torch` "
                    "(see https://pytorch.org/get-started for platform-specific wheels)."
                ) from exc
            return SGNS_Torch
        elif base == "pureml":
            try:
                from .SGNS_pml import SGNS_PureML
            except Exception as exc:
                raise ImportError(
                    "PureML is not installed, but base='pureml' was requested. "
                    "Install PureML first via `pip install ym-pure-ml` "
                ) from exc
            return SGNS_PureML
        else:
            raise NameError(f"Expected `base` in (\"torch\", \"pureml\"); Instead got: {base}")

    @staticmethod
    def _as_zerobase_intp(walks: np.ndarray, *, V: int) -> np.ndarray:
        """Convert 1-based walks to 0-based ``np.intp`` ids and bounds-check them.

        Args:
            walks: 2D array ``(num_walks, walk_len)`` of 1-based node ids.
            V: Vocabulary size; valid 0-based ids are ``0..V-1``.

        Returns:
            Array of the same shape with dtype ``np.intp``.

        Raises:
            ValueError: If ``walks`` is not 2D or any id is outside ``1..V``.
        """
        arr = np.asarray(walks)
        if arr.ndim != 2:
            raise ValueError("walks must be 2D: (num_walks, walk_len)")
        # Widen to signed int64 *before* subtracting: in an unsigned storage
        # dtype a stored 0 would wrap to the dtype maximum (e.g. 255 for
        # uint8) and could slip past the upper-bound check. Non-integer
        # input is truncated by the cast, as before.
        arr = arr.astype(np.int64, copy=False) - 1
        # Empty corpora carry no ids to validate (min/max raise on size 0).
        if arr.size and (arr.min() < 0 or arr.max() >= V):
            raise ValueError("walk ids out of range after 1→0-based normalization")
        return arr.astype(np.intp, copy=False)

    @staticmethod
    def _pairs_from_walks(walks0: np.ndarray, window_size: int) -> np.ndarray:
        """
        Skip-gram pairs including edge centers (one-sided when needed).
        walks0: (W, L) int array (0-based ids).
        Returns: (N_pairs, 2) int32 [center, context].
        """
        if walks0.ndim != 2:
            raise ValueError("walks must be 2D: (num_walks, walk_len)")

        _, L = walks0.shape
        k = int(window_size)

        if k <= 0:
            raise ValueError("window_size must be positive")

        if L == 0:
            return np.empty((0, 2), dtype=np.int32)

        out_chunks = []
        for d in range(1, k + 1):
            if L - d <= 0:
                break
            # right contexts: center j pairs with j+d (centers 0..L-d-1)
            out_chunks.append(np.stack((walks0[:, :L - d], walks0[:, d:]), axis=2).reshape(-1, 2))
            # left contexts: center j pairs with j-d (centers d..L-1)
            out_chunks.append(np.stack((walks0[:, d:], walks0[:, :L - d]), axis=2).reshape(-1, 2))

        # happens when L == 1: no position has a neighbour
        if not out_chunks:
            return np.empty((0, 2), dtype=np.int32)

        return np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)

    @staticmethod
    def _freq_from_walks(walks0: np.ndarray, *, V: int) -> np.ndarray:
        """Node frequencies from walks (0-based)."""
        return np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)

    @staticmethod
    def _soft_unigram(freq: np.ndarray, *, power: float = 0.75) -> np.ndarray:
        """Return normalized Pn(w) ∝ f(w)^power as float64 probs."""
        p = np.asarray(freq, dtype=np.float64)
        if p.sum() == 0:
            raise ValueError("all frequencies are zero")
        p = np.power(p, float(power))
        s = p.sum()
        if not np.isfinite(s) or s <= 0:
            raise ValueError("invalid unigram mass")
        return p / s

    def _materialize_walks(self, frame_id: int, rin: Literal["attr", "repuls"],
                           using: Literal["RW", "SAW", "merged"]) -> np.ndarray:
        """Return the walks of one frame (1-based ``frame_id``) for a channel.

        With ``using="merged"``, RWs and SAWs are concatenated along the walk
        axis (RWs first), skipping whichever collection is absent.

        Raises:
            IndexError: If ``frame_id`` is outside ``1..frame_count``.
            ValueError: If ``rin``/``using`` is unknown or no walks are available.
        """
        if not 1 <= frame_id <= int(self.frame_count):
            raise IndexError(f"frame_id must be in [1, {self.frame_count}]; got {frame_id}")

        # Previously any value other than "attr" silently selected the
        # repulsive channel; reject unknown channels explicitly instead.
        if rin == "attr":
            rws, saws = self.attractive_RWs, self.attractive_SAWs
        elif rin == "repuls":
            rws, saws = self.repulsive_RWs, self.repulsive_SAWs
        else:
            raise ValueError(f"Unknown rin: {rin!r}")
        if using not in ("RW", "SAW", "merged"):
            raise ValueError(f"Unknown using mode: {using!r}")

        idx = frame_id - 1
        parts = []
        if using in ("RW", "merged") and rws is not None:
            parts.append(rws[idx])
        if using in ("SAW", "merged") and saws is not None:
            parts.append(saws[idx])

        if not parts:
            raise ValueError(f"No walks available for {rin=} with {using=}")
        if len(parts) == 1:
            return parts[0]
        return np.concatenate(parts, axis=0)

    def _corpus_and_prob(self, rin: Literal["attr", "repuls"], *,
                         frame_id: int,
                         using: Literal["RW", "SAW", "merged"],
                         window_size: int,
                         alpha: float) -> tuple[np.ndarray, np.ndarray]:
        """Build ``(pairs, noise_probs)`` for one frame of the given channel."""
        walks0 = self._as_zerobase_intp(self._materialize_walks(frame_id, rin, using), V=self.vocab_size)
        pairs = self._pairs_from_walks(walks0, window_size)
        noise_probs = self._soft_unigram(self._freq_from_walks(walks0, V=self.vocab_size), power=alpha)
        return pairs, noise_probs

    # INTERFACES: (private)

    def _attractive_corpus_and_prob(self, *,
                                    frame_id: int,
                                    using: Literal["RW", "SAW", "merged"],
                                    window_size: int,
                                    alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
        """Skip-gram pairs and noise distribution for the attractive channel."""
        pairs, noise_probs = self._corpus_and_prob("attr", frame_id=frame_id, using=using,
                                                   window_size=window_size, alpha=alpha)
        _logger.info("ATTR corpus ready: pairs=%d", pairs.shape[0])
        return pairs, noise_probs

    def _repulsive_corpus_and_prob(self, *,
                                   frame_id: int,
                                   using: Literal["RW", "SAW", "merged"],
                                   window_size: int,
                                   alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
        """Skip-gram pairs and noise distribution for the repulsive channel."""
        pairs, noise_probs = self._corpus_and_prob("repuls", frame_id=frame_id, using=using,
                                                   window_size=window_size, alpha=alpha)
        _logger.info("REP corpus ready: pairs=%d", pairs.shape[0])
        return pairs, noise_probs

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

    def embed_frame(self,
                    frame_id: int,
                    RIN_type: Literal["attr", "repuls"],
                    using: Literal["RW", "SAW", "merged"],
                    window_size: int,
                    num_negative_samples: int,
                    num_epochs: int,
                    batch_size: int,
                    *,
                    shuffle_data: bool = True,
                    dimensionality: int = 128,
                    alpha: float = 0.75,
                    device: str | None = None,
                    sgns_kwargs: dict[str, object] | None = None,
                    _seed: int | None = None
                    ) -> np.ndarray:
        """Train embeddings for a single frame and return the input embedding matrix.

        Args:
            frame_id: 1-based frame index to train on.
            RIN_type: Interaction channel to use: ``"attr"`` (attractive) or
                ``"repuls"`` (repulsive).
            using: Which walk collections to include: ``"RW"``, ``"SAW"``, or
                ``"merged"`` (concatenates both if available).
            window_size: Symmetric skip-gram window size ``k``.
            num_negative_samples: Number of negative samples per positive pair.
            num_epochs: Number of passes over the pair dataset.
            batch_size: Mini-batch size for training.
            shuffle_data: Whether to shuffle pairs each epoch.
            dimensionality: Embedding dimensionality ``D``.
            alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
            device: Optional device string for the Torch backend (e.g., ``"cuda"``).
                Ignored by the PureML backend.
            sgns_kwargs: Extra keyword arguments forwarded to the backend SGNS
                constructor. For PureML, required keys are:
                ``{"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}``.
                ``V``, ``D`` and ``seed`` are always set by this method.
            _seed: Optional child seed for this frame's model initialization;
                defaults to the embedder's master seed.

        Returns:
            np.ndarray: Learned **input** embedding matrix of shape ``(V, D)``.

        Raises:
            IndexError: If ``frame_id`` is outside ``1..frame_count``.
            ValueError: If requested walks are missing, if no training pairs are
                generated, or if required ``sgns_kwargs`` for PureML are absent.
            AttributeError: If the SGNS model does not expose embeddings via
                ``.embeddings`` or ``.parameters[0]``.
        """
        _logger.info(
            "Preparing frame %d (rin=%s using=%s window=%d neg=%d epochs=%d batch=%d)",
            frame_id, RIN_type, using, window_size, num_negative_samples, num_epochs, batch_size
        )

        if RIN_type == "attr":
            if self.attractive_RWs is None and self.attractive_SAWs is None:
                raise ValueError("Attractive random walks are missing")
            pairs, noise_probs = self._attractive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
        elif RIN_type == "repuls":
            if self.repulsive_RWs is None and self.repulsive_SAWs is None:
                raise ValueError("Repulsive random walks are missing")
            pairs, noise_probs = self._repulsive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
        else:
            raise ValueError(f"Unknown RIN_type: {RIN_type!r}")

        if pairs.size == 0:
            raise ValueError("No training pairs generated for the requested configuration")

        centers = pairs[:, 0].astype(np.int64, copy=False)
        contexts = pairs[:, 1].astype(np.int64, copy=False)

        model_kwargs: dict[str, object] = dict(sgns_kwargs or {})
        if self.model_base == "pureml":
            required = {"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}
            missing = required.difference(model_kwargs)
            if missing:
                raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")

        child_seed = int(self._seed if _seed is None else _seed)
        model_kwargs.update({
            "V": self.vocab_size,
            "D": dimensionality,
            "seed": child_seed
        })

        if self.model_base == "torch" and device is not None:
            model_kwargs["device"] = device

        self.model = self.model_constructor(**model_kwargs)

        _logger.info(
            "Training SGNS base=%s constructor=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
            self.model_base,
            getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
            frame_id,
            pairs.shape[0],
            dimensionality,
            num_epochs,
            batch_size,
            num_negative_samples,
            shuffle_data
        )

        self.model.fit(
            centers,
            contexts,
            num_epochs,
            batch_size,
            num_negative_samples,
            noise_probs,
            shuffle_data,
            lr_step_per_batch=False
        )

        # Prefer an explicit `.embeddings`; fall back to the first entry of a
        # `parameters` tuple (a callable `parameters`, as on torch modules,
        # is not a tuple and is ignored).
        embeddings = getattr(self.model, "embeddings", None)
        if embeddings is None:
            params = getattr(self.model, "parameters", None)
            if isinstance(params, tuple) and params:
                embeddings = params[0]
        if embeddings is None:
            raise AttributeError("SGNS model does not expose embeddings via '.embeddings' or '.parameters[0]'")

        embeddings = np.asarray(embeddings)
        _logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
        return embeddings

    def embed_all(
        self,
        RIN_type: Literal["attr", "repuls"],
        using: Literal["RW", "SAW", "merged"],
        window_size: int,
        num_negative_samples: int,
        num_epochs: int,
        batch_size: int,
        *,
        shuffle_data: bool = True,
        dimensionality: int = 128,
        alpha: float = 0.75,
        device: str | None = None,
        sgns_kwargs: dict[str, object] | None = None,
        output_path: str | Path | None = None,
        num_matrices_in_compressed_blocks: int = 20,
        compression_level: int = 3) -> str:
        """Train embeddings for all frames and persist them to compressed storage.

        Iterates through all frames (``1..frame_count``), trains an SGNS model
        per frame using the configured backend, collects the resulting input
        embeddings, and writes them into a new compressed ``ArrayStorage`` archive.

        Args:
            RIN_type: Interaction channel to use: ``"attr"`` or ``"repuls"``.
            using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
            window_size: Symmetric skip-gram window size ``k``.
            num_negative_samples: Number of negative samples per positive pair.
            num_epochs: Number of epochs for each frame.
            batch_size: Mini-batch size used during training.
            shuffle_data: Whether to shuffle pairs each epoch.
            dimensionality: Embedding dimensionality ``D``.
            alpha: Noise distribution exponent (``Pn ∝ f^alpha``).
            device: Optional device string for Torch backend.
            sgns_kwargs: Extra constructor kwargs for the SGNS backend (see
                :meth:`embed_frame` for PureML requirements).
            output_path: Destination path. If ``None``, a new file named
                ``EMBEDDINGS_<timestamp>.zip`` is created next to the source
                WALKS archive. If the provided path lacks a suffix, ``.zip`` is
                appended.
            num_matrices_in_compressed_blocks: Number of per-frame matrices to
                store per compressed chunk in the output archive.
            compression_level: Blosc Zstd compression level (0-9).

        Returns:
            str: Filesystem path to the written embeddings archive (``.zip``).

        Raises:
            ValueError: If the archive contains no frames, if configuration
                produces no pairs for a frame, or if PureML kwargs are incomplete.
            RuntimeError: Propagated from storage operations on failure.

        Notes:
            - A deterministic child seed is spawned per frame from the master
              seed using ``np.random.SeedSequence`` to ensure reproducibility
              across runs.
        """
        # Fail before training/writing anything; `embeddings[0]` below would
        # otherwise raise an IndexError after creating the output archive.
        if self.frame_count < 1:
            raise ValueError("WALKS archive contains no frames to embed")

        current_time = sawnergy_util.current_time()
        if output_path is None:
            # Append ".zip" to the full name instead of with_suffix(), which
            # would replace anything after a '.' inside the timestamp
            # (the timestamp format is defined in sawnergy_util).
            output_path = self._walks_path.with_name(f"EMBEDDINGS_{current_time}.zip")
        else:
            output_path = Path(output_path)
            if output_path.suffix == "":
                output_path = output_path.with_suffix(".zip")

        _logger.info(
            "Embedding all frames -> %s | frames=%d dim=%d base=%s",
            output_path, self.frame_count, dimensionality, self.model_base
        )

        # One independent child stream per frame, derived from the master seed.
        master_ss = np.random.SeedSequence(self._seed)
        child_seeds = master_ss.spawn(self.frame_count)

        embeddings = []
        for frame_idx, seed_seq in enumerate(child_seeds, start=1):
            child_seed = int(seed_seq.generate_state(1, dtype=np.uint32)[0])
            _logger.info("Processing frame %d/%d (child_seed=%d entropy=%d)", frame_idx, self.frame_count, child_seed, seed_seq.entropy)
            embeddings.append(
                self.embed_frame(
                    frame_idx,
                    RIN_type,
                    using,
                    window_size,
                    num_negative_samples,
                    num_epochs,
                    batch_size,
                    shuffle_data=shuffle_data,
                    dimensionality=dimensionality,
                    alpha=alpha,
                    device=device,
                    sgns_kwargs=sgns_kwargs,
                    _seed=child_seed
                )
            )

        block_name = "FRAME_EMBEDDINGS"
        with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level=compression_level) as storage:
            storage.write(
                these_arrays=embeddings,
                to_block_named=block_name,
                arrays_per_chunk=num_matrices_in_compressed_blocks
            )
            storage.add_attr("time_created", current_time)
            storage.add_attr("seed", int(self._seed))
            storage.add_attr("rng_scheme", "SeedSequence.spawn_per_frame_v1")
            storage.add_attr("source_walks_path", str(self._walks_path))
            storage.add_attr("model_base", self.model_base)
            storage.add_attr("rin_type", RIN_type)
            storage.add_attr("using_mode", using)
            storage.add_attr("window_size", int(window_size))
            storage.add_attr("alpha", float(alpha))
            storage.add_attr("dimensionality", int(dimensionality))
            storage.add_attr("num_negative_samples", int(num_negative_samples))
            storage.add_attr("num_epochs", int(num_epochs))
            storage.add_attr("batch_size", int(batch_size))
            storage.add_attr("shuffle_data", bool(shuffle_data))
            storage.add_attr("frames_written", int(len(embeddings)))
            storage.add_attr("vocab_size", int(self.vocab_size))
            storage.add_attr("frame_count", int(self.frame_count))
            storage.add_attr("embedding_dtype", str(embeddings[0].dtype))
            storage.add_attr("frame_embeddings_name", block_name)
            storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
            storage.add_attr("compression_level", int(compression_level))

        _logger.info("Embedding archive written to %s", output_path)
        return str(output_path)


__all__ = ["Embedder"]
|
sawnergy/logging_util.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from logging.handlers import TimedRotatingFileHandler
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def configure_logging(
|
|
8
|
+
logs_dir: Path | str,
|
|
9
|
+
file_level: int = logging.DEBUG,
|
|
10
|
+
console_level: int = logging.WARNING
|
|
11
|
+
) -> None:
|
|
12
|
+
"""
|
|
13
|
+
Configure a logger with a timed rotating file handler and console handler.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
logs_dir: Directory where log files will be stored.
|
|
17
|
+
file_level: Logging level for the file handler (default: DEBUG).
|
|
18
|
+
console_level: Logging level for the console handler (default: WARNING).
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
if isinstance(logs_dir, str):
|
|
22
|
+
logs_dir = Path(logs_dir)
|
|
23
|
+
|
|
24
|
+
root = logging.getLogger()
|
|
25
|
+
if root.handlers:
|
|
26
|
+
return
|
|
27
|
+
|
|
28
|
+
logs_dir.mkdir(parents=True, exist_ok=True)
|
|
29
|
+
|
|
30
|
+
fmt = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
|
31
|
+
formatter = logging.Formatter(fmt)
|
|
32
|
+
|
|
33
|
+
logfile = logs_dir / f"sawnergy_{datetime.now():%Y-%m-%d_%H%M%S}.log"
|
|
34
|
+
|
|
35
|
+
file_h = TimedRotatingFileHandler(
|
|
36
|
+
logfile,
|
|
37
|
+
when="midnight",
|
|
38
|
+
backupCount=7,
|
|
39
|
+
encoding="utf-8"
|
|
40
|
+
)
|
|
41
|
+
file_h.setLevel(file_level)
|
|
42
|
+
file_h.setFormatter(formatter)
|
|
43
|
+
|
|
44
|
+
console_h = logging.StreamHandler()
|
|
45
|
+
console_h.setLevel(console_level)
|
|
46
|
+
console_h.setFormatter(formatter)
|
|
47
|
+
|
|
48
|
+
# ensure root level is low enough to handle both handlers
|
|
49
|
+
root.setLevel(min(file_level, console_level))
|
|
50
|
+
root.addHandler(file_h)
|
|
51
|
+
root.addHandler(console_h)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
__all__ = ["configure_logging"]
|