canns 0.12.7__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. canns/analyzer/data/__init__.py +3 -11
  2. canns/analyzer/data/asa/__init__.py +74 -0
  3. canns/analyzer/data/asa/cohospace.py +905 -0
  4. canns/analyzer/data/asa/config.py +246 -0
  5. canns/analyzer/data/asa/decode.py +448 -0
  6. canns/analyzer/data/asa/embedding.py +269 -0
  7. canns/analyzer/data/asa/filters.py +208 -0
  8. canns/analyzer/data/asa/fr.py +439 -0
  9. canns/analyzer/data/asa/path.py +389 -0
  10. canns/analyzer/data/asa/plotting.py +1276 -0
  11. canns/analyzer/data/asa/tda.py +901 -0
  12. canns/analyzer/data/legacy/__init__.py +6 -0
  13. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  14. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  15. canns/analyzer/visualization/core/backend.py +1 -1
  16. canns/analyzer/visualization/core/config.py +77 -0
  17. canns/analyzer/visualization/core/rendering.py +10 -6
  18. canns/analyzer/visualization/energy_plots.py +22 -8
  19. canns/analyzer/visualization/spatial_plots.py +31 -11
  20. canns/analyzer/visualization/theta_sweep_plots.py +15 -6
  21. canns/pipeline/__init__.py +4 -8
  22. canns/pipeline/asa/__init__.py +21 -0
  23. canns/pipeline/asa/__main__.py +11 -0
  24. canns/pipeline/asa/app.py +1000 -0
  25. canns/pipeline/asa/runner.py +1095 -0
  26. canns/pipeline/asa/screens.py +215 -0
  27. canns/pipeline/asa/state.py +248 -0
  28. canns/pipeline/asa/styles.tcss +221 -0
  29. canns/pipeline/asa/widgets.py +233 -0
  30. canns/pipeline/gallery/__init__.py +7 -0
  31. canns/task/open_loop_navigation.py +3 -1
  32. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  33. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/RECORD +36 -17
  34. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  35. canns/pipeline/theta_sweep.py +0 -573
  36. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  37. {canns-0.12.7.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1276 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Any
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ from matplotlib import animation, cm
9
+ from scipy import signal
10
+ from scipy.ndimage import binary_closing, gaussian_filter
11
+ from scipy.stats import binned_statistic_2d, multivariate_normal
12
+ from tqdm import tqdm
13
+
14
+ from ...visualization.core import (
15
+ PlotConfig,
16
+ emit_backend_warnings,
17
+ finalize_figure,
18
+ get_matplotlib_writer,
19
+ get_optimal_worker_count,
20
+ render_animation_parallel,
21
+ select_animation_backend,
22
+ warn_double_rendering,
23
+ )
24
+ from ...visualization.core.jupyter_utils import display_animation_in_jupyter, is_jupyter_environment
25
+ from .config import CANN2DPlotConfig, ProcessingError, SpikeEmbeddingConfig
26
+ from .embedding import embed_spike_trains
27
+
28
+
29
+ def _ensure_plot_config(
30
+ config: PlotConfig | None,
31
+ factory,
32
+ *,
33
+ kwargs: dict[str, Any] | None = None,
34
+ **defaults: Any,
35
+ ) -> PlotConfig:
36
+ if config is None:
37
+ defaults.update({"kwargs": kwargs or {}})
38
+ return factory(**defaults)
39
+
40
+ if kwargs:
41
+ config_kwargs = config.kwargs or {}
42
+ config_kwargs.update(kwargs)
43
+ config.kwargs = config_kwargs
44
+ return config
45
+
46
+
47
+ def _ensure_parent_dir(save_path: str | None) -> None:
48
+ if save_path:
49
+ parent = os.path.dirname(save_path)
50
+ if parent:
51
+ os.makedirs(parent, exist_ok=True)
52
+
53
+
54
def _render_torus_frame(frame_index: int, frame_data: dict[str, Any]) -> np.ndarray:
    """Render one torus frame off-screen and return it as an image array.

    Parameters
    ----------
    frame_index : int
        Index into ``frame_data["frames"]`` of the frame to draw.
    frame_data : dict
        Shared render payload. Keys read here: ``figsize``, ``zlim``,
        ``elev``, ``azim``, ``dpi``, ``torus_x``/``torus_y``/``torus_z``
        (surface geometry) and ``frames`` (sequence of dicts with an ``"m"``
        activity map and an optional ``"time"`` label).

    Returns
    -------
    np.ndarray
        The rendered PNG decoded by ``plt.imread``; float images are scaled
        by 255 and cast to ``uint8``.
    """
    from io import BytesIO

    fig = plt.figure(figsize=frame_data["figsize"])
    try:
        ax = fig.add_subplot(111, projection="3d")
        ax.set_zlim(*frame_data["zlim"])
        ax.view_init(frame_data["elev"], frame_data["azim"])
        ax.axis("off")

        frames = frame_data["frames"]
        frame = frames[frame_index]
        m = frame["m"]

        ax.plot_surface(
            frame_data["torus_x"],
            frame_data["torus_y"],
            frame_data["torus_z"],
            # Normalise by the frame maximum; the epsilon avoids division by zero.
            facecolors=cm.viridis(m / (np.max(m) + 1e-9)),
            alpha=1,
            linewidth=0.1,
            antialiased=True,
            rstride=1,
            cstride=1,
            shade=False,
        )

        label_text = f"Frame: {frame_index + 1}/{len(frames)}"
        time_label = frame.get("time")
        if time_label is not None:
            label_text = f"{label_text} | Time: {time_label}"
        ax.text2D(
            0.05,
            0.95,
            label_text,
            transform=ax.transAxes,
            fontsize=12,
            bbox=dict(facecolor="white", alpha=0.7),
        )

        fig.tight_layout()

        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=frame_data["dpi"], bbox_inches="tight")
            buf.seek(0)
            img = plt.imread(buf)
    finally:
        # Always release the figure, even if drawing or saving failed, so a
        # bad frame does not leave an open figure behind.
        plt.close(fig)

    if img.dtype in (np.float32, np.float64):
        img = (img * 255).astype(np.uint8)

    return img
