sawnergy 1.0.6__py3-none-any.whl → 1.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy might be problematic. Click here for more details.
- sawnergy/embedding/SGNS_pml.py +214 -16
- sawnergy/embedding/SGNS_torch.py +145 -11
- sawnergy/embedding/__init__.py +24 -0
- sawnergy/embedding/embedder.py +99 -49
- sawnergy/embedding/visualizer.py +247 -0
- sawnergy/logging_util.py +1 -1
- sawnergy/rin/rin_builder.py +1 -1
- sawnergy/visual/visualizer.py +6 -6
- sawnergy/visual/visualizer_util.py +3 -0
- {sawnergy-1.0.6.dist-info → sawnergy-1.0.7.dist-info}/METADATA +48 -24
- sawnergy-1.0.7.dist-info/RECORD +23 -0
- sawnergy-1.0.6.dist-info/RECORD +0 -22
- {sawnergy-1.0.6.dist-info → sawnergy-1.0.7.dist-info}/WHEEL +0 -0
- {sawnergy-1.0.6.dist-info → sawnergy-1.0.7.dist-info}/licenses/LICENSE +0 -0
- {sawnergy-1.0.6.dist-info → sawnergy-1.0.7.dist-info}/licenses/NOTICE +0 -0
- {sawnergy-1.0.6.dist-info → sawnergy-1.0.7.dist-info}/top_level.txt +0 -0
sawnergy/embedding/embedder.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
"""
|
|
4
|
-
Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
|
|
4
|
+
Embedding orchestration for Skip-Gram (SG) and Skip-Gram with Negative Sampling (SGNS).
|
|
5
5
|
|
|
6
6
|
This module consumes attractive/repulsive walk corpora produced by the walker
|
|
7
7
|
pipeline and trains per-frame embeddings using either the PyTorch or PureML
|
|
8
|
-
implementations of SGNS. The resulting embeddings can be persisted back into
|
|
8
|
+
implementations of SG/SGNS. The resulting embeddings can be persisted back into
|
|
9
9
|
an ``ArrayStorage`` archive along with rich metadata describing the training
|
|
10
10
|
configuration.
|
|
11
11
|
"""
|
|
@@ -38,7 +38,8 @@ class Embedder:
|
|
|
38
38
|
WALKS_path: str | Path,
|
|
39
39
|
base: Literal["torch", "pureml"],
|
|
40
40
|
*,
|
|
41
|
-
seed: int | None = None
|
|
41
|
+
seed: int | None = None,
|
|
42
|
+
objective: Literal["sgns", "sg"] = "sgns"
|
|
42
43
|
) -> None:
|
|
43
44
|
"""Initialize the embedder and load walk tensors.
|
|
44
45
|
|
|
@@ -53,6 +54,8 @@ class Embedder:
|
|
|
53
54
|
base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
|
|
54
55
|
seed: Optional seed for the embedder's RNG. If ``None``, a random
|
|
55
56
|
32-bit seed is chosen.
|
|
57
|
+
objective: Training objective, either ``"sgns"`` (negative sampling)
|
|
58
|
+
or ``"sg"`` (plain full-softmax Skip-Gram).
|
|
56
59
|
|
|
57
60
|
Raises:
|
|
58
61
|
ValueError: If required metadata is missing or any loaded walk array
|
|
@@ -156,20 +159,25 @@ class Embedder:
|
|
|
156
159
|
|
|
157
160
|
# MODEL HANDLE
|
|
158
161
|
self.model_base: Literal["torch", "pureml"] = base
|
|
159
|
-
self.
|
|
160
|
-
|
|
162
|
+
self.objective: Literal["sgns", "sg"] = objective
|
|
163
|
+
self.model_constructor = self._get_SGNS_constructor_from(base, objective)
|
|
164
|
+
_logger.info(
|
|
165
|
+
"SG backend resolved: %s (objective=%s)",
|
|
166
|
+
getattr(self.model_constructor, "__name__", repr(self.model_constructor)), self.objective
|
|
167
|
+
)
|
|
161
168
|
|
|
162
169
|
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
|
163
170
|
|
|
164
171
|
# HELPERS:
|
|
165
172
|
|
|
166
173
|
@staticmethod
|
|
167
|
-
def _get_SGNS_constructor_from(base: Literal["torch", "pureml"]
|
|
168
|
-
|
|
174
|
+
def _get_SGNS_constructor_from(base: Literal["torch", "pureml"],
|
|
175
|
+
objective: Literal["sgns", "sg"]):
|
|
176
|
+
"""Resolve the SG/SGNS implementation class for the selected backend."""
|
|
169
177
|
if base == "torch":
|
|
170
178
|
try:
|
|
171
|
-
from .SGNS_torch import SGNS_Torch
|
|
172
|
-
return SGNS_Torch
|
|
179
|
+
from .SGNS_torch import SGNS_Torch, SG_Torch
|
|
180
|
+
return SG_Torch if objective == "sg" else SGNS_Torch
|
|
173
181
|
except Exception:
|
|
174
182
|
raise ImportError(
|
|
175
183
|
"PyTorch is not installed, but base='torch' was requested. "
|
|
@@ -178,8 +186,8 @@ class Embedder:
|
|
|
178
186
|
)
|
|
179
187
|
elif base == "pureml":
|
|
180
188
|
try:
|
|
181
|
-
from .SGNS_pml import SGNS_PureML
|
|
182
|
-
return SGNS_PureML
|
|
189
|
+
from .SGNS_pml import SGNS_PureML, SG_PureML
|
|
190
|
+
return SG_PureML if objective == "sg" else SGNS_PureML
|
|
183
191
|
except Exception:
|
|
184
192
|
raise ImportError(
|
|
185
193
|
"PureML is not installed, but base='pureml' was requested. "
|
|
@@ -322,22 +330,24 @@ class Embedder:
|
|
|
322
330
|
# -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
|
|
323
331
|
|
|
324
332
|
def embed_frame(self,
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
333
|
+
frame_id: int,
|
|
334
|
+
RIN_type: Literal["attr", "repuls"],
|
|
335
|
+
using: Literal["RW", "SAW", "merged"],
|
|
336
|
+
window_size: int,
|
|
337
|
+
num_negative_samples: int,
|
|
338
|
+
num_epochs: int,
|
|
339
|
+
batch_size: int,
|
|
340
|
+
*,
|
|
341
|
+
lr_step_per_batch: bool = False,
|
|
342
|
+
shuffle_data: bool = True,
|
|
343
|
+
dimensionality: int = 128,
|
|
344
|
+
alpha: float = 0.75,
|
|
345
|
+
device: str | None = None,
|
|
346
|
+
sgns_kwargs: dict[str, object] | None = None,
|
|
347
|
+
kind: Literal["in", "out", "avg"] = "in",
|
|
348
|
+
_seed: int | None = None,
|
|
349
|
+
objective: Literal["sgns", "sg"] | None = None
|
|
350
|
+
) -> np.ndarray:
|
|
341
351
|
"""Train embeddings for a single frame and return the input embedding matrix.
|
|
342
352
|
|
|
343
353
|
Args:
|
|
@@ -348,6 +358,7 @@ class Embedder:
|
|
|
348
358
|
``"merged"`` (concatenates both if available).
|
|
349
359
|
window_size: Symmetric skip-gram window size ``k``.
|
|
350
360
|
num_negative_samples: Number of negative samples per positive pair.
|
|
361
|
+
Ignored when ``objective="sg"``.
|
|
351
362
|
num_epochs: Number of passes over the pair dataset.
|
|
352
363
|
batch_size: Mini-batch size for training.
|
|
353
364
|
shuffle_data: Whether to shuffle pairs each epoch.
|
|
@@ -358,10 +369,13 @@ class Embedder:
|
|
|
358
369
|
constructor. For PureML, required keys are:
|
|
359
370
|
``{"optim", "optim_kwargs"}``; ``lr_sched`` is optional, but if
|
|
360
371
|
provided then ``lr_sched_kwargs`` must also be provided.
|
|
372
|
+
kind: Which embedding matrix to return: ``"in"``, ``"out"``, or ``"avg"``.
|
|
361
373
|
_seed: Optional child seed for this frame's model initialization.
|
|
374
|
+
objective: Training objective override for this call (``"sgns"`` or
|
|
375
|
+
``"sg"``). If ``None``, uses the value set at construction.
|
|
362
376
|
|
|
363
377
|
Returns:
|
|
364
|
-
np.ndarray: Learned
|
|
378
|
+
np.ndarray: Learned embedding matrix (selected by ``kind``) of shape ``(V, D)``.
|
|
365
379
|
|
|
366
380
|
Raises:
|
|
367
381
|
ValueError: If requested walks are missing, if no training pairs are
|
|
@@ -412,12 +426,16 @@ class Embedder:
|
|
|
412
426
|
if self.model_base == "torch" and device is not None:
|
|
413
427
|
model_kwargs["device"] = device
|
|
414
428
|
|
|
429
|
+
# Resolve objective (call-level override beats constructor default)
|
|
430
|
+
obj = self.objective if objective is None else objective
|
|
431
|
+
self.model_constructor = self._get_SGNS_constructor_from(self.model_base, obj)
|
|
415
432
|
self.model = self.model_constructor(**model_kwargs)
|
|
416
433
|
|
|
417
434
|
_logger.info(
|
|
418
|
-
"Training
|
|
435
|
+
"Training SG base=%s constructor=%s objective=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
|
|
419
436
|
self.model_base,
|
|
420
437
|
getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
|
|
438
|
+
obj,
|
|
421
439
|
frame_id,
|
|
422
440
|
pairs.shape[0],
|
|
423
441
|
dimensionality,
|
|
@@ -427,24 +445,46 @@ class Embedder:
|
|
|
427
445
|
shuffle_data
|
|
428
446
|
)
|
|
429
447
|
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
448
|
+
if obj == "sgns":
|
|
449
|
+
self.model.fit(
|
|
450
|
+
centers,
|
|
451
|
+
contexts,
|
|
452
|
+
num_epochs,
|
|
453
|
+
batch_size,
|
|
454
|
+
num_negative_samples,
|
|
455
|
+
noise_probs,
|
|
456
|
+
shuffle_data,
|
|
457
|
+
lr_step_per_batch
|
|
458
|
+
)
|
|
459
|
+
else:
|
|
460
|
+
self.model.fit(
|
|
461
|
+
centers,
|
|
462
|
+
contexts,
|
|
463
|
+
num_epochs,
|
|
464
|
+
batch_size,
|
|
465
|
+
shuffle_data,
|
|
466
|
+
lr_step_per_batch
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
# Select embedding matrix by kind
|
|
470
|
+
if kind == "in":
|
|
471
|
+
embeddings = getattr(self.model, "in_embeddings", None)
|
|
472
|
+
if embeddings is None:
|
|
473
|
+
embeddings = getattr(self.model, "embeddings", None)
|
|
474
|
+
if embeddings is None:
|
|
475
|
+
params = getattr(self.model, "parameters", None)
|
|
476
|
+
if isinstance(params, tuple) and params:
|
|
477
|
+
embeddings = params[0]
|
|
478
|
+
elif kind == "out":
|
|
479
|
+
embeddings = getattr(self.model, "out_embeddings", None)
|
|
480
|
+
else: # "avg"
|
|
481
|
+
embeddings = getattr(self.model, "avg_embeddings", None)
|
|
440
482
|
|
|
441
|
-
embeddings = getattr(self.model, "embeddings", None)
|
|
442
|
-
if embeddings is None:
|
|
443
|
-
params = getattr(self.model, "parameters", None)
|
|
444
|
-
if isinstance(params, tuple) and params:
|
|
445
|
-
embeddings = params[0]
|
|
446
483
|
if embeddings is None:
|
|
447
|
-
raise AttributeError(
|
|
484
|
+
raise AttributeError(
|
|
485
|
+
"SG/SGNS model does not expose the requested embeddings: "
|
|
486
|
+
f"kind={kind!r}"
|
|
487
|
+
)
|
|
448
488
|
|
|
449
489
|
embeddings = np.asarray(embeddings)
|
|
450
490
|
_logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
|
|
@@ -466,7 +506,9 @@ class Embedder:
|
|
|
466
506
|
sgns_kwargs: dict[str, object] | None = None,
|
|
467
507
|
output_path: str | Path | None = None,
|
|
468
508
|
num_matrices_in_compressed_blocks: int = 20,
|
|
469
|
-
compression_level: int = 3
|
|
509
|
+
compression_level: int = 3,
|
|
510
|
+
objective: Literal["sgns", "sg"] | None = None,
|
|
511
|
+
kind: Literal["in", "out", "avg"] = "in"):
|
|
470
512
|
"""Train embeddings for all frames and persist them to compressed storage.
|
|
471
513
|
|
|
472
514
|
Iterates through all frames (``1..frame_count``), trains an SGNS model
|
|
@@ -478,6 +520,7 @@ class Embedder:
|
|
|
478
520
|
using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
|
|
479
521
|
window_size: Symmetric skip-gram window size ``k``.
|
|
480
522
|
num_negative_samples: Number of negative samples per positive pair.
|
|
523
|
+
Ignored when ``objective="sg"``.
|
|
481
524
|
num_epochs: Number of epochs for each frame.
|
|
482
525
|
batch_size: Mini-batch size used during training.
|
|
483
526
|
shuffle_data: Whether to shuffle pairs each epoch.
|
|
@@ -493,6 +536,10 @@ class Embedder:
|
|
|
493
536
|
num_matrices_in_compressed_blocks: Number of per-frame matrices to
|
|
494
537
|
store per compressed chunk in the output archive.
|
|
495
538
|
compression_level: Blosc Zstd compression level (0-9).
|
|
539
|
+
objective: Training objective for all frames (``"sgns"`` or ``"sg"``).
|
|
540
|
+
If ``None``, uses the value set at construction.
|
|
541
|
+
kind: Which embedding matrix to persist for each frame: ``"in"``,
|
|
542
|
+
``"out"``, or ``"avg"``.
|
|
496
543
|
|
|
497
544
|
Returns:
|
|
498
545
|
str: Filesystem path to the written embeddings archive (``.zip``).
|
|
@@ -504,8 +551,8 @@ class Embedder:
|
|
|
504
551
|
|
|
505
552
|
Notes:
|
|
506
553
|
- A deterministic child seed is spawned per frame from the master
|
|
507
|
-
|
|
508
|
-
|
|
554
|
+
seed using ``np.random.SeedSequence`` to ensure reproducibility
|
|
555
|
+
across runs.
|
|
509
556
|
"""
|
|
510
557
|
current_time = sawnergy_util.current_time()
|
|
511
558
|
if output_path is None:
|
|
@@ -541,7 +588,9 @@ class Embedder:
|
|
|
541
588
|
alpha=alpha,
|
|
542
589
|
device=device,
|
|
543
590
|
sgns_kwargs=sgns_kwargs,
|
|
544
|
-
_seed=child_seed
|
|
591
|
+
_seed=child_seed,
|
|
592
|
+
objective=objective,
|
|
593
|
+
kind=kind
|
|
545
594
|
)
|
|
546
595
|
)
|
|
547
596
|
|
|
@@ -574,6 +623,7 @@ class Embedder:
|
|
|
574
623
|
storage.add_attr("frame_embeddings_name", block_name)
|
|
575
624
|
storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
|
|
576
625
|
storage.add_attr("compression_level", int(compression_level))
|
|
626
|
+
storage.add_attr("objective", self.objective if objective is None else objective)
|
|
577
627
|
|
|
578
628
|
_logger.info("Embedding archive written to %s", output_path)
|
|
579
629
|
return str(output_path)
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# third party
|
|
4
|
+
import numpy as np
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
|
|
7
|
+
# built-in
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Sequence
|
|
10
|
+
import logging
|
|
11
|
+
|
|
12
|
+
# local
|
|
13
|
+
from ..visual import visualizer_util
|
|
14
|
+
from .. import sawnergy_util
|
|
15
|
+
|
|
16
|
+
# *----------------------------------------------------*
|
|
17
|
+
# GLOBALS
|
|
18
|
+
# *----------------------------------------------------*
|
|
19
|
+
|
|
20
|
+
_logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
# *----------------------------------------------------*
|
|
23
|
+
# HELPERS
|
|
24
|
+
# *----------------------------------------------------*
|
|
25
|
+
|
|
26
|
+
def _safe_svd_pca(X: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
|
|
27
|
+
"""Compute k principal directions via SVD and project onto them."""
|
|
28
|
+
if X.ndim != 2:
|
|
29
|
+
raise ValueError(f"PCA expects 2D array (N, D); got {X.shape}")
|
|
30
|
+
_, D = X.shape
|
|
31
|
+
if k not in (2, 3):
|
|
32
|
+
raise ValueError(f"PCA dimensionality must be 2 or 3; got {k}")
|
|
33
|
+
if D < k:
|
|
34
|
+
raise ValueError(f"Requested k={k} exceeds feature dim D={D}")
|
|
35
|
+
Xc = X - X.mean(axis=0, keepdims=True)
|
|
36
|
+
_, _, Vt = np.linalg.svd(Xc, full_matrices=False)
|
|
37
|
+
comps = Vt[:k].copy()
|
|
38
|
+
proj = Xc @ comps.T
|
|
39
|
+
return proj, comps
|
|
40
|
+
|
|
41
|
+
def _set_equal_axes_3d(ax, xyz: np.ndarray, *, padding: float = 0.05) -> None:
|
|
42
|
+
if xyz.size == 0:
|
|
43
|
+
return
|
|
44
|
+
x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
|
|
45
|
+
xmin, xmax = float(x.min()), float(x.max())
|
|
46
|
+
ymin, ymax = float(y.min()), float(y.max())
|
|
47
|
+
zmin, zmax = float(z.min()), float(z.max())
|
|
48
|
+
xr = xmax - xmin
|
|
49
|
+
yr = ymax - ymin
|
|
50
|
+
zr = zmax - zmin
|
|
51
|
+
r = max(xr, yr, zr)
|
|
52
|
+
pad = padding * (r if r > 0 else 1.0)
|
|
53
|
+
cx, cy, cz = (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, (zmin + zmax) / 2.0
|
|
54
|
+
ax.set_xlim(cx - r / 2 - pad, cx + r / 2 + pad)
|
|
55
|
+
ax.set_ylim(cy - r / 2 - pad, cy + r / 2 + pad)
|
|
56
|
+
ax.set_zlim(cz - r / 2 - pad, cz + r / 2 + pad)
|
|
57
|
+
try:
|
|
58
|
+
ax.set_box_aspect([1, 1, 1])
|
|
59
|
+
except Exception:
|
|
60
|
+
pass
|
|
61
|
+
|
|
62
|
+
# *----------------------------------------------------*
|
|
63
|
+
# CLASS
|
|
64
|
+
# *----------------------------------------------------*
|
|
65
|
+
|
|
66
|
+
class Visualizer:
    """3D PCA visualizer for per-frame embeddings.

    Loads a ``(T, N, D)`` stack of embeddings from an ``ArrayStorage``
    archive and renders one frame at a time as a 3D scatter of the PCA
    projection of the selected nodes.

    From the API point of view ``frame_id`` and ``displayed_nodes`` are
    1-based; they are converted to 0-based indices internally.
    """

    # True until the first instance has been constructed; gates the one-time
    # Matplotlib warm-up performed in ``__init__``.
    no_instances: bool = True

    def __init__(
        self,
        EMB_path: str | Path,
        figsize: tuple[int, int] = (9, 7),
        default_node_color: str = visualizer_util.GRAY,
        depthshade: bool = False,
        antialiased: bool = False,
        init_elev: float = 35,
        init_azim: float = 45,
        *,
        show: bool = False
    ) -> None:
        """Load the embeddings archive and create the (still empty) figure.

        Args:
            EMB_path: Path to the embeddings archive. It must carry a
                ``"frame_embeddings_name"`` attribute naming the stored
                ``(T, N, D)`` array.
            figsize: Figure size passed to ``pyplot.figure``.
            default_node_color: Color used for nodes not covered by an
                explicit color specification.
            depthshade: Forwarded to the 3D scatter's ``depthshade`` option.
            antialiased: Forwarded to the 3D scatter's ``antialiased`` option.
            init_elev: Elevation of the initial camera view.
            init_azim: Azimuth of the initial camera view.
            show: Passed to ``visualizer_util.ensure_backend`` before pyplot
                is imported.

        Raises:
            ValueError: If the stored embeddings array is not 3-dimensional.
        """
        # Backend & pyplot (the backend helper runs before pyplot is imported)
        visualizer_util.ensure_backend(show)
        import matplotlib.pyplot as plt
        self._plt = plt

        if Visualizer.no_instances:
            try:
                visualizer_util.warm_start_matplotlib()
            finally:
                # Warm-up is attempted at most once, even if it fails.
                Visualizer.no_instances = False

        # Load embeddings archive
        EMB_path = Path(EMB_path)
        with sawnergy_util.ArrayStorage(EMB_path, mode="r") as storage:
            name = storage.get_attr("frame_embeddings_name")
            E = storage.read(name, slice(None))
        if E.ndim != 3:
            raise ValueError(f"Expected embeddings of shape (T,N,D); got {E.shape}")
        self.E = np.asarray(E)
        self.T, self.N, self.D = map(int, self.E.shape)
        _logger.info("Loaded embeddings: T=%d, N=%d, D=%d", self.T, self.N, self.D)

        # Coloring normalizer (parity with RIN Visualizer): maps 0-based node
        # indices onto [0, 1] for colormap lookup.
        self._residue_norm = mpl.colors.Normalize(0, max(1, self.N - 1))

        # Figure / axes / artists (axes and scatter are created lazily)
        self._fig = self._plt.figure(figsize=figsize)
        self._ax = None
        self._scatter = None
        self._labels = []
        self._marker_size = 30.0
        self._init_elev = init_elev
        self._init_azim = init_azim
        self.default_node_color = default_node_color
        self._antialiased = bool(antialiased)
        self._depthshade = bool(depthshade)

    # ------------------------------ PRIVATE ------------------------------ #

    def _ensure_axes(self) -> None:
        """Create the 3D axes and the empty scatter artist on first use."""
        if self._ax is not None and self._scatter is not None:
            return
        self._fig.clf()
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._ax.view_init(self._init_elev, self._init_azim)
        self._scatter = self._ax.scatter(
            [], [], [],
            s=self._marker_size,
            depthshade=self._depthshade,
            edgecolors="none",
            antialiased=self._antialiased,
        )
        try:
            self._ax.set_axis_off()
        except Exception:
            # Best effort: hiding the axis decorations is purely cosmetic.
            pass

    def _project3(self, X: np.ndarray) -> np.ndarray:
        """Return a 3D PCA projection of embeddings (always 3 coordinates).

        If fewer than three principal coordinates are available (embedding
        dimensionality D < 3, or too few rows for the PCA helper to return
        three columns), the remaining coordinate(s) are set to 0 so that the
        returned array still has shape (n, 3).
        """
        n_features = X.shape[1]
        if n_features >= 2:
            P, _ = _safe_svd_pca(X, 3 if n_features >= 3 else 2)
        else:
            # The PCA helper only accepts k in {2, 3} with D >= k; with fewer
            # than two features, centring the data is all that is left to do.
            P = X - X.mean(axis=0, keepdims=True)

        # Pad with zero columns whenever fewer than 3 coordinates came back,
        # so callers can always index columns 0..2.
        missing = 3 - P.shape[1]
        if missing > 0:
            P = np.c_[P, np.zeros((P.shape[0], missing), dtype=P.dtype)]
        return P

    def _select_nodes(self, displayed_nodes: Sequence[int] | str | None) -> np.ndarray:
        """Translate a 1-based node selection into 0-based indices.

        Args:
            displayed_nodes: ``None`` or ``"ALL"`` for every node, otherwise
                a sequence of 1-based integer node ids.

        Returns:
            np.ndarray: 0-based ``int64`` indices.

        Raises:
            TypeError: If the selection is not an integer sequence.
            IndexError: If any id lies outside ``[1, N]``.
        """
        # Only compare against "ALL" when the argument really is a string:
        # ``ndarray == "ALL"`` is an element-wise comparison whose truth value
        # is ambiguous, so array selections must not reach the ``==``.
        if displayed_nodes is None or (
            isinstance(displayed_nodes, str) and displayed_nodes == "ALL"
        ):
            return np.arange(self.N, dtype=np.int64)
        idx = np.asarray(displayed_nodes)
        if idx.dtype.kind not in "iu":
            raise TypeError("displayed_nodes must be None, 'ALL', or an integer sequence.")
        if idx.min() < 1 or idx.max() > self.N:
            raise IndexError(f"displayed_nodes out of range [1,{self.N}]")
        return idx.astype(np.int64) - 1

    def _apply_colors(self, node_colors, idx: np.ndarray) -> np.ndarray:
        """Resolve ``node_colors`` into one color per selected node.

        Args:
            node_colors: A colormap name, ``None`` (default color for every
                node), an ``(N, 3)``/``(N, 4)`` color array, or a group
                specification understood by
                ``visualizer_util.map_groups_to_colors``.
            idx: 0-based indices of the nodes being displayed.

        Returns:
            np.ndarray: Colors for the nodes in ``idx``, in the same order.
        """
        # RIN Visualizer semantics:
        if isinstance(node_colors, str):
            node_cmap = self._plt.get_cmap(node_colors)
            return node_cmap(self._residue_norm(idx))
        if node_colors is None:
            full = visualizer_util.map_groups_to_colors(
                N=self.N, groups=None, default_color=self.default_node_color, one_based=True
            )
            return np.asarray(full)[idx]
        try:
            arr = np.asarray(node_colors)
        except (ValueError, TypeError):
            # NOTE(review): group specifications may be ragged (confirm the
            # accepted format against visualizer_util); recent NumPy refuses
            # to build an array from ragged input, so treat it as "not a
            # color array" and let the group mapper handle it below.
            arr = None
        if arr is not None and arr.ndim == 2 and arr.shape[0] == self.N and arr.shape[1] in (3, 4):
            return arr[idx]
        full = visualizer_util.map_groups_to_colors(
            N=self.N, groups=node_colors, default_color=self.default_node_color, one_based=True
        )
        return np.asarray(full)[idx]

    # ------------------------------ PUBLIC ------------------------------- #

    def build_frame(
        self,
        frame_id: int,
        *,
        node_colors: str | np.ndarray | None = "rainbow",
        displayed_nodes: Sequence[int] | str | None = "ALL",
        show_node_labels: bool = False,
        show: bool = False
    ) -> None:
        """Render a single frame as a PCA **3D** scatter (matches RIN Visualizer API).

        Args:
            frame_id: 1-based frame index.
            node_colors: Color specification; see ``_apply_colors``.
            displayed_nodes: 1-based node ids to draw, or ``None``/``"ALL"``.
            show_node_labels: Whether to annotate every point with its
                1-based node id.
            show: Whether to call ``pyplot.show`` after drawing.

        Raises:
            IndexError: If ``frame_id`` lies outside ``[1, T]``.
        """
        frame0 = int(frame_id) - 1
        if not (0 <= frame0 < self.T):
            raise IndexError(f"frame_id out of range [1,{self.T}]")
        self._ensure_axes()

        idx = self._select_nodes(displayed_nodes)
        X = self.E[frame0, idx, :]            # (n, D)
        P = self._project3(X)                 # (n, 3)
        colors = self._apply_colors(node_colors, idx)

        x, y, z = P[:, 0], P[:, 1], P[:, 2]
        # Update the existing scatter artist in place instead of re-creating it.
        self._scatter._offsets3d = (x, y, z)
        self._scatter.set_facecolors(colors)
        _set_equal_axes_3d(self._ax, P, padding=0.05)
        self._ax.view_init(self._init_elev, self._init_azim)

        # Always drop labels from a previous call first; otherwise they would
        # linger on the axes when this frame is drawn without labels.
        for txt in getattr(self, "_labels", []):
            try:
                txt.remove()
            except Exception:
                # Best effort: the artist may already be detached.
                pass
        self._labels = []
        if show_node_labels:
            for p, nid in zip(P, idx + 1):
                self._labels.append(self._ax.text(p[0], p[1], p[2], str(int(nid)), fontsize=8))

        # Be friendly to test dummies (they may lack tight_layout/canvas)
        try:
            self._fig.tight_layout()
        except Exception:
            try:
                self._fig.subplots_adjust()
            except Exception:
                pass
        try:
            self._fig.canvas.draw_idle()
        except Exception:
            pass

        if show:
            try:
                self._plt.show(block=True)
            except TypeError:
                # Some ``show`` implementations do not accept ``block``.
                self._plt.show()

    # convenience
    def savefig(self, path: str | Path, *, dpi: int = 150) -> None:
        """Save the current figure to ``path`` at the given ``dpi``."""
        self._fig.savefig(path, dpi=dpi)

    def close(self) -> None:
        """Close the underlying figure; failures are ignored (best effort)."""
        try:
            self._plt.close(self._fig)
        except Exception:
            pass
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
__all__ = ["Visualizer"]
|
sawnergy/logging_util.py
CHANGED
sawnergy/rin/rin_builder.py
CHANGED
|
@@ -669,7 +669,7 @@ class RINBuilder:
|
|
|
669
669
|
molecule_of_interest: int,
|
|
670
670
|
frame_range: tuple[int, int] | None = None,
|
|
671
671
|
frame_batch_size: int = -1,
|
|
672
|
-
prune_low_energies_frac: float = 0.
|
|
672
|
+
prune_low_energies_frac: float = 0.85,
|
|
673
673
|
output_path: str | Path | None = None,
|
|
674
674
|
keep_prenormalized_energies: bool = True,
|
|
675
675
|
*,
|
sawnergy/visual/visualizer.py
CHANGED
|
@@ -107,7 +107,7 @@ class Visualizer:
|
|
|
107
107
|
visualizer_util.ensure_backend(show)
|
|
108
108
|
import matplotlib.pyplot as plt
|
|
109
109
|
self._plt = plt
|
|
110
|
-
|
|
110
|
+
# ---------- WARM UP MPL ------------ #
|
|
111
111
|
_logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
|
|
112
112
|
RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
|
|
113
113
|
if Visualizer.no_instances:
|
|
@@ -116,7 +116,7 @@ class Visualizer:
|
|
|
116
116
|
else:
|
|
117
117
|
_logger.debug("Skipping warm-start (no_instances=False).")
|
|
118
118
|
|
|
119
|
-
|
|
119
|
+
# ---------- LOAD THE DATA ---------- #
|
|
120
120
|
with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
|
|
121
121
|
com_name = storage.get_attr("com_name")
|
|
122
122
|
attr_energies_name = storage.get_attr("attractive_energies_name")
|
|
@@ -135,7 +135,7 @@ class Visualizer:
|
|
|
135
135
|
self.N = np.size(self.COM_coords[0], axis=0)
|
|
136
136
|
_logger.debug("Computed N=%d", self.N)
|
|
137
137
|
|
|
138
|
-
|
|
138
|
+
# - SET UP THE CANVAS AND THE AXES - #
|
|
139
139
|
self._fig = plt.figure(figsize=figsize, num="SAWNERGY")
|
|
140
140
|
self._ax = self._fig.add_subplot(111, projection="3d")
|
|
141
141
|
self._fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
|
|
@@ -145,14 +145,14 @@ class Visualizer:
|
|
|
145
145
|
self._ax.set_axis_off()
|
|
146
146
|
_logger.debug("Figure and 3D axes initialized.")
|
|
147
147
|
|
|
148
|
-
|
|
148
|
+
# ------ SET UP PLOT ELEMENTS ------ #
|
|
149
149
|
self._scatter: PathCollection = self._ax.scatter([], [], [], s=node_size, depthshade=depthshade, edgecolors="none")
|
|
150
150
|
self._attr: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
|
|
151
151
|
self._repuls: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
|
|
152
152
|
self._ax.add_collection3d(self._attr); self._ax.add_collection3d(self._repuls) # set pointers to the attractive and repulsive collections
|
|
153
153
|
_logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")
|
|
154
154
|
|
|
155
|
-
|
|
155
|
+
# ---------- HELPER FIELDS --------- #
|
|
156
156
|
# NOTE: 'under the hood' everything is 0-base indexed,
|
|
157
157
|
# BUT, from the API point of view, the indexing is 1-base,
|
|
158
158
|
# because amino acid residues are 1-base indexed.
|
|
@@ -160,7 +160,7 @@ class Visualizer:
|
|
|
160
160
|
self.default_node_color = default_node_color
|
|
161
161
|
_logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", self.N-1, self.default_node_color)
|
|
162
162
|
|
|
163
|
-
|
|
163
|
+
# DISALLOW MPL WARM-UP IN THE FUTURE
|
|
164
164
|
Visualizer.no_instances = False
|
|
165
165
|
_logger.debug("Visualizer.no_instances set to False.")
|
|
166
166
|
|
|
@@ -319,6 +319,9 @@ def build_line_segments(
|
|
|
319
319
|
kept = edge_weights >= thresh
|
|
320
320
|
rows, cols = rows[kept], cols[kept]
|
|
321
321
|
|
|
322
|
+
nz = weights[rows, cols] > 0.0
|
|
323
|
+
rows, cols = rows[nz], cols[nz]
|
|
324
|
+
|
|
322
325
|
if rows.size == 0:
|
|
323
326
|
_logger.debug("build_line_segments: no edges kept after threshold; returning empties.")
|
|
324
327
|
return (np.empty((0, 2, 3), dtype=float),
|