sawnergy 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sawnergy might be problematic. Click here for more details.

@@ -0,0 +1,690 @@
1
+ from __future__ import annotations
2
+
3
+ # third-party
4
+ import numpy as np
5
+ import matplotlib as mpl
6
+ from mpl_toolkits.mplot3d.art3d import Line3DCollection
7
+ from matplotlib.collections import PathCollection
8
+ # built-in
9
+ from pathlib import Path
10
+ from typing import Iterable, Literal
11
+ import logging
12
+ # local
13
+ from . import visualizer_util
14
+ from .. import sawnergy_util
15
+
16
+ # *----------------------------------------------------*
17
+ # GLOBALS
18
+ # *----------------------------------------------------*
19
+
20
+ _logger = logging.getLogger(__name__)
21
+
22
+ # *----------------------------------------------------*
23
+ # CLASSES
24
+ # *----------------------------------------------------*
25
+
26
+ class Visualizer:
27
+ """3D network/trajectory visualizer.
28
+
29
+ This class renders nodes (scatter) and pairwise interactions (line segments)
30
+ for frames of a trajectory stored in an ArrayStorage-backed file (e.g., Zarr
31
+ in a ZIP). It supports showing only a subset of nodes, coloring nodes by
32
+ groups or a colormap, drawing attractive/repulsive edges by weight quantiles,
33
+ and animating the full trajectory.
34
+
35
+ Backend & GUI behavior:
36
+ - The Matplotlib backend is chosen in __init__ via
37
+ visualizer_util.ensure_backend(show), *before* importing pyplot.
38
+ If `show=True` but no GUI/display is available (e.g., headless Linux),
39
+ the backend is switched to 'Agg' and a warning is emitted. In this mode
40
+ figures render off-screen; use savefig() instead of interactive windows.
41
+ - pyplot is imported lazily inside __init__ after backend selection and
42
+ stored as `self._plt` to keep backend control deterministic.
43
+
44
+ Attributes:
45
+ no_instances: Class-level flag to warm-start Matplotlib only once.
46
+ COM_coords: Trajectory coordinates, shape (T, N, 3).
47
+ attr_energies: Attractive weights (shape (T, N, N)) or None if absent.
48
+ repuls_energies: Repulsive weights (shape (T, N, N)) or None if absent.
49
+ N: Number of nodes (int).
50
+ _fig: Matplotlib Figure.
51
+ _ax: 3D Axes.
52
+ _scatter: PathCollection for node markers.
53
+ _attr: Line3DCollection for attractive edges.
54
+ _repuls: Line3DCollection for repulsive edges.
55
+ _residue_norm: Normalizer mapping [0, N-1] to [0, 1] for colormaps.
56
+ default_node_color: Hex color string used when no group color is set.
57
+ """
58
+
59
+ no_instances: bool = True
60
+
61
def __init__(
    self,
    RIN_path: str | Path,
    figsize: tuple[int, int] = (9, 7),
    node_size: int = 120,
    edge_width: float = 1.25,
    default_node_color: str = visualizer_util.GRAY,
    depthshade: bool = False,
    antialiased: bool = False,
    init_elev: float = 35,
    init_azim: float = 45,
    *,
    show: bool = False
) -> None:
    """Select a backend, load the trajectory datasets and create empty artists.

    Args:
        RIN_path: Path handed to ``sawnergy_util.ArrayStorage`` (opened
            read-only) from which the datasets are loaded.
        figsize: Figure size passed to ``plt.figure``.
        node_size: Marker size ``s`` passed to the 3D scatter call.
        edge_width: ``linewidths`` of both edge collections.
        default_node_color: Stored as ``self.default_node_color``; used by
            ``build_frame`` for nodes that belong to no color group.
        depthshade: Forwarded to the scatter call.
        antialiased: Forwarded to both ``Line3DCollection`` constructors.
        init_elev: Initial elevation passed to ``view_init``.
        init_azim: Initial azimuth passed to ``view_init``.
        show: Forwarded to ``visualizer_util.ensure_backend``; it only
            influences backend selection and never opens a window here.

    Data discovery:
        Dataset names are read from the storage attributes ``'com_name'``,
        ``'attractive_energies_name'`` and ``'repulsive_energies_name'``.
        An energy dataset whose name attribute is ``None`` is not read and
        the matching attribute on ``self`` is set to ``None``.

    Side Effects:
        - Calls ``visualizer_util.ensure_backend(show)`` before importing
          pyplot, then keeps pyplot on ``self._plt``.
        - Warm-starts Matplotlib while ``Visualizer.no_instances`` is True
          and flips that class flag to False at the very end.
        - Creates the figure labelled "SAWNERGY", a 3D axes with autoscale
          and axis decorations turned off, an empty scatter and two empty
          edge collections.
    """
    # The backend has to be fixed before pyplot is imported for the first time.
    visualizer_util.ensure_backend(show)
    import matplotlib.pyplot as plt
    self._plt = plt

    # ---------- WARM UP MPL ------------ #
    _logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
                  RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
    if not Visualizer.no_instances:
        _logger.debug("Skipping warm-start (no_instances=False).")
    else:
        _logger.debug("Warm-starting Matplotlib (no_instances=True).")
        visualizer_util.warm_start_matplotlib()

    # ---------- LOAD THE DATA ---------- #
    whole_dataset = slice(None)
    with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
        com_name = storage.get_attr("com_name")
        attr_name = storage.get_attr("attractive_energies_name")
        repuls_name = storage.get_attr("repulsive_energies_name")

        self.COM_coords: np.ndarray = storage.read(com_name, whole_dataset)

        # An absent dataset name disables that edge channel.
        if attr_name is not None:
            self.attr_energies: np.ndarray | None = storage.read(attr_name, whole_dataset)
        else:
            self.attr_energies = None

        if repuls_name is not None:
            self.repuls_energies: np.ndarray | None = storage.read(repuls_name, whole_dataset)
        else:
            self.repuls_energies = None

    try:
        loaded_shapes = tuple(
            getattr(dataset, "shape", None)
            for dataset in (self.COM_coords, self.attr_energies, self.repuls_energies)
        )
        _logger.debug("Loaded datasets | COM_coords.shape=%s, attr_energies.shape=%s, repuls_energies.shape=%s",
                      *loaded_shapes)
    except Exception:
        _logger.debug("Loaded datasets (shapes unavailable).")

    # Node count = length of the first axis of the first frame.
    first_frame = self.COM_coords[0]
    self.N = np.size(first_frame, axis=0)
    _logger.debug("Computed N=%d", self.N)

    # - SET UP THE CANVAS AND THE AXES - #
    figure = plt.figure(figsize=figsize, num="SAWNERGY")
    self._fig = figure
    axes = figure.add_subplot(111, projection="3d")
    self._ax = axes
    figure.subplots_adjust(left=0, right=1, bottom=0, top=1)
    figure.patch.set_facecolor("#999999")
    axes.set_autoscale_on(False)
    axes.view_init(elev=init_elev, azim=init_azim)
    axes.set_axis_off()
    _logger.debug("Figure and 3D axes initialized.")

    # ------ SET UP PLOT ELEMENTS ------ #
    self._scatter: PathCollection = axes.scatter(
        [], [], [],
        s=node_size,
        depthshade=depthshade,
        edgecolors="none",
    )
    self._attr: Line3DCollection = Line3DCollection(
        np.empty((0, 2, 3)), linewidths=edge_width, antialiased=antialiased
    )
    self._repuls: Line3DCollection = Line3DCollection(
        np.empty((0, 2, 3)), linewidths=edge_width, antialiased=antialiased
    )
    # Register both (still empty) edge collections with the axes.
    for edge_collection in (self._attr, self._repuls):
        axes.add_collection3d(edge_collection)
    _logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")

    # ---------- HELPER FIELDS --------- #
    # NOTE: internally everything is 0-based, while the public API is
    #       1-based because amino acid residues are numbered from 1.
    last_index = self.N - 1
    self._residue_norm = mpl.colors.Normalize(0, last_index)  # uniform coloring
    self.default_node_color = default_node_color
    _logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", last_index, self.default_node_color)

    # DISALLOW MPL WARM-UP IN THE FUTURE
    Visualizer.no_instances = False
    _logger.debug("Visualizer.no_instances set to False.")
