sawnergy 1.0.5__py3-none-any.whl → 1.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy has been flagged as potentially problematic; consult the registry's advisory page for this release for details.
- sawnergy/embedding/SGNS_pml.py +276 -41
- sawnergy/embedding/SGNS_torch.py +145 -11
- sawnergy/embedding/__init__.py +24 -0
- sawnergy/embedding/embedder.py +106 -50
- sawnergy/embedding/visualizer.py +247 -0
- sawnergy/logging_util.py +1 -1
- sawnergy/rin/rin_builder.py +1 -1
- sawnergy/visual/visualizer.py +6 -6
- sawnergy/visual/visualizer_util.py +3 -0
- {sawnergy-1.0.5.dist-info → sawnergy-1.0.7.dist-info}/METADATA +48 -24
- sawnergy-1.0.7.dist-info/RECORD +23 -0
- sawnergy-1.0.5.dist-info/RECORD +0 -22
- {sawnergy-1.0.5.dist-info → sawnergy-1.0.7.dist-info}/WHEEL +0 -0
- {sawnergy-1.0.5.dist-info → sawnergy-1.0.7.dist-info}/licenses/LICENSE +0 -0
- {sawnergy-1.0.5.dist-info → sawnergy-1.0.7.dist-info}/licenses/NOTICE +0 -0
- {sawnergy-1.0.5.dist-info → sawnergy-1.0.7.dist-info}/top_level.txt +0 -0
sawnergy/embedding/embedder.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
"""
|
|
4
|
-
Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
|
|
4
|
+
Embedding orchestration for Skip-Gram (SG) and Skip-Gram with Negative Sampling (SGNS).
|
|
5
5
|
|
|
6
6
|
This module consumes attractive/repulsive walk corpora produced by the walker
|
|
7
7
|
pipeline and trains per-frame embeddings using either the PyTorch or PureML
|
|
8
|
-
implementations of SGNS. The resulting embeddings can be persisted back into
|
|
8
|
+
implementations of SG/SGNS. The resulting embeddings can be persisted back into
|
|
9
9
|
an ``ArrayStorage`` archive along with rich metadata describing the training
|
|
10
10
|
configuration.
|
|
11
11
|
"""
|
|
@@ -38,7 +38,8 @@ class Embedder:
|
|
|
38
38
|
WALKS_path: str | Path,
|
|
39
39
|
base: Literal["torch", "pureml"],
|
|
40
40
|
*,
|
|
41
|
-
seed: int | None = None
|
|
41
|
+
seed: int | None = None,
|
|
42
|
+
objective: Literal["sgns", "sg"] = "sgns"
|
|
42
43
|
) -> None:
|
|
43
44
|
"""Initialize the embedder and load walk tensors.
|
|
44
45
|
|
|
@@ -53,6 +54,8 @@ class Embedder:
|
|
|
53
54
|
base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
|
|
54
55
|
seed: Optional seed for the embedder's RNG. If ``None``, a random
|
|
55
56
|
32-bit seed is chosen.
|
|
57
|
+
objective: Training objective, either ``"sgns"`` (negative sampling)
|
|
58
|
+
or ``"sg"`` (plain full-softmax Skip-Gram).
|
|
56
59
|
|
|
57
60
|
Raises:
|
|
58
61
|
ValueError: If required metadata is missing or any loaded walk array
|
|
@@ -156,20 +159,25 @@ class Embedder:
|
|
|
156
159
|
|
|
157
160
|
# MODEL HANDLE
|
|
158
161
|
self.model_base: Literal["torch", "pureml"] = base
|
|
159
|
-
self.
|
|
160
|
-
|
|
162
|
+
self.objective: Literal["sgns", "sg"] = objective
|
|
163
|
+
self.model_constructor = self._get_SGNS_constructor_from(base, objective)
|
|
164
|
+
_logger.info(
|
|
165
|
+
"SG backend resolved: %s (objective=%s)",
|
|
166
|
+
getattr(self.model_constructor, "__name__", repr(self.model_constructor)), self.objective
|
|
167
|
+
)
|
|
161
168
|
|
|
162
169
|
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
|
163
170
|
|
|
164
171
|
# HELPERS:
|
|
165
172
|
|
|
166
173
|
@staticmethod
|
|
167
|
-
def _get_SGNS_constructor_from(base: Literal["torch", "pureml"]
|
|
168
|
-
|
|
174
|
+
def _get_SGNS_constructor_from(base: Literal["torch", "pureml"],
|
|
175
|
+
objective: Literal["sgns", "sg"]):
|
|
176
|
+
"""Resolve the SG/SGNS implementation class for the selected backend."""
|
|
169
177
|
if base == "torch":
|
|
170
178
|
try:
|
|
171
|
-
from .SGNS_torch import SGNS_Torch
|
|
172
|
-
return SGNS_Torch
|
|
179
|
+
from .SGNS_torch import SGNS_Torch, SG_Torch
|
|
180
|
+
return SG_Torch if objective == "sg" else SGNS_Torch
|
|
173
181
|
except Exception:
|
|
174
182
|
raise ImportError(
|
|
175
183
|
"PyTorch is not installed, but base='torch' was requested. "
|
|
@@ -178,8 +186,8 @@ class Embedder:
|
|
|
178
186
|
)
|
|
179
187
|
elif base == "pureml":
|
|
180
188
|
try:
|
|
181
|
-
from .SGNS_pml import SGNS_PureML
|
|
182
|
-
return SGNS_PureML
|
|
189
|
+
from .SGNS_pml import SGNS_PureML, SG_PureML
|
|
190
|
+
return SG_PureML if objective == "sg" else SGNS_PureML
|
|
183
191
|
except Exception:
|
|
184
192
|
raise ImportError(
|
|
185
193
|
"PureML is not installed, but base='pureml' was requested. "
|
|
@@ -322,21 +330,24 @@ class Embedder:
|
|
|
322
330
|
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
|
323
331
|
|
|
324
332
|
def embed_frame(self,
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
333
|
+
frame_id: int,
|
|
334
|
+
RIN_type: Literal["attr", "repuls"],
|
|
335
|
+
using: Literal["RW", "SAW", "merged"],
|
|
336
|
+
window_size: int,
|
|
337
|
+
num_negative_samples: int,
|
|
338
|
+
num_epochs: int,
|
|
339
|
+
batch_size: int,
|
|
340
|
+
*,
|
|
341
|
+
lr_step_per_batch: bool = False,
|
|
342
|
+
shuffle_data: bool = True,
|
|
343
|
+
dimensionality: int = 128,
|
|
344
|
+
alpha: float = 0.75,
|
|
345
|
+
device: str | None = None,
|
|
346
|
+
sgns_kwargs: dict[str, object] | None = None,
|
|
347
|
+
kind: Literal["in", "out", "avg"] = "in",
|
|
348
|
+
_seed: int | None = None,
|
|
349
|
+
objective: Literal["sgns", "sg"] | None = None
|
|
350
|
+
) -> np.ndarray:
|
|
340
351
|
"""Train embeddings for a single frame and return the input embedding matrix.
|
|
341
352
|
|
|
342
353
|
Args:
|
|
@@ -347,6 +358,7 @@ class Embedder:
|
|
|
347
358
|
``"merged"`` (concatenates both if available).
|
|
348
359
|
window_size: Symmetric skip-gram window size ``k``.
|
|
349
360
|
num_negative_samples: Number of negative samples per positive pair.
|
|
361
|
+
Ignored when ``objective="sg"``.
|
|
350
362
|
num_epochs: Number of passes over the pair dataset.
|
|
351
363
|
batch_size: Mini-batch size for training.
|
|
352
364
|
shuffle_data: Whether to shuffle pairs each epoch.
|
|
@@ -355,11 +367,15 @@ class Embedder:
|
|
|
355
367
|
device: Optional device string for the Torch backend (e.g., ``"cuda"``).
|
|
356
368
|
sgns_kwargs: Extra keyword arguments forwarded to the backend SGNS
|
|
357
369
|
constructor. For PureML, required keys are:
|
|
358
|
-
``{"optim", "optim_kwargs"
|
|
370
|
+
``{"optim", "optim_kwargs"}``; ``lr_sched`` is optional, but if
|
|
371
|
+
provided then ``lr_sched_kwargs`` must also be provided.
|
|
372
|
+
kind: Which embedding matrix to return: ``"in"``, ``"out"``, or ``"avg"``.
|
|
359
373
|
_seed: Optional child seed for this frame's model initialization.
|
|
374
|
+
objective: Training objective override for this call (``"sgns"`` or
|
|
375
|
+
``"sg"``). If ``None``, uses the value set at construction.
|
|
360
376
|
|
|
361
377
|
Returns:
|
|
362
|
-
np.ndarray: Learned
|
|
378
|
+
np.ndarray: Learned embedding matrix (selected by ``kind``) of shape ``(V, D)``.
|
|
363
379
|
|
|
364
380
|
Raises:
|
|
365
381
|
ValueError: If requested walks are missing, if no training pairs are
|
|
@@ -391,10 +407,14 @@ class Embedder:
|
|
|
391
407
|
|
|
392
408
|
model_kwargs: dict[str, object] = dict(sgns_kwargs or {})
|
|
393
409
|
if self.model_base == "pureml":
|
|
394
|
-
required = {"optim", "optim_kwargs"
|
|
410
|
+
required = {"optim", "optim_kwargs"}
|
|
395
411
|
missing = required.difference(model_kwargs)
|
|
396
412
|
if missing:
|
|
397
413
|
raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")
|
|
414
|
+
has_sched = ("lr_sched" in model_kwargs and model_kwargs["lr_sched"] is not None)
|
|
415
|
+
has_sched_kwargs = ("lr_sched_kwargs" in model_kwargs and model_kwargs["lr_sched_kwargs"] is not None)
|
|
416
|
+
if has_sched and not has_sched_kwargs:
|
|
417
|
+
raise ValueError("When providing lr_sched for PureML, you must also provide lr_sched_kwargs.")
|
|
398
418
|
|
|
399
419
|
child_seed = int(self._seed if _seed is None else _seed)
|
|
400
420
|
model_kwargs.update({
|
|
@@ -406,12 +426,16 @@ class Embedder:
|
|
|
406
426
|
if self.model_base == "torch" and device is not None:
|
|
407
427
|
model_kwargs["device"] = device
|
|
408
428
|
|
|
429
|
+
# Resolve objective (call-level override beats constructor default)
|
|
430
|
+
obj = self.objective if objective is None else objective
|
|
431
|
+
self.model_constructor = self._get_SGNS_constructor_from(self.model_base, obj)
|
|
409
432
|
self.model = self.model_constructor(**model_kwargs)
|
|
410
433
|
|
|
411
434
|
_logger.info(
|
|
412
|
-
"Training
|
|
435
|
+
"Training SG base=%s constructor=%s objective=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
|
|
413
436
|
self.model_base,
|
|
414
437
|
getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
|
|
438
|
+
obj,
|
|
415
439
|
frame_id,
|
|
416
440
|
pairs.shape[0],
|
|
417
441
|
dimensionality,
|
|
@@ -421,24 +445,46 @@ class Embedder:
|
|
|
421
445
|
shuffle_data
|
|
422
446
|
)
|
|
423
447
|
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
448
|
+
if obj == "sgns":
|
|
449
|
+
self.model.fit(
|
|
450
|
+
centers,
|
|
451
|
+
contexts,
|
|
452
|
+
num_epochs,
|
|
453
|
+
batch_size,
|
|
454
|
+
num_negative_samples,
|
|
455
|
+
noise_probs,
|
|
456
|
+
shuffle_data,
|
|
457
|
+
lr_step_per_batch
|
|
458
|
+
)
|
|
459
|
+
else:
|
|
460
|
+
self.model.fit(
|
|
461
|
+
centers,
|
|
462
|
+
contexts,
|
|
463
|
+
num_epochs,
|
|
464
|
+
batch_size,
|
|
465
|
+
shuffle_data,
|
|
466
|
+
lr_step_per_batch
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
# Select embedding matrix by kind
|
|
470
|
+
if kind == "in":
|
|
471
|
+
embeddings = getattr(self.model, "in_embeddings", None)
|
|
472
|
+
if embeddings is None:
|
|
473
|
+
embeddings = getattr(self.model, "embeddings", None)
|
|
474
|
+
if embeddings is None:
|
|
475
|
+
params = getattr(self.model, "parameters", None)
|
|
476
|
+
if isinstance(params, tuple) and params:
|
|
477
|
+
embeddings = params[0]
|
|
478
|
+
elif kind == "out":
|
|
479
|
+
embeddings = getattr(self.model, "out_embeddings", None)
|
|
480
|
+
else: # "avg"
|
|
481
|
+
embeddings = getattr(self.model, "avg_embeddings", None)
|
|
434
482
|
|
|
435
|
-
embeddings = getattr(self.model, "embeddings", None)
|
|
436
|
-
if embeddings is None:
|
|
437
|
-
params = getattr(self.model, "parameters", None)
|
|
438
|
-
if isinstance(params, tuple) and params:
|
|
439
|
-
embeddings = params[0]
|
|
440
483
|
if embeddings is None:
|
|
441
|
-
raise AttributeError(
|
|
484
|
+
raise AttributeError(
|
|
485
|
+
"SG/SGNS model does not expose the requested embeddings: "
|
|
486
|
+
f"kind={kind!r}"
|
|
487
|
+
)
|
|
442
488
|
|
|
443
489
|
embeddings = np.asarray(embeddings)
|
|
444
490
|
_logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
|
|
@@ -460,7 +506,9 @@ class Embedder:
|
|
|
460
506
|
sgns_kwargs: dict[str, object] | None = None,
|
|
461
507
|
output_path: str | Path | None = None,
|
|
462
508
|
num_matrices_in_compressed_blocks: int = 20,
|
|
463
|
-
compression_level: int = 3
|
|
509
|
+
compression_level: int = 3,
|
|
510
|
+
objective: Literal["sgns", "sg"] | None = None,
|
|
511
|
+
kind: Literal["in", "out", "avg"] = "in"):
|
|
464
512
|
"""Train embeddings for all frames and persist them to compressed storage.
|
|
465
513
|
|
|
466
514
|
Iterates through all frames (``1..frame_count``), trains an SGNS model
|
|
@@ -472,6 +520,7 @@ class Embedder:
|
|
|
472
520
|
using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
|
|
473
521
|
window_size: Symmetric skip-gram window size ``k``.
|
|
474
522
|
num_negative_samples: Number of negative samples per positive pair.
|
|
523
|
+
Ignored when ``objective="sg"``.
|
|
475
524
|
num_epochs: Number of epochs for each frame.
|
|
476
525
|
batch_size: Mini-batch size used during training.
|
|
477
526
|
shuffle_data: Whether to shuffle pairs each epoch.
|
|
@@ -487,6 +536,10 @@ class Embedder:
|
|
|
487
536
|
num_matrices_in_compressed_blocks: Number of per-frame matrices to
|
|
488
537
|
store per compressed chunk in the output archive.
|
|
489
538
|
compression_level: Blosc Zstd compression level (0-9).
|
|
539
|
+
objective: Training objective for all frames (``"sgns"`` or ``"sg"``).
|
|
540
|
+
If ``None``, uses the value set at construction.
|
|
541
|
+
kind: Which embedding matrix to persist for each frame: ``"in"``,
|
|
542
|
+
``"out"``, or ``"avg"``.
|
|
490
543
|
|
|
491
544
|
Returns:
|
|
492
545
|
str: Filesystem path to the written embeddings archive (``.zip``).
|
|
@@ -498,8 +551,8 @@ class Embedder:
|
|
|
498
551
|
|
|
499
552
|
Notes:
|
|
500
553
|
- A deterministic child seed is spawned per frame from the master
|
|
501
|
-
|
|
502
|
-
|
|
554
|
+
seed using ``np.random.SeedSequence`` to ensure reproducibility
|
|
555
|
+
across runs.
|
|
503
556
|
"""
|
|
504
557
|
current_time = sawnergy_util.current_time()
|
|
505
558
|
if output_path is None:
|
|
@@ -535,7 +588,9 @@ class Embedder:
|
|
|
535
588
|
alpha=alpha,
|
|
536
589
|
device=device,
|
|
537
590
|
sgns_kwargs=sgns_kwargs,
|
|
538
|
-
_seed=child_seed
|
|
591
|
+
_seed=child_seed,
|
|
592
|
+
objective=objective,
|
|
593
|
+
kind=kind
|
|
539
594
|
)
|
|
540
595
|
)
|
|
541
596
|
|
|
@@ -568,6 +623,7 @@ class Embedder:
|
|
|
568
623
|
storage.add_attr("frame_embeddings_name", block_name)
|
|
569
624
|
storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
|
|
570
625
|
storage.add_attr("compression_level", int(compression_level))
|
|
626
|
+
storage.add_attr("objective", self.objective if objective is None else objective)
|
|
571
627
|
|
|
572
628
|
_logger.info("Embedding archive written to %s", output_path)
|
|
573
629
|
return str(output_path)
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# third party
|
|
4
|
+
import numpy as np
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
|
|
7
|
+
# built-in
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Sequence
|
|
10
|
+
import logging
|
|
11
|
+
|
|
12
|
+
# local
|
|
13
|
+
from ..visual import visualizer_util
|
|
14
|
+
from .. import sawnergy_util
|
|
15
|
+
|
|
16
|
+
# *----------------------------------------------------*
|
|
17
|
+
# GLOBALS
|
|
18
|
+
# *----------------------------------------------------*
|
|
19
|
+
|
|
20
|
+
# Module-level logger named after this module (used by Visualizer for load/render info).
_logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
# *----------------------------------------------------*
|
|
23
|
+
# HELPERS
|
|
24
|
+
# *----------------------------------------------------*
|
|
25
|
+
|
|
26
|
+
def _safe_svd_pca(X: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
|
27
|
+
"""Compute k principal directions via SVD and project onto them."""
|
|
28
|
+
if X.ndim != 2:
|
|
29
|
+
raise ValueError(f"PCA expects 2D array (N, D); got {X.shape}")
|
|
30
|
+
_, D = X.shape
|
|
31
|
+
if k not in (2, 3):
|
|
32
|
+
raise ValueError(f"PCA dimensionality must be 2 or 3; got {k}")
|
|
33
|
+
if D < k:
|
|
34
|
+
raise ValueError(f"Requested k={k} exceeds feature dim D={D}")
|
|
35
|
+
Xc = X - X.mean(axis=0, keepdims=True)
|
|
36
|
+
_, _, Vt = np.linalg.svd(Xc, full_matrices=False)
|
|
37
|
+
comps = Vt[:k].copy()
|
|
38
|
+
proj = Xc @ comps.T
|
|
39
|
+
return proj, comps
|
|
40
|
+
|
|
41
|
+
def _set_equal_axes_3d(ax, xyz: np.ndarray, *, padding: float = 0.05) -> None:
|
|
42
|
+
if xyz.size == 0:
|
|
43
|
+
return
|
|
44
|
+
x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
|
|
45
|
+
xmin, xmax = float(x.min()), float(x.max())
|
|
46
|
+
ymin, ymax = float(y.min()), float(y.max())
|
|
47
|
+
zmin, zmax = float(z.min()), float(z.max())
|
|
48
|
+
xr = xmax - xmin
|
|
49
|
+
yr = ymax - ymin
|
|
50
|
+
zr = zmax - zmin
|
|
51
|
+
r = max(xr, yr, zr)
|
|
52
|
+
pad = padding * (r if r > 0 else 1.0)
|
|
53
|
+
cx, cy, cz = (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, (zmin + zmax) / 2.0
|
|
54
|
+
ax.set_xlim(cx - r / 2 - pad, cx + r / 2 + pad)
|
|
55
|
+
ax.set_ylim(cy - r / 2 - pad, cy + r / 2 + pad)
|
|
56
|
+
ax.set_zlim(cz - r / 2 - pad, cz + r / 2 + pad)
|
|
57
|
+
try:
|
|
58
|
+
ax.set_box_aspect([1, 1, 1])
|
|
59
|
+
except Exception:
|
|
60
|
+
pass
|
|
61
|
+
|
|
62
|
+
# *----------------------------------------------------*
|
|
63
|
+
# CLASS
|
|
64
|
+
# *----------------------------------------------------*
|
|
65
|
+
|
|
66
|
+
class Visualizer:
    """3D PCA visualizer for per-frame embeddings.

    Loads a ``(T, N, D)`` embedding tensor from an ``ArrayStorage`` archive
    and renders one frame at a time as a 3D scatter of the node embeddings
    projected onto their leading principal components. Frame ids and node
    ids in the public API are 1-based; indexing is 0-based internally.
    """

    # Shared across instances: matplotlib is warm-started only for the first one.
    no_instances: bool = True

    def __init__(
        self,
        EMB_path: str | Path,
        figsize: tuple[int, int] = (9, 7),
        default_node_color: str = visualizer_util.GRAY,
        depthshade: bool = False,
        antialiased: bool = False,
        init_elev: float = 35,
        init_azim: float = 45,
        *,
        show: bool = False
    ) -> None:
        """Load the embeddings archive and create an empty figure.

        Args:
            EMB_path: Path to the embeddings archive. The tensor is read from
                the block named by its ``frame_embeddings_name`` attribute.
            figsize: Matplotlib figure size.
            default_node_color: Color for nodes not covered by a group mapping
                (and for all nodes when ``node_colors=None``).
            depthshade: Forwarded to the 3D ``scatter`` call.
            antialiased: Forwarded to the 3D ``scatter`` call.
            init_elev: Camera elevation re-applied on every ``build_frame``.
            init_azim: Camera azimuth re-applied on every ``build_frame``.
            show: Forwarded to ``visualizer_util.ensure_backend``.

        Raises:
            ValueError: If the stored embeddings are not 3-dimensional.
        """
        # Backend & pyplot: ensure_backend runs before pyplot is imported.
        visualizer_util.ensure_backend(show)
        import matplotlib.pyplot as plt
        self._plt = plt

        if Visualizer.no_instances:
            try:
                visualizer_util.warm_start_matplotlib()
            finally:
                # Never retry the warm-up, even if it failed.
                Visualizer.no_instances = False

        # Load embeddings; materialize while the archive is still open.
        EMB_path = Path(EMB_path)
        with sawnergy_util.ArrayStorage(EMB_path, mode="r") as storage:
            name = storage.get_attr("frame_embeddings_name")
            E = np.asarray(storage.read(name, slice(None)))
        if E.ndim != 3:
            raise ValueError(f"Expected embeddings of shape (T,N,D); got {E.shape}")
        self.E = E
        self.T, self.N, self.D = map(int, self.E.shape)
        _logger.info("Loaded embeddings: T=%d, N=%d, D=%d", self.T, self.N, self.D)

        # Maps 0-based node indices onto [0, 1] for colormap lookups.
        self._residue_norm = mpl.colors.Normalize(0, max(1, self.N - 1))

        # Figure now; axes and scatter are created lazily by _ensure_axes().
        self._fig = self._plt.figure(figsize=figsize)
        self._ax = None
        self._scatter = None
        self._labels: list = []  # text artists for node labels of the current frame
        self._marker_size = 30.0
        self._init_elev = init_elev
        self._init_azim = init_azim
        self.default_node_color = default_node_color
        self._antialiased = bool(antialiased)
        self._depthshade = bool(depthshade)

    # ------------------------------ PRIVATE ------------------------------ #

    def _ensure_axes(self) -> None:
        """Create the 3D axes and an empty scatter artist if not done yet."""
        if self._ax is not None and self._scatter is not None:
            return
        self._fig.clf()
        # clf() discards every artist on the figure, including old labels.
        self._labels = []
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._ax.view_init(self._init_elev, self._init_azim)
        self._scatter = self._ax.scatter(
            [], [], [],
            s=self._marker_size,
            depthshade=self._depthshade,
            edgecolors="none",
            antialiased=self._antialiased,
        )
        try:
            self._ax.set_axis_off()
        except Exception:
            # Best-effort: tolerate minimal/dummy axes objects.
            pass

    def _clear_labels(self) -> None:
        """Remove node-label text artists left over from a previous render."""
        for txt in getattr(self, "_labels", ()):
            try:
                txt.remove()
            except Exception:
                # Best-effort: the artist may already be detached from its axes.
                pass
        self._labels = []

    def _project3(self, X: np.ndarray) -> np.ndarray:
        """Return a 3D PCA projection of embeddings with shape ``(n, 3)``.

        Missing coordinates are zero-filled. This happens when the embedding
        dimensionality ``D`` is below 3, or when fewer than 3 nodes are
        selected so fewer than 3 principal directions exist. For ``D == 1``
        the centered values themselves are used as the single coordinate.
        """
        D = X.shape[1]
        if D >= 2:
            P, _ = _safe_svd_pca(X, 3 if D >= 3 else 2)
        else:
            P = X - X.mean(axis=0, keepdims=True)
        missing = 3 - P.shape[1]
        if missing > 0:
            P = np.hstack([P, np.zeros((P.shape[0], missing), dtype=P.dtype)])
        return P

    def _select_nodes(self, displayed_nodes: Sequence[int] | str | None) -> np.ndarray:
        """Convert 1-based node ids (or ``"ALL"``/``None``) to 0-based indices.

        Raises:
            TypeError: If the selection is a non-"ALL" string or not integer-valued.
            ValueError: If the selection is empty.
            IndexError: If any id lies outside ``[1, N]``.
        """
        # Test for the "ALL" sentinel only on real strings: comparing a NumPy
        # array with "ALL" is elementwise and makes the `if` ambiguous.
        if displayed_nodes is None or (isinstance(displayed_nodes, str) and displayed_nodes == "ALL"):
            return np.arange(self.N, dtype=np.int64)
        idx = np.asarray(displayed_nodes).reshape(-1)
        if idx.size == 0:
            raise ValueError("displayed_nodes must contain at least one node id.")
        if idx.dtype.kind not in "iu":
            raise TypeError("displayed_nodes must be None, 'ALL', or an integer sequence.")
        if idx.min() < 1 or idx.max() > self.N:
            raise IndexError(f"displayed_nodes out of range [1,{self.N}]")
        return idx.astype(np.int64) - 1

    def _apply_colors(self, node_colors, idx: np.ndarray) -> np.ndarray:
        """Resolve ``node_colors`` to one color per selected node (0-based ``idx``).

        Intended to mirror the RIN Visualizer's coloring semantics:
        - ``str``: colormap name; nodes are colored by their index.
        - ``None``: ``default_node_color`` for every node.
        - ``(N, 3|4)`` array: per-node RGB(A), indexed by ``idx``.
        - anything else: passed as ``groups`` to ``map_groups_to_colors``.
        """
        if isinstance(node_colors, str):
            node_cmap = self._plt.get_cmap(node_colors)
            return node_cmap(self._residue_norm(idx))
        if node_colors is None:
            full = visualizer_util.map_groups_to_colors(
                N=self.N, groups=None, default_color=self.default_node_color, one_based=True
            )
            return np.asarray(full)[idx]
        arr = np.asarray(node_colors)
        if arr.ndim == 2 and arr.shape[0] == self.N and arr.shape[1] in (3, 4):
            return arr[idx]
        full = visualizer_util.map_groups_to_colors(
            N=self.N, groups=node_colors, default_color=self.default_node_color, one_based=True
        )
        return np.asarray(full)[idx]

    # ------------------------------ PUBLIC ------------------------------- #

    def build_frame(
        self,
        frame_id: int,
        *,
        node_colors: str | np.ndarray | None = "rainbow",
        displayed_nodes: Sequence[int] | str | None = "ALL",
        show_node_labels: bool = False,
        show: bool = False
    ) -> None:
        """Render a single frame as a PCA 3D scatter.

        Args:
            frame_id: 1-based frame id in ``[1, T]``.
            node_colors: Colormap name, ``None``, an ``(N, 3|4)`` color array,
                or a group specification (see ``_apply_colors``).
            displayed_nodes: ``"ALL"``/``None`` or 1-based node ids to show.
            show_node_labels: Annotate each displayed node with its 1-based id.
                Labels from a previous call are always removed first.
            show: Call ``plt.show`` (blocking when supported) after drawing.

        Raises:
            IndexError: If ``frame_id`` or any node id is out of range.
            TypeError, ValueError: If ``displayed_nodes`` is malformed.
        """
        frame0 = int(frame_id) - 1
        if not (0 <= frame0 < self.T):
            raise IndexError(f"frame_id out of range [1,{self.T}]")
        self._ensure_axes()

        idx = self._select_nodes(displayed_nodes)
        X = self.E[frame0, idx, :]  # (n, D)
        P = self._project3(X)       # (n, 3)
        colors = self._apply_colors(node_colors, idx)

        self._scatter._offsets3d = (P[:, 0], P[:, 1], P[:, 2])
        self._scatter.set_facecolors(colors)
        _set_equal_axes_3d(self._ax, P, padding=0.05)
        self._ax.view_init(self._init_elev, self._init_azim)

        # Always drop previous labels so stale text never lingers at old positions.
        self._clear_labels()
        if show_node_labels:
            self._labels = [
                self._ax.text(p[0], p[1], p[2], str(int(nid)), fontsize=8)
                for p, nid in zip(P, idx + 1)
            ]

        # Be friendly to test dummies (they may lack tight_layout/canvas)
        try:
            self._fig.tight_layout()
        except Exception:
            try:
                self._fig.subplots_adjust()
            except Exception:
                pass
        try:
            self._fig.canvas.draw_idle()
        except Exception:
            pass

        if show:
            try:
                self._plt.show(block=True)
            except TypeError:
                # Some backends/stubs do not accept the `block` keyword.
                self._plt.show()

    # convenience
    def savefig(self, path: str | Path, *, dpi: int = 150) -> None:
        """Save the current figure to ``path`` at the given ``dpi``."""
        self._fig.savefig(path, dpi=dpi)

    def close(self) -> None:
        """Close the figure (best-effort; errors are ignored)."""
        try:
            self._plt.close(self._fig)
        except Exception:
            pass
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# Public API of this module: only the Visualizer class is exported.
__all__ = ["Visualizer"]
|
sawnergy/logging_util.py
CHANGED
sawnergy/rin/rin_builder.py
CHANGED
|
@@ -669,7 +669,7 @@ class RINBuilder:
|
|
|
669
669
|
molecule_of_interest: int,
|
|
670
670
|
frame_range: tuple[int, int] | None = None,
|
|
671
671
|
frame_batch_size: int = -1,
|
|
672
|
-
prune_low_energies_frac: float = 0.
|
|
672
|
+
prune_low_energies_frac: float = 0.85,
|
|
673
673
|
output_path: str | Path | None = None,
|
|
674
674
|
keep_prenormalized_energies: bool = True,
|
|
675
675
|
*,
|
sawnergy/visual/visualizer.py
CHANGED
|
@@ -107,7 +107,7 @@ class Visualizer:
|
|
|
107
107
|
visualizer_util.ensure_backend(show)
|
|
108
108
|
import matplotlib.pyplot as plt
|
|
109
109
|
self._plt = plt
|
|
110
|
-
|
|
110
|
+
# ---------- WARM UP MPL ------------ #
|
|
111
111
|
_logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
|
|
112
112
|
RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
|
|
113
113
|
if Visualizer.no_instances:
|
|
@@ -116,7 +116,7 @@ class Visualizer:
|
|
|
116
116
|
else:
|
|
117
117
|
_logger.debug("Skipping warm-start (no_instances=False).")
|
|
118
118
|
|
|
119
|
-
|
|
119
|
+
# ---------- LOAD THE DATA ---------- #
|
|
120
120
|
with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
|
|
121
121
|
com_name = storage.get_attr("com_name")
|
|
122
122
|
attr_energies_name = storage.get_attr("attractive_energies_name")
|
|
@@ -135,7 +135,7 @@ class Visualizer:
|
|
|
135
135
|
self.N = np.size(self.COM_coords[0], axis=0)
|
|
136
136
|
_logger.debug("Computed N=%d", self.N)
|
|
137
137
|
|
|
138
|
-
|
|
138
|
+
# - SET UP THE CANVAS AND THE AXES - #
|
|
139
139
|
self._fig = plt.figure(figsize=figsize, num="SAWNERGY")
|
|
140
140
|
self._ax = self._fig.add_subplot(111, projection="3d")
|
|
141
141
|
self._fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
|
@@ -145,14 +145,14 @@ class Visualizer:
|
|
|
145
145
|
self._ax.set_axis_off()
|
|
146
146
|
_logger.debug("Figure and 3D axes initialized.")
|
|
147
147
|
|
|
148
|
-
|
|
148
|
+
# ------ SET UP PLOT ELEMENTS ------ #
|
|
149
149
|
self._scatter: PathCollection = self._ax.scatter([], [], [], s=node_size, depthshade=depthshade, edgecolors="none")
|
|
150
150
|
self._attr: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
|
|
151
151
|
self._repuls: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
|
|
152
152
|
self._ax.add_collection3d(self._attr); self._ax.add_collection3d(self._repuls) # set pointers to the attractive and repulsive collections
|
|
153
153
|
_logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")
|
|
154
154
|
|
|
155
|
-
|
|
155
|
+
# ---------- HELPER FIELDS --------- #
|
|
156
156
|
# NOTE: 'under the hood' everything is 0-base indexed,
|
|
157
157
|
# BUT, from the API point of view, the indexing is 1-base,
|
|
158
158
|
# because amino acid residues are 1-base indexed.
|
|
@@ -160,7 +160,7 @@ class Visualizer:
|
|
|
160
160
|
self.default_node_color = default_node_color
|
|
161
161
|
_logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", self.N-1, self.default_node_color)
|
|
162
162
|
|
|
163
|
-
|
|
163
|
+
# DISALLOW MPL WARM-UP IN THE FUTURE
|
|
164
164
|
Visualizer.no_instances = False
|
|
165
165
|
_logger.debug("Visualizer.no_instances set to False.")
|
|
166
166
|
|
|
@@ -319,6 +319,9 @@ def build_line_segments(
|
|
|
319
319
|
kept = edge_weights >= thresh
|
|
320
320
|
rows, cols = rows[kept], cols[kept]
|
|
321
321
|
|
|
322
|
+
nz = weights[rows, cols] > 0.0
|
|
323
|
+
rows, cols = rows[nz], cols[nz]
|
|
324
|
+
|
|
322
325
|
if rows.size == 0:
|
|
323
326
|
_logger.debug("build_line_segments: no edges kept after threshold; returning empties.")
|
|
324
327
|
return (np.empty((0, 2, 3), dtype=float),
|