sawnergy 1.0.5__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sawnergy might be problematic; see the source registry for details.

@@ -1,11 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  """
4
- Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
4
+ Embedding orchestration for Skip-Gram (SG) and Skip-Gram with Negative Sampling (SGNS).
5
5
 
6
6
  This module consumes attractive/repulsive walk corpora produced by the walker
7
7
  pipeline and trains per-frame embeddings using either the PyTorch or PureML
8
- implementations of SGNS. The resulting embeddings can be persisted back into
8
+ implementations of SG/SGNS. The resulting embeddings can be persisted back into
9
9
  an ``ArrayStorage`` archive along with rich metadata describing the training
10
10
  configuration.
11
11
  """
@@ -38,7 +38,8 @@ class Embedder:
38
38
  WALKS_path: str | Path,
39
39
  base: Literal["torch", "pureml"],
40
40
  *,
41
- seed: int | None = None
41
+ seed: int | None = None,
42
+ objective: Literal["sgns", "sg"] = "sgns"
42
43
  ) -> None:
43
44
  """Initialize the embedder and load walk tensors.
44
45
 
@@ -53,6 +54,8 @@ class Embedder:
53
54
  base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
54
55
  seed: Optional seed for the embedder's RNG. If ``None``, a random
55
56
  32-bit seed is chosen.
57
+ objective: Training objective, either ``"sgns"`` (negative sampling)
58
+ or ``"sg"`` (plain full-softmax Skip-Gram).
56
59
 
57
60
  Raises:
58
61
  ValueError: If required metadata is missing or any loaded walk array
@@ -156,20 +159,25 @@ class Embedder:
156
159
 
157
160
  # MODEL HANDLE
158
161
  self.model_base: Literal["torch", "pureml"] = base
159
- self.model_constructor = self._get_SGNS_constructor_from(base)
160
- _logger.info("SGNS backend resolved: %s", getattr(self.model_constructor, "__name__", repr(self.model_constructor)))
162
+ self.objective: Literal["sgns", "sg"] = objective
163
+ self.model_constructor = self._get_SGNS_constructor_from(base, objective)
164
+ _logger.info(
165
+ "SG backend resolved: %s (objective=%s)",
166
+ getattr(self.model_constructor, "__name__", repr(self.model_constructor)), self.objective
167
+ )
161
168
 
162
169
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
163
170
 
164
171
  # HELPERS:
165
172
 
166
173
  @staticmethod
167
- def _get_SGNS_constructor_from(base: Literal["torch", "pureml"]):
168
- """Resolve the SGNS implementation class for the selected backend."""
174
+ def _get_SGNS_constructor_from(base: Literal["torch", "pureml"],
175
+ objective: Literal["sgns", "sg"]):
176
+ """Resolve the SG/SGNS implementation class for the selected backend."""
169
177
  if base == "torch":
170
178
  try:
171
- from .SGNS_torch import SGNS_Torch
172
- return SGNS_Torch
179
+ from .SGNS_torch import SGNS_Torch, SG_Torch
180
+ return SG_Torch if objective == "sg" else SGNS_Torch
173
181
  except Exception:
174
182
  raise ImportError(
175
183
  "PyTorch is not installed, but base='torch' was requested. "
@@ -178,8 +186,8 @@ class Embedder:
178
186
  )
179
187
  elif base == "pureml":
180
188
  try:
181
- from .SGNS_pml import SGNS_PureML
182
- return SGNS_PureML
189
+ from .SGNS_pml import SGNS_PureML, SG_PureML
190
+ return SG_PureML if objective == "sg" else SGNS_PureML
183
191
  except Exception:
184
192
  raise ImportError(
185
193
  "PureML is not installed, but base='pureml' was requested. "
@@ -322,21 +330,24 @@ class Embedder:
322
330
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
323
331
 
324
332
  def embed_frame(self,
325
- frame_id: int,
326
- RIN_type: Literal["attr", "repuls"],
327
- using: Literal["RW", "SAW", "merged"],
328
- window_size: int,
329
- num_negative_samples: int,
330
- num_epochs: int,
331
- batch_size: int,
332
- *,
333
- shuffle_data: bool = True,
334
- dimensionality: int = 128,
335
- alpha: float = 0.75,
336
- device: str | None = None,
337
- sgns_kwargs: dict[str, object] | None = None,
338
- _seed: int | None = None
339
- ) -> np.ndarray:
333
+ frame_id: int,
334
+ RIN_type: Literal["attr", "repuls"],
335
+ using: Literal["RW", "SAW", "merged"],
336
+ window_size: int,
337
+ num_negative_samples: int,
338
+ num_epochs: int,
339
+ batch_size: int,
340
+ *,
341
+ lr_step_per_batch: bool = False,
342
+ shuffle_data: bool = True,
343
+ dimensionality: int = 128,
344
+ alpha: float = 0.75,
345
+ device: str | None = None,
346
+ sgns_kwargs: dict[str, object] | None = None,
347
+ kind: Literal["in", "out", "avg"] = "in",
348
+ _seed: int | None = None,
349
+ objective: Literal["sgns", "sg"] | None = None
350
+ ) -> np.ndarray:
340
351
  """Train embeddings for a single frame and return the input embedding matrix.