166
+
167
+ # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
168
+ # PRIVATE
169
+ # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
170
+
171
+ # --- UPDS ---
172
def _update_scatter(self, xyz, *, colors=None):
    """Update scatter artist with new positions (and optional colors).

    The input is coerced with ``np.asarray`` so that any array-like works,
    as the contract promises: a list of points, a single ``(3,)`` point
    (treated as one node) or an empty selection (treated as zero nodes).

    Args:
        xyz: Array-like of shape (N_visible, 3) containing node positions
            for the *currently displayed* nodes.
        colors: Optional colors passed to ``set_facecolors``; left
            untouched when None.

    Returns:
        None
    """
    points = np.asarray(xyz)
    # Normalise degenerate shapes so the transpose always unpacks into three
    # coordinate vectors: empty input -> (0, 3), one point -> (1, 3).
    points = points.reshape(0, 3) if points.size == 0 else np.atleast_2d(points)

    _logger.debug("_update_scatter | xyz.shape=%s, colors=%s",
                  points.shape,
                  "provided" if colors is not None else "None")

    x, y, z = points.T
    # Path3DCollection keeps its 3D positions in this private attribute.
    self._scatter._offsets3d = (x, y, z)
    if colors is not None:
        self._scatter.set_facecolors(colors)
    # Writing the private attribute does not flag the artist as changed.
    self._scatter.stale = True
    _logger.debug("_update_scatter done | n_points=%s", len(x))