107
+
108
+
109
def _render_2d_bump_frame(frame_index: int, frame_data: dict[str, Any]) -> np.ndarray:
    """Render one 2D bump heatmap frame off-screen and return it as an image array.

    Parameters
    ----------
    frame_index : int
        Index into ``frame_data["maps"]`` of the activity map to draw.
    frame_data : dict
        Shared render payload. Keys read here: ``figsize``, ``dpi`` and
        ``maps`` (sequence of 2D activity maps; each is transposed before
        being shown over ``[0, 2*pi] x [0, 2*pi]``).

    Returns
    -------
    np.ndarray
        The rendered PNG decoded by ``plt.imread``; float images are scaled
        by 255 and cast to ``uint8``.
    """
    from io import BytesIO

    fig, ax = plt.subplots(figsize=frame_data["figsize"])
    try:
        ax.set_xlabel("Manifold Dimension 1 (rad)", fontsize=12)
        ax.set_ylabel("Manifold Dimension 2 (rad)", fontsize=12)
        ax.set_title("CANN2D Bump Activity (2D Projection)", fontsize=14, fontweight="bold")

        maps = frame_data["maps"]
        im = ax.imshow(
            maps[frame_index].T,
            extent=[0, 2 * np.pi, 0, 2 * np.pi],
            origin="lower",
            cmap="viridis",
            aspect="auto",
        )
        fig.colorbar(im, ax=ax).set_label("Activity", fontsize=11)
        ax.text(
            0.02,
            0.98,
            f"Frame: {frame_index + 1}/{len(maps)}",
            transform=ax.transAxes,
            fontsize=11,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

        fig.tight_layout()
        with BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=frame_data["dpi"], bbox_inches="tight")
            buf.seek(0)
            img = plt.imread(buf)
    finally:
        # Always release the figure, even if drawing or saving failed, so a
        # bad frame does not leave an open figure behind.
        plt.close(fig)

    if img.dtype in (np.float32, np.float64):
        img = (img * 255).astype(np.uint8)

    return img
147
+
148
+
149
def plot_projection(
    reduce_func,
    embed_data,
    config: CANN2DPlotConfig | None = None,
    title="Projection (3D)",
    xlabel="Component 1",
    ylabel="Component 2",
    zlabel="Component 3",
    save_path=None,
    show=True,
    dpi=300,
    figsize=(10, 8),
    **kwargs,
):
    """
    Scatter-plot a 3D projection of embedded data.

    Every fifth row of ``embed_data`` is passed to ``reduce_func`` and the
    first three columns of the result are drawn as a 3D scatter.

    Parameters
    ----------
    reduce_func : callable
        Dimensionality reduction applied to ``embed_data[::5]``; its result
        must expose at least three columns.
    embed_data : ndarray
        Data to be projected.
    config : CANN2DPlotConfig, optional
        Unified plotting configuration. When omitted, one is built with
        ``CANN2DPlotConfig.for_projection_3d`` from the arguments below.
        When supplied, ``save_path`` (if not None), ``show`` and ``dpi``
        overwrite the config values, the labels only fill in empty fields,
        and ``figsize`` is applied only if the config still carries the
        ``PlotConfig`` default size.
    title, xlabel, ylabel, zlabel : str
        Figure title and axis labels.
    save_path : str, optional
        Where to save the figure; nothing is saved when None.
    show : bool
        Whether to display the figure.
    dpi : int
        Resolution used when saving.
    figsize : tuple
        Figure size in inches.
    **kwargs
        Forwarded to ``CANN2DPlotConfig.for_projection_3d`` when ``config``
        is None; unused otherwise.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Examples
    --------
    >>> fig = plot_projection(reduce_func, embed_data, show=False)  # doctest: +SKIP
    """
    if config is None:
        config = CANN2DPlotConfig.for_projection_3d(
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            zlabel=zlabel,
            save_path=save_path,
            show=show,
            figsize=figsize,
            dpi=dpi,
            **kwargs,
        )
    else:
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show

        # Text fields only act as fallbacks for values the config left empty.
        fallbacks = (
            ("title", title),
            ("xlabel", xlabel),
            ("ylabel", ylabel),
            ("zlabel", zlabel),
        )
        for attr, fallback in fallbacks:
            if not getattr(config, attr):
                setattr(config, attr, fallback)

        if config.figsize == PlotConfig().figsize:
            config.figsize = figsize
        if dpi is not None:
            config.dpi = dpi

    points = reduce_func(embed_data[::5])

    fig = plt.figure(figsize=config.figsize)
    ax = fig.add_subplot(111, projection="3d")
    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
    ax.scatter(xs, ys, zs, s=1, alpha=0.5)

    ax.set_title(config.title)
    ax.set_xlabel(config.xlabel)
    ax.set_ylabel(config.ylabel)
    ax.set_zlabel(config.zlabel)

    # Save at the requested dpi when the config carries one.
    config.save_dpi = getattr(config, "dpi", config.save_dpi)
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)

    return fig
