sawnergy 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of sawnergy might be problematic. Click here for more details.
- sawnergy/__init__.py +13 -0
- sawnergy/embedding/SGNS_pml.py +135 -0
- sawnergy/embedding/SGNS_torch.py +177 -0
- sawnergy/embedding/__init__.py +34 -0
- sawnergy/embedding/embedder.py +578 -0
- sawnergy/logging_util.py +54 -0
- sawnergy/rin/__init__.py +9 -0
- sawnergy/rin/rin_builder.py +936 -0
- sawnergy/rin/rin_util.py +391 -0
- sawnergy/sawnergy_util.py +1182 -0
- sawnergy/visual/__init__.py +42 -0
- sawnergy/visual/visualizer.py +690 -0
- sawnergy/visual/visualizer_util.py +387 -0
- sawnergy/walks/__init__.py +16 -0
- sawnergy/walks/walker.py +795 -0
- sawnergy/walks/walker_util.py +384 -0
- sawnergy-1.0.0.dist-info/METADATA +290 -0
- sawnergy-1.0.0.dist-info/RECORD +22 -0
- sawnergy-1.0.0.dist-info/WHEEL +5 -0
- sawnergy-1.0.0.dist-info/licenses/LICENSE +201 -0
- sawnergy-1.0.0.dist-info/licenses/NOTICE +4 -0
- sawnergy-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,690 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# third-party
|
|
4
|
+
import numpy as np
|
|
5
|
+
import matplotlib as mpl
|
|
6
|
+
from mpl_toolkits.mplot3d.art3d import Line3DCollection
|
|
7
|
+
from matplotlib.collections import PathCollection
|
|
8
|
+
# built-in
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Iterable, Literal
|
|
11
|
+
import logging
|
|
12
|
+
# local
|
|
13
|
+
from . import visualizer_util
|
|
14
|
+
from .. import sawnergy_util
|
|
15
|
+
|
|
16
|
+
# *----------------------------------------------------*
#                       GLOBALS
# *----------------------------------------------------*