195
+
196
def _update_attr_edges(self, segs, *, colors=None, opacity=None):
    """Update attractive edge collection.

    Args:
        segs: Array of shape (E, 2, 3) with edge endpoints.
        colors: Optional per-edge RGB/RGBA array or a single color in any
            form Matplotlib's ``to_rgba_array`` accepts.
        opacity: Optional scalar or per-edge alpha(s). If both `colors` and
            `opacity` are provided, alpha is fused into the RGBA array;
            a scalar applies to every edge, and a single color is repeated
            to match a per-edge alpha array.

    Returns:
        None
    """
    _logger.debug("_update_attr_edges | segs.shape=%s, colors=%s, opacity=%s",
                  getattr(segs, "shape", None),
                  "provided" if colors is not None else "None",
                  "array/scalar" if opacity is not None else "None")
    self._attr.set_segments(segs)

    if colors is not None and opacity is not None:
        # Normalise to a float (M, 4) copy first: this handles RGB input,
        # named colors and integer arrays, and never mutates the caller's data.
        rgba = np.array(mpl.colors.to_rgba_array(colors), dtype=float, copy=True)
        alpha = np.asarray(opacity, dtype=float).ravel()
        if rgba.shape[0] == 1 and alpha.size > 1:
            # One color, many alphas: repeat the color once per alpha value.
            rgba = np.repeat(rgba, alpha.size, axis=0)
        # A lone alpha is assigned as a scalar so it applies to every row.
        rgba[:, 3] = alpha[0] if alpha.size == 1 else alpha
        self._attr.set_colors(rgba)
    elif colors is not None:
        self._attr.set_colors(colors)
    elif opacity is not None:
        self._attr.set_alpha(opacity)
    _logger.debug("_update_attr_edges done.")
