sawnergy 1.0.6__py3-none-any.whl → 1.0.7__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sawnergy might be problematic. Click here for more details.

@@ -1,11 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  """
4
- Embedding orchestration for Skip-Gram with Negative Sampling (SGNS).
4
+ Embedding orchestration for Skip-Gram (SG) and Skip-Gram with Negative Sampling (SGNS).
5
5
 
6
6
  This module consumes attractive/repulsive walk corpora produced by the walker
7
7
  pipeline and trains per-frame embeddings using either the PyTorch or PureML
8
- implementations of SGNS. The resulting embeddings can be persisted back into
8
+ implementations of SG/SGNS. The resulting embeddings can be persisted back into
9
9
  an ``ArrayStorage`` archive along with rich metadata describing the training
10
10
  configuration.
11
11
  """
@@ -38,7 +38,8 @@ class Embedder:
38
38
  WALKS_path: str | Path,
39
39
  base: Literal["torch", "pureml"],
40
40
  *,
41
- seed: int | None = None
41
+ seed: int | None = None,
42
+ objective: Literal["sgns", "sg"] = "sgns"
42
43
  ) -> None:
43
44
  """Initialize the embedder and load walk tensors.
44
45
 
@@ -53,6 +54,8 @@ class Embedder:
53
54
  base: Which SGNS backend to use, either ``"torch"`` or ``"pureml"``.
54
55
  seed: Optional seed for the embedder's RNG. If ``None``, a random
55
56
  32-bit seed is chosen.
57
+ objective: Training objective, either ``"sgns"`` (negative sampling)
58
+ or ``"sg"`` (plain full-softmax Skip-Gram).
56
59
 
57
60
  Raises:
58
61
  ValueError: If required metadata is missing or any loaded walk array
@@ -156,20 +159,25 @@ class Embedder:
156
159
 
157
160
  # MODEL HANDLE
158
161
  self.model_base: Literal["torch", "pureml"] = base
159
- self.model_constructor = self._get_SGNS_constructor_from(base)
160
- _logger.info("SGNS backend resolved: %s", getattr(self.model_constructor, "__name__", repr(self.model_constructor)))
162
+ self.objective: Literal["sgns", "sg"] = objective
163
+ self.model_constructor = self._get_SGNS_constructor_from(base, objective)
164
+ _logger.info(
165
+ "SG backend resolved: %s (objective=%s)",
166
+ getattr(self.model_constructor, "__name__", repr(self.model_constructor)), self.objective
167
+ )
161
168
 
162
169
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- PRIVATE -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
163
170
 
164
171
  # HELPERS:
165
172
 
166
173
  @staticmethod
167
- def _get_SGNS_constructor_from(base: Literal["torch", "pureml"]):
168
- """Resolve the SGNS implementation class for the selected backend."""
174
+ def _get_SGNS_constructor_from(base: Literal["torch", "pureml"],
175
+ objective: Literal["sgns", "sg"]):
176
+ """Resolve the SG/SGNS implementation class for the selected backend."""
169
177
  if base == "torch":
170
178
  try:
171
- from .SGNS_torch import SGNS_Torch
172
- return SGNS_Torch
179
+ from .SGNS_torch import SGNS_Torch, SG_Torch
180
+ return SG_Torch if objective == "sg" else SGNS_Torch
173
181
  except Exception:
174
182
  raise ImportError(
175
183
  "PyTorch is not installed, but base='torch' was requested. "
@@ -178,8 +186,8 @@ class Embedder:
178
186
  )
179
187
  elif base == "pureml":
180
188
  try:
181
- from .SGNS_pml import SGNS_PureML
182
- return SGNS_PureML
189
+ from .SGNS_pml import SGNS_PureML, SG_PureML
190
+ return SG_PureML if objective == "sg" else SGNS_PureML
183
191
  except Exception:
184
192
  raise ImportError(
185
193
  "PureML is not installed, but base='pureml' was requested. "
@@ -322,22 +330,24 @@ class Embedder:
322
330
  # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= PUBLIC -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
323
331
 
324
332
  def embed_frame(self,
325
- frame_id: int,
326
- RIN_type: Literal["attr", "repuls"],
327
- using: Literal["RW", "SAW", "merged"],
328
- window_size: int,
329
- num_negative_samples: int,
330
- num_epochs: int,
331
- batch_size: int,
332
- *,
333
- lr_step_per_batch: bool = False,
334
- shuffle_data: bool = True,
335
- dimensionality: int = 128,
336
- alpha: float = 0.75,
337
- device: str | None = None,
338
- sgns_kwargs: dict[str, object] | None = None,
339
- _seed: int | None = None
340
- ) -> np.ndarray:
333
+ frame_id: int,
334
+ RIN_type: Literal["attr", "repuls"],
335
+ using: Literal["RW", "SAW", "merged"],
336
+ window_size: int,
337
+ num_negative_samples: int,
338
+ num_epochs: int,
339
+ batch_size: int,
340
+ *,
341
+ lr_step_per_batch: bool = False,
342
+ shuffle_data: bool = True,
343
+ dimensionality: int = 128,
344
+ alpha: float = 0.75,
345
+ device: str | None = None,
346
+ sgns_kwargs: dict[str, object] | None = None,
347
+ kind: Literal["in", "out", "avg"] = "in",
348
+ _seed: int | None = None,
349
+ objective: Literal["sgns", "sg"] | None = None
350
+ ) -> np.ndarray:
341
351
  """Train embeddings for a single frame and return the input embedding matrix.