# Module-level logger, named after this module so its output can be filtered
# through the standard logging hierarchy.
_logger: logging.Logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
# *----------------------------------------------------*
|
|
23
|
+
# CLASSES
|
|
24
|
+
# *----------------------------------------------------*
|
|
25
|
+
|
|
26
|
+
class Visualizer:
    """3D network/trajectory visualizer.

    This class renders nodes (scatter) and pairwise interactions (line segments)
    for frames of a trajectory stored in an ArrayStorage-backed file (e.g., Zarr
    in a ZIP). It supports showing only a subset of nodes, coloring nodes by
    groups or a colormap, drawing attractive/repulsive edges by weight quantiles,
    and animating the full trajectory.

    Backend & GUI behavior:
      - The Matplotlib backend is chosen in __init__ via
        visualizer_util.ensure_backend(show), *before* importing pyplot.
        If `show=True` but no GUI/display is available (e.g., headless Linux),
        the backend is switched to 'Agg' and a warning is emitted. In this mode
        figures render off-screen; use savefig() instead of interactive windows.
      - pyplot is imported lazily inside __init__ after backend selection and
        stored as `self._plt` to keep backend control deterministic.
      - Every instance owns its own figure. The first is labelled "SAWNERGY";
        later instances get "SAWNERGY (2)", "SAWNERGY (3)", ... so a new
        instance never draws a second 3D axes on top of an existing window.

    Attributes:
        no_instances: Class-level flag to warm-start Matplotlib only once.
        COM_coords: Trajectory coordinates, shape (T, N, 3).
        attr_energies: Attractive weights (shape (T, N, N)) or None if absent.
        repuls_energies: Repulsive weights (shape (T, N, N)) or None if absent.
        N: Number of nodes (int).
        _fig: Matplotlib Figure.
        _ax: 3D Axes.
        _scatter: PathCollection for node markers.
        _attr: Line3DCollection for attractive edges.
        _repuls: Line3DCollection for repulsive edges.
        _title_artist: Text artist for the frame title (created on first use
            and reused by later frames), or None.
        _node_label_artists: Text artists for the node labels drawn by the most
            recent `build_frame` call (removed before the next frame).
        _residue_norm: Normalizer mapping [0, N-1] to [0, 1] for colormaps.
        default_node_color: Hex color string used when no group color is set.
    """

    no_instances: bool = True

    def __init__(
        self,
        RIN_path: str | Path,
        figsize: tuple[int, int] = (9, 7),
        node_size: int = 120,
        edge_width: float = 1.25,
        default_node_color: str = visualizer_util.GRAY,
        depthshade: bool = False,
        antialiased: bool = False,
        init_elev: float = 35,
        init_azim: float = 45,
        *,
        show: bool = False
    ) -> None:
        """Initialize the visualizer and load datasets.

        Args:
            RIN_path: Path to the archive or store containing datasets.
            figsize: Figure size (inches) for the Matplotlib window.
            node_size: Marker area for nodes (passed to `Axes3D.scatter`).
            edge_width: Line width for edge collections.
            default_node_color: Hex color used for nodes not in any group.
            depthshade: Whether to apply depth shading to scatter points.
            antialiased: Whether to antialias line collections.
            init_elev: Initial elevation angle (degrees) for 3D view.
            init_azim: Initial azimuth angle (degrees) for 3D view.
            show: Hint about intended usage. If True and a GUI/display is available,
                interactive windows can be shown later (e.g., via `self._plt.show()`).
                If True but no GUI/display is available, the backend is switched to
                'Agg' (off-screen) and a warning is issued. This flag does not itself
                call `show()`; it only influences backend selection.

        Data discovery:
            Dataset names are auto-resolved from storage attrs:
            'com_name', 'attractive_energies_name', 'repulsive_energies_name'.
            Any missing channel remains disabled (None) but visualization still works
            just without edges of a specific missing type.

        Side Effects:
            - Selects a Matplotlib backend before importing pyplot; may fall back
              to 'Agg' in headless environments when `show=True`.
            - Optionally warms up Matplotlib once per process (first instance only).
            - Opens and reads required datasets from storage.
            - Creates a new figure (with a label unique among open figures),
              3D axes, and empty artists for later updates.
        """
        # choose GUI backend before importing pyplot
        visualizer_util.ensure_backend(show)
        import matplotlib.pyplot as plt
        self._plt = plt

        # ---------- WARM UP MPL ------------ #
        _logger.debug("Visualizer.__init__ start | RIN_path=%s, figsize=%s, node_size=%s, edge_width=%s, depthshade=%s, antialiased=%s, init_view=(%s,%s)",
                      RIN_path, figsize, node_size, edge_width, depthshade, antialiased, init_elev, init_azim)
        if Visualizer.no_instances:
            _logger.debug("Warm-starting Matplotlib (no_instances=True).")
            visualizer_util.warm_start_matplotlib()
        else:
            _logger.debug("Skipping warm-start (no_instances=False).")

        # ---------- LOAD THE DATA ---------- #
        with sawnergy_util.ArrayStorage(RIN_path, mode="r") as storage:
            com_name = storage.get_attr("com_name")
            attr_energies_name = storage.get_attr("attractive_energies_name")
            repuls_energies_name = storage.get_attr("repulsive_energies_name")
            self.COM_coords: np.ndarray = storage.read(com_name, slice(None))
            self.attr_energies: np.ndarray | None = (
                storage.read(attr_energies_name, slice(None))
                if attr_energies_name is not None else None
            )
            self.repuls_energies: np.ndarray | None = (
                storage.read(repuls_energies_name, slice(None))
                if repuls_energies_name is not None else None
            )
        _logger.debug("Loaded datasets | COM_coords.shape=%s, attr_energies.shape=%s, repuls_energies.shape=%s",
                      getattr(self.COM_coords, "shape", None),
                      getattr(self.attr_energies, "shape", None),
                      getattr(self.repuls_energies, "shape", None))

        self.N = np.size(self.COM_coords[0], axis=0)
        _logger.debug("Computed N=%d", self.N)

        # - SET UP THE CANVAS AND THE AXES - #
        # A fixed `num` would hand a second instance the *existing* figure, and
        # add_subplot would then stack a new 3D axes over the old one.
        self._fig = plt.figure(figsize=figsize, num=self._unique_figure_label(plt))
        self._ax = self._fig.add_subplot(111, projection="3d")
        self._fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
        self._fig.patch.set_facecolor("#999999")
        self._ax.set_autoscale_on(False)
        self._ax.view_init(elev=init_elev, azim=init_azim)
        self._ax.set_axis_off()
        _logger.debug("Figure and 3D axes initialized.")

        # ------ SET UP PLOT ELEMENTS ------ #
        self._scatter: PathCollection = self._ax.scatter([], [], [], s=node_size, depthshade=depthshade, edgecolors="none")
        self._attr: Line3DCollection = Line3DCollection(np.empty((0, 2, 3)), linewidths=edge_width, antialiased=antialiased)
        self._repuls: Line3DCollection = Line3DCollection(np.empty((0, 2, 3)), linewidths=edge_width, antialiased=antialiased)
        # register the attractive and repulsive collections with the axes
        self._ax.add_collection3d(self._attr)
        self._ax.add_collection3d(self._repuls)
        # text artists are tracked so each frame replaces (not accumulates) them
        self._title_artist = None
        self._node_label_artists: list = []
        _logger.debug("Artists created | scatter(empty), attr_lines(empty), repuls_lines(empty).")

        # ---------- HELPER FIELDS --------- #
        # NOTE: 'under the hood' everything is 0-base indexed,
        # BUT, from the API point of view, the indexing is 1-base,
        # because amino acid residues are 1-base indexed.
        self._residue_norm = mpl.colors.Normalize(0, self.N - 1)  # uniform coloring
        self.default_node_color = default_node_color
        _logger.debug("Helper fields set | residue_norm=[0,%d], default_node_color=%s", self.N - 1, self.default_node_color)

        # DISALLOW MPL WARM-UP IN THE FUTURE
        Visualizer.no_instances = False
        _logger.debug("Visualizer.no_instances set to False.")

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
    #                              PRIVATE
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #

    @staticmethod
    def _unique_figure_label(plt, base: str = "SAWNERGY") -> str:
        """Return `base`, or `"base (k)"` with the smallest k >= 2 not used by an open figure.

        Args:
            plt: The pyplot module (passed in because pyplot is imported lazily).
            base: Preferred figure label.

        Returns:
            str: A label no currently open figure is using.
        """
        label, k = base, 2
        while plt.fignum_exists(label):
            label = f"{base} ({k})"
            k += 1
        return label

    def _to_zero_based(self, indices, arg_name: str) -> np.ndarray:
        """Convert 1-based node indices from the public API into validated 0-based indices.

        Args:
            indices: A scalar or array-like of 1-based integer node indices.
            arg_name: Name of the public argument (used in error messages).

        Returns:
            np.ndarray: Flat integer array of 0-based indices (may be empty).

        Raises:
            ValueError: If the indices are not integers or fall outside [1, N].
                (Without this check, index 0 would silently wrap to the last node.)
        """
        idx = np.asarray(indices).ravel()
        if idx.size == 0:
            return idx.astype(np.intp)
        if not np.issubdtype(idx.dtype, np.integer):
            raise ValueError(f"'{arg_name}' must contain integer (1-based) node indices.")
        idx = idx.astype(np.intp) - 1  # 1-base indexing
        if idx.min() < 0 or idx.max() >= self.N:
            raise ValueError(f"'{arg_name}' indices must lie in [1, {self.N}] (1-based).")
        return idx

    # --- UPDS ---
    def _update_scatter(self, xyz, *, colors=None):
        """Update scatter artist with new positions (and optional colors).

        Args:
            xyz: Array-like of shape (N_visible, 3) containing node positions
                for the *currently displayed* nodes.
            colors: Optional array of RGBA colors (len=N_visible) or a single
                color broadcastable by Matplotlib.

        Returns:
            None
        """
        _logger.debug("_update_scatter | xyz.shape=%s, colors=%s",
                      getattr(xyz, "shape", None),
                      "provided" if colors is not None else "None")
        x, y, z = np.asarray(xyz).T
        self._scatter._offsets3d = (x, y, z)
        if colors is not None:
            self._scatter.set_facecolors(colors)
        _logger.debug("_update_scatter done | n_points=%s", len(x))

    @staticmethod
    def _style_edges(collection, segs, *, colors=None, opacity=None):
        """Set segments and colors/alpha on one edge collection (shared by both edge kinds).

        Args:
            collection: The Line3DCollection to update.
            segs: Array of shape (E, 2, 3) with edge endpoints.
            colors: Optional single color, or per-edge RGB/RGBA array.
            opacity: Optional scalar or per-edge alpha(s).

        Returns:
            None
        """
        collection.set_segments(segs)
        if colors is not None and opacity is not None:
            # to_rgba_array handles a single color as well as (E, 3)/(E, 4)
            # arrays and fuses scalar or per-edge alpha, tiling a single color
            # when needed (the former np.c_ approach broke for scalar alpha).
            collection.set_colors(mpl.colors.to_rgba_array(colors, alpha=opacity))
        else:
            if colors is not None:
                collection.set_colors(colors)
            if opacity is not None:
                collection.set_alpha(opacity)

    def _update_attr_edges(self, segs, *, colors=None, opacity=None):
        """Update attractive edge collection.

        Args:
            segs: Array of shape (E, 2, 3) with edge endpoints.
            colors: Optional array of RGB/RGBA per-edge colors or a single color.
            opacity: Optional array or scalar alpha(s). If both `colors` and
                `opacity` are provided, alpha will be fused into the RGBA.

        Returns:
            None
        """
        _logger.debug("_update_attr_edges | segs.shape=%s, colors=%s, opacity=%s",
                      getattr(segs, "shape", None),
                      "provided" if colors is not None else "None",
                      "array/scalar" if opacity is not None else "None")
        self._style_edges(self._attr, segs, colors=colors, opacity=opacity)
        _logger.debug("_update_attr_edges done.")

    def _update_repuls_edges(self, segs, *, colors=None, opacity=None):
        """Update repulsive edge collection.

        Args:
            segs: Array of shape (E, 2, 3) with edge endpoints.
            colors: Optional array of RGB/RGBA per-edge colors or a single color.
            opacity: Optional array or scalar alpha(s). If both `colors` and
                `opacity` are provided, alpha will be fused into the RGBA.

        Returns:
            None
        """
        _logger.debug("_update_repuls_edges | segs.shape=%s, colors=%s, opacity=%s",
                      getattr(segs, "shape", None),
                      "provided" if colors is not None else "None",
                      "array/scalar" if opacity is not None else "None")
        self._style_edges(self._repuls, segs, colors=colors, opacity=opacity)
        _logger.debug("_update_repuls_edges done.")

    # --- CLEARS ---

    def _clear_scatter(self):
        """Clear node positions from the scatter artist.

        Returns:
            None
        """
        _logger.debug("_clear_scatter called.")
        self._scatter._offsets3d = ([], [], [])

    def _clear_attr_edges(self):
        """Clear all attractive edges from the collection.

        Returns:
            None
        """
        _logger.debug("_clear_attr_edges called.")
        self._attr.set_segments(np.empty((0, 2, 3)))

    def _clear_repuls_edges(self):
        """Clear all repulsive edges from the collection.

        Returns:
            None
        """
        _logger.debug("_clear_repuls_edges called.")
        self._repuls.set_segments(np.empty((0, 2, 3)))

    # --- FINAL UPD ---

    def _update_canvas(self, *, pause_for: float = 0.0):
        """Request a canvas redraw and optionally pause.

        Args:
            pause_for: If > 0, calls `plt.pause(pause_for)` to advance GUI
                event loops and create a visible delay (useful in animations).

        Returns:
            None
        """
        _logger.debug("_update_canvas | pause_for=%s", pause_for)
        self._fig.canvas.draw_idle()
        if pause_for > 0.0:
            self._plt.pause(pause_for)

    # ADJUST THE VIEW
    def _fix_view(self, coordinates: np.ndarray, padding: float, spread: float):
        """Adjust axes limits/box aspect and apply optional spatial spreading.

        Optionally spreads points around their centroid by factor ``spread``,
        then computes a bounding box around the (spread) coordinates, expands it
        by ``padding`` (relative to span), and sets axes limits and box aspect.

        Args:
            coordinates: Array of shape (M, 3) for currently displayed nodes.
            padding: Fraction of the original span added to min/max on each axis.
            spread: If != 1.0, multiply deviations from centroid by this factor.

        Returns:
            np.ndarray: Possibly modified copy of `coordinates` (same shape) if
            `spread` was applied; otherwise the input array is returned. An empty
            input is returned unchanged and leaves the view untouched.
        """
        _logger.debug("_fix_view | coords.shape=%s, padding=%s, spread=%s",
                      getattr(coordinates, "shape", None), padding, spread)

        coordinates = np.asarray(coordinates)
        if coordinates.size == 0:
            # min/max of an empty selection would raise; nothing to frame
            _logger.debug("_fix_view | no coordinates; view left unchanged.")
            return coordinates

        # Apply spread first so limits reflect the final positions
        if spread != 1.0:
            center = coordinates.mean(axis=0, keepdims=True)
            coordinates = center + spread * (coordinates - center)
            _logger.debug("_fix_view | applied spread around centroid.")

        orig_min = coordinates.min(axis=0)
        orig_max = coordinates.max(axis=0)
        orig_span = np.maximum(orig_max - orig_min, 1e-12)  # avoid zero-width axes
        xyz_min = orig_min - padding * orig_span
        xyz_max = orig_max + padding * orig_span

        self._ax.set_xlim(xyz_min[0], xyz_max[0])
        self._ax.set_ylim(xyz_min[1], xyz_max[1])
        self._ax.set_zlim(xyz_min[2], xyz_max[2])
        self._ax.set_box_aspect(np.maximum(xyz_max - xyz_min, 1e-12))
        _logger.debug("_fix_view | bounds set: x=(%s,%s), y=(%s,%s), z=(%s,%s)",
                      xyz_min[0], xyz_max[0], xyz_min[1], xyz_max[1], xyz_min[2], xyz_max[2])

        return coordinates

    # --- build_frame helpers ---

    def _compute_edges(self, kind, arg_name, selector, energies, frame_id,
                       displayed_nodes, coords_for_edges, *, frac,
                       global_weights_frac, global_opacity, global_color_saturation):
        """Resolve an edge-node selector and build the line segments for one edge kind.

        Args:
            kind: "attractive" or "repulsive" (used in log messages).
            arg_name: Public argument name of the selector (used in error messages).
            selector: None, the literal "DISPLAYED_NODES", or 1-based node indices.
            energies: (T, N, N) weights for this edge kind, or None if unavailable.
            frame_id: 0-based frame index.
            displayed_nodes: 0-based indices of the displayed nodes.
            coords_for_edges: (N, 3) coordinates with displayed nodes already
                moved by the view's spread.
            frac: Fraction of heaviest edges to keep.
            global_weights_frac, global_opacity, global_color_saturation:
                Forwarded to `visualizer_util.build_line_segments`.

        Returns:
            (segments, color_weights, opacity_weights), or None when the edges
            are skipped (selector is None or the dataset is missing).

        Raises:
            ValueError: On an invalid selector string, invalid indices, or a node
                set that is not a subset of `displayed_nodes`.
        """
        label = kind.capitalize()
        if selector is None:
            _logger.debug("%s edges skipped (selector=None).", label)
            return None
        if energies is None:
            _logger.warning("%s dataset unavailable; skipping %s edges.", label, kind)
            return None

        if isinstance(selector, str):
            if selector != "DISPLAYED_NODES":
                _logger.error("Invalid %s selector string: %s", kind, selector)
                raise ValueError(
                    f"'{arg_name}' has to be either an ArrayLike "
                    "collection of node indices, or an 'DISPLAYED_NODES' string, "
                    "or None.")
            selected = displayed_nodes
            _logger.debug("%s nodes='DISPLAYED_NODES' -> count=%d", label, selected.size)
        else:
            selected = self._to_zero_based(selector, arg_name)
            _logger.debug("%s nodes provided | count=%d", label, selected.size)

        if np.setdiff1d(selected, displayed_nodes).size > 0:
            _logger.error("%s nodes not a subset of displayed_nodes.", label)
            raise ValueError(f"'{arg_name}' must be a subset of 'displayed_nodes'")

        segs, color_w, opacity_w = visualizer_util.build_line_segments(
            self.N,
            selected,
            coords_for_edges,
            energies[frame_id],
            frac,
            global_weights_frac=global_weights_frac,
            global_opacity=global_opacity,
            global_color_saturation=global_color_saturation
        )
        _logger.debug("%s edges built | segs.shape=%s, color_w.shape=%s, opacity_w.shape=%s",
                      label,
                      getattr(segs, "shape", None),
                      getattr(color_w, "shape", None),
                      getattr(opacity_w, "shape", None))
        return segs, color_w, opacity_w

    def _draw_edges(self, edges, cmap_name, update, clear):
        """Push computed edges to their collection, or clear it when edges were skipped.

        Clearing matters: otherwise a skipped channel would keep showing the
        previous frame's segments.

        Args:
            edges: Result of `_compute_edges` (tuple or None).
            cmap_name: Matplotlib colormap name for the color weights.
            update: Bound `_update_*_edges` method for this edge kind.
            clear: Bound `_clear_*_edges` method for this edge kind.

        Returns:
            None
        """
        if edges is None:
            clear()
            return
        segs, color_w, opacity_w = edges
        rgba = self._plt.get_cmap(cmap_name)(color_w)  # (E,4)
        rgba[:, 3] = opacity_w
        update(segs, colors=rgba, opacity=None)

    def _node_colors_for(self, displayed_nodes, node_colors):
        """Return per-node colors for the displayed (0-based) nodes.

        Args:
            displayed_nodes: 0-based indices of the displayed nodes.
            node_colors: Colormap name (str), or group spec / None forwarded to
                `visualizer_util.map_groups_to_colors`.

        Returns:
            Array of colors aligned with `displayed_nodes`.
        """
        if isinstance(node_colors, str):
            node_cmap = self._plt.get_cmap(node_colors)
            idx0 = np.asarray(displayed_nodes, dtype=int)
            color_array = node_cmap(self._residue_norm(idx0))
            _logger.debug("Node colors via colormap '%s' | count=%d", node_colors, idx0.size)
            return color_array

        color_array_full = visualizer_util.map_groups_to_colors(
            N=self.N,
            groups=node_colors,
            default_color=self.default_node_color,
            one_based=True
        )
        color_array = np.asarray(color_array_full)[displayed_nodes]
        _logger.debug("Node colors via groups/default | count=%d", color_array.shape[0])
        return color_array

    def _set_title(self, title):
        """Show `title` in a single reusable text artist; hide it when `title` is falsy.

        Args:
            title: Title text, or None/"" for no title.

        Returns:
            None
        """
        if title:
            if self._title_artist is None:
                self._title_artist = self._ax.text2D(0.5, 0.99, title, transform=self._ax.transAxes,
                                                     ha="center", va="top")
            else:
                self._title_artist.set_text(title)
            _logger.debug("Title set: %s", title)
        elif self._title_artist is not None:
            self._title_artist.set_text("")

    def _set_node_labels(self, nodes, displayed_nodes, show_node_labels, node_label_size):
        """Replace the previous frame's node labels with labels at the current positions.

        Args:
            nodes: (M, 3) displayed node positions (after spread).
            displayed_nodes: 0-based indices of the displayed nodes.
            show_node_labels: Whether to draw labels at all.
            node_label_size: Font size for labels.

        Returns:
            None
        """
        for artist in self._node_label_artists:
            artist.remove()
        self._node_label_artists = []
        if not show_node_labels:
            return

        labs = np.asarray(displayed_nodes, dtype=int) + 1  # one-based labels
        _logger.debug("Adding node labels | count=%d, fontsize=%d", labs.size, node_label_size)
        for (x, y, z), lab in zip(nodes, labs):
            self._node_label_artists.append(
                self._ax.text(float(x) + .3, float(y) + .3, float(z) + 1.3, str(lab),
                              fontsize=node_label_size, color="k")
            )

    def _show_figure(self):
        """Display the figure: non-blocking in IPython/interactive mode, blocking otherwise.

        Returns:
            None
        """
        try:
            get_ipython  # type: ignore  # noqa: F821
            in_ipy = True
        except NameError:
            in_ipy = False

        _logger.debug("Showing figure | in_ipy=%s, interactive=%s", in_ipy, self._plt.isinteractive())

        if in_ipy or self._plt.isinteractive():
            self._plt.show(block=False)
            self._plt.pause(0.05)
        else:
            self._plt.show()

    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #
    #                               PUBLIC
    # -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= #

    def build_frame(
        self,
        frame_id: int,
        displayed_nodes: np.typing.ArrayLike | Literal["ALL"] | None = "ALL",
        displayed_pairwise_attraction_for_nodes: np.typing.ArrayLike | Literal["DISPLAYED_NODES"] | None = "DISPLAYED_NODES",
        displayed_pairwise_repulsion_for_nodes: np.typing.ArrayLike | Literal["DISPLAYED_NODES"] | None = "DISPLAYED_NODES",
        frac_node_interactions_displayed: float = 0.01,  # 1%
        global_interactions_frac: bool = True,
        global_opacity: bool = True,
        global_color_saturation: bool = True,
        node_colors: str | tuple[tuple[Iterable[int], str], ...] | None = None,
        title: str | None = None,
        padding: float = 0.1,
        spread: float = 1.0,
        show: bool = False,
        *,
        show_node_labels: bool = False,
        node_label_size: int = 6,
        attractive_edge_cmap: str = visualizer_util.HEAT,
        repulsive_edge_cmap: str = visualizer_util.COLD):
        """Render a single frame into existing artists.

        This updates node positions/colors and draws attractive/repulsive
        edges chosen by a quantile threshold on weights. Indices passed from
        the public API are interpreted as 1-based and converted to 0-based
        internally. The title and node labels replace those of the previous
        frame, and an edge kind that is skipped is cleared from the canvas.

        Args:
            frame_id: 1-based frame index to render (1..T).
            displayed_nodes: Iterable of node indices to show (1-based),
                "ALL" for all nodes, or None to return early.
            displayed_pairwise_attraction_for_nodes: Iterable of node indices
                (1-based) or the literal "DISPLAYED_NODES" to restrict
                candidate attractive edges to those whose endpoints are both
                in this set.
            displayed_pairwise_repulsion_for_nodes: Same contract as
                `displayed_pairwise_attraction_for_nodes` but for repulsive edges.
            frac_node_interactions_displayed: Fraction of heaviest edges to keep
                (approximate top-`frac`) after endpoint filtering.
            global_interactions_frac: If True, the threshold uses all
                upper-triangle weights; otherwise only candidate edges.
            global_opacity: If True, opacity uses global row-wise normalization;
                otherwise, normalization uses only kept edges (others set to 0).
            global_color_saturation: If True, color uses global absolute
                normalization; otherwise, normalization uses only kept edges.
            node_colors: Either a Matplotlib colormap name (str) or a tuple of
                (indices, hex_color) group pairs for per-node colors.
            title: Optional title displayed in axis coordinates; a falsy value
                hides any title from a previous frame.
            padding: Fractional padding around the displayed nodes' bounds.
            spread: Spatial spread multiplier applied about centroid (displayed
                nodes only).
            show: If True, request a window display:
                - In IPython/interactive mode, uses non-blocking `show(block=False)` and
                  a short `pause()` to flush the GUI loop.
                - In non-interactive scripts, uses blocking `show()` at the end
                  of the draw step.
                - If the backend is 'Agg' (headless fallback), this is a no-op for
                  windows; use `self._plt.savefig(...)` to persist images.
            show_node_labels: Draw 1-based index labels next to displayed nodes.
            node_label_size: Font size of node labels.
            attractive_edge_cmap: Colormap name for attractive edges.
            repulsive_edge_cmap: Colormap name for repulsive edges.

        Returns:
            None

        Raises:
            IndexError: If `frame_id` is outside [1, T].
            ValueError: If any requested attraction/repulsion node set is not a
                subset of `displayed_nodes`, node indices are not integers in
                [1, N], or invalid sentinel strings are used.
        """
        # PRELIMINARY
        _logger.debug("build_frame called | frame_id(1-based)=%s, frac_node_interactions_displayed=%s, padding=%s, spread=%s, show=%s, show_node_labels=%s",
                      frame_id, frac_node_interactions_displayed, padding, spread, show, show_node_labels)
        n_frames = int(self.COM_coords.shape[0])
        if not 1 <= frame_id <= n_frames:
            # without this check frame_id=0 would silently render the last frame
            _logger.error("frame_id out of range: %s", frame_id)
            raise IndexError(f"'frame_id' must be in [1, {n_frames}] (1-based); got {frame_id}.")
        frame_id -= 1  # 1-base indexing
        _logger.debug("build_frame | using frame_id(0-based)=%s", frame_id)

        # NODES
        if displayed_nodes is None:
            _logger.debug("displayed_nodes is None -> returning early.")
            return
        if isinstance(displayed_nodes, str):
            if displayed_nodes != "ALL":
                _logger.error("Invalid displayed_nodes string: %s", displayed_nodes)
                raise ValueError(
                    "'displayed_nodes' has to be either an ArrayLike "
                    "collection of node indices, or an 'ALL' string, "
                    "or None.")
            displayed_nodes = np.arange(0, self.N, 1)
            _logger.debug("displayed_nodes='ALL' -> count=%d", displayed_nodes.size)
        else:
            displayed_nodes = self._to_zero_based(displayed_nodes, "displayed_nodes")
            _logger.debug("displayed_nodes provided | count=%d", displayed_nodes.size)

        frame_coords = self.COM_coords[frame_id]
        nodes = frame_coords[displayed_nodes]
        _logger.debug("Selected nodes | nodes.shape=%s (before view fix)", getattr(nodes, "shape", None))

        nodes = self._fix_view(nodes, padding, spread)
        _logger.debug("Nodes after _fix_view | shape=%s", getattr(nodes, "shape", None))
        # edges must attach to the spread positions of the displayed nodes
        coords_for_edges = frame_coords.copy()
        coords_for_edges[displayed_nodes] = nodes

        # EDGES (attraction first, then repulsion, as validation order matters)
        edge_options = dict(
            frac=frac_node_interactions_displayed,
            global_weights_frac=global_interactions_frac,
            global_opacity=global_opacity,
            global_color_saturation=global_color_saturation,
        )
        attractive = self._compute_edges(
            "attractive", "displayed_pairwise_attraction_for_nodes",
            displayed_pairwise_attraction_for_nodes, self.attr_energies,
            frame_id, displayed_nodes, coords_for_edges, **edge_options)
        repulsive = self._compute_edges(
            "repulsive", "displayed_pairwise_repulsion_for_nodes",
            displayed_pairwise_repulsion_for_nodes, self.repuls_energies,
            frame_id, displayed_nodes, coords_for_edges, **edge_options)

        # COLOR THE DATA POINTS
        color_array = self._node_colors_for(displayed_nodes, node_colors)

        # UPDATE CANVAS
        self._update_scatter(nodes, colors=color_array)
        self._draw_edges(attractive, attractive_edge_cmap, self._update_attr_edges, self._clear_attr_edges)
        self._draw_edges(repulsive, repulsive_edge_cmap, self._update_repuls_edges, self._clear_repuls_edges)
        _logger.debug("Edges updated on canvas | attractive=%s, repulsive=%s",
                      attractive is not None, repulsive is not None)

        # EXTRAS
        self._set_title(title)
        self._set_node_labels(nodes, displayed_nodes, show_node_labels, node_label_size)

        if show:
            # auto-block in scripts; non-block in notebooks/interactive
            self._show_figure()
        _logger.debug("build_frame completed.")

    def animate_trajectory(
        self,
        start: int = 1,
        stop: int | None = None,
        step: int = 1,
        interval_ms: int = 50,
        loop: bool = False,
        **build_kwargs,
    ):
        """Play frames as an animation by reusing existing artists.

        Iterates frames from `start` to `stop` (inclusive, stepping by `step`)
        and calls `build_frame` for each, pausing `interval_ms` between
        updates. If `loop=True`, the sequence repeats until the figure is
        closed or the user interrupts.

        Args:
            start: 1-based starting frame index.
            stop: 1-based ending frame index (inclusive). If None, defaults to
                the last frame for forward play (`step > 0`) and to frame 1
                for backward play (`step < 0`).
            step: Step size between frames. Negative values play backwards.
            interval_ms: Pause between frames in milliseconds.
            loop: If True, repeat the frame sequence indefinitely (until the
                figure is closed or interrupted).
            **build_kwargs: Additional keyword arguments forwarded to
                `build_frame` (e.g., `displayed_nodes`, `padding`, `spread`,
                `node_colors`, etc.). `show=False` is enforced internally.

        Returns:
            None

        Raises:
            ValueError: If `step` is zero.
            IndexError: (from `build_frame`) if a frame index is outside [1, T].

        Notes:
            - Internally enforces `build_kwargs["show"] = False` during iteration to
              avoid blocking; a final `self._plt.show()` is issued at the end of a
              single pass.
            - In headless mode (backend 'Agg'), no GUI window appears; use
              `self._plt.savefig(...)` or a writer to export frames.
        """
        _logger.debug(
            "animate_trajectory | start=%s, stop=%s, step=%s, interval_ms=%s, loop=%s",
            start, stop, step, interval_ms, loop
        )

        if step == 0:
            raise ValueError("step must be non-zero")

        n_frames = int(self.COM_coords.shape[0])
        if stop is None:
            # forward play ends at the last frame, backward play at the first
            stop = n_frames if step > 0 else 1

        # build the list of frame ids; stop is inclusive in either direction
        frames = list(range(start, stop + (1 if step > 0 else -1), step))
        if not frames:
            _logger.debug("animate_trajectory | empty frame list -> return")
            return

        build_kwargs["show"] = False
        pause_s = interval_ms / 1000.0

        try:
            if loop:
                _logger.debug("animate_trajectory | entering repeat loop until window closed.")
                while self._plt.fignum_exists(self._fig.number):
                    for fid in frames:
                        if not self._plt.fignum_exists(self._fig.number):
                            break
                        self.build_frame(fid, **build_kwargs)
                        self._update_canvas(pause_for=pause_s)
            else:
                _logger.debug("animate_trajectory | single pass over frames.")
                for fid in frames:
                    self.build_frame(fid, **build_kwargs)
                    self._update_canvas(pause_for=pause_s)
                # one final show so the window stays up when the loop ends
                self._plt.show()
        except KeyboardInterrupt:
            _logger.debug("animate_trajectory | interrupted by user (KeyboardInterrupt).")
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
# Names exported by a star-import of this module.
__all__ = ["Visualizer"]


if __name__ == "__main__":
    # The module only provides the Visualizer class; running it directly does nothing.
    pass
|