341
352
 
342
353
  Args:
@@ -347,6 +358,7 @@ class Embedder:
347
358
  ``"merged"`` (concatenates both if available).
348
359
  window_size: Symmetric skip-gram window size ``k``.
349
360
  num_negative_samples: Number of negative samples per positive pair.
361
+ Ignored when ``objective="sg"``.
350
362
  num_epochs: Number of passes over the pair dataset.
351
363
  batch_size: Mini-batch size for training.
352
364
  shuffle_data: Whether to shuffle pairs each epoch.
@@ -355,11 +367,15 @@ class Embedder:
355
367
  device: Optional device string for the Torch backend (e.g., ``"cuda"``).
356
368
  sgns_kwargs: Extra keyword arguments forwarded to the backend SGNS
357
369
  constructor. For PureML, required keys are:
358
- ``{"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}``.
370
+ ``{"optim", "optim_kwargs"}``; ``lr_sched`` is optional, but if
371
+ provided then ``lr_sched_kwargs`` must also be provided.
372
+ kind: Which embedding matrix to return: ``"in"``, ``"out"``, or ``"avg"``.
359
373
  _seed: Optional child seed for this frame's model initialization.
374
+ objective: Training objective override for this call (``"sgns"`` or
375
+ ``"sg"``). If ``None``, uses the value set at construction.
360
376
 
361
377
  Returns:
362
- np.ndarray: Learned **input** embedding matrix of shape ``(V, D)``.
378
+ np.ndarray: Learned embedding matrix (selected by ``kind``) of shape ``(V, D)``.
363
379
 
364
380
  Raises:
365
381
  ValueError: If requested walks are missing, if no training pairs are
@@ -391,10 +407,14 @@ class Embedder:
391
407
 
392
408
  model_kwargs: dict[str, object] = dict(sgns_kwargs or {})
393
409
  if self.model_base == "pureml":
394
- required = {"optim", "optim_kwargs", "lr_sched", "lr_sched_kwargs"}
410
+ required = {"optim", "optim_kwargs"}
395
411
  missing = required.difference(model_kwargs)
396
412
  if missing:
397
413
  raise ValueError(f"PureML backend requires {sorted(missing)} in sgns_kwargs.")
414
+ has_sched = ("lr_sched" in model_kwargs and model_kwargs["lr_sched"] is not None)
415
+ has_sched_kwargs = ("lr_sched_kwargs" in model_kwargs and model_kwargs["lr_sched_kwargs"] is not None)
416
+ if has_sched and not has_sched_kwargs:
417
+ raise ValueError("When providing lr_sched for PureML, you must also provide lr_sched_kwargs.")
398
418
 
399
419
  child_seed = int(self._seed if _seed is None else _seed)
400
420
  model_kwargs.update({
@@ -406,12 +426,16 @@ class Embedder:
406
426
  if self.model_base == "torch" and device is not None:
407
427
  model_kwargs["device"] = device
408
428
 
429
+ # Resolve objective (call-level override beats constructor default)
430
+ obj = self.objective if objective is None else objective
431
+ self.model_constructor = self._get_SGNS_constructor_from(self.model_base, obj)
409
432
  self.model = self.model_constructor(**model_kwargs)
410
433
 
411
434
  _logger.info(
412
- "Training SGNS base=%s constructor=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
435
+ "Training SG base=%s constructor=%s objective=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
413
436
  self.model_base,
414
437
  getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
438
+ obj,
415
439
  frame_id,
416
440
  pairs.shape[0],
417
441
  dimensionality,
@@ -421,24 +445,46 @@ class Embedder:
421
445
  shuffle_data
422
446
  )
423
447
 
424
- self.model.fit(
425
- centers,
426
- contexts,
427
- num_epochs,
428
- batch_size,
429
- num_negative_samples,
430
- noise_probs,
431
- shuffle_data,
432
- lr_step_per_batch=False
433
- )
448
+ if obj == "sgns":
449
+ self.model.fit(
450
+ centers,
451
+ contexts,
452
+ num_epochs,
453
+ batch_size,
454
+ num_negative_samples,
455
+ noise_probs,
456
+ shuffle_data,
457
+ lr_step_per_batch
458
+ )
459
+ else:
460
+ self.model.fit(
461
+ centers,
462
+ contexts,
463
+ num_epochs,
464
+ batch_size,
465
+ shuffle_data,
466
+ lr_step_per_batch
467
+ )
468
+
469
+ # Select embedding matrix by kind
470
+ if kind == "in":
471
+ embeddings = getattr(self.model, "in_embeddings", None)
472
+ if embeddings is None:
473
+ embeddings = getattr(self.model, "embeddings", None)
474
+ if embeddings is None:
475
+ params = getattr(self.model, "parameters", None)
476
+ if isinstance(params, tuple) and params:
477
+ embeddings = params[0]
478
+ elif kind == "out":
479
+ embeddings = getattr(self.model, "out_embeddings", None)
480
+ else: # "avg"
481
+ embeddings = getattr(self.model, "avg_embeddings", None)
434
482
 
435
- embeddings = getattr(self.model, "embeddings", None)
436
- if embeddings is None:
437
- params = getattr(self.model, "parameters", None)
438
- if isinstance(params, tuple) and params:
439
- embeddings = params[0]
440
483
  if embeddings is None:
441
- raise AttributeError("SGNS model does not expose embeddings via '.embeddings' or '.parameters[0]'")
484
+ raise AttributeError(
485
+ "SG/SGNS model does not expose the requested embeddings: "
486
+ f"kind={kind!r}"
487
+ )
442
488
 
443
489
  embeddings = np.asarray(embeddings)
444
490
  _logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
@@ -460,7 +506,9 @@ class Embedder:
460
506
  sgns_kwargs: dict[str, object] | None = None,
461
507
  output_path: str | Path | None = None,
462
508
  num_matrices_in_compressed_blocks: int = 20,
463
- compression_level: int = 3):
509
+ compression_level: int = 3,
510
+ objective: Literal["sgns", "sg"] | None = None,
511
+ kind: Literal["in", "out", "avg"] = "in"):
464
512
  """Train embeddings for all frames and persist them to compressed storage.
465
513
 
466
514
  Iterates through all frames (``1..frame_count``), trains an SGNS model
@@ -472,6 +520,7 @@ class Embedder:
472
520
  using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
473
521
  window_size: Symmetric skip-gram window size ``k``.
474
522
  num_negative_samples: Number of negative samples per positive pair.
523
+ Ignored when ``objective="sg"``.
475
524
  num_epochs: Number of epochs for each frame.
476
525
  batch_size: Mini-batch size used during training.
477
526
  shuffle_data: Whether to shuffle pairs each epoch.
@@ -487,6 +536,10 @@ class Embedder:
487
536
  num_matrices_in_compressed_blocks: Number of per-frame matrices to
488
537
  store per compressed chunk in the output archive.
489
538
  compression_level: Blosc Zstd compression level (0-9).
539
+ objective: Training objective for all frames (``"sgns"`` or ``"sg"``).
540
+ If ``None``, uses the value set at construction.
541
+ kind: Which embedding matrix to persist for each frame: ``"in"``,
542
+ ``"out"``, or ``"avg"``.
490
543
 
491
544
  Returns:
492
545
  str: Filesystem path to the written embeddings archive (``.zip``).
@@ -498,8 +551,8 @@ class Embedder:
498
551
 
499
552
  Notes:
500
553
  - A deterministic child seed is spawned per frame from the master
501
- seed using ``np.random.SeedSequence`` to ensure reproducibility
502
- across runs.
554
+ seed using ``np.random.SeedSequence`` to ensure reproducibility
555
+ across runs.
503
556
  """
504
557
  current_time = sawnergy_util.current_time()
505
558
  if output_path is None:
@@ -535,7 +588,9 @@ class Embedder:
535
588
  alpha=alpha,
536
589
  device=device,
537
590
  sgns_kwargs=sgns_kwargs,
538
- _seed=child_seed
591
+ _seed=child_seed,
592
+ objective=objective,
593
+ kind=kind
539
594
  )