342
352
 
343
353
  Args:
@@ -348,6 +358,7 @@ class Embedder:
348
358
  ``"merged"`` (concatenates both if available).
349
359
  window_size: Symmetric skip-gram window size ``k``.
350
360
  num_negative_samples: Number of negative samples per positive pair.
361
+ Ignored when ``objective="sg"``.
351
362
  num_epochs: Number of passes over the pair dataset.
352
363
  batch_size: Mini-batch size for training.
353
364
  shuffle_data: Whether to shuffle pairs each epoch.
@@ -358,10 +369,13 @@ class Embedder:
358
369
  constructor. For PureML, required keys are:
359
370
  ``{"optim", "optim_kwargs"}``; ``lr_sched`` is optional, but if
360
371
  provided then ``lr_sched_kwargs`` must also be provided.
372
+ kind: Which embedding matrix to return: ``"in"``, ``"out"``, or ``"avg"``.
361
373
  _seed: Optional child seed for this frame's model initialization.
374
+ objective: Training objective override for this call (``"sgns"`` or
375
+ ``"sg"``). If ``None``, uses the value set at construction.
362
376
 
363
377
  Returns:
364
- np.ndarray: Learned **input** embedding matrix of shape ``(V, D)``.
378
+ np.ndarray: Learned embedding matrix (selected by ``kind``) of shape ``(V, D)``.
365
379
 
366
380
  Raises:
367
381
  ValueError: If requested walks are missing, if no training pairs are
@@ -412,12 +426,16 @@ class Embedder:
412
426
  if self.model_base == "torch" and device is not None:
413
427
  model_kwargs["device"] = device
414
428
 
429
+ # Resolve objective (call-level override beats constructor default)
430
+ obj = self.objective if objective is None else objective
431
+ self.model_constructor = self._get_SGNS_constructor_from(self.model_base, obj)
415
432
  self.model = self.model_constructor(**model_kwargs)
416
433
 
417
434
  _logger.info(
418
- "Training SGNS base=%s constructor=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
435
+ "Training SG base=%s constructor=%s objective=%s frame=%d pairs=%d dim=%d epochs=%d batch=%d neg=%d shuffle=%s",
419
436
  self.model_base,
420
437
  getattr(self.model_constructor, "__name__", repr(self.model_constructor)),
438
+ obj,
421
439
  frame_id,
422
440
  pairs.shape[0],
423
441
  dimensionality,
@@ -427,24 +445,46 @@ class Embedder:
427
445
  shuffle_data
428
446
  )