227
+
228
def _update_repuls_edges(self, segs, *, colors=None, opacity=None):
    """Update repulsive edge collection.

    Args:
        segs: Array of shape (E, 2, 3) with edge endpoints.
        colors: Optional per-edge RGB/RGBA array or a single color in any
            form Matplotlib's ``to_rgba_array`` accepts.
        opacity: Optional scalar or per-edge alpha(s). If both `colors` and
            `opacity` are provided, alpha is fused into the RGBA array;
            a scalar applies to every edge, and a single color is repeated
            to match a per-edge alpha array.

    Returns:
        None
    """
    _logger.debug("_update_repuls_edges | segs.shape=%s, colors=%s, opacity=%s",
                  getattr(segs, "shape", None),
                  "provided" if colors is not None else "None",
                  "array/scalar" if opacity is not None else "None")
    self._repuls.set_segments(segs)

    if colors is not None and opacity is not None:
        # Normalise to a float (M, 4) copy first: this handles RGB input,
        # named colors and integer arrays, and never mutates the caller's data.
        rgba = np.array(mpl.colors.to_rgba_array(colors), dtype=float, copy=True)
        alpha = np.asarray(opacity, dtype=float).ravel()
        if rgba.shape[0] == 1 and alpha.size > 1:
            # One color, many alphas: repeat the color once per alpha value.
            rgba = np.repeat(rgba, alpha.size, axis=0)
        # A lone alpha is assigned as a scalar so it applies to every row.
        rgba[:, 3] = alpha[0] if alpha.size == 1 else alpha
        self._repuls.set_colors(rgba)
    elif colors is not None:
        self._repuls.set_colors(colors)
    elif opacity is not None:
        self._repuls.set_alpha(opacity)
    _logger.debug("_update_repuls_edges done.")
258
+
259
+ # --- CLEARS ---
260
+
261
def _clear_scatter(self):
    """Drop every node marker by resetting the scatter's 3D offsets.

    Returns:
        None
    """
    _logger.debug("_clear_scatter called.")
    no_points = ([], [], [])  # empty x, y and z sequences
    self._scatter._offsets3d = no_points
269
+
270
def _clear_attr_edges(self):
    """Replace the attractive edge segments with an empty (0, 2, 3) array.

    Returns:
        None
    """
    _logger.debug("_clear_attr_edges called.")
    no_segments = np.empty((0, 2, 3))
    self._attr.set_segments(no_segments)
278
+
279
def _clear_repuls_edges(self):
    """Replace the repulsive edge segments with an empty (0, 2, 3) array.

    Returns:
        None
    """
    _logger.debug("_clear_repuls_edges called.")
    no_segments = np.empty((0, 2, 3))
    self._repuls.set_segments(no_segments)
287
+
288
+ # --- FINAL UPD ---
289
+
290
def _update_canvas(self, *, pause_for: float = 0.0):
    """Schedule a redraw of the figure canvas, then optionally pause.

    Args:
        pause_for: Seconds handed to ``plt.pause``; the pause is skipped
            unless the value is strictly greater than zero.

    Returns:
        None
    """
    _logger.debug("_update_canvas | pause_for=%s", pause_for)
    canvas = self._fig.canvas
    canvas.draw_idle()
    if not pause_for > 0.0:
        return
    self._plt.pause(pause_for)