238
+
239
+
240
def plot_path_compare(
    x: np.ndarray,
    y: np.ndarray,
    coords: np.ndarray,
    config: PlotConfig | None = None,
    *,
    title: str = "Path Compare",
    figsize: tuple[int, int] = (12, 5),
    show: bool = True,
    save_path: str | None = None,
) -> tuple[plt.Figure, np.ndarray]:
    """Draw the physical path next to the decoded coho-space path.

    Parameters
    ----------
    x, y : np.ndarray
        Physical position arrays; both are flattened before plotting.
    coords : np.ndarray
        Decoded circular coordinates as a 2D array. With two or more columns
        the first two are wrapped to ``[0, 2*pi)``, skew-transformed and drawn
        inside the base parallelogram; with a single column the angle is
        drawn on the unit circle.
    config : PlotConfig, optional
        Plot configuration. If None, one is created with
        ``PlotConfig.for_static_plot`` from the keyword arguments below; if
        given, those keyword arguments are not applied.
    title, figsize, show, save_path : optional
        Backward-compatibility parameters used to build the default config.

    Returns
    -------
    (Figure, ndarray)
        The figure and its array of two axes.

    Raises
    ------
    ValueError
        If ``coords`` is not 2D or has no columns.

    Examples
    --------
    >>> fig, axes = plot_path_compare(x, y, coords, show=False)  # doctest: +SKIP
    """
    from .path import draw_base_parallelogram, skew_transform, snake_wrap_trail_in_parallelogram

    x = np.asarray(x).ravel()
    y = np.asarray(y).ravel()
    coords = np.asarray(coords)

    if coords.ndim != 2 or coords.shape[1] < 1:
        raise ValueError(f"coords must be 2D with at least 1 column, got {coords.shape}")

    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title=title,
        figsize=figsize,
        save_path=save_path,
        show=show,
    )

    fig, axes = plt.subplots(1, 2, figsize=config.figsize)
    if config.title:
        fig.suptitle(config.title)

    physical_ax, decoded_ax = axes[0], axes[1]

    physical_ax.set_title("Physical path (x,y)")
    physical_ax.set_aspect("equal", "box")
    physical_ax.axis("off")
    physical_ax.plot(x, y, lw=0.9, alpha=0.8)

    decoded_ax.set_title("Decoded coho path")
    decoded_ax.set_aspect("equal", "box")
    decoded_ax.axis("off")

    two_pi = 2 * np.pi
    if coords.shape[1] < 2:
        # Single circular coordinate: trace it on the unit circle.
        angle = coords[:, 0] % two_pi
        decoded_ax.plot(np.cos(angle), np.sin(angle), lw=0.9, alpha=0.9)
        decoded_ax.set_xlim(-1.2, 1.2)
        decoded_ax.set_ylim(-1.2, 1.2)
    else:
        # Two circular coordinates: draw the wrapped trail inside the parallelogram.
        skewed = skew_transform(coords[:, :2] % two_pi)
        draw_base_parallelogram(decoded_ax)
        edge_a = np.array([2 * np.pi, 0.0])
        edge_b = np.array([np.pi, np.sqrt(3) * np.pi])
        trail = snake_wrap_trail_in_parallelogram(skewed, edge_a, edge_b)
        decoded_ax.plot(trail[:, 0], trail[:, 1], lw=0.9, alpha=0.9)

    fig.tight_layout()
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig, axes
324
+
325
+
326
def plot_cohomap(
    decoding_result: dict[str, Any],
    position_data: dict[str, Any],
    config: PlotConfig | None = None,
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    dpi: int = 300,
    subsample: int = 10,
) -> plt.Figure:
    """
    Visualize CohoMap 1.0: decoded circular coordinates mapped onto the spatial trajectory.

    Creates a two-panel figure. Each panel shows the spatial path coloured by
    the cosine of one of the first two decoded circular coordinates.

    Parameters
    ----------
    decoding_result : dict
        Dictionary from decode_circular_coordinates() containing:
        - 'coordsbox': decoded coordinates (n_times x n_dims); the first two
          columns are plotted
        - 'times_box': time indices for coordsbox, used to index the positions
    position_data : dict
        Position data containing 'x' and 'y' arrays for spatial coordinates.
    config : PlotConfig, optional
        Plot configuration. If None, one is built with
        ``PlotConfig.for_static_plot`` from ``figsize``, ``save_path`` and
        ``show``; if given, those three arguments are not applied.
    save_path : str, optional
        Path to save the visualization. If None, no save is performed.
    show : bool, default=False
        Whether to display the visualization.
    figsize : tuple[int, int], default=(10, 4)
        Figure size (width, height) in inches.
    dpi : int, default=300
        Resolution for the saved figure (always written to ``config.save_dpi``).
    subsample : int, default=10
        Subsampling interval for plotting (plot every Nth timepoint).

    Returns
    -------
    matplotlib.figure.Figure
        The matplotlib figure object.

    Raises
    ------
    KeyError
        If a required key is missing from either input dictionary.
    IndexError
        If ``coordsbox`` has fewer than two columns or the time indices fall
        outside the position arrays (assuming NumPy array inputs).

    Examples
    --------
    >>> # Decode coordinates
    >>> decoding = decode_circular_coordinates(persistence_result, spike_data)  # doctest: +SKIP
    >>> # Visualize with trajectory data
    >>> fig = plot_cohomap(  # doctest: +SKIP
    ...     decoding,
    ...     position_data={'x': xx, 'y': yy},
    ...     save_path='cohomap.png',
    ...     show=True
    ... )
    """
    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="CohoMap",
        xlabel="",
        ylabel="",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )
    config.save_dpi = dpi

    coordsbox = decoding_result["coordsbox"]
    times_box = decoding_result["times_box"]
    xx = position_data["x"]
    yy = position_data["y"]

    # Subsample time indices for plotting; the path is shared by both panels.
    plot_times = np.arange(0, len(coordsbox), subsample)
    path_x = xx[times_box][plot_times]
    path_y = yy[times_box][plot_times]

    # The colormap is passed explicitly to each scatter, so the global default
    # colormap (and any figure that happens to be current) is left untouched.
    fig, axes = plt.subplots(1, 2, figsize=config.figsize)

    for dim, panel in enumerate(axes):
        panel.axis("off")
        panel.set_aspect("equal", "box")
        points = panel.scatter(
            path_x,
            path_y,
            c=np.cos(coordsbox[plot_times, dim]),
            s=8,
            cmap="viridis",
        )
        fig.colorbar(points, ax=panel, label="cos(coord)")
        panel.set_title(f"CohoMap Dim {dim + 1}", fontsize=10)

    fig.tight_layout()

    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig
439
+
440
+
441
def plot_cohomap_multi(
    decoding_result: dict,
    position_data: dict,
    config: PlotConfig | None = None,
    save_path: str | None = None,
    show: bool = False,
    figsize: tuple[int, int] = (10, 4),
    dpi: int = 300,
    subsample: int = 10,
) -> plt.Figure:
    """
    Visualize CohoMap for any number of decoded circular coordinates.

    One subplot is drawn per column of ``coordsbox``; each shows the spatial
    trajectory coloured by ``cos(coord_i)``.

    Parameters
    ----------
    decoding_result : dict
        Dictionary containing ``coordsbox`` and ``times_box``.
    position_data : dict
        Position data containing ``x`` and ``y`` arrays.
    config : PlotConfig, optional
        Plot configuration for saving and showing. If None, one is built with
        ``PlotConfig.for_static_plot``.
    save_path : str, optional
        Path to save the figure (used when ``config`` is None).
    show : bool
        Whether to show the figure (used when ``config`` is None).
    figsize : tuple[int, int]
        Stored on the default config. Note that the figure itself is created
        at ``(5 * num_dims, 4)`` inches, independent of this value.
    dpi : int
        Save DPI (always written to ``config.save_dpi``).
    subsample : int
        Subsample stride for plotting.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.

    Examples
    --------
    >>> fig = plot_cohomap_multi(decoding, {"x": xx, "y": yy}, show=False)  # doctest: +SKIP
    """
    config = _ensure_plot_config(
        config,
        PlotConfig.for_static_plot,
        title="CohoMap",
        xlabel="",
        ylabel="",
        figsize=figsize,
        save_path=save_path,
        show=show,
    )
    config.save_dpi = dpi

    coordsbox = decoding_result["coordsbox"]
    times_box = decoding_result["times_box"]
    xx = position_data["x"]
    yy = position_data["y"]

    num_dims = coordsbox.shape[1]
    plot_times = np.arange(0, len(coordsbox), subsample)

    # squeeze=False always yields a 2D axes grid, so the single-panel case
    # needs no special handling.
    fig, grid = plt.subplots(1, num_dims, figsize=(5 * num_dims, 4), squeeze=False)
    panels = grid[0]

    for dim, panel in enumerate(panels):
        panel.axis("off")
        panel.set_aspect("equal", "box")
        points = panel.scatter(
            xx[times_box][plot_times],
            yy[times_box][plot_times],
            c=np.cos(coordsbox[plot_times, dim]),
            s=8,
            cmap="viridis",
        )
        plt.colorbar(points, ax=panel, label=f"cos(coord {dim + 1})")
        panel.set_title(f"CohoMap Dim {dim + 1}")

    fig.tight_layout()
    _ensure_parent_dir(config.save_path)
    finalize_figure(fig, config)
    return fig