429
447
 
430
- self.model.fit(
431
- centers,
432
- contexts,
433
- num_epochs,
434
- batch_size,
435
- num_negative_samples,
436
- noise_probs,
437
- shuffle_data,
438
- lr_step_per_batch
439
- )
448
+ if obj == "sgns":
449
+ self.model.fit(
450
+ centers,
451
+ contexts,
452
+ num_epochs,
453
+ batch_size,
454
+ num_negative_samples,
455
+ noise_probs,
456
+ shuffle_data,
457
+ lr_step_per_batch
458
+ )
459
+ else:
460
+ self.model.fit(
461
+ centers,
462
+ contexts,
463
+ num_epochs,
464
+ batch_size,
465
+ shuffle_data,
466
+ lr_step_per_batch
467
+ )
468
+
469
+ # Select embedding matrix by kind
470
+ if kind == "in":
471
+ embeddings = getattr(self.model, "in_embeddings", None)
472
+ if embeddings is None:
473
+ embeddings = getattr(self.model, "embeddings", None)
474
+ if embeddings is None:
475
+ params = getattr(self.model, "parameters", None)
476
+ if isinstance(params, tuple) and params:
477
+ embeddings = params[0]
478
+ elif kind == "out":
479
+ embeddings = getattr(self.model, "out_embeddings", None)
480
+ else: # "avg"
481
+ embeddings = getattr(self.model, "avg_embeddings", None)
440
482
 
441
- embeddings = getattr(self.model, "embeddings", None)
442
- if embeddings is None:
443
- params = getattr(self.model, "parameters", None)
444
- if isinstance(params, tuple) and params:
445
- embeddings = params[0]
446
483
  if embeddings is None:
447
- raise AttributeError("SGNS model does not expose embeddings via '.embeddings' or '.parameters[0]'")
484
+ raise AttributeError(
485
+ "SG/SGNS model does not expose the requested embeddings: "
486
+ f"kind={kind!r}"
487
+ )
448
488
 
449
489
  embeddings = np.asarray(embeddings)
450
490
  _logger.info("Frame %d embeddings ready: shape=%s dtype=%s", frame_id, embeddings.shape, embeddings.dtype)
@@ -466,7 +506,9 @@ class Embedder:
466
506
  sgns_kwargs: dict[str, object] | None = None,
467
507
  output_path: str | Path | None = None,
468
508
  num_matrices_in_compressed_blocks: int = 20,
469
- compression_level: int = 3):
509
+ compression_level: int = 3,
510
+ objective: Literal["sgns", "sg"] | None = None,
511
+ kind: Literal["in", "out", "avg"] = "in"):
470
512
  """Train embeddings for all frames and persist them to compressed storage.
471
513
 
472
514
  Iterates through all frames (``1..frame_count``), trains an SGNS model
@@ -478,6 +520,7 @@ class Embedder:
478
520
  using: Walk collections: ``"RW"``, ``"SAW"``, or ``"merged"``.
479
521
  window_size: Symmetric skip-gram window size ``k``.
480
522
  num_negative_samples: Number of negative samples per positive pair.
523
+ Ignored when ``objective="sg"``.
481
524
  num_epochs: Number of epochs for each frame.
482
525
  batch_size: Mini-batch size used during training.
483
526
  shuffle_data: Whether to shuffle pairs each epoch.
@@ -493,6 +536,10 @@ class Embedder:
493
536
  num_matrices_in_compressed_blocks: Number of per-frame matrices to
494
537
  store per compressed chunk in the output archive.
495
538
  compression_level: Blosc Zstd compression level (0-9).
539
+ objective: Training objective for all frames (``"sgns"`` or ``"sg"``).
540
+ If ``None``, uses the value set at construction.
541
+ kind: Which embedding matrix to persist for each frame: ``"in"``,
542
+ ``"out"``, or ``"avg"``.
496
543
 
497
544
  Returns:
498
545
  str: Filesystem path to the written embeddings archive (``.zip``).
@@ -504,8 +551,8 @@ class Embedder:
504
551
 
505
552
  Notes:
506
553
  - A deterministic child seed is spawned per frame from the master
507
- seed using ``np.random.SeedSequence`` to ensure reproducibility
508
- across runs.
554
+ seed using ``np.random.SeedSequence`` to ensure reproducibility
555
+ across runs.
509
556
  """
