sawnergy 1.0.3__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sawnergy/__init__.py +3 -1
- sawnergy/embedding/SGNS_pml.py +324 -51
- sawnergy/embedding/SGNS_torch.py +282 -39
- sawnergy/embedding/__init__.py +26 -1
- sawnergy/embedding/embedder.py +426 -203
- sawnergy/embedding/visualizer.py +251 -0
- sawnergy/logging_util.py +1 -1
- sawnergy/rin/rin_builder.py +4 -4
- sawnergy/visual/visualizer.py +6 -6
- sawnergy/visual/visualizer_util.py +3 -0
- sawnergy/walks/walker.py +43 -22
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/METADATA +91 -57
- sawnergy-1.0.9.dist-info/RECORD +23 -0
- sawnergy-1.0.3.dist-info/RECORD +0 -22
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/WHEEL +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/licenses/LICENSE +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/licenses/NOTICE +0 -0
- {sawnergy-1.0.3.dist-info → sawnergy-1.0.9.dist-info}/top_level.txt +0 -0
sawnergy/embedding/embedder.py
CHANGED
|
@@ -1,15 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
"""
|
|
4
|
-
Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
|
|
5
|
-
|
|
6
|
-
This module consumes attractive/repulsive walk corpora produced by the walker
|
|
7
|
-
pipeline and trains per-frame embeddings using either the PyTorch or PureML
|
|
8
|
-
implementations of SGNS. The resulting embeddings can be persisted back into
|
|
9
|
-
an ``ArrayStorage`` archive along with rich metadata describing the training
|
|
10
|
-
configuration.
|
|
11
|
-
"""
|
|
12
|
-
|
|
13
3
|
# third-pary
|
|
14
4
|
import numpy as np
|
|
15
5
|
|
|
@@ -36,9 +26,8 @@ class Embedder:
|
|
|
36
26
|
|
|
37
27
|
def __init__(self,
|
|
38
28
|
WALKS_path: str | Path,
|
|
39
|
-
base: Literal["torch", "pureml"],
|
|
40
29
|
*,
|
|
41
|
-
seed: int | None = None
|
|
30
|
+
seed: int | None = None,
|
|
42
31
|
) -> None:
|
|
43
32
|
"""Initialize the embedder and load walk tensors.
|
|
44
33
|
|
|
@@ -50,22 +39,19 @@ class Embedder:
|
|
|
50
39
|
``None`` if that collection is absent), and the metadata
|
|
51
40
|
``num_RWs``, ``num_SAWs``, ``node_count``, ``time_stamp_count``,
|
|
52
41
|
``walk_length``.
|
|
53
|
-
base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
|
|
54
42
|
seed: Optional seed for the embedder's RNG. If ``None``, a random
|
|
55
43
|
32-bit seed is chosen.
|
|
56
44
|
|
|
57
45
|
Raises:
|
|
58
46
|
ValueError: If required metadata is missing or any loaded walk array
|
|
59
47
|
has an unexpected shape.
|
|
60
|
-
ImportError: If the requested backend is not installed.
|
|
61
|
-
NameError: If ``base`` is not one of ``{"torch","pureml"}``.
|
|
62
48
|
|
|
63
49
|
Notes:
|
|
64
50
|
- Walks in storage are 1-based (residue indexing). Internally, this
|
|
65
51
|
class normalizes to 0-based indices for training utilities.
|
|
66
52
|
"""
|
|
67
53
|
self._walks_path = Path(WALKS_path)
|
|
68
|
-
_logger.info("Initializing Embedder from %s
|
|
54
|
+
_logger.info("Initializing Embedder from %s", self._walks_path)
|
|
69
55
|
|
|
70
56
|
# placeholders for optional walk collections
|
|
71
57
|
self.attractive_RWs : np.ndarray | None = None
|
|
@@ -124,53 +110,76 @@ class Embedder:
|
|
|
124
110
|
RWs_expected = (time_stamp_count, node_count * num_RWs, walk_length+1) if (num_RWs > 0) else None
|
|
125
111
|
SAWs_expected = (time_stamp_count, node_count * num_SAWs, walk_length+1) if (num_SAWs > 0) else None
|
|
126
112
|
|
|
127
|
-
self.vocab_size
|
|
128
|
-
self.frame_count
|
|
129
|
-
self.walk_length
|
|
113
|
+
self.vocab_size = int(node_count)
|
|
114
|
+
self.frame_count = int(time_stamp_count)
|
|
115
|
+
self.walk_length = int(walk_length)
|
|
116
|
+
self.num_RWs = int(num_RWs)
|
|
117
|
+
self.num_SAWs = int(num_SAWs)
|
|
118
|
+
# Keep dataset names for metadata passthrough
|
|
119
|
+
self._attractive_RWs_name = attractive_RWs_name
|
|
120
|
+
self._repulsive_RWs_name = repulsive_RWs_name
|
|
121
|
+
self._attractive_SAWs_name = attractive_SAWs_name
|
|
122
|
+
self._repulsive_SAWs_name = repulsive_SAWs_name
|
|
130
123
|
|
|
131
124
|
# store walks if present
|
|
132
125
|
if attractive_RWs is not None:
|
|
133
126
|
if RWs_expected and attractive_RWs.shape != RWs_expected:
|
|
134
127
|
raise ValueError(f"ATTR RWs: expected {RWs_expected}, got {attractive_RWs.shape}")
|
|
135
128
|
self.attractive_RWs = attractive_RWs
|
|
129
|
+
_logger.debug("ATTR RWs loaded: %s", self.attractive_RWs.shape)
|
|
136
130
|
|
|
137
131
|
if repulsive_RWs is not None:
|
|
138
132
|
if RWs_expected and repulsive_RWs.shape != RWs_expected:
|
|
139
133
|
raise ValueError(f"REP RWs: expected {RWs_expected}, got {repulsive_RWs.shape}")
|
|
140
134
|
self.repulsive_RWs = repulsive_RWs
|
|
135
|
+
_logger.debug("REP RWs loaded: %s", self.repulsive_RWs.shape)
|
|
141
136
|
|
|
142
137
|
if attractive_SAWs is not None:
|
|
143
138
|
if SAWs_expected and attractive_SAWs.shape != SAWs_expected:
|
|
144
139
|
raise ValueError(f"ATTR SAWs: expected {SAWs_expected}, got {attractive_SAWs.shape}")
|
|
145
140
|
self.attractive_SAWs = attractive_SAWs
|
|
141
|
+
_logger.debug("ATTR SAWs loaded: %s", self.attractive_SAWs.shape)
|
|
146
142
|
|
|
147
143
|
if repulsive_SAWs is not None:
|
|
148
144
|
if SAWs_expected and repulsive_SAWs.shape != SAWs_expected:
|
|
149
145
|
raise ValueError(f"REP SAWs: expected {SAWs_expected}, got {repulsive_SAWs.shape}")
|
|
150
146
|
self.repulsive_SAWs = repulsive_SAWs
|
|
147
|
+
_logger.debug("REP SAWs loaded: %s", self.repulsive_SAWs.shape)
|
|
151
148
|
|
|
152
149
|
# INTERNAL RNG
|
|
153
150
|
self._seed = np.random.randint(0, 2**32 - 1) if seed is None else int(seed)
|
|
154
151
|
self.rng = np.random.default_rng(self._seed)
|
|
155
152
|
_logger.info("RNG initialized from seed=%d", self._seed)
|
|
156
153
|
|
|
157
|
-
# MODEL HANDLE
|
|
158
|
-
self.model_base: Literal["torch", "pureml"] = base
|
|
159
|
-
self.model_constructor = self._get_SGNS_constructor_from(base)
|
|
160
|
-
_logger.info("SGNS backend resolved: %s", getattr(self.model_constructor, "__name__", repr(self.model_constructor)))
|
|
161
|
-
|
|
162
154
|
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
|
163
155
|
|
|
164
156
|
# HELPERS:
|
|
165
157
|
|
|
166
158
|
@staticmethod
|
|
167
|
-
def
|
|
168
|
-
|
|
159
|
+
def _get_NN_constructor_from(base: Literal["torch", "pureml"],
|
|
160
|
+
objective: Literal["sgns", "sg"]):
|
|
161
|
+
"""Resolve the SG/SGNS implementation class for the selected backend.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
base: Backend family to use, ``"torch"`` or ``"pureml"``.
|
|
165
|
+
objective: Training objective, ``"sgns"`` or ``"sg"``.
|
|
166
|
+
|
|
167
|
+
Returns:
|
|
168
|
+
A callable class (constructor) implementing the requested model.
|
|
169
|
+
|
|
170
|
+
Raises:
|
|
171
|
+
ImportError: If the requested backend cannot be imported.
|
|
172
|
+
NameError: If ``base`` is not one of the supported values.
|
|
173
|
+
"""
|
|
174
|
+
_logger.debug("Resolving model constructor: base=%s objective=%s", base, objective)
|
|
169
175
|
if base == "torch":
|
|
170
176
|
try:
|
|
171
|
-
from .SGNS_torch import SGNS_Torch
|
|
172
|
-
|
|
177
|
+
from .SGNS_torch import SGNS_Torch, SG_Torch
|
|
178
|
+
ctor = SG_Torch if objective == "sg" else SGNS_Torch
|
|
179
|
+
_logger.debug("Resolved PyTorch class: %s", getattr(ctor, "__name__", str(ctor)))
|
|
180
|
+
return ctor
|
|
173
181
|
except Exception:
|
|
182
|
+
_logger.exception("Failed to import PyTorch backend.")
|
|
174
183
|
raise ImportError(
|
|
175
184
|
"PyTorch is not installed, but base='torch' was requested. "
|
|
176
185
|
"Install PyTorch first, e.g.: `pip install torch` "
|
|
@@ -178,9 +187,12 @@ class Embedder:
|
|
|
178
187
|
)
|
|
179
188
|
elif base == "pureml":
|
|
180
189
|
try:
|
|
181
|
-
from .SGNS_pml import SGNS_PureML
|
|
182
|
-
|
|
190
|
+
from .SGNS_pml import SGNS_PureML, SG_PureML
|
|
191
|
+
ctor = SG_PureML if objective == "sg" else SGNS_PureML
|
|
192
|
+
_logger.debug("Resolved PureML class: %s", getattr(ctor, "__name__", str(ctor)))
|
|
193
|
+
return ctor
|
|
183
194
|
except Exception:
|
|
195
|
+
_logger.exception("Failed to import PureML backend.")
|
|
184
196
|
raise ImportError(
|
|
185
197
|
"PureML is not installed, but base='pureml' was requested. "
|
|
186
198
|
"Install PureML first via `pip install ym-pure-ml` "
|
|
@@ -190,7 +202,18 @@ class Embedder:
|
|
|
190
202
|
|
|
191
203
|
@staticmethod
|
|
192
204
|
def _as_zerobase_intp(walks: np.ndarray, *, V: int) -> np.ndarray:
|
|
193
|
-
"""Validate 1-based
|
|
205
|
+
"""Validate and convert 1-based walks to 0-based ``intp``.
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
walks: 2D array of node ids with 1-based indexing.
|
|
209
|
+
V: Vocabulary size for bounds checking.
|
|
210
|
+
|
|
211
|
+
Returns:
|
|
212
|
+
2D array of dtype ``intp`` with 0-based indices.
|
|
213
|
+
|
|
214
|
+
Raises:
|
|
215
|
+
ValueError: If shape is not 2D or indices are out of bounds.
|
|
216
|
+
"""
|
|
194
217
|
arr = np.asarray(walks)
|
|
195
218
|
if arr.ndim != 2:
|
|
196
219
|
raise ValueError("walks must be 2D: (num_walks, walk_len)")
|
|
@@ -198,7 +221,9 @@ class Embedder:
|
|
|
198
221
|
arr = arr.astype(np.int64, copy=False)
|
|
199
222
|
# 1-based → 0-based
|
|
200
223
|
arr = arr - 1
|
|
201
|
-
|
|
224
|
+
mn, mx = int(arr.min()), int(arr.max())
|
|
225
|
+
_logger.debug("Zero-basing walks: min=%d max=%d V=%d", mn, mx, V)
|
|
226
|
+
if mn < 0 or mx >= V:
|
|
202
227
|
raise ValueError("walk ids out of range after 1→0-based normalization")
|
|
203
228
|
return arr.astype(np.intp, copy=False)
|
|
204
229
|
|
|
@@ -206,19 +231,29 @@ class Embedder:
|
|
|
206
231
|
def _pairs_from_walks(walks0: np.ndarray, window_size: int) -> np.ndarray:
|
|
207
232
|
"""
|
|
208
233
|
Skip-gram pairs including edge centers (one-sided when needed).
|
|
209
|
-
|
|
210
|
-
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
walks0: (W, L) int array (0-based ids).
|
|
237
|
+
window_size: Symmetric context window radius.
|
|
238
|
+
|
|
239
|
+
Returns:
|
|
240
|
+
Array of shape (N_pairs, 2) int32 with columns [center, context].
|
|
241
|
+
|
|
242
|
+
Raises:
|
|
243
|
+
ValueError: If shape is invalid or ``window_size`` <= 0.
|
|
211
244
|
"""
|
|
212
245
|
if walks0.ndim != 2:
|
|
213
246
|
raise ValueError("walks must be 2D: (num_walks, walk_len)")
|
|
214
247
|
|
|
215
248
|
_, L = walks0.shape
|
|
216
249
|
k = int(window_size)
|
|
250
|
+
_logger.debug("Building SG pairs: L=%d, window=%d", L, k)
|
|
217
251
|
|
|
218
252
|
if k <= 0:
|
|
219
253
|
raise ValueError("window_size must be positive")
|
|
220
254
|
|
|
221
255
|
if L == 0:
|
|
256
|
+
_logger.debug("Empty walks length; returning 0 pairs.")
|
|
222
257
|
return np.empty((0, 2), dtype=np.int32)
|
|
223
258
|
|
|
224
259
|
out_chunks = []
|
|
@@ -236,18 +271,42 @@ class Embedder:
|
|
|
236
271
|
out_chunks.append(np.stack((centers_l, ctx_l), axis=2).reshape(-1, 2))
|
|
237
272
|
|
|
238
273
|
if not out_chunks:
|
|
274
|
+
_logger.debug("No offsets produced pairs; returning empty.")
|
|
239
275
|
return np.empty((0, 2), dtype=np.int32)
|
|
240
276
|
|
|
241
|
-
|
|
277
|
+
pairs = np.concatenate(out_chunks, axis=0).astype(np.int32, copy=False)
|
|
278
|
+
_logger.debug("Built %d training pairs", pairs.shape[0])
|
|
279
|
+
return pairs
|
|
242
280
|
|
|
243
281
|
@staticmethod
|
|
244
282
|
def _freq_from_walks(walks0: np.ndarray, *, V: int) -> np.ndarray:
|
|
245
|
-
"""Node frequencies from walks (0-based).
|
|
246
|
-
|
|
283
|
+
"""Node frequencies from walks (0-based).
|
|
284
|
+
|
|
285
|
+
Args:
|
|
286
|
+
walks0: 2D array of 0-based node ids.
|
|
287
|
+
V: Vocabulary size (minlength for bincount).
|
|
288
|
+
|
|
289
|
+
Returns:
|
|
290
|
+
1D array of int64 frequencies with length ``V``.
|
|
291
|
+
"""
|
|
292
|
+
freq = np.bincount(walks0.ravel(), minlength=V).astype(np.int64, copy=False)
|
|
293
|
+
_logger.debug("Frequency mass: total=%d nonzero=%d", int(freq.sum()), int(np.count_nonzero(freq)))
|
|
294
|
+
return freq
|
|
247
295
|
|
|
248
296
|
@staticmethod
|
|
249
297
|
def _soft_unigram(freq: np.ndarray, *, power: float = 0.75) -> np.ndarray:
|
|
250
|
-
"""Return normalized Pn(w) ∝ f(w)^power as float64 probs.
|
|
298
|
+
"""Return normalized Pn(w) ∝ f(w)^power as float64 probs.
|
|
299
|
+
|
|
300
|
+
Args:
|
|
301
|
+
freq: 1D array of token frequencies.
|
|
302
|
+
power: Exponent used for smoothing (default 0.75 à la word2vec).
|
|
303
|
+
|
|
304
|
+
Returns:
|
|
305
|
+
1D array of probabilities summing to 1.0.
|
|
306
|
+
|
|
307
|
+
Raises:
|
|
308
|
+
ValueError: If mass is invalid (all zeros or non-finite).
|
|
309
|
+
"""
|
|
251
310
|
p = np.asarray(freq, dtype=np.float64)
|
|
252
311
|
if p.sum() == 0:
|
|
253
312
|
raise ValueError("all frequencies are zero")
|
|
@@ -255,13 +314,31 @@ class Embedder:
|
|
|
255
314
|
s = p.sum()
|
|
256
315
|
if not np.isfinite(s) or s <= 0:
|
|
257
316
|
raise ValueError("invalid unigram mass")
|
|
258
|
-
|
|
317
|
+
probs = p / s
|
|
318
|
+
_logger.debug("Noise distribution ready (power=%.3f)", power)
|
|
319
|
+
return probs
|
|
259
320
|
|
|
260
321
|
def _materialize_walks(self, frame_id: int, rin: Literal["attr", "repuls"],
|
|
261
322
|
using: Literal["RW", "SAW", "merged"]) -> np.ndarray:
|
|
323
|
+
"""Materialize the requested set of walks for a frame.
|
|
324
|
+
|
|
325
|
+
Args:
|
|
326
|
+
frame_id: 1-based frame index.
|
|
327
|
+
rin: Which RIN to pull from: ``"attr"`` or ``"repuls"``.
|
|
328
|
+
using: Which walk sets to include: ``"RW"``, ``"SAW"``, or ``"merged"``.
|
|
329
|
+
If ``"merged"``, concatenate available RW and SAW along axis 0.
|
|
330
|
+
|
|
331
|
+
Returns:
|
|
332
|
+
A 2D array of walks with shape (num_walks, walk_length+1).
|
|
333
|
+
|
|
334
|
+
Raises:
|
|
335
|
+
IndexError: If ``frame_id`` is out of range.
|
|
336
|
+
ValueError: If no matching walks are available.
|
|
337
|
+
"""
|
|
262
338
|
if not 1 <= frame_id <= int(self.frame_count):
|
|
263
339
|
raise IndexError(f"frame_id must be in [1, {self.frame_count}]; got {frame_id}")
|
|
264
340
|
|
|
341
|
+
_logger.debug("Materializing %s walks at frame=%d using=%s", rin, frame_id, using)
|
|
265
342
|
frame_id -= 1
|
|
266
343
|
|
|
267
344
|
if rin == "attr":
|
|
@@ -288,8 +365,12 @@ class Embedder:
|
|
|
288
365
|
if not parts:
|
|
289
366
|
raise ValueError(f"No walks available for {rin=} with {using=}")
|
|
290
367
|
if len(parts) == 1:
|
|
291
|
-
|
|
292
|
-
|
|
368
|
+
out = parts[0]
|
|
369
|
+
else:
|
|
370
|
+
out = np.concatenate(parts, axis=0)
|
|
371
|
+
|
|
372
|
+
_logger.debug("Materialized walks shape: %s", getattr(out, "shape", None))
|
|
373
|
+
return out
|
|
293
374
|
|
|
294
375
|
# INTERFACES: (private)
|
|
295
376
|
|
|
@@ -298,6 +379,17 @@ class Embedder:
|
|
|
298
379
|
using: Literal["RW", "SAW", "merged"],
|
|
299
380
|
window_size: int,
|
|
300
381
|
alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
|
|
382
|
+
"""Construct (center, context) pairs and noise distribution for ATTR.
|
|
383
|
+
|
|
384
|
+
Args:
|
|
385
|
+
frame_id: 1-based frame index.
|
|
386
|
+
using: Walk subset to include.
|
|
387
|
+
window_size: Skip-gram window radius.
|
|
388
|
+
alpha: Unigram smoothing exponent.
|
|
389
|
+
|
|
390
|
+
Returns:
|
|
391
|
+
Tuple of (pairs, noise_probs).
|
|
392
|
+
"""
|
|
301
393
|
walks = self._materialize_walks(frame_id, "attr", using)
|
|
302
394
|
walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
|
|
303
395
|
attractive_corpus = self._pairs_from_walks(walks0, window_size)
|
|
@@ -311,6 +403,17 @@ class Embedder:
|
|
|
311
403
|
using: Literal["RW", "SAW", "merged"],
|
|
312
404
|
window_size: int,
|
|
313
405
|
alpha: float = 0.75) -> tuple[np.ndarray, np.ndarray]:
|
|
406
|
+
"""Construct (center, context) pairs and noise distribution for REP.
|
|
407
|
+
|
|
408
|
+
Args:
|
|
409
|
+
frame_id: 1-based frame index.
|
|
410
|
+
using: Walk subset to include.
|
|
411
|
+
window_size: Skip-gram window radius.
|
|
412
|
+
alpha: Unigram smoothing exponent.
|
|
413
|
+
|
|
414
|
+
Returns:
|
|
415
|
+
Tuple of (pairs, noise_probs).
|
|
416
|
+
"""
|
|
314
417
|
walks = self._materialize_walks(frame_id, "repuls", using)
|
|
315
418
|
walks0 = self._as_zerobase_intp(walks, V=self.vocab_size)
|
|
316
419
|
repulsive_corpus = self._pairs_from_walks(walks0, window_size)
|
|
@@ -322,56 +425,63 @@ class Embedder:
|
|
|
322
425
|
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
|
323
426
|
|
|
324
427
|
def embed_frame(self,
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
428
|
+
frame_id: int,
|
|
429
|
+
RIN_type: Literal["attr", "repuls"],
|
|
430
|
+
using: Literal["RW", "SAW", "merged"],
|
|
431
|
+
num_epochs: int,
|
|
432
|
+
negative_sampling: bool = False,
|
|
433
|
+
window_size: int = 5,
|
|
434
|
+
num_negative_samples: int = 10,
|
|
435
|
+
batch_size: int = 1024,
|
|
436
|
+
*,
|
|
437
|
+
in_weights: np.ndarray | None = None,
|
|
438
|
+
out_weights: np.ndarray | None = None,
|
|
439
|
+
lr_step_per_batch: bool = False,
|
|
440
|
+
shuffle_data: bool = True,
|
|
441
|
+
dimensionality: int = 128,
|
|
442
|
+
alpha: float = 0.75,
|
|
443
|
+
device: str | None = None,
|
|
444
|
+
model_base: Literal["torch", "pureml"] = "pureml",
|
|
445
|
+
model_kwargs: dict[str, object] | None = None,
|
|
446
|
+
kind: tuple[Literal["in", "out", "avg"]] = ("in",),
|
|
447
|
+
_seed: int | None = None
|
|
448
|
+
) -> list[tuple[np.ndarray, str]]:
|
|
449
|
+
"""Train embeddings for a single frame and return the matrix containing embeddings of the specified `kind`.
|
|
341
450
|
|
|
342
451
|
Args:
|
|
343
|
-
frame_id: 1-based frame index to
|
|
344
|
-
RIN_type:
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
window_size:
|
|
349
|
-
num_negative_samples:
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
452
|
+
frame_id: 1-based frame index to embed.
|
|
453
|
+
RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
|
|
454
|
+
using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
|
|
455
|
+
num_epochs: Number of passes over the pairs.
|
|
456
|
+
negative_sampling: If ``True``, use SGNS objective; else plain SG.
|
|
457
|
+
window_size: Skip-gram symmetric window radius.
|
|
458
|
+
num_negative_samples: Negatives per positive pair (SGNS only).
|
|
459
|
+
batch_size: Minibatch size for training.
|
|
460
|
+
in_weights: Optional starting input-embedding matrix of shape (V, D).
|
|
461
|
+
out_weights: Optional starting output-embedding matrix of shape (V, D).
|
|
462
|
+
SGNS: shape (V, D)
|
|
463
|
+
SG: shape (D, V)
|
|
464
|
+
lr_step_per_batch: If ``True``, step LR every batch (else per epoch).
|
|
465
|
+
shuffle_data: Shuffle pairs each epoch.
|
|
466
|
+
dimensionality: Embedding dimension ``D``.
|
|
467
|
+
alpha: Unigram smoothing power for noise distribution.
|
|
468
|
+
device: Optional backend device hint (e.g., ``"cuda"``).
|
|
469
|
+
model_base: Backend family (``"torch"`` or ``"pureml"``).
|
|
470
|
+
model_kwargs: Passed through to backend model constructor.
|
|
471
|
+
kind: Which embedding to return: ``"in"``, ``"out"``, or ``"avg"``.
|
|
472
|
+
_seed: Optional override seed for this frame.
|
|
360
473
|
|
|
361
474
|
Returns:
|
|
362
|
-
np.ndarray
|
|
363
|
-
|
|
364
|
-
Raises:
|
|
365
|
-
ValueError: If requested walks are missing, if no training pairs are
|
|
366
|
-
generated, or if required ``sgns_kwargs`` for PureML are absent.
|
|
367
|
-
AttributeError: If the SGNS model does not expose embeddings via
|
|
368
|
-
``.embeddings`` or ``.parameters[0]``.
|
|
475
|
+
list[tuple[np.ndarray, Literal["avg","in","out"]]]:
|
|
476
|
+
(embedding, kind) pairs sorted as 'avg', 'in', 'out'.
|
|
369
477
|
"""
|
|
370
478
|
_logger.info(
|
|
371
|
-
"
|
|
372
|
-
frame_id, RIN_type, using,
|
|
479
|
+
"embed_frame: frame=%d RIN=%s using=%s base=%s D=%d epochs=%d batch=%d sgns=%s window_size=%d alpha=%.3f",
|
|
480
|
+
frame_id, RIN_type, using, model_base, dimensionality, num_epochs, batch_size,
|
|
481
|
+
str(negative_sampling), window_size, alpha
|
|
373
482
|
)
|
|
374
483
|
|
|
484
|
+
# ------------------ resolve training data -----------------
|
|
375
485
|
if RIN_type == "attr":
|
|
376
486
|
if self.attractive_RWs is None and self.attractive_SAWs is None:
|
|
377
487
|
raise ValueError("Attractive random walks are missing")
|
|
@@ -381,125 +491,125 @@ class Embedder:
|
|
|
381
491
|
raise ValueError("Repulsive random walks are missing")
|
|
382
492
|
pairs, noise_probs = self._repulsive_corpus_and_prob(frame_id=frame_id, using=using, window_size=window_size, alpha=alpha)
|
|
383
493
|
else:
|
|
384
|
-
raise
|
|
385
|
-
|
|
494
|
+
raise NameError(f"Unknown RIN_type: {RIN_type!r}")
|
|
386
495
|
if pairs.size == 0:
|
|
387
496
|
raise ValueError("No training pairs generated for the requested configuration")
|
|
497
|
+
# ----------------------------------------------------------
|
|
388
498
|
|
|
499
|
+
# ---------------- construct training corpus ---------------
|
|
389
500
|
centers = pairs[:, 0].astype(np.int64, copy=False)
|
|
390
501
|
contexts = pairs[:, 1].astype(np.int64, copy=False)
|
|
502
|
+
_logger.debug("Pairs split: centers=%s contexts=%s", centers.shape, contexts.shape)
|
|
503
|
+
# ----------------------------------------------------------
|
|
391
504
|
|
|
392
|
-
|
|
393
|
-
if
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")
|
|
505
|
+
# ------------ resolve model_constructor kwargs ------------
|
|
506
|
+
if model_kwargs is not None:
|
|
507
|
+
if (("lr_sched" in model_kwargs and model_kwargs.get("lr_sched", None) is not None)
|
|
508
|
+
and ("lr_sched_kwargs" in model_kwargs and model_kwargs.get("lr_sched_kwargs", None) is None)):
|
|
509
|
+
raise ValueError("When `lr_sched`, you must also provide `lr_sched_kwargs`.")
|
|
398
510
|
|
|
399
|
-
|
|
400
|
-
|
|
511
|
+
constructor_kwargs: dict[str, object] = dict(model_kwargs or {})
|
|
512
|
+
constructor_kwargs.update({
|
|
401
513
|
"V": self.vocab_size,
|
|
402
514
|
"D": dimensionality,
|
|
403
|
-
"
|
|
515
|
+
"in_weights": in_weights,
|
|
516
|
+
"out_weights": out_weights,
|
|
517
|
+
"seed": int(self._seed if _seed is None else _seed),
|
|
518
|
+
"device": device
|
|
404
519
|
})
|
|
520
|
+
_logger.debug("Model constructor kwargs: %s", {k: constructor_kwargs[k] for k in ("V","D","seed","device")})
|
|
521
|
+
# ----------------------------------------------------------
|
|
522
|
+
|
|
523
|
+
# --------------- resolve model constructor ----------------
|
|
524
|
+
model_constructor = self._get_NN_constructor_from(
|
|
525
|
+
model_base, objective=("sgns" if negative_sampling else "sg"))
|
|
526
|
+
# ----------------------------------------------------------
|
|
527
|
+
|
|
528
|
+
# ------------------ initialize the model ------------------
|
|
529
|
+
model = model_constructor(**constructor_kwargs)
|
|
530
|
+
_logger.debug("Model initialized: %s", model_constructor.__name__ if hasattr(model_constructor,"__name__") else str(model_constructor))
|
|
531
|
+
# ----------------------------------------------------------
|
|
532
|
+
|
|
533
|
+
# -------------------- fitting the data --------------------
|
|
534
|
+
_logger.info("Fitting model on %d pairs ...", pairs.shape[0])
|
|
535
|
+
model.fit(centers=centers,
|
|
536
|
+
contexts=contexts,
|
|
537
|
+
num_epochs=num_epochs,
|
|
538
|
+
batch_size=batch_size,
|
|
539
|
+
# -- optional; for SGNS; safely ignored by SG via **_ignore --
|
|
540
|
+
num_negative_samples=num_negative_samples,
|
|
541
|
+
noise_dist=noise_probs,
|
|
542
|
+
# -----------------------------------------
|
|
543
|
+
shuffle_data=shuffle_data,
|
|
544
|
+
lr_step_per_batch=lr_step_per_batch
|
|
545
|
+
)
|
|
546
|
+
_logger.info("Training complete for frame %d", frame_id)
|
|
547
|
+
# ----------------------------------------------------------
|
|
405
548
|
|
|
406
|
-
if
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
self.model = self.model_constructor(**model_kwargs)
|
|
410
|
-
|
|
411
|
-
_logger.info(
|
|
412
|
-
"Training SGNS base=%s constructor=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
|
|
413
|
-
self.model_base,
|
|
414
|
-
getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
|
|
415
|
-
frame_id,
|
|
416
|
-
pairs.shape[0],
|
|
417
|
-
dimensionality,
|
|
418
|
-
num_epochs,
|
|
419
|
-
batch_size,
|
|
420
|
-
num_negative_samples,
|
|
421
|
-
shuffle_data
|
|
422
|
-
)
|
|
423
|
-
|
|
424
|
-
self.model.fit(
|
|
425
|
-
centers,
|
|
426
|
-
contexts,
|
|
427
|
-
num_epochs,
|
|
428
|
-
batch_size,
|
|
429
|
-
num_negative_samples,
|
|
430
|
-
noise_probs,
|
|
431
|
-
shuffle_data,
|
|
432
|
-
lr_step_per_batch=False
|
|
433
|
-
)
|
|
549
|
+
if any([k not in ("in", "out", "avg") for k in kind]):
|
|
550
|
+
raise NameError(f"Unknown embeddings kind in {kind}. Expected: one of ['in', 'out', 'avg']")
|
|
434
551
|
|
|
435
|
-
|
|
436
|
-
if
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
552
|
+
# OUTPUT:
|
|
553
|
+
embeddings = [(np.asarray(model.in_embeddings, dtype=np.float32), k) if k == "in" else
|
|
554
|
+
(np.asarray(model.out_embeddings, dtype=np.float32), k) if k == "out" else
|
|
555
|
+
(np.asarray(model.avg_embeddings, dtype=np.float32), k) if k == "avg" else
|
|
556
|
+
(None, k)
|
|
557
|
+
for k in kind
|
|
558
|
+
]
|
|
559
|
+
embeddings.sort(key=lambda pair: pair[1]) # ensures 'avg', 'in', 'out' ordering
|
|
442
560
|
|
|
443
|
-
embeddings = np.asarray(embeddings)
|
|
444
|
-
_logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
|
|
445
561
|
return embeddings
|
|
446
562
|
|
|
447
563
|
def embed_all(
|
|
448
564
|
self,
|
|
449
565
|
RIN_type: Literal["attr", "repuls"],
|
|
450
566
|
using: Literal["RW", "SAW", "merged"],
|
|
451
|
-
window_size: int,
|
|
452
|
-
num_negative_samples: int,
|
|
453
567
|
num_epochs: int,
|
|
454
|
-
|
|
568
|
+
negative_sampling: bool = False,
|
|
569
|
+
window_size: int = 2,
|
|
570
|
+
num_negative_samples: int = 10,
|
|
571
|
+
batch_size: int = 1024,
|
|
455
572
|
*,
|
|
573
|
+
lr_step_per_batch: bool = False,
|
|
456
574
|
shuffle_data: bool = True,
|
|
457
575
|
dimensionality: int = 128,
|
|
458
576
|
alpha: float = 0.75,
|
|
459
577
|
device: str | None = None,
|
|
460
|
-
|
|
578
|
+
model_base: Literal["torch", "pureml"] = "pureml",
|
|
579
|
+
model_kwargs: dict[str, object] | None = None,
|
|
580
|
+
kind: Literal["in", "out", "avg"] = "in",
|
|
461
581
|
output_path: str | Path | None = None,
|
|
462
582
|
num_matrices_in_compressed_blocks: int = 20,
|
|
463
|
-
compression_level: int = 3
|
|
464
|
-
|
|
583
|
+
compression_level: int = 3,
|
|
584
|
+
) -> str:
|
|
585
|
+
"""Embed all frames and persist a self-contained archive.
|
|
465
586
|
|
|
466
|
-
|
|
467
|
-
per
|
|
468
|
-
|
|
587
|
+
The resulting file stores a block named ``FRAME_EMBEDDINGS`` with a
|
|
588
|
+
compressed sequence of per-frame matrices (each ``(V, D)``), alongside
|
|
589
|
+
rich metadata mirroring the style of other SAWNERGY modules.
|
|
469
590
|
|
|
470
591
|
Args:
|
|
471
|
-
RIN_type:
|
|
472
|
-
using:
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
num_matrices_in_compressed_blocks:
|
|
488
|
-
|
|
489
|
-
compression_level: Blosc Zstd compression level (0-9).
|
|
592
|
+
RIN_type: ``"attr"`` or ``"repuls"`` - which corpus to use.
|
|
593
|
+
using: Which walks to use (``"RW"``, ``"SAW"``, or ``"merged"``).
|
|
594
|
+
num_epochs: Number of epochs to train per frame.
|
|
595
|
+
negative_sampling: If ``True``, use SGNS; otherwise plain SG.
|
|
596
|
+
window_size: Skip-gram window radius.
|
|
597
|
+
num_negative_samples: Negatives per positive pair (SGNS).
|
|
598
|
+
batch_size: Minibatch size for training.
|
|
599
|
+
lr_step_per_batch: If ``True``, step LR per batch (else per epoch).
|
|
600
|
+
shuffle_data: Shuffle pairs each epoch.
|
|
601
|
+
dimensionality: Embedding dimension.
|
|
602
|
+
alpha: Unigram smoothing power for noise distribution.
|
|
603
|
+
device: Backend device hint (e.g., ``"cuda"``).
|
|
604
|
+
model_base: Backend family (``"torch"`` or ``"pureml"``).
|
|
605
|
+
model_kwargs: Passed through to backend model constructor.
|
|
606
|
+
kind: Which embedding to store: ``"in"``, ``"out"``, or ``"avg"``.
|
|
607
|
+
output_path: Optional path for the output archive (``.zip`` inferred).
|
|
608
|
+
num_matrices_in_compressed_blocks: How many frames per compressed chunk.
|
|
609
|
+
compression_level: Integer compression level for the archive.
|
|
490
610
|
|
|
491
611
|
Returns:
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
Raises:
|
|
495
|
-
ValueError: If configuration produces no pairs for a frame or if
|
|
496
|
-
PureML kwargs are incomplete.
|
|
497
|
-
RuntimeError: Propagated from storage operations on failure.
|
|
498
|
-
|
|
499
|
-
Notes:
|
|
500
|
-
- A deterministic child seed is spawned per frame from the master
|
|
501
|
-
seed using ``np.random.SeedSequence`` to ensure reproducibility
|
|
502
|
-
across runs.
|
|
612
|
+
Path to the created embeddings archive, as ``str``.
|
|
503
613
|
"""
|
|
504
614
|
current_time = sawnergy_util.current_time()
|
|
505
615
|
if output_path is None:
|
|
@@ -510,69 +620,182 @@ class Embedder:
|
|
|
510
620
|
output_path = output_path.with_suffix(".zip")
|
|
511
621
|
|
|
512
622
|
_logger.info(
|
|
513
|
-
"
|
|
514
|
-
|
|
623
|
+
"embed_all: frames=%d D=%d base=%s RIN=%s using=%s out=%s",
|
|
624
|
+
self.frame_count, dimensionality, model_base, RIN_type, using, output_path
|
|
515
625
|
)
|
|
516
626
|
|
|
627
|
+
# Per-frame deterministic seeds
|
|
517
628
|
master_ss = np.random.SeedSequence(self._seed)
|
|
518
629
|
child_seeds = master_ss.spawn(self.frame_count)
|
|
519
630
|
|
|
520
|
-
embeddings = []
|
|
521
|
-
|
|
631
|
+
embeddings: list[np.ndarray] = []
|
|
632
|
+
last_frame_in_embs: np.ndarray = None
|
|
633
|
+
last_frame_out_embs: np.ndarray = None
|
|
634
|
+
used_child_seeds: list[int] = []
|
|
635
|
+
for frame_id, seed_seq in enumerate(child_seeds, start=1):
|
|
522
636
|
child_seed = int(seed_seq.generate_state(1, dtype=np.uint32)[0])
|
|
523
|
-
|
|
524
|
-
|
|
637
|
+
used_child_seeds.append(child_seed)
|
|
638
|
+
_logger.info("Embedding frame %d/%d with seed=%d", frame_id, self.frame_count, child_seed)
|
|
639
|
+
|
|
640
|
+
embs_and_kinds: list[tuple[np.ndarray, str]] = \
|
|
525
641
|
self.embed_frame(
|
|
526
|
-
|
|
527
|
-
RIN_type,
|
|
528
|
-
using,
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
642
|
+
frame_id=frame_id,
|
|
643
|
+
RIN_type=RIN_type,
|
|
644
|
+
using=using,
|
|
645
|
+
num_epochs=num_epochs,
|
|
646
|
+
negative_sampling=negative_sampling,
|
|
647
|
+
window_size=window_size,
|
|
648
|
+
num_negative_samples=num_negative_samples,
|
|
649
|
+
batch_size=batch_size,
|
|
650
|
+
in_weights=last_frame_in_embs,
|
|
651
|
+
out_weights=last_frame_out_embs,
|
|
652
|
+
lr_step_per_batch=lr_step_per_batch,
|
|
533
653
|
shuffle_data=shuffle_data,
|
|
534
654
|
dimensionality=dimensionality,
|
|
535
655
|
alpha=alpha,
|
|
536
656
|
device=device,
|
|
537
|
-
|
|
657
|
+
model_base=model_base,
|
|
658
|
+
model_kwargs=model_kwargs,
|
|
659
|
+
kind=("in", "out", "avg"),
|
|
538
660
|
_seed=child_seed
|
|
539
661
|
)
|
|
540
|
-
)
|
|
662
|
+
embs = {K: E for (E, K) in embs_and_kinds}
|
|
663
|
+
|
|
664
|
+
last_frame_in_embs = embs["in"] # (V, D)
|
|
665
|
+
last_frame_out_embs = embs["out"] if negative_sampling else embs["out"].T # SG needs (D, V), SGNS keeps (V, D)
|
|
666
|
+
|
|
667
|
+
resolved_embedding = embs[kind]
|
|
668
|
+
embeddings.append(np.asarray(resolved_embedding, dtype=np.float32, copy=False))
|
|
669
|
+
|
|
670
|
+
_logger.debug("Frame %d embedded: E.shape=%s", frame_id, resolved_embedding.shape)
|
|
541
671
|
|
|
542
|
-
embeddings = [np.asarray(e) for e in embeddings]
|
|
543
672
|
block_name = "FRAME_EMBEDDINGS"
|
|
544
673
|
with sawnergy_util.ArrayStorage.compress_and_cleanup(output_path, compression_level=compression_level) as storage:
|
|
674
|
+
_logger.info("Writing %d frame matrices to block '%s' ...", len(embeddings), block_name)
|
|
545
675
|
storage.write(
|
|
546
676
|
these_arrays=embeddings,
|
|
547
677
|
to_block_named=block_name,
|
|
548
678
|
arrays_per_chunk=num_matrices_in_compressed_blocks
|
|
549
679
|
)
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
storage.add_attr("
|
|
553
|
-
storage.add_attr("
|
|
554
|
-
storage.add_attr("
|
|
555
|
-
storage.add_attr("
|
|
556
|
-
|
|
680
|
+
|
|
681
|
+
# Core dataset discovery (for consumers like the Embeddings Visualizer)
|
|
682
|
+
storage.add_attr("frame_embeddings_name", block_name)
|
|
683
|
+
storage.add_attr("time_stamp_count", int(self.frame_count))
|
|
684
|
+
storage.add_attr("node_count", int(self.vocab_size))
|
|
685
|
+
storage.add_attr("embedding_dim", int(dimensionality))
|
|
686
|
+
|
|
687
|
+
# Provenance of input WALKS
|
|
688
|
+
storage.add_attr("source_WALKS_path", str(self._walks_path))
|
|
689
|
+
storage.add_attr("walk_length", int(self.walk_length))
|
|
690
|
+
storage.add_attr("num_RWs", int(self.num_RWs))
|
|
691
|
+
storage.add_attr("num_SAWs", int(self.num_SAWs))
|
|
692
|
+
storage.add_attr("attractive_RWs_name", self._attractive_RWs_name)
|
|
693
|
+
storage.add_attr("repulsive_RWs_name", self._repulsive_RWs_name)
|
|
694
|
+
storage.add_attr("attractive_SAWs_name", self._attractive_SAWs_name)
|
|
695
|
+
storage.add_attr("repulsive_SAWs_name", self._repulsive_SAWs_name)
|
|
696
|
+
|
|
697
|
+
# Training configuration (sufficient to reproduce)
|
|
698
|
+
storage.add_attr("objective", "sgns" if negative_sampling else "sg")
|
|
699
|
+
storage.add_attr("model_base", model_base)
|
|
700
|
+
storage.add_attr("embedding_kind", kind) # 'in' | 'out' | 'avg'
|
|
701
|
+
storage.add_attr("num_epochs", int(num_epochs))
|
|
702
|
+
storage.add_attr("batch_size", int(batch_size))
|
|
557
703
|
storage.add_attr("window_size", int(window_size))
|
|
558
704
|
storage.add_attr("alpha", float(alpha))
|
|
559
|
-
storage.add_attr("
|
|
705
|
+
storage.add_attr("negative_sampling", bool(negative_sampling))
|
|
560
706
|
storage.add_attr("num_negative_samples", int(num_negative_samples))
|
|
561
|
-
storage.add_attr("
|
|
562
|
-
storage.add_attr("batch_size", int(batch_size))
|
|
707
|
+
storage.add_attr("lr_step_per_batch", bool(lr_step_per_batch))
|
|
563
708
|
storage.add_attr("shuffle_data", bool(shuffle_data))
|
|
564
|
-
storage.add_attr("
|
|
565
|
-
storage.add_attr("
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
storage.add_attr("
|
|
709
|
+
storage.add_attr("device_hint", device if device is not None else "")
|
|
710
|
+
storage.add_attr("model_kwargs_repr", repr(model_kwargs) if model_kwargs is not None else "{}")
|
|
711
|
+
|
|
712
|
+
# Which walks were used to train
|
|
713
|
+
storage.add_attr("RIN_type", RIN_type) # 'attr' or 'repuls'
|
|
714
|
+
storage.add_attr("using", using) # 'RW' | 'SAW' | 'merged'
|
|
715
|
+
|
|
716
|
+
# Reproducibility
|
|
717
|
+
storage.add_attr("master_seed", int(self._seed))
|
|
718
|
+
storage.add_attr("per_frame_seeds", [int(s) for s in used_child_seeds])
|
|
719
|
+
|
|
720
|
+
# Archive/IO details
|
|
569
721
|
storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
|
|
570
722
|
storage.add_attr("compression_level", int(compression_level))
|
|
723
|
+
storage.add_attr("created_at", current_time)
|
|
724
|
+
|
|
725
|
+
_logger.info(
|
|
726
|
+
"Stored embeddings archive: %s | shape=(T,N,D)=(%d,%d,%d)",
|
|
727
|
+
output_path, self.frame_count, self.vocab_size, dimensionality
|
|
728
|
+
)
|
|
571
729
|
|
|
572
|
-
_logger.info("Embedding archive written to %s", output_path)
|
|
573
730
|
return str(output_path)
|
|
574
731
|
|
|
575
|
-
|
|
732
|
+
# *----------------------------------------------------*
|
|
733
|
+
# FUNCTIONS
|
|
734
|
+
# *----------------------------------------------------*
|
|
735
|
+
|
|
736
|
+
def align_frames(this: np.ndarray,
|
|
737
|
+
to_this: np.ndarray,
|
|
738
|
+
*,
|
|
739
|
+
center: bool = True,
|
|
740
|
+
add_back_mean: bool = True,
|
|
741
|
+
allow_reflection: bool = False) -> np.ndarray:
|
|
742
|
+
"""
|
|
743
|
+
Align `this` to `to_this` via Orthogonal Procrustes.
|
|
744
|
+
|
|
745
|
+
Solves: min_{R ∈ O(D)} || X R - Y ||_F
|
|
746
|
+
with X = this, Y = to_this (both shape (N, D)). Returns X aligned.
|
|
747
|
+
|
|
748
|
+
Args:
|
|
749
|
+
this: (N, D) matrix to be aligned.
|
|
750
|
+
to_this: (N, D) target matrix.
|
|
751
|
+
center: if True, subtract per-dimension means before solving.
|
|
752
|
+
add_back_mean: if True, add Y's mean back after alignment.
|
|
753
|
+
allow_reflection: if False, enforce det(R) = +1 (proper rotation).
|
|
754
|
+
|
|
755
|
+
Returns:
|
|
756
|
+
Aligned copy of `this` with shape (N, D).
|
|
757
|
+
"""
|
|
758
|
+
X = np.asarray(this, dtype=np.float64)
|
|
759
|
+
Y = np.asarray(to_this, dtype=np.float64)
|
|
760
|
+
|
|
761
|
+
if X.ndim != 2 or Y.ndim != 2:
|
|
762
|
+
raise ValueError(f"Expected 2D arrays; got {X.ndim=} and {Y.ndim=}")
|
|
763
|
+
if X.shape[1] != Y.shape[1]:
|
|
764
|
+
raise ValueError(f"Dimensionalities must match: X.shape={X.shape}, Y.shape={Y.shape}")
|
|
765
|
+
if X.shape[0] != Y.shape[0]:
|
|
766
|
+
raise ValueError(f"Row counts must match (one-to-one correspondence): {X.shape[0]} vs {Y.shape[0]}")
|
|
767
|
+
|
|
768
|
+
# center
|
|
769
|
+
if center:
|
|
770
|
+
X_mean = X.mean(axis=0, keepdims=True)
|
|
771
|
+
Y_mean = Y.mean(axis=0, keepdims=True)
|
|
772
|
+
Xc = X - X_mean
|
|
773
|
+
Yc = Y - Y_mean
|
|
774
|
+
else:
|
|
775
|
+
Xc, Yc = X, Y
|
|
776
|
+
Y_mean = 0.0
|
|
777
|
+
|
|
778
|
+
# Cross-covariance and SVD
|
|
779
|
+
# M = Xᵀ Y (D×D); solution R = U Vᵀ for SVD(M) = U Σ Vᵀ
|
|
780
|
+
M = Xc.T @ Yc
|
|
781
|
+
U, _, Vt = np.linalg.svd(M, full_matrices=False)
|
|
782
|
+
R = U @ Vt
|
|
783
|
+
|
|
784
|
+
# enforce proper rotation unless reflections are allowed
|
|
785
|
+
if not allow_reflection and np.linalg.det(R) < 0:
|
|
786
|
+
Vt[-1, :] *= -1
|
|
787
|
+
R = U @ Vt
|
|
788
|
+
|
|
789
|
+
X_aligned = Xc @ R
|
|
790
|
+
|
|
791
|
+
if center and add_back_mean is True:
|
|
792
|
+
X_aligned = X_aligned + Y_mean
|
|
793
|
+
|
|
794
|
+
# match input dtype if possible
|
|
795
|
+
return X_aligned.astype(this.dtype, copy=False)
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
__all__ = ["Embedder", "align_frames"]
|
|
576
799
|
|
|
577
800
|
if __name__ == "__main__":
|
|
578
801
|
pass
|