canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/visualization/core/backend.py +1 -1
- canns/analyzer/visualization/core/config.py +77 -0
- canns/analyzer/visualization/core/rendering.py +10 -6
- canns/analyzer/visualization/energy_plots.py +22 -8
- canns/analyzer/visualization/spatial_plots.py +31 -11
- canns/analyzer/visualization/theta_sweep_plots.py +15 -6
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/open_loop_navigation.py +3 -1
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1276 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import numpy as np
|
|
8
|
+
from matplotlib import animation, cm
|
|
9
|
+
from scipy import signal
|
|
10
|
+
from scipy.ndimage import binary_closing, gaussian_filter
|
|
11
|
+
from scipy.stats import binned_statistic_2d, multivariate_normal
|
|
12
|
+
from tqdm import tqdm
|
|
13
|
+
|
|
14
|
+
from ...visualization.core import (
|
|
15
|
+
PlotConfig,
|
|
16
|
+
emit_backend_warnings,
|
|
17
|
+
finalize_figure,
|
|
18
|
+
get_matplotlib_writer,
|
|
19
|
+
get_optimal_worker_count,
|
|
20
|
+
render_animation_parallel,
|
|
21
|
+
select_animation_backend,
|
|
22
|
+
warn_double_rendering,
|
|
23
|
+
)
|
|
24
|
+
from ...visualization.core.jupyter_utils import display_animation_in_jupyter, is_jupyter_environment
|
|
25
|
+
from .config import CANN2DPlotConfig, ProcessingError, SpikeEmbeddingConfig
|
|
26
|
+
from .embedding import embed_spike_trains
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _ensure_plot_config(
|
|
30
|
+
config: PlotConfig | None,
|
|
31
|
+
factory,
|
|
32
|
+
*,
|
|
33
|
+
kwargs: dict[str, Any] | None = None,
|
|
34
|
+
**defaults: Any,
|
|
35
|
+
) -> PlotConfig:
|
|
36
|
+
if config is None:
|
|
37
|
+
defaults.update({"kwargs": kwargs or {}})
|
|
38
|
+
return factory(**defaults)
|
|
39
|
+
|
|
40
|
+
if kwargs:
|
|
41
|
+
config_kwargs = config.kwargs or {}
|
|
42
|
+
config_kwargs.update(kwargs)
|
|
43
|
+
config.kwargs = config_kwargs
|
|
44
|
+
return config
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _ensure_parent_dir(save_path: str | None) -> None:
|
|
48
|
+
if save_path:
|
|
49
|
+
parent = os.path.dirname(save_path)
|
|
50
|
+
if parent:
|
|
51
|
+
os.makedirs(parent, exist_ok=True)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _render_torus_frame(frame_index: int, frame_data: dict[str, Any]) -> np.ndarray:
    """Render one frame of the torus bump animation to an image array.

    Used as the per-frame callback passed to ``render_animation_parallel`` by
    ``plot_3d_bump_on_torus`` (imageio backend).

    Parameters
    ----------
    frame_index : int
        Index into ``frame_data["frames"]``.
    frame_data : dict
        Shared render payload with keys ``figsize``, ``dpi``, ``zlim``, ``elev``,
        ``azim``, ``torus_x``/``torus_y``/``torus_z`` (surface coordinates) and
        ``frames`` (list of dicts holding the activity map ``"m"`` and an optional
        ``"time"`` label).

    Returns
    -------
    np.ndarray
        The rendered PNG as decoded by ``plt.imread``; float images are scaled to
        ``uint8``.
    """
    from io import BytesIO

    frames = frame_data["frames"]
    frame = frames[frame_index]
    m = frame["m"]

    fig = plt.figure(figsize=frame_data["figsize"])
    try:
        ax = fig.add_subplot(111, projection="3d")
        ax.set_zlim(*frame_data["zlim"])
        ax.view_init(frame_data["elev"], frame_data["azim"])
        ax.axis("off")

        # Colors are the activity map normalized by its maximum (epsilon avoids 0/0).
        ax.plot_surface(
            frame_data["torus_x"],
            frame_data["torus_y"],
            frame_data["torus_z"],
            facecolors=cm.viridis(m / (np.max(m) + 1e-9)),
            alpha=1,
            linewidth=0.1,
            antialiased=True,
            rstride=1,
            cstride=1,
            shade=False,
        )

        time_label = frame.get("time")
        label_text = f"Frame: {frame_index + 1}/{len(frames)}"
        if time_label is not None:
            label_text = f"{label_text} | Time: {time_label}"
        ax.text2D(
            0.05,
            0.95,
            label_text,
            transform=ax.transAxes,
            fontsize=12,
            bbox=dict(facecolor="white", alpha=0.7),
        )

        fig.tight_layout()

        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=frame_data["dpi"], bbox_inches="tight")
            buf.seek(0)
            img = plt.imread(buf)
    finally:
        # Always release the figure, even when drawing or saving fails, so worker
        # processes do not accumulate open figures.
        plt.close(fig)

    if img.dtype in (np.float32, np.float64):
        img = (img * 255).astype(np.uint8)

    return img
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _render_2d_bump_frame(frame_index: int, frame_data: dict[str, Any]) -> np.ndarray:
    """Render one frame of a 2D bump-activity animation to an image array.

    Parameters
    ----------
    frame_index : int
        Index into ``frame_data["maps"]``.
    frame_data : dict
        Shared render payload with keys ``figsize``, ``dpi`` and ``maps`` (a sequence
        of 2D activity maps). Each map is transposed before display, so its first
        axis runs along the horizontal (dimension 1) axis over ``[0, 2*pi]``.

    Returns
    -------
    np.ndarray
        The rendered PNG as decoded by ``plt.imread``; float images are scaled to
        ``uint8``.
    """
    from io import BytesIO

    maps = frame_data["maps"]
    activity_map = maps[frame_index]

    fig, ax = plt.subplots(figsize=frame_data["figsize"])
    try:
        ax.set_xlabel("Manifold Dimension 1 (rad)", fontsize=12)
        ax.set_ylabel("Manifold Dimension 2 (rad)", fontsize=12)
        ax.set_title("CANN2D Bump Activity (2D Projection)", fontsize=14, fontweight="bold")

        im = ax.imshow(
            activity_map.T,
            extent=[0, 2 * np.pi, 0, 2 * np.pi],
            origin="lower",
            cmap="viridis",
            aspect="auto",
        )
        fig.colorbar(im, ax=ax).set_label("Activity", fontsize=11)
        ax.text(
            0.02,
            0.98,
            f"Frame: {frame_index + 1}/{len(maps)}",
            transform=ax.transAxes,
            fontsize=11,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

        fig.tight_layout()
        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=frame_data["dpi"], bbox_inches="tight")
            buf.seek(0)
            img = plt.imread(buf)
    finally:
        # Release the figure even if rendering fails, so callers running many frames
        # (possibly in worker processes) do not leak open figures.
        plt.close(fig)

    if img.dtype in (np.float32, np.float64):
        img = (img * 255).astype(np.uint8)

    return img
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def plot_projection(
    reduce_func,
    embed_data,
    config: CANN2DPlotConfig | None = None,
    title="Projection (3D)",
    xlabel="Component 1",
    ylabel="Component 2",
    zlabel="Component 3",
    save_path=None,
    show=None,
    dpi=None,
    figsize=(10, 8),
    **kwargs,
):
    """
    Plot a 3D projection of the embedded data.

    Every 5th row of ``embed_data`` is passed to ``reduce_func``; the first three
    columns of the result are drawn as a 3D scatter plot.

    Parameters
    ----------
    reduce_func (callable): Function to reduce the dimensionality of the data. Its
        output must be indexable as ``[:, 0]``, ``[:, 1]`` and ``[:, 2]``.
    embed_data (ndarray): Data to be projected.
    config (PlotConfig, optional): Configuration object for unified plotting parameters.
        When given, only arguments passed explicitly (not None) override its fields;
        empty title/labels are filled from the arguments, and figsize is applied only
        if the config still has the default ``PlotConfig().figsize``.
    **kwargs: backward compatibility parameters forwarded to
        ``CANN2DPlotConfig.for_projection_3d`` when ``config`` is None.
    title (str): Title of the plot.
    xlabel (str): Label for the x-axis.
    ylabel (str): Label for the y-axis.
    zlabel (str): Label for the z-axis.
    save_path (str, optional): Path to save the plot. If None, plot will not be saved.
    show (bool, optional): Whether to display the plot. Defaults to True when no config
        is given; with a config, None keeps ``config.show``.
    dpi (int, optional): Dots per inch for saving the figure. Defaults to 300 when no
        config is given; with a config, None keeps ``config.dpi``.
    figsize (tuple): Size of the figure.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Examples
    --------
    >>> fig = plot_projection(reduce_func, embed_data, show=False)  # doctest: +SKIP
    """

    # Handle backward compatibility and configuration
    if config is None:
        # No config: build one, using the historical defaults for unset arguments.
        config = CANN2DPlotConfig.for_projection_3d(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            zlabel=zlabel,
            save_path=save_path,
            show=True if show is None else show,
            figsize=figsize,
            dpi=300 if dpi is None else dpi,
            **kwargs,
        )
    else:
        # A config was supplied: override its fields only with explicitly passed values,
        # so e.g. ``config.show = False`` is not silently replaced by a default.
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show
        if not config.title:
            config.title = title
        if not config.xlabel:
            config.xlabel = xlabel
        if not config.ylabel:
            config.ylabel = ylabel
        if not config.zlabel:
            config.zlabel = zlabel
        if config.figsize == PlotConfig().figsize:
            config.figsize = figsize
        if dpi is not None:
            config.dpi = dpi

    # Subsample every 5th point to keep the scatter plot light.
    reduced_data = reduce_func(embed_data[::5])

    fig = plt.figure(figsize=config.figsize)
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(reduced_data[:, 0], reduced_data[:, 1], reduced_data[:, 2], s=1, alpha=0.5)

    ax.set_title(config.title)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_zlabel(config.zlabel)

    config.save_dpi = getattr(config, "dpi", config.save_dpi)
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def plot_path_compare(
    x: np.ndarray,
    y: np.ndarray,
    coords: np.ndarray,
    config: PlotConfig | None = None,
    *,
    title: str = "Path Compare",
    figsize: tuple[int, int] = (12, 5),
    show: bool = True,
    save_path: str | None = None,
) -> tuple[plt.Figure, np.ndarray]:
    """Draw the physical trajectory next to the decoded cohomology-space trajectory.

    The left panel shows ``(x, y)``. The right panel depends on the number of
    decoded coordinates. With two or more columns, the first two angles are
    wrapped to ``[0, 2*pi)``, skew-transformed and drawn as a wrapped trail inside
    the base parallelogram. With a single column, the angle is drawn on the unit
    circle.

    Parameters
    ----------
    x, y : np.ndarray
        Physical positions; flattened to 1D.
    coords : np.ndarray
        Decoded circular coordinates of shape (T, k) with k >= 1.
    config : PlotConfig, optional
        Plot configuration. When None, one is built from ``title``, ``figsize``,
        ``show`` and ``save_path``.
    title, figsize, show, save_path : optional
        Backward-compatibility parameters; ignored when ``config`` is given.

    Returns
    -------
    (Figure, ndarray)
        The figure and its array of two axes.

    Raises
    ------
    ValueError
        If ``coords`` is not 2D with at least one column.

    Examples
    --------
    >>> fig, axes = plot_path_compare(x, y, coords, show=False)  # doctest: +SKIP
    """
    from .path import draw_base_parallelogram, skew_transform, snake_wrap_trail_in_parallelogram

    xs = np.asarray(x).ravel()
    ys = np.asarray(y).ravel()
    decoded = np.asarray(coords)

    if not (decoded.ndim == 2 and decoded.shape[1] >= 1):
        raise ValueError(f"coords must be 2D with at least 1 column, got {decoded.shape}")

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, axes = plt.subplots(1, 2, figsize=config.figsize)
    if config.title:
        fig.suptitle(config.title)

    phys_ax, coho_ax = axes
    for panel, panel_title in ((phys_ax, "Physical path (x,y)"), (coho_ax, "Decoded coho path")):
        panel.set_title(panel_title)
        panel.set_aspect("equal", "box")
        panel.axis("off")

    phys_ax.plot(xs, ys, lw=0.9, alpha=0.8)

    full_turn = 2 * np.pi
    if decoded.shape[1] < 2:
        # Single circular coordinate: draw it on the unit circle.
        angle = decoded[:, 0] % full_turn
        coho_ax.plot(np.cos(angle), np.sin(angle), lw=0.9, alpha=0.9)
        coho_ax.set_xlim(-1.2, 1.2)
        coho_ax.set_ylim(-1.2, 1.2)
    else:
        # Two angles: map to the skewed (hexagonal) fundamental domain and wrap the trail.
        planar = skew_transform(decoded[:, :2] % full_turn)
        draw_base_parallelogram(coho_ax)
        edge_a = np.array([full_turn, 0.0])
        edge_b = np.array([np.pi, np.sqrt(3) * np.pi])
        wrapped = snake_wrap_trail_in_parallelogram(planar, edge_a, edge_b)
        coho_ax.plot(wrapped[:, 0], wrapped[:, 1], lw=0.9, alpha=0.9)

    fig.tight_layout()
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig, axes
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def plot_cohomap(
    decoding_result: dict[str, Any],
    position_data: dict[str, Any],
    config: PlotConfig | None = None,
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    dpi: int = 300,
    subsample: int = 10,
) -> plt.Figure:
    """
    Visualize CohoMap 1.0: decoded circular coordinates mapped onto spatial trajectory.

    Creates a two-panel visualization showing how the two decoded circular coordinates
    vary across the animal's spatial trajectory. Each panel displays the spatial path
    colored by the cosine of one circular coordinate dimension.

    Parameters
    ----------
    decoding_result : dict
        Dictionary from decode_circular_coordinates() containing:
        - 'coordsbox': decoded coordinates for box timepoints (n_times x n_dims, n_dims >= 2)
        - 'times_box': time indices for coordsbox
    position_data : dict
        Position data containing 'x' and 'y' arrays for spatial coordinates
    config : PlotConfig, optional
        Plot configuration. When None, one is built from ``figsize``, ``save_path``
        and ``show``. ``dpi`` is always written to ``config.save_dpi``.
    save_path : str, optional
        Path to save the visualization. If None, no save performed
    show : bool, default=False
        Whether to display the visualization
    figsize : tuple[int, int], default=(10, 4)
        Figure size (width, height) in inches
    dpi : int, default=300
        Resolution for saved figure
    subsample : int, default=10
        Subsampling interval for plotting (plot every Nth timepoint); must be >= 1

    Returns
    -------
    matplotlib.figure.Figure
        The matplotlib figure object.

    Raises
    ------
    KeyError
        If required keys are missing from input dictionaries
    ValueError
        If ``coordsbox`` is not 2D with at least 2 columns, or ``subsample`` < 1
    IndexError
        If time indices are out of bounds

    Examples
    --------
    >>> # Decode coordinates
    >>> decoding = decode_circular_coordinates(persistence_result, spike_data)
    >>> # Visualize with trajectory data
    >>> fig = plot_cohomap(
    ...     decoding,
    ...     position_data={'x': xx, 'y': yy},
    ...     save_path='cohomap.png',
    ...     show=True
    ... )
    """
    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="CohoMap",
        xlabel="",
        ylabel="",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )
    config.save_dpi = dpi

    # Extract data (asarray so list inputs support fancy indexing)
    coordsbox = np.asarray(decoding_result["coordsbox"])
    times_box = np.asarray(decoding_result["times_box"])
    xx = np.asarray(position_data["x"])
    yy = np.asarray(position_data["y"])

    # Validate before creating a figure so a bad input does not leave one open.
    if coordsbox.ndim != 2 or coordsbox.shape[1] < 2:
        raise ValueError(
            f"coordsbox must be 2D with at least 2 columns, got {coordsbox.shape}"
        )
    if subsample < 1:
        raise ValueError(f"subsample must be a positive integer, got {subsample}")

    # Subsample time indices for plotting; positions are shared by both panels.
    plot_times = np.arange(0, len(coordsbox), subsample)
    x_plot = xx[times_box][plot_times]
    y_plot = yy[times_box][plot_times]

    # Create a two-panel figure (one per cohomology dimension). The colormap is
    # passed to scatter directly; no global pyplot colormap state is changed.
    fig, ax = plt.subplots(1, 2, figsize=config.figsize)

    for dim in range(2):
        panel = ax[dim]
        panel.axis("off")
        panel.set_aspect("equal", "box")
        im = panel.scatter(
            x_plot,
            y_plot,
            c=np.cos(coordsbox[plot_times, dim]),
            s=8,
            cmap="viridis",
        )
        fig.colorbar(im, ax=panel, label="cos(coord)")
        panel.set_title(f"CohoMap Dim {dim + 1}", fontsize=10)

    fig.tight_layout()

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def plot_cohomap_multi(
    decoding_result: dict,
    position_data: dict,
    config: PlotConfig | None = None,
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] | None = None,
    dpi: int = 300,
    subsample: int = 10,
) -> plt.Figure:
    """
    Visualize CohoMap with N-dimensional decoded coordinates.

    Each subplot shows the spatial trajectory colored by ``cos(coord_i)`` for a single
    circular coordinate.

    Parameters
    ----------
    decoding_result : dict
        Dictionary containing ``coordsbox`` (shape (T, n_dims), n_dims >= 1) and
        ``times_box``.
    position_data : dict
        Position data containing ``x`` and ``y`` arrays.
    config : PlotConfig, optional
        Plot configuration for styling, saving, and showing. ``dpi`` is always
        written to ``config.save_dpi``.
    save_path : str, optional
        Path to save the figure (used when ``config`` is None).
    show : bool
        Whether to show the figure (used when ``config`` is None).
    figsize : tuple[int, int], optional
        Figure size in inches. When None (default), the size is ``(5 * n_dims, 4)``.
    dpi : int
        Save DPI.
    subsample : int
        Subsample stride for plotting; must be >= 1.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Raises
    ------
    ValueError
        If ``coordsbox`` is not 2D with at least one column, or ``subsample`` < 1.

    Examples
    --------
    >>> fig = plot_cohomap_multi(decoding, {"x": xx, "y": yy}, show=False)  # doctest: +SKIP
    """
    coordsbox = np.asarray(decoding_result["coordsbox"])
    times_box = np.asarray(decoding_result["times_box"])
    xx = np.asarray(position_data["x"])
    yy = np.asarray(position_data["y"])

    if coordsbox.ndim != 2 or coordsbox.shape[1] < 1:
        raise ValueError(
            f"coordsbox must be 2D with at least 1 column, got {coordsbox.shape}"
        )
    if subsample < 1:
        raise ValueError(f"subsample must be a positive integer, got {subsample}")

    num_dims = coordsbox.shape[1]
    # Honour an explicit figsize; otherwise give each panel a 5x4 inch slot.
    fig_size = (5 * num_dims, 4) if figsize is None else figsize

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="CohoMap",
        xlabel="",
        ylabel="",
        figsize=fig_size,
        save_path=save_path,
        show=show,
    )
    config.save_dpi = dpi

    plot_times = np.arange(0, len(coordsbox), subsample)
    # Positions at the subsampled decoded timepoints, shared by every panel.
    x_plot = xx[times_box][plot_times]
    y_plot = yy[times_box][plot_times]

    # squeeze=False keeps a 2D axes array even when num_dims == 1.
    fig, axes = plt.subplots(1, num_dims, figsize=fig_size, squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.axis("off")
        ax.set_aspect("equal", "box")
        im = ax.scatter(
            x_plot,
            y_plot,
            c=np.cos(coordsbox[plot_times, i]),
            s=8,
            cmap="viridis",
        )
        fig.colorbar(im, ax=ax, label=f"cos(coord {i + 1})")
        ax.set_title(f"CohoMap Dim {i + 1}")

    fig.tight_layout()
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def plot_3d_bump_on_torus(
    decoding_result: dict[str, Any] | str,
    spike_data: dict[str, Any],
    config: CANN2DPlotConfig | None = None,
    save_path: str | None = None,
    numangsint: int = 51,
    r1: float = 1.5,
    r2: float = 1.0,
    window_size: int = 300,
    frame_step: int = 5,
    n_frames: int = 20,
    fps: int = 5,
    show_progress: bool = True,
    show: bool = True,
    figsize: tuple[int, int] = (8, 8),
    render_backend: str | None = "auto",
    output_dpi: int = 150,
    render_workers: int | None = None,
    **kwargs,
) -> animation.FuncAnimation | None:
    """
    Visualize the movement of the neural activity bump on a torus using matplotlib animation.

    For each frame, decoded samples whose time index falls in a sliding window are
    binned into a 2D map of summed activity. The map is smoothed, blended with the
    previous frame (0.7 * previous + 0.3 * current) and painted onto a precomputed
    torus surface.

    Parameters
    ----------
    decoding_result : dict, str or os.PathLike
        Dictionary containing decoding results with 'coordsbox' and 'times_box' keys,
        or path to an .npz file containing these arrays.
    spike_data : dict
        Spike data dictionary passed to ``embed_spike_trains``.
    config : CANN2DPlotConfig, optional
        Configuration object for unified plotting parameters. When given, the explicit
        arguments below (including their defaults) overwrite the matching config
        fields; ``save_path`` overwrites only when not None.
    save_path : str, optional
        Path to save the animation (e.g., 'animation.gif' or 'animation.mp4')
    numangsint : int
        Grid resolution for the torus surface
    r1 : float
        Major radius of the torus
    r2 : float
        Minor radius of the torus
    window_size : int
        Time window (in number of time points) for each frame
    frame_step : int
        Step size to slide the time window between frames
    n_frames : int
        Maximum number of frames in the animation
    fps : int
        Frames per second for the output animation
    show_progress : bool
        Whether to show progress bars during frame processing and saving
    show : bool
        Whether to display the animation
    figsize : tuple[int, int]
        Figure size for the animation
    render_backend : str or None
        Requested backend, passed to ``select_animation_backend`` when saving.
    output_dpi : int
        Per-frame DPI used by the imageio backend.
    render_workers : int, optional
        Worker count for the imageio backend; falls back to ``config.render_workers``
        and then ``get_optimal_worker_count()``.
    **kwargs
        Extra config fields. When ``config`` is None they are forwarded to
        ``CANN2DPlotConfig.for_torus_animation`` (``title`` and ``repeat`` replace the
        defaults "3D Bump on Torus" and True). Otherwise they are set on ``config``
        when it already has an attribute of that name.

    Returns
    -------
    matplotlib.animation.FuncAnimation | None
        The animation object, or None when shown in Jupyter or when the animation was
        only saved through the imageio backend without ``show``.

    Raises
    ------
    ProcessingError
        If no time window contains decoded samples, or if building, saving or showing
        the animation fails.

    Examples
    --------
    >>> ani = plot_3d_bump_on_torus(decoding, spike_data, show=False)  # doctest: +SKIP
    """
    # Camera and limits shared by the matplotlib and imageio render paths.
    elev, azim = -125, 135
    zlim = (-2, 2)

    # Handle backward compatibility and configuration
    if config is None:
        # Merge into one dict so keys such as ``title``/``repeat`` supplied through
        # **kwargs cannot collide with explicit keyword arguments (which would raise
        # "got multiple values for keyword argument").
        factory_args: dict[str, Any] = {"title": "3D Bump on Torus", "repeat": True}
        factory_args.update(kwargs)
        factory_args.update(
            figsize=figsize,
            fps=fps,
            show_progress_bar=show_progress,
            save_path=save_path,
            show=show,
            numangsint=numangsint,
            r1=r1,
            r2=r2,
            window_size=window_size,
            frame_step=frame_step,
            n_frames=n_frames,
        )
        config = CANN2DPlotConfig.for_torus_animation(**factory_args)
    else:
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show
        if figsize is not None:
            config.figsize = figsize
        if fps is not None:
            config.fps = fps
        if show_progress is not None:
            config.show_progress_bar = show_progress
        config.numangsint = numangsint
        config.r1 = r1
        config.r2 = r2
        config.window_size = window_size
        config.frame_step = frame_step
        config.n_frames = n_frames

        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)

    # Extract configuration values
    save_path = config.save_path
    show = config.show
    figsize = config.figsize
    fps = config.fps
    show_progress = config.show_progress_bar
    numangsint = config.numangsint
    r1 = config.r1
    r2 = config.r2
    window_size = config.window_size
    frame_step = config.frame_step
    n_frames = config.n_frames

    # Load decoding results if a path is provided
    if isinstance(decoding_result, (str, os.PathLike)):
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
        # execute arbitrary code; only load files from trusted sources.
        with np.load(decoding_result, allow_pickle=True) as npz:
            coords = npz["coordsbox"]
            times = npz["times_box"]
    else:
        coords = decoding_result["coordsbox"]
        times = decoding_result["times_box"]

    spk, *_ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
    )

    # Torus geometry is constant across frames, so compute it once.
    x_edge = np.linspace(0, 2 * np.pi, numangsint)
    y_edge = np.linspace(0, 2 * np.pi, numangsint)
    X_grid, Y_grid = np.meshgrid(x_edge, y_edge)
    X_transformed = (X_grid + np.pi / 5) % (2 * np.pi)

    torus_x = (r1 + r2 * np.cos(X_transformed)) * np.cos(Y_grid)
    torus_y = (r1 + r2 * np.cos(X_transformed)) * np.sin(Y_grid)
    torus_z = -r2 * np.sin(X_transformed)  # Flip torus surface orientation

    # Per-frame data holds only the activity map and a time label; geometry is reused.
    frame_data: list[dict[str, Any]] = []
    prev_m = None
    last_time = np.max(times) if np.size(times) else -np.inf
    bin_edges = np.linspace(0, 2 * np.pi, numangsint - 1)

    for frame_idx in tqdm(range(n_frames), desc="Processing frames", disable=not show_progress):
        start_idx = frame_idx * frame_step
        end_idx = start_idx + window_size
        if end_idx > last_time:
            break

        mask = (times >= start_idx) & (times < end_idx)
        coords_window = coords[mask]
        if len(coords_window) == 0:
            continue

        # Activity per decoded sample: spk summed over its second axis
        # (presumably neurons — confirm against embed_spike_trains).
        spk_window = spk[times[mask], :]
        activity = np.sum(spk_window, axis=1)

        m, _, _, _ = binned_statistic_2d(
            coords_window[:, 0],
            coords_window[:, 1],
            activity,
            statistic="sum",
            bins=bin_edges,
        )
        m = np.nan_to_num(m)
        m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
        m = gaussian_filter(m, sigma=1.0)

        # Exponential blending with the previous frame for temporal smoothness.
        if prev_m is not None:
            m = 0.7 * prev_m + 0.3 * m
        prev_m = m

        # The label is the start of this frame's time window.
        frame_data.append({"m": m, "time": start_idx})

    if not frame_data:
        raise ProcessingError("No valid frames generated for animation")

    fig = plt.figure(figsize=figsize)

    try:
        ax = fig.add_subplot(111, projection="3d")

        def style_axes() -> None:
            """Apply the fixed camera and limits (needed again after every ``ax.clear()``)."""
            ax.set_zlim(*zlim)
            ax.view_init(elev, azim)
            ax.axis("off")

        def draw_surface(m):
            """Paint activity map ``m``, normalized by its max, onto the torus geometry."""
            return ax.plot_surface(
                torus_x,
                torus_y,
                torus_z,
                facecolors=cm.viridis(m / (np.max(m) + 1e-9)),
                alpha=1,
                linewidth=0.1,
                antialiased=True,
                rstride=1,
                cstride=1,
                shade=False,
            )

        # Initialize with first frame
        style_axes()
        draw_surface(frame_data[0]["m"])

        def animate(frame_idx):
            """Redraw the surface for one frame."""
            frame = frame_data[frame_idx]
            # 3D surfaces cannot be blitted, so the axes are cleared and redrawn.
            ax.clear()
            style_axes()
            new_surface = draw_surface(frame["m"])
            time_text = ax.text2D(
                0.05,
                0.95,
                f"Frame: {frame_idx + 1}/{len(frame_data)}",
                transform=ax.transAxes,
                fontsize=12,
                bbox=dict(facecolor="white", alpha=0.7),
            )
            return new_surface, time_text

        interval_ms = 1000 / fps

        def make_animation():
            """Build the FuncAnimation (blit=False due to 3D limitation)."""
            return animation.FuncAnimation(
                fig,
                animate,
                frames=len(frame_data),
                interval=interval_ms,
                blit=False,
                repeat=config.repeat,
            )

        ani = None
        progress_bar_enabled = show_progress

        if save_path:
            _ensure_parent_dir(save_path)
            if show and len(frame_data) > 50:
                warn_double_rendering(len(frame_data), save_path, stacklevel=2)

            backend_selection = select_animation_backend(
                save_path=save_path,
                requested_backend=render_backend,
                check_imageio_plugins=True,
            )
            emit_backend_warnings(backend_selection.warnings, stacklevel=2)
            backend = backend_selection.backend

            if backend == "imageio":
                render_data = {
                    "frames": frame_data,
                    "torus_x": torus_x,
                    "torus_y": torus_y,
                    "torus_z": torus_z,
                    "figsize": figsize,
                    "dpi": output_dpi,
                    "elev": elev,
                    "azim": azim,
                    "zlim": zlim,
                }
                workers = render_workers
                if workers is None:
                    workers = config.render_workers
                if workers is None:
                    workers = get_optimal_worker_count()
                try:
                    render_animation_parallel(
                        _render_torus_frame,
                        render_data,
                        num_frames=len(frame_data),
                        save_path=save_path,
                        fps=fps,
                        num_workers=workers,
                        show_progress=progress_bar_enabled,
                    )
                except Exception as e:
                    import warnings

                    warnings.warn(
                        f"imageio rendering failed: {e}. Falling back to matplotlib.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    backend = "matplotlib"

            if backend == "matplotlib":
                ani = make_animation()

                writer = get_matplotlib_writer(save_path, fps=fps)
                if progress_bar_enabled:
                    pbar = tqdm(total=len(frame_data), desc=f"Saving to {save_path}")

                    def progress_callback(current_frame: int, total_frames: int) -> None:
                        pbar.update(1)

                    try:
                        ani.save(save_path, writer=writer, progress_callback=progress_callback)
                    finally:
                        pbar.close()
                else:
                    ani.save(save_path, writer=writer)

        if show:
            if ani is None:
                ani = make_animation()
            if is_jupyter_environment():
                display_animation_in_jupyter(ani)
                plt.close(fig)
            else:
                plt.show()
        else:
            plt.close(fig)

        if show and is_jupyter_environment():
            return None
        return ani

    except Exception as e:
        plt.close(fig)
        raise ProcessingError(f"Failed to create torus animation: {e}") from e
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
def _smooth_tuning_map(mtot, numangsint, sig, bClose=True):
    """
    Smooth an activity map whose edges wrap around (e.g. a torus).

    Each row ``i`` is rolled by ``i // 2`` columns, and the result is tiled three
    times horizontally. A copy of that strip, rolled by half a tile width, is
    stacked above and below it. The resulting image is smoothed and the centre
    tile is cut back out with a per-row column offset of ``i // 2 + 1``, so the
    smoothing wraps across the map boundaries instead of seeing hard edges.

    Parameters:
        mtot (ndarray): Raw activity map, expected to be square with side
            ``numangsint - 1``.
        numangsint (int): Grid resolution; the map side is ``numangsint - 1``.
        sig (float): Smoothing kernel standard deviation, in bins.
        bClose (bool): If True, smooth with the NaN-aware ``_smooth_image``;
            otherwise use ``scipy.ndimage.gaussian_filter``.

    Returns:
        mtot_out (ndarray): Smoothed map with the same shape as ``mtot``.
        Floating/complex inputs keep their dtype. Integer/bool inputs are
        promoted to float64 so the smoothed values are not truncated.
    """
    side = numangsint - 1
    half = side // 2

    mtot = np.asarray(mtot)
    # gaussian_filter returns the input dtype by default, and the output buffer
    # mirrors the input dtype, so integer maps would be silently truncated.
    work_dtype = mtot.dtype if np.issubdtype(mtot.dtype, np.inexact) else np.float64
    rolled = mtot.astype(work_dtype)  # astype() returns a copy
    for i in range(side):
        rolled[i, :] = np.roll(rolled[i, :], i // 2)

    # Horizontal strip of three copies, plus the same strip rolled by half a
    # tile width for the rows above and below.
    strip = np.concatenate((rolled, rolled, rolled), axis=1)
    strip_shifted = np.roll(strip, half, axis=1)
    tiled = np.concatenate((strip_shifted, strip, strip_shifted), axis=0)

    if bClose:
        smoothed = _smooth_image(tiled, sigma=sig)
    else:
        smoothed = gaussian_filter(tiled, sigma=sig)

    mtot_out = np.zeros_like(rolled)
    for i in range(side):
        # NOTE(review): because of the "+ 1", output[i, k] is taken from input
        # column (k + 1) mod side, i.e. a one-column circular shift. This matches
        # the reference implementation, but confirm that it is intended.
        col0 = side + i // 2 + 1
        mtot_out[i, :] = smoothed[side + i, col0 : col0 + side]
    return mtot_out
|
|
917
|
+
|
|
918
|
+
|
|
919
|
+
def _smooth_image(img, sigma):
    """
    Smooth a 2D image with a normalised isotropic Gaussian kernel, handling NaNs.

    NaN pixels contribute zero to the weighted sum. The mask of valid pixels is
    morphologically closed with a 3x3 cross, so small gaps surrounded by valid
    pixels are filled in. The smoothed sum is divided by the fraction of the
    in-image kernel mass that falls on valid pixels. Pixels that are still
    invalid after the closing are set to ``-inf``.

    Parameters:
        img (ndarray): 2D input image (square or rectangular); may contain NaNs.
        sigma (float): Standard deviation of the smoothing kernel, in pixels.

    Returns:
        imgC (ndarray): Smoothed float image with the same shape as ``img``.
    """
    img = np.asarray(img)
    n_rows, n_cols = img.shape

    # The kernel spans the full image extent along each axis, so a "valid"
    # convolution with it returns an array of exactly img's shape.
    ys = np.arange(-n_rows + 1, n_rows)
    xs = np.arange(-n_cols + 1, n_cols)
    xx, yy = np.meshgrid(xs, ys)
    gauss = multivariate_normal(mean=[0, 0], cov=[[sigma**2, 0], [0, sigma**2]])
    # pdf() squeezes singleton axes (e.g. for a single-row image); restore 2D.
    k = np.reshape(gauss.pdf(np.dstack((xx, yy))), xx.shape)
    k = k / np.sum(k)

    nans = np.isnan(img)
    imgA = img.copy()
    imgA[nans] = 0
    imgA = signal.convolve2d(imgA, k, mode="valid")

    # Close small holes in the validity mask. The one-pixel zero padding lets
    # the dilation step spill past the image edge, so the erosion step (which
    # treats out-of-bounds as 0) does not invalidate valid border pixels.
    radius = 1
    offsets = np.arange(-radius, radius + 1)
    X, Y = np.meshgrid(offsets, offsets)
    dk = (X**2 + Y**2) <= radius**2
    imgE = np.zeros((n_rows + 2, n_cols + 2))
    imgE[1:-1, 1:-1] = ~nans
    imgE = binary_closing(imgE, iterations=1, structure=dk)
    imgD = imgE[1:-1, 1:-1]

    # Fraction of the in-image kernel mass that lands on valid pixels.
    imgB = np.divide(
        signal.convolve2d(imgD, k, mode="valid"),
        signal.convolve2d(np.ones(np.shape(imgD)), k, mode="valid"),
    )
    imgC = np.divide(imgA, imgB)
    imgC[imgD == 0] = -np.inf
    return imgC
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
def plot_2d_bump_on_manifold(
    decoding_result: dict[str, Any] | str,
    spike_data: dict[str, Any],
    save_path: str | None = None,
    fps: int = 20,
    show: bool = True,
    mode: str = "fast",
    window_size: int = 10,
    frame_step: int = 5,
    numangsint: int = 20,
    figsize: tuple[int, int] = (8, 6),
    show_progress: bool = False,
    config: PlotConfig | None = None,
    render_backend: str | None = "auto",
    output_dpi: int = 150,
    render_workers: int | None = None,
) -> animation.FuncAnimation | None:
    """
    Create 2D projection animation of CANN2D bump activity with full blitting support.

    This function provides a fast 2D heatmap visualization as an alternative to the
    3D torus animation. It uses matplotlib blitting when the canvas supports it,
    making it suited to rapid prototyping and daily analysis.

    Args:
        decoding_result: Decoding results containing ``coordsbox`` and ``times_box``
            (a dict, or a path to an ``.npz`` file holding those arrays)
        spike_data: Dictionary containing spike train data
        save_path: Path to save animation (None to skip saving)
        fps: Frames per second
        show: Whether to display the animation
        mode: Visualization mode - 'fast' for 2D heatmap (default), '3d' falls back to 3D
        window_size: Width of each frame's time window, in the units of ``times_box``
        frame_step: Offset between the starts of consecutive windows
        numangsint: Controls spatial binning; ``numangsint - 1`` bin edges span [0, 2*pi]
        figsize: Figure size (width, height) in inches
        show_progress: Show progress bar during processing
        config: Optional PlotConfig. It is modified in place: every non-None value
            among save_path, show, figsize, fps and show_progress is written onto it.
            Because show, figsize, fps and show_progress have non-None defaults,
            those always override the config.
        render_backend: Requested animation backend, passed to the backend selector
        output_dpi: Output DPI used by the imageio backend
        render_workers: Worker count for the imageio backend. Falls back to
            ``config.render_workers``, then to ``get_optimal_worker_count()``.

    Returns
    -------
    matplotlib.animation.FuncAnimation | None
        Animation object (or None in Jupyter when showing).

    Raises:
        ProcessingError: If mode is invalid, no frames can be produced, or building,
            saving or showing the animation fails

    Examples
    --------
    >>> # Fast 2D visualization (recommended for daily use)
    >>> ani = plot_2d_bump_on_manifold(
    ...     decoding_result, spike_data,
    ...     save_path='bump_2d.mp4', mode='fast'
    ... )
    >>> # For publication-ready 3D visualization, use mode='3d'
    >>> ani = plot_2d_bump_on_manifold(
    ...     decoding_result, spike_data, mode='3d'
    ... )
    """
    import matplotlib.animation as animation

    # Validate inputs
    if mode == "3d":
        # Delegate to the 3D torus renderer (note: ``config`` is not forwarded).
        return plot_3d_bump_on_torus(
            decoding_result=decoding_result,
            spike_data=spike_data,
            save_path=save_path,
            fps=fps,
            show=show,
            window_size=window_size,
            frame_step=frame_step,
            numangsint=numangsint,
            figsize=figsize,
            show_progress=show_progress,
            render_backend=render_backend,
            output_dpi=output_dpi,
            render_workers=render_workers,
        )

    if mode != "fast":
        raise ProcessingError(f"Invalid mode '{mode}'. Must be 'fast' or '3d'.")

    if config is None:
        config = PlotConfig.for_animation(
            time_steps_per_second=1000,
            title="CANN2D Bump Activity (2D Projection)",
            figsize=figsize,
            fps=fps,
            show=show,
            save_path=save_path,
            show_progress_bar=show_progress,
        )
    else:
        # Keyword arguments are written onto the caller's config object.
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show
        if figsize is not None:
            config.figsize = figsize
        if fps is not None:
            config.fps = fps
        if show_progress is not None:
            config.show_progress_bar = show_progress

    save_path = config.save_path
    show = config.show
    fps = config.fps
    figsize = config.figsize
    show_progress = config.show_progress_bar

    # Load decoding results
    if isinstance(decoding_result, str):
        # NOTE(review): allow_pickle=True permits unpickling object arrays, which
        # can execute arbitrary code; only load decoding files from trusted sources.
        with np.load(decoding_result, allow_pickle=True) as f:
            coords = f["coordsbox"]
            times = f["times_box"]
    else:
        coords = decoding_result["coordsbox"]
        times = decoding_result["times_box"]

    # Process spike data for 2D projection; rows of ``spk`` are indexed by the
    # entries of ``times`` below.
    spk, *_ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
    )

    # Build one smoothed activity map per sliding time window.
    t_max = np.max(times)
    n_frames = (t_max - window_size) // frame_step
    bin_edges = np.linspace(0, 2 * np.pi, numangsint - 1)
    frame_activity_maps = []
    prev_m = None

    for frame_idx in tqdm(range(n_frames), desc="Processing frames", disable=not show_progress):
        start_idx = frame_idx * frame_step
        end_idx = start_idx + window_size
        if end_idx > t_max:
            break

        mask = (times >= start_idx) & (times < end_idx)
        coords_window = coords[mask]
        if len(coords_window) == 0:
            continue

        # Row sums of the spike matrix at each time point in the window
        # (presumably summing over neurons).
        activity = np.sum(spk[times[mask], :], axis=1)

        m, _, _, _ = binned_statistic_2d(
            coords_window[:, 0],
            coords_window[:, 1],
            activity,
            statistic="sum",
            bins=bin_edges,
        )
        m = np.nan_to_num(m)
        m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
        m = gaussian_filter(m, sigma=1.0)

        # Exponential moving average across frames to reduce flicker.
        if prev_m is not None:
            m = 0.7 * prev_m + 0.3 * m
        prev_m = m

        frame_activity_maps.append(m)

    if not frame_activity_maps:
        raise ProcessingError("No valid frames generated for animation")

    # imshow derives its colour limits from the first frame only, and set_array()
    # does not rescale. Use the finite range over all frames so later frames
    # do not saturate.
    stacked_maps = np.asarray(frame_activity_maps, dtype=float)
    finite_values = stacked_maps[np.isfinite(stacked_maps)]
    if finite_values.size:
        vmin, vmax = float(finite_values.min()), float(finite_values.max())
    else:
        vmin = vmax = None

    # Create 2D visualization with blitting
    fig, ax = plt.subplots(figsize=figsize)
    try:
        ax.set_xlabel("Manifold Dimension 1 (rad)", fontsize=12)
        ax.set_ylabel("Manifold Dimension 2 (rad)", fontsize=12)
        ax.set_title("CANN2D Bump Activity (2D Projection)", fontsize=14, fontweight="bold")

        # Artists are created once; only their data changes per frame (blitting).
        im = ax.imshow(
            frame_activity_maps[0].T,  # Transpose for correct orientation
            extent=[0, 2 * np.pi, 0, 2 * np.pi],
            origin="lower",
            cmap="viridis",
            vmin=vmin,
            vmax=vmax,
            animated=True,
            aspect="auto",
        )
        # Colorbar (static)
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label("Activity", fontsize=11)

        # Time text
        time_text = ax.text(
            0.02,
            0.98,
            "",
            transform=ax.transAxes,
            fontsize=11,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
            animated=True,
        )

        def init():
            """Reset the artists to the first frame."""
            im.set_array(frame_activity_maps[0].T)
            time_text.set_text("")
            return im, time_text

        def update(frame_idx):
            """Swap in the heatmap data and frame label for ``frame_idx``."""
            if frame_idx >= len(frame_activity_maps):
                return im, time_text
            im.set_array(frame_activity_maps[frame_idx].T)
            time_text.set_text(f"Frame: {frame_idx + 1}/{len(frame_activity_maps)}")
            return im, time_text

        # Fall back to full redraws on canvases without blitting support.
        use_blitting = bool(getattr(fig.canvas, "supports_blit", False))

        interval_ms = 1000 / fps

        def _build_animation():
            return animation.FuncAnimation(
                fig,
                update,
                frames=len(frame_activity_maps),
                init_func=init,
                interval=interval_ms,
                blit=use_blitting,
                repeat=config.repeat,
            )

        ani = None
        progress_bar_enabled = show_progress

        if save_path:
            _ensure_parent_dir(save_path)
            if show and len(frame_activity_maps) > 50:
                warn_double_rendering(len(frame_activity_maps), save_path, stacklevel=2)

            backend_selection = select_animation_backend(
                save_path=save_path,
                requested_backend=render_backend,
                check_imageio_plugins=True,
            )
            emit_backend_warnings(backend_selection.warnings, stacklevel=2)
            backend = backend_selection.backend

            if backend == "imageio":
                render_data = {
                    "maps": frame_activity_maps,
                    "figsize": figsize,
                    "dpi": output_dpi,
                }
                workers = render_workers
                if workers is None:
                    workers = config.render_workers
                if workers is None:
                    workers = get_optimal_worker_count()
                try:
                    render_animation_parallel(
                        _render_2d_bump_frame,
                        render_data,
                        num_frames=len(frame_activity_maps),
                        save_path=save_path,
                        fps=fps,
                        num_workers=workers,
                        show_progress=progress_bar_enabled,
                    )
                except Exception as e:
                    import warnings

                    warnings.warn(
                        f"imageio rendering failed: {e}. Falling back to matplotlib.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    backend = "matplotlib"

            if backend == "matplotlib":
                ani = _build_animation()
                writer = get_matplotlib_writer(save_path, fps=fps)
                if progress_bar_enabled:
                    pbar = tqdm(total=len(frame_activity_maps), desc=f"Saving to {save_path}")

                    def progress_callback(current_frame: int, total_frames: int) -> None:
                        pbar.update(1)

                    try:
                        ani.save(save_path, writer=writer, progress_callback=progress_callback)
                    finally:
                        pbar.close()
                else:
                    ani.save(save_path, writer=writer)

        if show:
            if ani is None:
                ani = _build_animation()
            if is_jupyter_environment():
                display_animation_in_jupyter(ani)
                plt.close(fig)
                return None
            plt.show()
        else:
            plt.close(fig)

        return ani

    except Exception as e:
        # Mirror the 3D renderer: release the figure and surface a ProcessingError.
        plt.close(fig)
        raise ProcessingError(f"Failed to create 2D bump animation: {e}") from e
|