510
557
  current_time = sawnergy_util.current_time()
511
558
  if output_path is None:
@@ -541,7 +588,9 @@ class Embedder:
541
588
  alpha=alpha,
542
589
  device=device,
543
590
  sgns_kwargs=sgns_kwargs,
544
- _seed=child_seed
591
+ _seed=child_seed,
592
+ objective=objective,
593
+ kind=kind
545
594
  )
546
595
  )
547
596
 
@@ -574,6 +623,7 @@ class Embedder:
574
623
  storage.add_attr("frame_embeddings_name", block_name)
575
624
  storage.add_attr("arrays_per_chunk", int(num_matrices_in_compressed_blocks))
576
625
  storage.add_attr("compression_level", int(compression_level))
626
+ storage.add_attr("objective", self.objective if objective is None else objective)
577
627
 
578
628
  _logger.info("Embedding archive written to %s", output_path)
579
629
  return str(output_path)
@@ -0,0 +1,247 @@
1
+ from __future__ import annotations
2
+
3
+ # third party
4
+ import numpy as np
5
+ import matplotlib as mpl
6
+
7
+ # built-in
8
+ from pathlib import Path
9
+ from typing import Sequence
10
+ import logging
11
+
12
+ # local
13
+ from ..visual import visualizer_util
14
+ from .. import sawnergy_util
15
+
16
+ # *----------------------------------------------------*
17
+ # GLOBALS
18
+ # *----------------------------------------------------*
19
+
20
+ _logger = logging.getLogger(__name__)
21
+
22
+ # *----------------------------------------------------*
23
+ # HELPERS
24
+ # *----------------------------------------------------*
25
+
26
def _safe_svd_pca(X: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]:
    """Project the rows of ``X`` onto their first ``k`` principal directions.

    The columns of ``X`` are mean-centered and the centered matrix is
    decomposed with a thin SVD; the leading right-singular vectors are used
    as the principal directions.

    Args:
        X: 2D array of shape ``(N, D)``.
        k: Number of components to keep; must be 2 or 3.

    Returns:
        A pair ``(proj, comps)`` where ``proj`` has shape ``(N, k)`` and
        ``comps`` has shape ``(k, D)``. When ``X`` has fewer than ``k`` rows
        the thin SVD yields fewer than ``k`` directions; the missing rows of
        ``comps`` (and hence the matching columns of ``proj``) are zero.

    Raises:
        ValueError: If ``X`` is not 2D, if ``k`` is not 2 or 3, or if
            ``D < k``.
    """
    if X.ndim != 2:
        raise ValueError(f"PCA expects 2D array (N, D); got {X.shape}")
    N, D = X.shape
    if k not in (2, 3):
        raise ValueError(f"PCA dimensionality must be 2 or 3; got {k}")
    if D < k:
        raise ValueError(f"Requested k={k} exceeds feature dim D={D}")
    if N == 0:
        # Nothing to center or decompose; keep the documented output shapes.
        return np.zeros((0, k), dtype=float), np.zeros((k, D), dtype=float)

    Xc = X - X.mean(axis=0, keepdims=True)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    comps = Vt[:k].copy()

    # With full_matrices=False, Vt has only min(N, D) rows. Since D >= k is
    # enforced above, a shortfall can only come from N < k; pad with zero
    # directions so callers can always rely on exactly k components.
    missing = k - comps.shape[0]
    if missing > 0:
        comps = np.vstack([comps, np.zeros((missing, D), dtype=comps.dtype)])

    proj = Xc @ comps.T
    return proj, comps