540
595
  )
541
596
 
@@ -568,6 +623,7 @@ class Embedder:
568
623
  storage.add_attr("frame_embeddings_name", block_name)
569
624
  storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
570
625
  storage.add_attr("compression_level", int(compression_level))
626
+ storage.add_attr("objective", self.objective if objective is None else objective)
571
627
 
572
628
  _logger.info("Embedding archive written to %s", output_path)
573
629
  return str(output_path)
@@ -0,0 +1,247 @@
1
+ from __future__ import annotations
2
+
3
+ # third party
4
+ import numpy as np
5
+ import matplotlib as mpl
6
+
7
+ # built-in
8
+ from pathlib import Path
9
+ from typing import Sequence
10
+ import logging
11
+
12
+ # local
13
+ from ..visual import visualizer_util
14
+ from .. import sawnergy_util
15
+
16
+ # *----------------------------------------------------*
17
+ # GLOBALS
18
+ # *----------------------------------------------------*
19
+
20
# Module-level logger, keyed by this module's dotted import name.
_logger = logging.getLogger(name=__name__)
21
+
22
+ # *----------------------------------------------------*
23
+ # HELPERS
24
+ # *----------------------------------------------------*
25
+
26
+ def _safe_svd_pca(X: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
27
+ """Compute k principal directions via SVD and project onto them."""
28
+ if X.ndim != 2:
29
+ raise ValueError(f"PCA expects 2D array (N, D); got {X.shape}")
30
+ _, D = X.shape
31
+ if k not in (2, 3):
32
+ raise ValueError(f"PCA dimensionality must be 2 or 3; got {k}")
33
+ if D < k:
34
+ raise ValueError(f"Requested k={k} exceeds feature dim D={D}")
35
+ Xc = X - X.mean(axis=0, keepdims=True)
36
+ _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
37
+ comps = Vt[:k].copy()
38
+ proj = Xc @ comps.T
39
+ return proj, comps
40
+
41
+ def _set_equal_axes_3d(ax, xyz: np.ndarray, *, padding: float = 0.05) -> None:
42
+ if xyz.size == 0:
43
+ return
44
+ x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
45
+ xmin, xmax = float(x.min()), float(x.max())
46
+ ymin, ymax = float(y.min()), float(y.max())
47
+ zmin, zmax = float(z.min()), float(z.max())
48
+ xr = xmax - xmin
49
+ yr = ymax - ymin
50
+ zr = zmax - zmin
51
+ r = max(xr, yr, zr)
52
+ pad = padding * (r if r > 0 else 1.0)
53
+ cx, cy, cz = (xmin + xmax) / 2.0, (ymin + ymax) / 2.0, (zmin + zmax) / 2.0
54
+ ax.set_xlim(cx - r / 2 - pad, cx + r / 2 + pad)
55
+ ax.set_ylim(cy - r / 2 - pad, cy + r / 2 + pad)
56
+ ax.set_zlim(cz - r / 2 - pad, cz + r / 2 + pad)
57
+ try:
58
+ ax.set_box_aspect([1, 1, 1])
59
+ except Exception:
60
+ pass
61
+
62
+ # *----------------------------------------------------*
63
+ # CLASS
64
+ # *----------------------------------------------------*
65
+
66
class Visualizer:
    """3D PCA visualizer for per-frame embeddings.

    Reads a ``(T, N, D)`` stack of node embeddings from an ``ArrayStorage``
    archive: ``T`` frames, ``N`` nodes, ``D`` features. It renders one frame
    at a time as a 3D scatter of the nodes' principal-component projection.

    Frame ids and node ids are 1-based at the API level and 0-based
    internally.
    """

    # Shared across instances: the matplotlib warm-start only runs for the
    # first Visualizer constructed in the process.
    no_instances: bool = True

    def __init__(
        self,
        EMB_path: str | Path,
        figsize: tuple[int, int] = (9, 7),
        default_node_color: str = visualizer_util.GRAY,
        depthshade: bool = False,
        antialiased: bool = False,
        init_elev: float = 35,
        init_azim: float = 45,
        *,
        show: bool = False
    ) -> None:
        """Load the embeddings archive and create an empty figure.

        Args:
            EMB_path: Path to the embeddings archive. Its
                ``"frame_embeddings_name"`` attribute names the array to read.
            figsize: Matplotlib figure size.
            default_node_color: Fallback color passed to
                ``visualizer_util.map_groups_to_colors``.
            depthshade: Forwarded to the 3D ``scatter`` call.
            antialiased: Forwarded to the 3D ``scatter`` call.
            init_elev: Camera elevation, re-applied on every render.
            init_azim: Camera azimuth, re-applied on every render.
            show: Forwarded to ``visualizer_util.ensure_backend``.

        Raises:
            ValueError: If the stored embeddings are not 3D ``(T, N, D)``.
        """
        # Resolve the backend first; pyplot is only imported afterwards
        # (presumably so the backend choice takes effect).
        visualizer_util.ensure_backend(show)
        import matplotlib.pyplot as plt
        self._plt = plt

        if Visualizer.no_instances:
            try:
                visualizer_util.warm_start_matplotlib()
            finally:
                # Flip the flag even if warm-up fails, so it is never retried.
                Visualizer.no_instances = False

        # Load embeddings archive
        EMB_path = Path(EMB_path)
        with sawnergy_util.ArrayStorage(EMB_path, mode="r") as storage:
            name = storage.get_attr("frame_embeddings_name")
            # Convert first so ``ndim`` exists whatever ``read`` returns.
            E = np.asarray(storage.read(name, slice(None)))
        if E.ndim != 3:
            raise ValueError(f"Expected embeddings of shape (T,N,D); got {E.shape}")
        self.E = E
        self.T, self.N, self.D = map(int, self.E.shape)
        _logger.info("Loaded embeddings: T=%d, N=%d, D=%d", self.T, self.N, self.D)

        # Maps 0-based node indices onto [0, 1] for colormap lookups
        # (parity with the RIN Visualizer).
        self._residue_norm = mpl.colors.Normalize(0, max(1, self.N - 1))

        # Figure / axes / artists; axes and scatter are created lazily.
        self._fig = self._plt.figure(figsize=figsize)
        self._ax = None
        self._scatter = None
        self._labels = []
        self._marker_size = 30.0
        self._init_elev = init_elev
        self._init_azim = init_azim
        self.default_node_color = default_node_color
        self._antialiased = bool(antialiased)
        self._depthshade = bool(depthshade)

    # ------------------------------ PRIVATE ------------------------------ #

    def _ensure_axes(self) -> None:
        """Create the 3D axes and an empty scatter artist on first use."""
        if self._ax is not None and self._scatter is not None:
            return
        self._fig.clf()
        # clf() discards every artist, including any previous labels.
        self._labels = []
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._ax.view_init(self._init_elev, self._init_azim)
        self._scatter = self._ax.scatter(
            [], [], [],
            s=self._marker_size,
            depthshade=self._depthshade,
            edgecolors="none",
            antialiased=self._antialiased,
        )
        # Best-effort cosmetic tweak.
        try:
            self._ax.set_axis_off()
        except Exception:
            pass

    def _project3(self, X: np.ndarray) -> np.ndarray:
        """Return a 3D PCA projection of embeddings (always 3 coordinates).

        Missing coordinates are zero-filled, so the result always has shape
        ``(n, 3)``. Coordinates go missing when the embedding dimensionality
        is ``D < 3``, or when fewer than 3 nodes are projected (the PCA then
        yields fewer directions). For ``D == 1``, PCA reduces to mean-centering.

        Raises:
            ValueError: If ``X`` is not a 2D array.
        """
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(f"Expected a 2D (n, D) array; got shape {X.shape}")
        d = X.shape[1]
        if d >= 2:
            P, _ = _safe_svd_pca(X, 3 if d >= 3 else 2)
        else:
            # A single feature has one principal direction: just center it.
            P = X - X.mean(axis=0, keepdims=True)
        missing = 3 - P.shape[1]
        if missing > 0:
            P = np.hstack([P, np.zeros((P.shape[0], missing), dtype=P.dtype)])
        return P

    def _select_nodes(self, displayed_nodes: Sequence[int] | str | None) -> np.ndarray:
        """Translate 1-based node ids (or ``None``/``"ALL"``) into 0-based indices.

        Raises:
            TypeError: If ``displayed_nodes`` is neither ``None``, ``"ALL"``,
                nor an integer sequence.
            ValueError: If an empty integer array is given.
            IndexError: If any id falls outside ``[1, N]``.
        """
        # Check the type before comparing: ``ndarray == "ALL"`` is elementwise
        # and its truth value is ambiguous for multi-element arrays.
        if displayed_nodes is None or (isinstance(displayed_nodes, str) and displayed_nodes == "ALL"):
            return np.arange(self.N, dtype=np.int64)
        if isinstance(displayed_nodes, str):
            raise TypeError("displayed_nodes must be None, 'ALL', or an integer sequence.")
        idx = np.asarray(displayed_nodes)
        if idx.dtype.kind not in "iu":
            raise TypeError("displayed_nodes must be None, 'ALL', or an integer sequence.")
        if idx.size == 0:
            raise ValueError("displayed_nodes must contain at least one node id.")
        if idx.min() < 1 or idx.max() > self.N:
            raise IndexError(f"displayed_nodes out of range [1,{self.N}]")
        return idx.astype(np.int64) - 1

    def _apply_colors(self, node_colors, idx: np.ndarray) -> np.ndarray:
        """Resolve colors for the 0-based node indices ``idx``.

        Follows the RIN Visualizer's handling of ``node_colors``:

        - ``str``: a matplotlib colormap name; nodes are colored by index.
        - ``None``: ``map_groups_to_colors`` is called with ``groups=None``
          (presumably every node gets ``default_node_color``).
        - array of shape ``(N, 3)`` or ``(N, 4)``: explicit per-node colors.
        - anything else: forwarded as ``groups`` to
          ``visualizer_util.map_groups_to_colors`` (1-based).
        """
        if isinstance(node_colors, str):
            node_cmap = self._plt.get_cmap(node_colors)
            return node_cmap(self._residue_norm(idx))
        if node_colors is None:
            full = visualizer_util.map_groups_to_colors(
                N=self.N, groups=None, default_color=self.default_node_color, one_based=True
            )
            return np.asarray(full)[idx]
        arr = np.asarray(node_colors)
        if arr.ndim == 2 and arr.shape[0] == self.N and arr.shape[1] in (3, 4):
            return arr[idx]
        full = visualizer_util.map_groups_to_colors(
            N=self.N, groups=node_colors, default_color=self.default_node_color, one_based=True
        )
        return np.asarray(full)[idx]

    # ------------------------------ PUBLIC ------------------------------- #

    def build_frame(
        self,
        frame_id: int,
        *,
        node_colors: str | np.ndarray | None = "rainbow",
        displayed_nodes: Sequence[int] | str | None = "ALL",
        show_node_labels: bool = False,
        show: bool = False
    ) -> None:
        """Render a single frame as a PCA **3D** scatter (matches RIN Visualizer API).

        Args:
            frame_id: 1-based frame index in ``[1, T]``.
            node_colors: Colormap name, explicit ``(N, 3|4)`` colors, ``None``
                for the default color, or a groups spec understood by
                ``visualizer_util.map_groups_to_colors``.
            displayed_nodes: ``None``/``"ALL"`` or a sequence of 1-based node ids.
            show_node_labels: Annotate each point with its 1-based node id.
                Labels from a previous render are always removed first.
            show: Call ``pyplot.show`` (blocking when supported) after drawing.

        Raises:
            IndexError: If ``frame_id`` or any displayed node id is out of range.
            TypeError: If ``displayed_nodes`` has an unsupported type.
            ValueError: If ``displayed_nodes`` is an empty integer array.
        """
        frame0 = int(frame_id) - 1
        if not (0 <= frame0 < self.T):
            raise IndexError(f"frame_id out of range [1,{self.T}]")
        self._ensure_axes()

        idx = self._select_nodes(displayed_nodes)
        X = self.E[frame0, idx, :]        # (n, D)
        P = self._project3(X)             # (n, 3)
        colors = self._apply_colors(node_colors, idx)

        x, y, z = P[:, 0], P[:, 1], P[:, 2]
        # Private matplotlib attribute: the usual way to move a 3D scatter's points.
        self._scatter._offsets3d = (x, y, z)
        self._scatter.set_facecolors(colors)
        _set_equal_axes_3d(self._ax, P, padding=0.05)
        self._ax.view_init(self._init_elev, self._init_azim)

        # Remove the previous render's labels unconditionally, so turning labels
        # off (or switching frames) never leaves stale text behind.
        for txt in getattr(self, "_labels", []):
            try:
                txt.remove()
            except Exception:
                # Best-effort: the artist may already be detached.
                pass
        self._labels = []
        if show_node_labels:
            for p, nid in zip(P, idx + 1):
                self._labels.append(self._ax.text(p[0], p[1], p[2], str(int(nid)), fontsize=8))

        # Be friendly to test dummies (they may lack tight_layout/canvas)
        try:
            self._fig.tight_layout()
        except Exception:
            try:
                self._fig.subplots_adjust()
            except Exception:
                pass
        try:
            self._fig.canvas.draw_idle()
        except Exception:
            pass

        if show:
            try:
                self._plt.show(block=True)
            except TypeError:
                self._plt.show()

    # convenience
    def savefig(self, path: str | Path, *, dpi: int = 150) -> None:
        """Save the current figure to ``path`` at the given ``dpi``."""
        self._fig.savefig(path, dpi=dpi)

    def close(self) -> None:
        """Close the figure; failures are ignored (best-effort cleanup)."""
        try:
            self._plt.close(self._fig)
        except Exception:
            pass
245
+
246
+
247
# Public API of this module: star-imports expose only the embedding Visualizer.
__all__ = ["Visualizer"]
sawnergy/logging_util.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
6
 
7
7
  def configure_logging(
8
8
  logs_dir: Path | str,
9
- file_level: int = logging.DEBUG,
9
+ file_level: int = logging.WARNING,
10
10
  console_level: int = logging.WARNING
11
11
  ) -> None:
12
12
  """
@@ -669,7 +669,7 @@ class RINBuilder:
669
669
  molecule_of_interest: int,
670
670
  frame_range: tuple[int, int] | None = None,
671
671
  frame_batch_size: int = -1,
672
- prune_low_energies_frac: float = 0.3,
672
+ prune_low_energies_frac: float = 0.85,
673
673
  output_path: str | Path | None = None,
674
674
  keep_prenormalized_energies: bool = True,
675
675
  *,
@@ -107,7 +107,7 @@ class Visualizer:
107
107
  visualizer_util.ensure_backend(show)
108
108
  import matplotlib.pyplot as plt
109
109
  self._plt = plt
110
- # ---------- WARM UP MPL ------------ #
110
+ # ---------- WARM UP MPL ------------ #
111
111
  _logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
112
112
  RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
113
113
  if Visualizer.no_instances:
@@ -116,7 +116,7 @@ class Visualizer:
116
116
  else:
117
117
  _logger.debug("Skipping warm-start (no_instances=False).")
118
118
 
119
- # ---------- LOAD THE DATA ---------- #
119
+ # ---------- LOAD THE DATA ---------- #
120
120
  with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
121
121
  com_name = storage.get_attr("com_name")
122
122
  attr_energies_name = storage.get_attr("attractive_energies_name")
@@ -135,7 +135,7 @@ class Visualizer:
135
135
  self.N = np.size(self.COM_coords[0], axis=0)
136
136
  _logger.debug("Computed N=%d", self.N)
137
137
 
138
- # - SET UP THE CANVAS AND THE AXES - #
138
+ # - SET UP THE CANVAS AND THE AXES - #
139
139
  self._fig = plt.figure(figsize=figsize, num="SAWNERGY")
140
140
  self._ax = self._fig.add_subplot(111, projection="3d")
141
141
  self._fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
@@ -145,14 +145,14 @@ class Visualizer:
145
145
  self._ax.set_axis_off()
146
146
  _logger.debug("Figure and 3D axes initialized.")
147
147
 
148
- # ------ SET UP PLOT ELEMENTS ------ #
148
+ # ------ SET UP PLOT ELEMENTS ------ #
149
149
  self._scatter: PathCollection = self._ax.scatter([], [], [], s=node_size, depthshade=depthshade, edgecolors="none")
150
150
  self._attr: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
151
151
  self._repuls: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
152
152
  self._ax.add_collection3d(self._attr); self._ax.add_collection3d(self._repuls) # set pointers to the attractive and repulsive collections
153
153
  _logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")
154
154
 
155
- # ---------- HELPER FIELDS --------- #
155
+ # ---------- HELPER FIELDS --------- #
156
156
  # NOTE: 'under the hood' everything is 0-base indexed,
157
157
  # BUT, from the API point of view, the indexing is 1-base,
158
158
  # because amino acid residues are 1-base indexed.
@@ -160,7 +160,7 @@ class Visualizer:
160
160
  self.default_node_color = default_node_color
161
161
  _logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", self.N-1, self.default_node_color)
162
162
 
163
- # DISALLOW MPL WARM-UP IN THE FUTURE
163
+ # DISALLOW MPL WARM-UP IN THE FUTURE
164
164
  Visualizer.no_instances = False
165
165
  _logger.debug("Visualizer.no_instances set to False.")
166
166
 
@@ -319,6 +319,9 @@ def build_line_segments(
319
319
  kept = edge_weights >= thresh
320
320
  rows, cols = rows[kept], cols[kept]
321
321
 
322
+ nz = weights[rows, cols] > 0.0
323
+ rows, cols = rows[nz], cols[nz]
324
+
322
325
  if rows.size == 0:
323
326
  _logger.debug("build_line_segments: no edges kept after threshold; returning empties.")
324
327
  return (np.empty((0, 2, 3), dtype=float),