304
+
305
+ # ADJUST THE VIEW
306
def _fix_view(self, coordinates: np.ndarray, padding: float, spread: float) -> np.ndarray:
    """Adjust axes limits/box aspect and apply optional spatial spreading.

    Optionally spreads the points about their centroid by ``spread``, then
    computes their bounding box, widens it by ``padding`` (relative to the
    per-axis span) and applies it as axes limits and box aspect.

    An empty coordinate set has no bounding box, so in that case the view
    is left untouched and the input is returned as-is instead of failing
    inside ``min``/``max``.

    Args:
        coordinates: Array of shape (M, 3) for currently displayed nodes.
        padding: Fraction of the span added to min/max on each axis.
        spread: If != 1.0, multiply deviations from centroid by this factor.

    Returns:
        np.ndarray: A new array with the spread positions if `spread` was
        applied; otherwise the input array.
    """
    coordinates = np.asarray(coordinates)
    _logger.debug("_fix_view | coords.shape=%s, padding=%s, spread=%s",
                  coordinates.shape, padding, spread)

    if coordinates.size == 0:
        _logger.debug("_fix_view | no coordinates; view left unchanged.")
        return coordinates

    # Apply spread first so limits reflect the final positions
    if spread != 1.0:
        center = coordinates.mean(axis=0, keepdims=True)
        coordinates = center + spread * (coordinates - center)
        _logger.debug("_fix_view | applied spread around centroid.")

    orig_min = coordinates.min(axis=0)
    orig_max = coordinates.max(axis=0)
    # Floor the span so a flat axis (all points equal) still gets a
    # non-zero extent.
    orig_span = np.maximum(orig_max - orig_min, 1e-12)
    xyz_min = orig_min - padding * orig_span
    xyz_max = orig_max + padding * orig_span

    self._ax.set_xlim(xyz_min[0], xyz_max[0])
    self._ax.set_ylim(xyz_min[1], xyz_max[1])
    self._ax.set_zlim(xyz_min[2], xyz_max[2])
    self._ax.set_box_aspect(np.maximum(xyz_max - xyz_min, 1e-12))
    _logger.debug("_fix_view | bounds set: x=(%s,%s), y=(%s,%s), z=(%s,%s)",
                  xyz_min[0], xyz_max[0], xyz_min[1], xyz_max[1], xyz_min[2], xyz_max[2])

    return coordinates