40
+
41
def _set_equal_axes_3d(ax, xyz: np.ndarray, *, padding: float = 0.05) -> None:
    """Give a 3D axes the same span on x, y and z, centered on the data.

    The largest per-axis extent ``r`` of the points is used for all three
    axes, and each axis is centered on the midpoint of its own data range.
    A margin of ``padding * r`` (or ``padding`` itself when all points
    coincide, i.e. ``r == 0``) is added on both sides.

    Args:
        ax: Object exposing ``set_xlim``, ``set_ylim`` and ``set_zlim``;
            ``set_box_aspect`` is used when available.
        xyz: Array of point coordinates, indexed as ``xyz[:, 0..2]``.
        padding: Fractional margin added around the data.

    Notes:
        Rows containing NaN/Inf are ignored when computing the limits,
        because non-finite limits are rejected by Matplotlib. If the input
        is empty or has no finite row, the axes are left untouched.
    """
    if xyz.size == 0:
        return

    # Only fully finite points may contribute to the limits.
    pts = xyz[np.isfinite(xyz).all(axis=1)]
    if pts.size == 0:
        return

    lows = [float(pts[:, axis].min()) for axis in range(3)]
    highs = [float(pts[:, axis].max()) for axis in range(3)]

    r = max(hi - lo for lo, hi in zip(lows, highs))
    pad = padding * (r if r > 0 else 1.0)

    for setter_name, lo, hi in zip(("set_xlim", "set_ylim", "set_zlim"), lows, highs):
        mid = (lo + hi) / 2.0
        getattr(ax, setter_name)(mid - r / 2 - pad, mid + r / 2 + pad)

    # Best effort: not every axes object (or test stand-in) provides
    # set_box_aspect, so a failure here must not abort rendering.
    try:
        ax.set_box_aspect([1, 1, 1])
    except Exception:
        pass