526
+
527
+
528
def plot_3d_bump_on_torus(
    decoding_result: dict[str, Any] | str,
    spike_data: dict[str, Any],
    config: CANN2DPlotConfig | None = None,
    save_path: str | None = None,
    numangsint: int = 51,
    r1: float = 1.5,
    r2: float = 1.0,
    window_size: int = 300,
    frame_step: int = 5,
    n_frames: int = 20,
    fps: int = 5,
    show_progress: bool = True,
    show: bool = True,
    figsize: tuple[int, int] = (8, 8),
    render_backend: str | None = "auto",
    output_dpi: int = 150,
    render_workers: int | None = None,
    **kwargs,
) -> animation.FuncAnimation | None:
    """
    Animate the neural activity bump moving over a torus surface.

    For each frame, spikes inside a sliding time window are binned over the
    two decoded circular coordinates, smoothed, blended with the previous
    frame (0.7 / 0.3) and used to colour a pre-computed torus surface.

    Parameters
    ----------
    decoding_result : dict or str
        Dictionary with 'coordsbox' and 'times_box' keys, or path to an
        ``.npz`` file containing them. The file is loaded with
        ``allow_pickle=True``, so only trusted files should be passed.
    spike_data : dict
        Spike data passed to ``embed_spike_trains``.
    config : CANN2DPlotConfig, optional
        Unified plotting configuration. When omitted one is built with
        ``CANN2DPlotConfig.for_torus_animation``. When supplied, the explicit
        arguments below overwrite the corresponding config fields.
    save_path : str, optional
        Path to save the animation (e.g., 'animation.gif' or 'animation.mp4').
    numangsint : int
        Grid resolution for the torus surface.
    r1 : float
        Major radius of the torus.
    r2 : float
        Minor radius of the torus.
    window_size : int
        Time window (in number of time points) for each frame.
    frame_step : int
        Step size to slide the time window between frames.
    n_frames : int
        Maximum number of frames in the animation.
    fps : int
        Frames per second for the output animation.
    show_progress : bool
        Whether to show progress bars while processing and saving frames.
    show : bool
        Whether to display the animation.
    figsize : tuple[int, int]
        Figure size for the animation.
    render_backend : str, optional
        Requested backend, passed to ``select_animation_backend``.
    output_dpi : int
        DPI used by the parallel (imageio) frame renderer.
    render_workers : int, optional
        Worker count for parallel rendering; falls back to
        ``config.render_workers`` and then ``get_optimal_worker_count()``.
    **kwargs
        Extra config fields (e.g. ``title``). Forwarded to the config factory
        when ``config`` is None and applied to any matching config attribute.

    Returns
    -------
    matplotlib.animation.FuncAnimation | None
        The animation object, or None when shown in Jupyter (and also when the
        animation was only saved through the imageio backend without showing).

    Raises
    ------
    ProcessingError
        If no frame could be generated or rendering/saving fails.

    Examples
    --------
    >>> ani = plot_3d_bump_on_torus(decoding, spike_data, show=False)  # doctest: +SKIP
    """
    # Handle backward compatibility and configuration
    if config is None:
        factory_args: dict[str, Any] = {
            "title": "3D Bump on Torus",
            "figsize": figsize,
            "fps": fps,
            "repeat": True,
            "show_progress_bar": show_progress,
            "save_path": save_path,
            "show": show,
            "numangsint": numangsint,
            "r1": r1,
            "r2": r2,
            "window_size": window_size,
            "frame_step": frame_step,
            "n_frames": n_frames,
        }
        # Merge instead of passing both explicit keywords and **kwargs: a
        # caller-supplied ``title`` would otherwise be given twice and raise
        # TypeError.
        factory_args.update(kwargs)
        config = CANN2DPlotConfig.for_torus_animation(**factory_args)
    else:
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show
        if figsize is not None:
            config.figsize = figsize
        if fps is not None:
            config.fps = fps
        if show_progress is not None:
            config.show_progress_bar = show_progress
        config.numangsint = numangsint
        config.r1 = r1
        config.r2 = r2
        config.window_size = window_size
        config.frame_step = frame_step
        config.n_frames = n_frames

    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)

    # Extract configuration values
    save_path = config.save_path
    show = config.show
    figsize = config.figsize
    fps = config.fps
    show_progress = config.show_progress_bar
    numangsint = config.numangsint
    r1 = config.r1
    r2 = config.r2
    window_size = config.window_size
    frame_step = config.frame_step
    n_frames = config.n_frames

    # Load decoding results if a path is provided
    if isinstance(decoding_result, str):
        # NOTE: allow_pickle=True can execute code embedded in the file; only
        # load archives from trusted sources.
        with np.load(decoding_result, allow_pickle=True) as archive:
            coords = archive["coordsbox"]
            times = archive["times_box"]
    else:
        coords = decoding_result["coordsbox"]
        times = decoding_result["times_box"]

    spk, *_ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
    )

    # Torus geometry is constant across frames, so compute it once.
    x_edge = np.linspace(0, 2 * np.pi, numangsint)
    y_edge = np.linspace(0, 2 * np.pi, numangsint)
    X_grid, Y_grid = np.meshgrid(x_edge, y_edge)
    X_transformed = (X_grid + np.pi / 5) % (2 * np.pi)

    torus_x = (r1 + r2 * np.cos(X_transformed)) * np.cos(Y_grid)
    torus_y = (r1 + r2 * np.cos(X_transformed)) * np.sin(Y_grid)
    torus_z = -r2 * np.sin(X_transformed)  # Flip torus surface orientation

    # NOTE(review): these edges give (numangsint - 2) bins per axis, while the
    # surface above has (numangsint - 1) faces per axis - confirm that
    # plot_surface accepts the one-cell difference in ``facecolors``.
    bin_edges = np.linspace(0, 2 * np.pi, numangsint - 1)

    # Per-frame data only stores the colour map and a time label.
    frame_data = []
    prev_m = None

    frame_iter = tqdm(range(n_frames), desc="Processing frames", disable=not show_progress)
    for frame_idx in frame_iter:
        start_idx = frame_idx * frame_step
        end_idx = start_idx + window_size
        if end_idx > np.max(times):
            break

        mask = (times >= start_idx) & (times < end_idx)
        coords_window = coords[mask]
        if len(coords_window) == 0:
            continue

        spk_window = spk[times[mask], :]
        activity = np.sum(spk_window, axis=1)

        m, _, _, _ = binned_statistic_2d(
            coords_window[:, 0],
            coords_window[:, 1],
            activity,
            statistic="sum",
            bins=bin_edges,
        )
        m = np.nan_to_num(m)
        m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
        m = gaussian_filter(m, sigma=1.0)

        # Temporal smoothing: blend with the previous frame.
        if prev_m is not None:
            m = 0.7 * prev_m + 0.3 * m
        prev_m = m

        # ``start_idx`` already includes ``frame_step``; label the frame with
        # the start of its time window.
        frame_data.append({"m": m, "time": start_idx})

    if not frame_data:
        raise ProcessingError("No valid frames generated for animation")

    # Camera and z-range shared by the interactive axes and the parallel renderer.
    elev, azim = -125, 135
    zlim = (-2, 2)

    fig = plt.figure(figsize=figsize)

    try:
        ax = fig.add_subplot(111, projection="3d")

        def style_axes() -> None:
            """Apply the fixed z-range, camera and hidden axes."""
            ax.set_zlim(*zlim)
            ax.view_init(elev, azim)
            ax.axis("off")

        def draw_surface(activity_map):
            """Draw the torus coloured by ``activity_map`` (scaled by its maximum)."""
            return ax.plot_surface(
                torus_x,
                torus_y,
                torus_z,
                facecolors=cm.viridis(activity_map / (np.max(activity_map) + 1e-9)),
                alpha=1,
                linewidth=0.1,
                antialiased=True,
                rstride=1,
                cstride=1,
                shade=False,
            )

        # Initialize with the first frame
        style_axes()
        draw_surface(frame_data[0]["m"])

        def animate(frame_idx):
            """Redraw the torus for one frame, reusing the pre-computed geometry."""
            # 3D surfaces cannot be blitted, so the axes are cleared and redrawn.
            ax.clear()
            style_axes()
            new_surface = draw_surface(frame_data[frame_idx]["m"])

            time_text = ax.text2D(
                0.05,
                0.95,
                f"Frame: {frame_idx + 1}/{len(frame_data)}",
                transform=ax.transAxes,
                fontsize=12,
                bbox=dict(facecolor="white", alpha=0.7),
            )

            return new_surface, time_text

        # blit=False because of the 3D limitation noted above
        interval_ms = 1000 / fps
        ani = None
        progress_bar_enabled = show_progress

        if save_path:
            _ensure_parent_dir(save_path)
            if show and len(frame_data) > 50:
                warn_double_rendering(len(frame_data), save_path, stacklevel=2)

            backend_selection = select_animation_backend(
                save_path=save_path,
                requested_backend=render_backend,
                check_imageio_plugins=True,
            )
            emit_backend_warnings(backend_selection.warnings, stacklevel=2)
            backend = backend_selection.backend

            if backend == "imageio":
                render_data = {
                    "frames": frame_data,
                    "torus_x": torus_x,
                    "torus_y": torus_y,
                    "torus_z": torus_z,
                    "figsize": figsize,
                    "dpi": output_dpi,
                    "elev": elev,
                    "azim": azim,
                    "zlim": zlim,
                }
                workers = render_workers
                if workers is None:
                    workers = config.render_workers
                if workers is None:
                    workers = get_optimal_worker_count()
                try:
                    render_animation_parallel(
                        _render_torus_frame,
                        render_data,
                        num_frames=len(frame_data),
                        save_path=save_path,
                        fps=fps,
                        num_workers=workers,
                        show_progress=progress_bar_enabled,
                    )
                except Exception as e:
                    import warnings

                    # Best effort: fall back to the matplotlib writer below.
                    warnings.warn(
                        f"imageio rendering failed: {e}. Falling back to matplotlib.",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    backend = "matplotlib"

            if backend == "matplotlib":
                ani = animation.FuncAnimation(
                    fig,
                    animate,
                    frames=len(frame_data),
                    interval=interval_ms,
                    blit=False,
                    repeat=config.repeat,
                )

                writer = get_matplotlib_writer(save_path, fps=fps)
                if progress_bar_enabled:
                    pbar = tqdm(total=len(frame_data), desc=f"Saving to {save_path}")

                    def progress_callback(current_frame: int, total_frames: int) -> None:
                        pbar.update(1)

                    try:
                        ani.save(save_path, writer=writer, progress_callback=progress_callback)
                    finally:
                        pbar.close()
                else:
                    ani.save(save_path, writer=writer)

        if show:
            if ani is None:
                ani = animation.FuncAnimation(
                    fig,
                    animate,
                    frames=len(frame_data),
                    interval=interval_ms,
                    blit=False,
                    repeat=config.repeat,
                )
            if is_jupyter_environment():
                display_animation_in_jupyter(ani)
                plt.close(fig)
            else:
                plt.show()
        else:
            plt.close(fig)

        if show and is_jupyter_environment():
            return None
        return ani

    except Exception as e:
        plt.close(fig)
        raise ProcessingError(f"Failed to create torus animation: {e}") from e
880
+
881
+
882
+ def _smooth_tuning_map(mtot, numangsint, sig, bClose=True):
883
+ """
884
+ Smooth activity map over circular topology (e.g., torus).
885
+
886
+ Parameters:
887
+ mtot (ndarray): Raw activity map matrix.
888
+ numangsint (int): Grid resolution.
889
+ sig (float): Smoothing kernel standard deviation.
890
+ bClose (bool): Whether to assume circular boundary conditions.
891
+
892
+ Returns:
893
+ mtot_out (ndarray): Smoothed map matrix.
894
+ """
895
+ numangsint_1 = numangsint - 1
896
+ indstemp1 = np.zeros((numangsint_1, numangsint_1), dtype=int)
897
+ indstemp1[indstemp1 == 0] = np.arange((numangsint_1) ** 2)
898
+ mid = int((numangsint_1) / 2)
899
+ mtemp1_3 = mtot.copy()
900
+ for i in range(numangsint_1):
901
+ mtemp1_3[i, :] = np.roll(mtemp1_3[i, :], int(i / 2))
902
+ mtot_out = np.zeros_like(mtot)
903
+ mtemp1_4 = np.concatenate((mtemp1_3, mtemp1_3, mtemp1_3), 1)
904
+ mtemp1_5 = np.zeros_like(mtemp1_4)
905
+ mtemp1_5[:, :mid] = mtemp1_4[:, (numangsint_1) * 3 - mid :]
906
+ mtemp1_5[:, mid:] = mtemp1_4[:, : (numangsint_1) * 3 - mid]
907
+ if bClose:
908
+ mtemp1_6 = _smooth_image(np.concatenate((mtemp1_5, mtemp1_4, mtemp1_5)), sigma=sig)
909
+ else:
910
+ mtemp1_6 = gaussian_filter(np.concatenate((mtemp1_5, mtemp1_4, mtemp1_5)), sigma=sig)
911
+ for i in range(numangsint_1):
912
+ mtot_out[i, :] = mtemp1_6[
913
+ (numangsint_1) + i,
914
+ (numangsint_1) + (int(i / 2) + 1) : (numangsint_1) * 2 + (int(i / 2) + 1),
915
+ ]
916
+ return mtot_out
917
+
918
+
919
+ def _smooth_image(img, sigma):
920
+ """
921
+ Smooth image using multivariate Gaussian kernel, handling missing (NaN) values.
922
+
923
+ Parameters:
924
+ img (ndarray): Input image matrix.
925
+ sigma (float): Standard deviation of smoothing kernel.
926
+
927
+ Returns:
928
+ imgC (ndarray): Smoothed image with inpainting around NaNs.
929
+ """
930
+ filterSize = max(np.shape(img))
931
+ grid = np.arange(-filterSize + 1, filterSize, 1)
932
+ xx, yy = np.meshgrid(grid, grid)
933
+
934
+ pos = np.dstack((xx, yy))
935
+
936
+ var = multivariate_normal(mean=[0, 0], cov=[[sigma**2, 0], [0, sigma**2]])
937
+ k = var.pdf(pos)
938
+ k = k / np.sum(k)
939
+
940
+ nans = np.isnan(img)
941
+ imgA = img.copy()
942
+ imgA[nans] = 0
943
+ imgA = signal.convolve2d(imgA, k, mode="valid")
944
+ imgD = img.copy()
945
+ imgD[nans] = 0
946
+ imgD[~nans] = 1
947
+ radius = 1
948
+ L = np.arange(-radius, radius + 1)
949
+ X, Y = np.meshgrid(L, L)
950
+ dk = np.array((X**2 + Y**2) <= radius**2, dtype=bool)
951
+ imgE = np.zeros((filterSize + 2, filterSize + 2))
952
+ imgE[1:-1, 1:-1] = imgD
953
+ imgE = binary_closing(imgE, iterations=1, structure=dk)
954
+ imgD = imgE[1:-1, 1:-1]
955
+
956
+ imgB = np.divide(
957
+ signal.convolve2d(imgD, k, mode="valid"),
958
+ signal.convolve2d(np.ones(np.shape(imgD)), k, mode="valid"),
959
+ )
960
+ imgC = np.divide(imgA, imgB)
961
+ imgC[imgD == 0] = -np.inf
962
+ return imgC
963
+
964
+
965
def plot_2d_bump_on_manifold(
    decoding_result: dict[str, Any] | str,
    spike_data: dict[str, Any],
    save_path: str | None = None,
    fps: int = 20,
    show: bool = True,
    mode: str = "fast",
    window_size: int = 10,
    frame_step: int = 5,
    numangsint: int = 20,
    figsize: tuple[int, int] = (8, 6),
    show_progress: bool = False,
    config: PlotConfig | None = None,
    render_backend: str | None = "auto",
    output_dpi: int = 150,
    render_workers: int | None = None,
) -> animation.FuncAnimation | None:
    """
    Animate CANN2D bump activity as a 2D heatmap over the decoded manifold.

    For every time window the summed spike activity is binned over the two
    decoded angular coordinates, smoothed, blended with the previous frame and
    drawn with ``imshow``. Blitting is used when the canvas reports support
    for it. This is intended as a lighter-weight alternative to the 3D torus
    animation; ``mode='3d'`` delegates to ``plot_3d_bump_on_torus``.

    Args:
        decoding_result: Decoding results providing ``"coordsbox"`` and
            ``"times_box"``, either as a dict or as the path to an archive
            readable by ``np.load``.
        spike_data: Dictionary containing spike train data; it is passed to
            ``embed_spike_trains`` with smoothing off and the speed filter on.
        save_path: Path to save the animation (None to skip saving).
        fps: Frames per second.
        show: Whether to display the animation.
        mode: ``'fast'`` for the 2D heatmap (default) or ``'3d'`` to fall back
            to the 3D torus animation.
        window_size: Length of the time window aggregated into one frame.
        frame_step: Time step between the starts of consecutive frames.
        numangsint: Number of angular bins for spatial discretization.
        figsize: Figure size (width, height) in inches.
        show_progress: Show a progress bar during processing and saving.
        config: Optional plot configuration. When given, ``save_path``,
            ``show``, ``figsize``, ``fps`` and ``show_progress`` are written
            onto it whenever they are not None. Because all but ``save_path``
            have non-None defaults, they always override the values stored in
            ``config``. The object is modified in place.
        render_backend: Requested rendering backend, forwarded to
            ``select_animation_backend``.
        output_dpi: DPI handed to the frame renderer of the imageio backend.
        render_workers: Worker count for the imageio backend. Falls back to
            ``config.render_workers`` and then ``get_optimal_worker_count()``.

    Returns:
        The ``FuncAnimation`` object, or None when the animation is shown
        inside Jupyter. None is also returned when the file was written by the
        imageio backend and ``show`` is false, since no ``FuncAnimation`` is
        built in that case.

    Raises:
        ProcessingError: If ``mode`` is invalid, if ``window_size`` or
            ``frame_step`` is not positive, or if no frame contains any data.

    Examples:
        >>> # Fast 2D visualization (recommended for daily use)
        >>> ani = plot_2d_bump_on_manifold(
        ...     decoding_result, spike_data,
        ...     save_path='bump_2d.mp4', mode='fast'
        ... )
        >>> # For publication-ready 3D visualization, use mode='3d'
        >>> ani = plot_2d_bump_on_manifold(
        ...     decoding_result, spike_data, mode='3d'
        ... )
    """
    import matplotlib.animation as animation

    # Validate inputs
    if mode == "3d":
        # Fall back to 3D visualization
        return plot_3d_bump_on_torus(
            decoding_result=decoding_result,
            spike_data=spike_data,
            save_path=save_path,
            fps=fps,
            show=show,
            window_size=window_size,
            frame_step=frame_step,
            numangsint=numangsint,
            figsize=figsize,
            show_progress=show_progress,
            render_backend=render_backend,
            output_dpi=output_dpi,
            render_workers=render_workers,
        )

    if mode != "fast":
        raise ProcessingError(f"Invalid mode '{mode}'. Must be 'fast' or '3d'.")

    # Non-positive values can never yield a frame; fail early with a message
    # that names the actual cause instead of "No valid frames generated".
    if window_size <= 0 or frame_step <= 0:
        raise ProcessingError(
            "window_size and frame_step must be positive, got "
            f"window_size={window_size}, frame_step={frame_step}."
        )

    if config is None:
        config = PlotConfig.for_animation(
            time_steps_per_second=1000,
            title="CANN2D Bump Activity (2D Projection)",
            figsize=figsize,
            fps=fps,
            show=show,
            save_path=save_path,
            show_progress_bar=show_progress,
        )
    else:
        # Explicit keyword values take precedence over the supplied config.
        if save_path is not None:
            config.save_path = save_path
        if show is not None:
            config.show = show
        if figsize is not None:
            config.figsize = figsize
        if fps is not None:
            config.fps = fps
        if show_progress is not None:
            config.show_progress_bar = show_progress

    # From here on the config is the single source of truth.
    save_path = config.save_path
    show = config.show
    fps = config.fps
    figsize = config.figsize
    show_progress = config.show_progress_bar

    # Load decoding results
    if isinstance(decoding_result, str):
        # NOTE: allow_pickle=True can execute arbitrary code while loading;
        # only pass paths to trusted files.
        # The context manager closes the archive even if a key is missing.
        with np.load(decoding_result, allow_pickle=True) as archive:
            coords = archive["coordsbox"]
            times = archive["times_box"]
    else:
        coords = decoding_result["coordsbox"]
        times = decoding_result["times_box"]

    # Process spike data for 2D projection
    spk, *_ = embed_spike_trains(
        spike_data, config=SpikeEmbeddingConfig(smooth=False, speed_filter=True)
    )

    # Process frames
    last_time = np.max(times)  # loop-invariant, computed once
    n_frames = (last_time - window_size) // frame_step
    bin_edges = np.linspace(0, 2 * np.pi, numangsint - 1)
    frame_activity_maps = []
    prev_m = None

    for frame_idx in tqdm(range(n_frames), desc="Processing frames", disable=not show_progress):
        start_idx = frame_idx * frame_step
        end_idx = start_idx + window_size
        if end_idx > last_time:
            break

        mask = (times >= start_idx) & (times < end_idx)
        coords_window = coords[mask]
        if len(coords_window) == 0:
            # Nothing decoded in this window: skip it rather than emit an empty frame.
            continue

        # Total population activity at each decoded time point in the window.
        spk_window = spk[times[mask], :]
        activity = np.sum(spk_window, axis=1)

        m, _, _, _ = binned_statistic_2d(
            coords_window[:, 0],
            coords_window[:, 1],
            activity,
            statistic="sum",
            bins=bin_edges,
        )
        m = np.nan_to_num(m)
        m = _smooth_tuning_map(m, numangsint - 1, sig=4.0, bClose=True)
        m = gaussian_filter(m, sigma=1.0)

        # Blend with the previous frame to reduce flicker between frames.
        if prev_m is not None:
            m = 0.7 * prev_m + 0.3 * m
        prev_m = m

        frame_activity_maps.append(m)

    if not frame_activity_maps:
        raise ProcessingError("No valid frames generated for animation")

    total_frames = len(frame_activity_maps)

    # Create 2D visualization with blitting
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xlabel("Manifold Dimension 1 (rad)", fontsize=12)
    ax.set_ylabel("Manifold Dimension 2 (rad)", fontsize=12)
    ax.set_title("CANN2D Bump Activity (2D Projection)", fontsize=14, fontweight="bold")

    # Pre-create artists for blitting
    # Heatmap
    im = ax.imshow(
        frame_activity_maps[0].T,  # Transpose for correct orientation
        extent=[0, 2 * np.pi, 0, 2 * np.pi],
        origin="lower",
        cmap="viridis",
        animated=True,
        aspect="auto",
    )
    # Colorbar (static)
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label("Activity", fontsize=11)

    # Time text
    time_text = ax.text(
        0.02,
        0.98,
        "",
        transform=ax.transAxes,
        fontsize=11,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        animated=True,
    )

    def init():
        """Initialize animation"""
        im.set_array(frame_activity_maps[0].T)
        time_text.set_text("")
        return im, time_text

    def update(frame_idx):
        """Update function - only modify data using blitting"""
        if frame_idx >= total_frames:
            return im, time_text

        # Update heatmap data
        im.set_array(frame_activity_maps[frame_idx].T)

        # Update time text
        time_text.set_text(f"Frame: {frame_idx + 1}/{total_frames}")

        return im, time_text

    # Check blitting support
    use_blitting = True
    try:
        if not fig.canvas.supports_blit:
            use_blitting = False
    except AttributeError:
        use_blitting = False

    interval_ms = 1000 / fps

    def _build_animation():
        return animation.FuncAnimation(
            fig,
            update,
            frames=total_frames,
            init_func=init,
            interval=interval_ms,
            blit=use_blitting,
            repeat=config.repeat,
        )

    ani = None
    progress_bar_enabled = show_progress

    if save_path:
        _ensure_parent_dir(save_path)
        if show and total_frames > 50:
            warn_double_rendering(total_frames, save_path, stacklevel=2)

        backend_selection = select_animation_backend(
            save_path=save_path,
            requested_backend=render_backend,
            check_imageio_plugins=True,
        )
        emit_backend_warnings(backend_selection.warnings, stacklevel=2)
        backend = backend_selection.backend

        if backend == "imageio":
            render_data = {
                "maps": frame_activity_maps,
                "figsize": figsize,
                "dpi": output_dpi,
            }
            # Worker count: explicit argument, then config, then auto-detected.
            workers = render_workers
            if workers is None:
                workers = config.render_workers
            if workers is None:
                workers = get_optimal_worker_count()
            try:
                render_animation_parallel(
                    _render_2d_bump_frame,
                    render_data,
                    num_frames=total_frames,
                    save_path=save_path,
                    fps=fps,
                    num_workers=workers,
                    show_progress=progress_bar_enabled,
                )
            except Exception as e:
                # Deliberate best-effort: any imageio failure degrades to the
                # matplotlib writer instead of aborting the save.
                import warnings

                warnings.warn(
                    f"imageio rendering failed: {e}. Falling back to matplotlib.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                backend = "matplotlib"

        if backend == "matplotlib":
            ani = _build_animation()
            writer = get_matplotlib_writer(save_path, fps=fps)
            if progress_bar_enabled:
                pbar = tqdm(total=total_frames, desc=f"Saving to {save_path}")

                def progress_callback(current_frame: int, total_frames: int) -> None:
                    pbar.update(1)

                try:
                    ani.save(save_path, writer=writer, progress_callback=progress_callback)
                finally:
                    pbar.close()
            else:
                ani.save(save_path, writer=writer)

    if show:
        # The imageio path never builds a FuncAnimation, so create one for display.
        if ani is None:
            ani = _build_animation()
        if is_jupyter_environment():
            display_animation_in_jupyter(ani)
            plt.close(fig)
        else:
            plt.show()
    else:
        plt.close(fig)

    if show and is_jupyter_environment():
        return None
    return ani