345
+
346
+ # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
347
+ # PUBLIC
348
+ # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
349
+
350
def build_frame(
    self,
    frame_id: int,
    displayed_nodes: np.typing.ArrayLike | Literal["ALL"] | None = "ALL",
    displayed_pairwise_attraction_for_nodes: np.typing.ArrayLike | Literal["DISPLAYED_NODES"] | None = "DISPLAYED_NODES",
    displayed_pairwise_repulsion_for_nodes: np.typing.ArrayLike | Literal["DISPLAYED_NODES"] | None = "DISPLAYED_NODES",
    frac_node_interactions_displayed: float = 0.01,  # 1%
    global_interactions_frac: bool = True,
    global_opacity: bool = True,
    global_color_saturation: bool = True,
    node_colors: str | tuple[tuple[Iterable[int], str]] | None = None,
    title: str | None = None,
    padding: float = 0.1,
    spread: float = 1.0,
    show: bool = False,
    *,
    show_node_labels: bool = False,
    node_label_size: int = 6,
    attractive_edge_cmap: str = visualizer_util.HEAT,
    repulsive_edge_cmap: str = visualizer_util.COLD):
    """Render a single frame into existing artists.

    Updates node positions/colors and the attractive/repulsive edge
    collections. Indices passed from the public API are interpreted as
    1-based and converted to 0-based internally.

    Each call fully replaces what the previous call drew: an edge channel
    that yields no edges in this call is cleared, and the title/label text
    created by the previous call is removed before new text is added, so
    repeated calls (e.g. from `animate_trajectory`) do not pile up artists.

    Args:
        frame_id: 1-based frame index to render.
        displayed_nodes: Iterable of node indices to show (1-based),
            "ALL" for all nodes, or None to return early without drawing.
        displayed_pairwise_attraction_for_nodes: Iterable of node indices
            (1-based), the literal "DISPLAYED_NODES", or None to draw no
            attractive edges. Must be a subset of `displayed_nodes`.
        displayed_pairwise_repulsion_for_nodes: Same contract as
            `displayed_pairwise_attraction_for_nodes` but for repulsive edges.
        frac_node_interactions_displayed: Forwarded to
            `visualizer_util.build_line_segments` (fraction of edges kept).
        global_interactions_frac: Forwarded to `build_line_segments` as
            `global_weights_frac`.
        global_opacity: Forwarded to `build_line_segments`.
        global_color_saturation: Forwarded to `build_line_segments`.
        node_colors: Either a Matplotlib colormap name (str) or group
            definitions passed to `visualizer_util.map_groups_to_colors`.
        title: Optional title drawn at the top centre in axes coordinates.
        padding: Fractional padding around the displayed nodes' bounds.
        spread: Spatial spread multiplier applied about the centroid of the
            displayed nodes.
        show: If True, display the figure: non-blocking
            `show(block=False)` plus a short `pause()` when running under
            IPython or in interactive mode, blocking `show()` otherwise.
        show_node_labels: If True, draw each displayed node's 1-based index
            next to it.
        node_label_size: Font size of the node labels.
        attractive_edge_cmap: Colormap name for attractive edges.
        repulsive_edge_cmap: Colormap name for repulsive edges.

    Returns:
        None

    Raises:
        ValueError: If any requested attraction/repulsion node set is not a
            subset of `displayed_nodes`, or invalid sentinel strings are used.
        IndexError: If `frame_id` is outside ``1..T``. Values below 1 used
            to wrap around silently to a frame counted from the end.
    """
    # PRELIMINARY
    _logger.debug("build_frame called | frame_id(1-based)=%s, frac_node_interactions_displayed=%s, padding=%s, spread=%s, show=%s, show_node_labels=%s",
                  frame_id, frac_node_interactions_displayed, padding, spread, show, show_node_labels)
    frame_id -= 1  # 1-base indexing
    _logger.debug("build_frame | using frame_id(0-based)=%s", frame_id)

    # NODES
    if displayed_nodes is None:
        _logger.debug("displayed_nodes is None -> returning early.")
        return
    if isinstance(displayed_nodes, str):
        if displayed_nodes != "ALL":
            _logger.error("Invalid displayed_nodes string: %s", displayed_nodes)
            raise ValueError(
                "'displayed_nodes' has to be either an ArrayLike "
                "collection of node indices, or an 'ALL' string, "
                "or None.")
        displayed_nodes = np.arange(0, self.N, 1)
        _logger.debug("displayed_nodes='ALL' -> count=%d", displayed_nodes.size)
    else:
        displayed_nodes = np.asarray(displayed_nodes) - 1  # 1-base indexing
        _logger.debug("displayed_nodes provided | count=%d", displayed_nodes.size)

    # Reject out-of-range frames explicitly: a 0-based index below zero
    # would otherwise be accepted by NumPy and render the wrong frame.
    n_frames = int(self.COM_coords.shape[0])
    if not 0 <= frame_id < n_frames:
        _logger.error("frame_id out of range: %s (valid: 1..%d)", frame_id + 1, n_frames)
        raise IndexError(
            f"'frame_id' must be within 1..{n_frames}, got {frame_id + 1}")

    frame_coords = self.COM_coords[frame_id]
    nodes = frame_coords[displayed_nodes]
    _logger.debug("Selected nodes | nodes.shape=%s (before view fix)", getattr(nodes, "shape", None))

    nodes = self._fix_view(nodes, padding, spread)
    _logger.debug("Nodes after _fix_view | shape=%s", getattr(nodes, "shape", None))
    # Edges must end at the (possibly spread) positions of displayed nodes.
    coords_for_edges = frame_coords.copy()
    coords_for_edges[displayed_nodes] = nodes

    def _select_edges(selector, energies, param_name, noun, adjective):
        """Resolve one edge selector; return (segs, color_w, opacity_w) or None."""
        if selector is None:
            _logger.debug("%s edges skipped (selector=None).", noun)
            return None
        if energies is None:
            _logger.warning("%s dataset unavailable; skipping %s edges.", adjective, adjective.lower())
            return None

        if isinstance(selector, str):
            if selector != "DISPLAYED_NODES":
                _logger.error("Invalid %s selector string: %s", noun.lower(), selector)
                raise ValueError(
                    f"'{param_name}' has to be either an ArrayLike "
                    "collection of node indices, or an 'DISPLAYED_NODES' string, "
                    "or None.")
            selector = displayed_nodes
            _logger.debug("%s nodes='DISPLAYED_NODES' -> count=%d", noun, selector.size)
        else:
            selector = np.asarray(selector) - 1  # 1-base indexing
            _logger.debug("%s nodes provided | count=%d", noun, selector.size)

        if np.setdiff1d(selector, displayed_nodes).size > 0:
            _logger.error("%s nodes not a subset of displayed_nodes.", noun)
            raise ValueError(f"'{param_name}' must be a subset of 'displayed_nodes'")

        segs, color_w, opacity_w = visualizer_util.build_line_segments(
            self.N,
            selector,
            coords_for_edges,
            energies[frame_id],
            frac_node_interactions_displayed,
            global_weights_frac=global_interactions_frac,
            global_opacity=global_opacity,
            global_color_saturation=global_color_saturation
        )
        _logger.debug("%s edges built | segs.shape=%s, color_w.shape=%s, opacity_w.shape=%s",
                      noun,
                      getattr(segs, "shape", None),
                      getattr(color_w, "shape", None),
                      getattr(opacity_w, "shape", None))
        if segs is None:
            return None
        return segs, color_w, opacity_w

    # ATTRACTIVE EDGES
    attraction = _select_edges(
        displayed_pairwise_attraction_for_nodes,
        self.attr_energies,
        "displayed_pairwise_attraction_for_nodes",
        "Attraction",
        "Attractive",
    )

    # REPULSIVE EDGES
    repulsion = _select_edges(
        displayed_pairwise_repulsion_for_nodes,
        self.repuls_energies,
        "displayed_pairwise_repulsion_for_nodes",
        "Repulsion",
        "Repulsive",
    )

    # COLOR THE DATA POINTS
    if isinstance(node_colors, str):
        node_cmap = self._plt.get_cmap(node_colors)
        idx0 = np.asarray(displayed_nodes, dtype=int)
        color_array = node_cmap(self._residue_norm(idx0))
        _logger.debug("Node colors via colormap '%s' | count=%d", node_colors, idx0.size)
    else:
        color_array_full = visualizer_util.map_groups_to_colors(
            N=self.N,
            groups=node_colors,
            default_color=self.default_node_color,
            one_based=True
        )
        color_array = np.asarray(color_array_full)[displayed_nodes]
        _logger.debug("Node colors via groups/default | count=%d", color_array.shape[0])

    # UPDATE CANVAS
    self._update_scatter(nodes, colors=color_array)

    if attraction is None:
        # Nothing to draw now: wipe whatever an earlier call left behind.
        self._clear_attr_edges()
    else:
        attractive_edges, attractive_color_weights, attractive_opacity_weights = attraction
        attr_rgba = self._plt.get_cmap(attractive_edge_cmap)(attractive_color_weights)  # (E,4)
        attr_rgba[:, 3] = attractive_opacity_weights
        self._update_attr_edges(attractive_edges, colors=attr_rgba, opacity=None)
        _logger.debug("Attraction edges updated on canvas.")

    if repulsion is None:
        # Nothing to draw now: wipe whatever an earlier call left behind.
        self._clear_repuls_edges()
    else:
        repulsive_edges, repulsive_color_weights, repulsive_opacity_weights = repulsion
        rep_rgba = self._plt.get_cmap(repulsive_edge_cmap)(repulsive_color_weights)  # (E,4)
        rep_rgba[:, 3] = repulsive_opacity_weights
        self._update_repuls_edges(repulsive_edges, colors=rep_rgba, opacity=None)
        _logger.debug("Repulsion edges updated on canvas.")

    # EXTRAS
    # Text artists are created anew on every call, so remove the ones made
    # by the previous call first. The attribute may not exist yet because
    # __init__ does not create it.
    for stale_text in getattr(self, "_frame_text_artists", ()):
        try:
            stale_text.remove()
        except (ValueError, NotImplementedError):
            # Already detached, e.g. the axes were cleared externally.
            pass
    frame_texts = []

    if title:
        frame_texts.append(
            self._ax.text2D(0.5, 0.99, title, transform=self._ax.transAxes,
                            ha="center", va="top"))
        _logger.debug("Title set: %s", title)

    if show_node_labels:
        labs = (np.asarray(displayed_nodes, dtype=int) + 1)  # one-based labels
        _logger.debug("Adding node labels | count=%d, fontsize=%d", labs.size, node_label_size)
        for (x, y, z), lab in zip(nodes, labs):
            frame_texts.append(
                self._ax.text(float(x)+.3, float(y)+.3, float(z)+1.3, str(lab),
                              fontsize=node_label_size, color="k"))

    self._frame_text_artists = frame_texts

    if show:
        # auto-block in scripts; non-block in notebooks/interactive
        try:
            get_ipython  # type: ignore  # noqa: F821
            in_ipy = True
        except NameError:
            in_ipy = False

        _logger.debug("Showing figure | in_ipy=%s, interactive=%s", in_ipy, self._plt.isinteractive())

        if in_ipy or self._plt.isinteractive():
            self._plt.show(block=False)
            self._plt.pause(0.05)
        else:
            self._plt.show()
    _logger.debug("build_frame completed.")