61
+
62
+ # *----------------------------------------------------*
63
+ # CLASS
64
+ # *----------------------------------------------------*
65
+
66
class Visualizer:
    """3D PCA visualizer for per-frame embeddings.

    Loads a ``(T, N, D)`` stack of embedding matrices from an
    ``ArrayStorage`` archive and renders individual frames as a 3D scatter
    of their PCA projection. Frame ids and node ids are 1-based in the
    public API and converted to 0-based indices internally.
    """

    # True until the first instance has been constructed; gates the
    # one-time Matplotlib warm start performed in __init__.
    no_instances: bool = True

    def __init__(
        self,
        EMB_path: str | Path,
        figsize: tuple[int, int] = (9, 7),
        default_node_color: str = visualizer_util.GRAY,
        depthshade: bool = False,
        antialiased: bool = False,
        init_elev: float = 35,
        init_azim: float = 45,
        *,
        show: bool = False
    ) -> None:
        """Load the embeddings archive and prepare an empty figure.

        Args:
            EMB_path: Path to the embeddings archive. The archive must carry
                a ``"frame_embeddings_name"`` attribute naming the stored
                embedding block.
            figsize: Figure size passed to ``pyplot.figure``.
            default_node_color: Color used for nodes that receive no explicit
                color.
            depthshade: Forwarded to the 3D scatter artist.
            antialiased: Forwarded to the 3D scatter artist.
            init_elev: Elevation applied via ``view_init``.
            init_azim: Azimuth applied via ``view_init``.
            show: Forwarded to ``visualizer_util.ensure_backend``.

        Raises:
            ValueError: If the stored embeddings are not a 3D array.
        """
        # Backend & pyplot (pyplot is imported only after the backend helper ran)
        visualizer_util.ensure_backend(show)
        import matplotlib.pyplot as plt
        self._plt = plt

        if Visualizer.no_instances:
            try:
                visualizer_util.warm_start_matplotlib()
            finally:
                # Never retry the warm start, even if it raised.
                Visualizer.no_instances = False

        # Load embeddings archive
        EMB_path = Path(EMB_path)
        with sawnergy_util.ArrayStorage(EMB_path, mode="r") as storage:
            name = storage.get_attr("frame_embeddings_name")
            E = storage.read(name, slice(None))
            if E.ndim != 3:
                raise ValueError(f"Expected embeddings of shape (T,N,D); got {E.shape}")
            self.E = np.asarray(E)
        self.T, self.N, self.D = map(int, self.E.shape)
        _logger.info("Loaded embeddings: T=%d, N=%d, D=%d", self.T, self.N, self.D)

        # Coloring normalizer (parity with RIN Visualizer)
        self._residue_norm = mpl.colors.Normalize(0, max(1, self.N - 1))

        # Figure / axes / artists (axes and scatter are created lazily)
        self._fig = self._plt.figure(figsize=figsize)
        self._ax = None
        self._scatter = None
        self._marker_size = 30.0
        self._init_elev = init_elev
        self._init_azim = init_azim
        self.default_node_color = default_node_color
        self._antialiased = bool(antialiased)
        self._depthshade = bool(depthshade)

    # ------------------------------ PRIVATE ------------------------------ #

    def _ensure_axes(self) -> None:
        """Create the 3D axes and the empty scatter artist if they are missing."""
        if self._ax is not None and self._scatter is not None:
            return
        self._fig.clf()
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._ax.view_init(self._init_elev, self._init_azim)
        self._scatter = self._ax.scatter(
            [], [], [],
            s=self._marker_size,
            depthshade=self._depthshade,
            edgecolors="none",
            antialiased=self._antialiased,
        )
        try:
            self._ax.set_axis_off()
        except Exception:
            pass

    def _project3(self, X: np.ndarray) -> np.ndarray:
        """Return a 3D PCA projection of embeddings (always 3 coordinates).

        Uses 3 components when ``X`` has at least 3 columns and 2 otherwise.
        Whenever the projection comes back with fewer than 3 columns (2D
        embeddings, or fewer rows than requested components), the remaining
        coordinate(s) are set to 0 so the result always has shape ``(n, 3)``.
        """
        k = 3 if X.shape[1] >= 3 else 2
        P, _ = _safe_svd_pca(X, k)
        missing = 3 - P.shape[1]
        if missing > 0:
            P = np.c_[P, np.zeros((P.shape[0], missing), dtype=P.dtype)]
        return P

    def _select_nodes(self, displayed_nodes: Sequence[int] | str | None) -> np.ndarray:
        """Translate the 1-based node selection into 0-based ``int64`` indices.

        Args:
            displayed_nodes: ``None`` or ``"ALL"`` for every node, otherwise
                an integer sequence of 1-based node ids.

        Raises:
            TypeError: If the selection is not integer-typed.
            ValueError: If the selection is an empty integer array.
            IndexError: If any id lies outside ``[1, N]``.
        """
        # The string check must come first: comparing an ndarray with "ALL"
        # is elementwise and has no unambiguous truth value.
        select_all = displayed_nodes is None or (
            isinstance(displayed_nodes, str) and displayed_nodes == "ALL"
        )
        if select_all:
            return np.arange(self.N, dtype=np.int64)
        idx = np.asarray(displayed_nodes)
        if idx.dtype.kind not in "iu":
            raise TypeError("displayed_nodes must be None, 'ALL', or an integer sequence.")
        if idx.size == 0:
            raise ValueError("displayed_nodes must not be empty.")
        if idx.min() < 1 or idx.max() > self.N:
            raise IndexError(f"displayed_nodes out of range [1,{self.N}]")
        return idx.astype(np.int64) - 1

    def _apply_colors(self, node_colors, idx: np.ndarray) -> np.ndarray:
        """Resolve ``node_colors`` into one color per selected node.

        - ``str``: treated as a colormap name; colors are sampled at the
          normalized 0-based node index.
        - ``None``: every node gets the default color.
        - array of shape ``(N, 3)`` or ``(N, 4)``: used directly, one row per node.
        - anything else: handed to ``visualizer_util.map_groups_to_colors``
          as a group specification.
        """
        # RIN Visualizer semantics:
        if isinstance(node_colors, str):
            node_cmap = self._plt.get_cmap(node_colors)
            return node_cmap(self._residue_norm(idx))
        if node_colors is None:
            full = visualizer_util.map_groups_to_colors(
                N=self.N, groups=None, default_color=self.default_node_color, one_based=True
            )
            return np.asarray(full)[idx]
        arr = np.asarray(node_colors)
        if arr.ndim == 2 and arr.shape[0] == self.N and arr.shape[1] in (3, 4):
            return arr[idx]
        full = visualizer_util.map_groups_to_colors(
            N=self.N, groups=node_colors, default_color=self.default_node_color, one_based=True
        )
        return np.asarray(full)[idx]

    # ------------------------------ PUBLIC ------------------------------- #

    def build_frame(
        self,
        frame_id: int,
        *,
        node_colors: str | np.ndarray | None = "rainbow",
        displayed_nodes: Sequence[int] | str | None = "ALL",
        show_node_labels: bool = False,
        show: bool = False
    ) -> None:
        """Render a single frame as a PCA **3D** scatter (matches RIN Visualizer API).

        Args:
            frame_id: 1-based frame id.
            node_colors: Colormap name, ``None``, per-node color array, or a
                group specification (see ``_apply_colors``).
            displayed_nodes: ``None``/``"ALL"`` or 1-based node ids to draw.
            show_node_labels: Whether to annotate each point with its
                1-based node id. Labels from a previous render are always
                removed first.
            show: Whether to call ``pyplot.show`` after drawing.

        Raises:
            IndexError: If ``frame_id`` is outside ``[1, T]``.
        """
        frame0 = int(frame_id) - 1
        if not (0 <= frame0 < self.T):
            raise IndexError(f"frame_id out of range [1,{self.T}]")
        self._ensure_axes()

        idx = self._select_nodes(displayed_nodes)
        X = self.E[frame0, idx, :]            # (n, D)
        P = self._project3(X)                 # (n, 3)
        colors = self._apply_colors(node_colors, idx)

        x, y, z = P[:, 0], P[:, 1], P[:, 2]
        self._scatter._offsets3d = (x, y, z)
        self._scatter.set_facecolors(colors)
        _set_equal_axes_3d(self._ax, P, padding=0.05)
        self._ax.view_init(self._init_elev, self._init_azim)

        # Drop labels left over from the previous render unconditionally, so
        # a frame drawn without labels does not show stale ones.
        for txt in getattr(self, "_labels", []):
            try:
                txt.remove()
            except Exception:
                pass
        self._labels = []
        if show_node_labels:
            for p, nid in zip(P, idx + 1):
                self._labels.append(self._ax.text(p[0], p[1], p[2], str(int(nid)), fontsize=8))

        # Be friendly to test dummies (they may lack tight_layout/canvas)
        try:
            self._fig.tight_layout()
        except Exception:
            try:
                self._fig.subplots_adjust()
            except Exception:
                pass
        try:
            self._fig.canvas.draw_idle()
        except Exception:
            pass

        if show:
            try:
                self._plt.show(block=True)
            except TypeError:
                # Fallback for a show() that does not accept ``block``.
                self._plt.show()

    # convenience
    def savefig(self, path: str | Path, *, dpi: int = 150) -> None:
        """Save the current figure to ``path`` at the given resolution."""
        self._fig.savefig(path, dpi=dpi)

    def close(self) -> None:
        """Close the figure; errors raised while closing are ignored."""
        try:
            self._plt.close(self._fig)
        except Exception:
            pass
245
+
246
+
247
+ __all__ = ["Visualizer"]
sawnergy/logging_util.py CHANGED
@@ -6,7 +6,7 @@ from datetime import datetime
6
6
 
7
7
  def configure_logging(
8
8
  logs_dir: Path | str,
9
- file_level: int = logging.DEBUG,
9
+ file_level: int = logging.WARNING,
10
10
  console_level: int = logging.WARNING
11
11
  ) -> None:
12
12
  """
@@ -669,7 +669,7 @@ class RINBuilder:
669
669
  molecule_of_interest: int,
670
670
  frame_range: tuple[int, int] | None = None,
671
671
  frame_batch_size: int = -1,
672
- prune_low_energies_frac: float = 0.3,
672
+ prune_low_energies_frac: float = 0.85,
673
673
  output_path: str | Path | None = None,
674
674
  keep_prenormalized_energies: bool = True,
675
675
  *,
@@ -107,7 +107,7 @@ class Visualizer:
107
107
  visualizer_util.ensure_backend(show)
108
108
  import matplotlib.pyplot as plt
109
109
  self._plt = plt
110
- # ---------- WARM UP MPL ------------ #
110
+ # ---------- WARM UP MPL ------------ #
111
111
  _logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
112
112
  RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
113
113
  if Visualizer.no_instances:
@@ -116,7 +116,7 @@ class Visualizer:
116
116
  else:
117
117
  _logger.debug("Skipping warm-start (no_instances=False).")
118
118
 
119
- # ---------- LOAD THE DATA ---------- #
119
+ # ---------- LOAD THE DATA ---------- #
120
120
  with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
121
121
  com_name = storage.get_attr("com_name")
122
122
  attr_energies_name = storage.get_attr("attractive_energies_name")
@@ -135,7 +135,7 @@ class Visualizer:
135
135
  self.N = np.size(self.COM_coords[0], axis=0)
136
136
  _logger.debug("Computed N=%d", self.N)
137
137
 
138
- # - SET UP THE CANVAS AND THE AXES - #
138
+ # - SET UP THE CANVAS AND THE AXES - #
139
139
  self._fig = plt.figure(figsize=figsize, num="SAWNERGY")
140
140
  self._ax = self._fig.add_subplot(111, projection="3d")
141
141
  self._fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
@@ -145,14 +145,14 @@ class Visualizer:
145
145
  self._ax.set_axis_off()
146
146
  _logger.debug("Figure and 3D axes initialized.")
147
147
 
148
- # ------ SET UP PLOT ELEMENTS ------ #
148
+ # ------ SET UP PLOT ELEMENTS ------ #
149
149
  self._scatter: PathCollection = self._ax.scatter([], [], [], s=node_size, depthshade=depthshade, edgecolors="none")
150
150
  self._attr: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
151
151
  self._repuls: Line3DCollection = Line3DCollection(np.empty((0,2,3)), linewidths=edge_width, antialiased=antialiased)
152
152
  self._ax.add_collection3d(self._attr); self._ax.add_collection3d(self._repuls) # set pointers to the attractive and repulsive collections
153
153
  _logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")
154
154
 
155
- # ---------- HELPER FIELDS --------- #
155
+ # ---------- HELPER FIELDS --------- #
156
156
  # NOTE: 'under the hood' everything is 0-base indexed,
157
157
  # BUT, from the API point of view, the indexing is 1-base,
158
158
  # because amino acid residues are 1-base indexed.
@@ -160,7 +160,7 @@ class Visualizer:
160
160
  self.default_node_color = default_node_color
161
161
  _logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", self.N-1, self.default_node_color)
162
162
 
163
- # DISALLOW MPL WARM-UP IN THE FUTURE
163
+ # DISALLOW MPL WARM-UP IN THE FUTURE
164
164
  Visualizer.no_instances = False
165
165
  _logger.debug("Visualizer.no_instances set to False.")
166
166
 
@@ -319,6 +319,9 @@ def build_line_segments(
319
319
  kept = edge_weights >= thresh
320
320
  rows, cols = rows[kept], cols[kept]
321
321
 
322
+ nz = weights[rows, cols] > 0.0
323
+ rows, cols = rows[nz], cols[nz]
324
+
322
325
  if rows.size == 0:
323
326
  _logger.debug("build_line_segments: no edges kept after threshold; returning empties.")
324
327
  return (np.empty((0, 2, 3), dtype=float),