602
+
603
def animate_trajectory(
    self,
    start: int = 1,
    stop: int | None = None,
    step: int = 1,
    interval_ms: int = 50,
    loop: bool = False,
    **build_kwargs,
):
    """Render a sequence of frames one after another on the same figure.

    Frame ids run from `start` to `stop` inclusive in increments of `step`;
    each one is drawn with `build_frame` and followed by a canvas update
    that pauses for `interval_ms` milliseconds.

    Args:
        start: 1-based index of the first frame.
        stop: 1-based index of the last frame (inclusive). None means the
            number of frames in `COM_coords`.
        step: Increment between frames; a negative value counts downwards.
        interval_ms: Pause after each frame, in milliseconds.
        loop: If True, replay the sequence for as long as the figure
            exists; if False, play it once and then call `plt.show()`.
        **build_kwargs: Extra keyword arguments forwarded to `build_frame`.
            The "show" entry is always overwritten with False.

    Returns:
        None

    Raises:
        ValueError: If `step` is zero.

    Notes:
        - A `KeyboardInterrupt` raised during playback is caught and only
          logged.
        - An empty frame sequence returns immediately without drawing.
    """
    _logger.debug(
        "animate_trajectory | start=%s, stop=%s, step=%s, interval_ms=%s, loop=%s",
        start, stop, step, interval_ms, loop
    )

    # default to all frames
    total_frames = int(self.COM_coords.shape[0])
    if stop is None:
        stop = total_frames

    if step == 0:
        raise ValueError("step must be non-zero")

    # range() excludes its end, so move the bound one past `stop` in the
    # direction of travel to make `stop` inclusive.
    if step > 0:
        exclusive_end = stop + 1
    else:
        exclusive_end = stop - 1
    frame_ids = list(range(start, exclusive_end, step))
    if not frame_ids:
        _logger.debug("animate_trajectory | empty frame list -> return")
        return

    # Never let build_frame block on a window while frames are playing.
    build_kwargs["show"] = False

    def figure_open():
        return self._plt.fignum_exists(self._fig.number)

    def render(frame_number):
        self.build_frame(frame_number, **build_kwargs)
        self._update_canvas(pause_for=interval_ms / 1000.0)

    try:
        if not loop:
            _logger.debug("animate_trajectory | single pass over frames.")
            for frame_number in frame_ids:
                render(frame_number)
            # one final show so the window stays up when the loop ends
            self._plt.show()
        else:
            _logger.debug("animate_trajectory | entering repeat loop until window closed.")
            while figure_open():
                for frame_number in frame_ids:
                    if not figure_open():
                        break
                    render(frame_number)
    except KeyboardInterrupt:
        _logger.debug("animate_trajectory | interrupted by user (KeyboardInterrupt).")
683
+
684
+
685
+ __all__ = [
686
+ "Visualizer"
687
+ ]
688
+
689
+ if __name__ == "__main__":
690
+ pass