canns-0.12.6-py3-none-any.whl → canns-0.13.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64):
  1. canns/__init__.py +39 -3
  2. canns/analyzer/__init__.py +7 -6
  3. canns/analyzer/data/__init__.py +3 -11
  4. canns/analyzer/data/asa/__init__.py +74 -0
  5. canns/analyzer/data/asa/cohospace.py +905 -0
  6. canns/analyzer/data/asa/config.py +246 -0
  7. canns/analyzer/data/asa/decode.py +448 -0
  8. canns/analyzer/data/asa/embedding.py +269 -0
  9. canns/analyzer/data/asa/filters.py +208 -0
  10. canns/analyzer/data/asa/fr.py +439 -0
  11. canns/analyzer/data/asa/path.py +389 -0
  12. canns/analyzer/data/asa/plotting.py +1276 -0
  13. canns/analyzer/data/asa/tda.py +901 -0
  14. canns/analyzer/data/legacy/__init__.py +6 -0
  15. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  16. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  17. canns/analyzer/metrics/spatial_metrics.py +70 -100
  18. canns/analyzer/metrics/systematic_ratemap.py +12 -17
  19. canns/analyzer/metrics/utils.py +28 -0
  20. canns/analyzer/model_specific/hopfield.py +19 -16
  21. canns/analyzer/slow_points/checkpoint.py +32 -9
  22. canns/analyzer/slow_points/finder.py +33 -6
  23. canns/analyzer/slow_points/fixed_points.py +12 -0
  24. canns/analyzer/slow_points/visualization.py +22 -10
  25. canns/analyzer/visualization/core/backend.py +15 -26
  26. canns/analyzer/visualization/core/config.py +120 -15
  27. canns/analyzer/visualization/core/jupyter_utils.py +34 -16
  28. canns/analyzer/visualization/core/rendering.py +42 -40
  29. canns/analyzer/visualization/core/writers.py +10 -20
  30. canns/analyzer/visualization/energy_plots.py +78 -28
  31. canns/analyzer/visualization/spatial_plots.py +81 -36
  32. canns/analyzer/visualization/spike_plots.py +27 -7
  33. canns/analyzer/visualization/theta_sweep_plots.py +159 -72
  34. canns/analyzer/visualization/tuning_plots.py +11 -3
  35. canns/data/__init__.py +7 -4
  36. canns/models/__init__.py +10 -0
  37. canns/models/basic/cann.py +102 -40
  38. canns/models/basic/grid_cell.py +9 -8
  39. canns/models/basic/hierarchical_model.py +57 -11
  40. canns/models/brain_inspired/hopfield.py +26 -14
  41. canns/models/brain_inspired/linear.py +15 -16
  42. canns/models/brain_inspired/spiking.py +23 -12
  43. canns/pipeline/__init__.py +4 -8
  44. canns/pipeline/asa/__init__.py +21 -0
  45. canns/pipeline/asa/__main__.py +11 -0
  46. canns/pipeline/asa/app.py +1000 -0
  47. canns/pipeline/asa/runner.py +1095 -0
  48. canns/pipeline/asa/screens.py +215 -0
  49. canns/pipeline/asa/state.py +248 -0
  50. canns/pipeline/asa/styles.tcss +221 -0
  51. canns/pipeline/asa/widgets.py +233 -0
  52. canns/pipeline/gallery/__init__.py +7 -0
  53. canns/task/closed_loop_navigation.py +54 -13
  54. canns/task/open_loop_navigation.py +230 -147
  55. canns/task/tracking.py +156 -24
  56. canns/trainer/__init__.py +8 -5
  57. canns/utils/__init__.py +12 -4
  58. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  59. canns-0.13.0.dist-info/RECORD +91 -0
  60. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  61. canns/pipeline/theta_sweep.py +0 -573
  62. canns-0.12.6.dist-info/RECORD +0 -72
  63. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  64. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -45,9 +45,10 @@ def _render_single_energy_1d_frame(
45
45
  options: _Energy1DRenderOptions,
46
46
  ) -> np.ndarray:
47
47
  """Render a single frame for 1D energy landscape animation (module-level for pickling)."""
48
+ from io import BytesIO
49
+
48
50
  import matplotlib.pyplot as plt
49
51
  import numpy as np
50
- from io import BytesIO
51
52
 
52
53
  fig, ax = plt.subplots(figsize=options.figsize)
53
54
  sim_index = options.sim_indices_to_render[frame_index]
@@ -121,9 +122,10 @@ def _render_single_energy_2d_frame(
121
122
  options: _Energy2DRenderOptions,
122
123
  ) -> np.ndarray:
123
124
  """Render a single frame for 2D energy landscape animation (module-level for pickling)."""
125
+ from io import BytesIO
126
+
124
127
  import matplotlib.pyplot as plt
125
128
  import numpy as np
126
- from io import BytesIO
127
129
 
128
130
  fig, ax = plt.subplots(figsize=options.figsize)
129
131
  sim_index = options.sim_indices_to_render[frame_index]
@@ -211,16 +213,10 @@ def energy_landscape_1d_static(
211
213
  show: bool = True,
212
214
  **kwargs: Any,
213
215
  ):
214
- """Plot a 1D static energy landscape using Matplotlib.
215
-
216
- This mirrors the long-form description from the pre-reorganisation module so
217
- existing documentation references stay accurate. The function accepts a
218
- dictionary of datasets, plotting each curve on the same set of axes while
219
- honouring the ``PlotConfig`` defaults callers relied on previously.
216
+ """Plot a 1D static energy landscape.
220
217
 
221
218
  Args:
222
- data_sets: Mapping of series labels to ``(x, y)`` tuples representing
223
- the energy curve to draw.
219
+ data_sets: Mapping ``label -> (x, y)`` where ``x`` and ``y`` are 1D arrays.
224
220
  config: Optional :class:`PlotConfig` carrying shared styling.
225
221
  title: Plot title when no config override is supplied.
226
222
  xlabel: X-axis label when no config override is supplied.
@@ -234,6 +230,17 @@ def energy_landscape_1d_static(
234
230
 
235
231
  Returns:
236
232
  Tuple[plt.Figure, plt.Axes]: The created figure and axes handles.
233
+
234
+ Examples:
235
+ >>> import numpy as np
236
+ >>> from canns.analyzer.visualization import energy_landscape_1d_static, PlotConfigs
237
+ >>>
238
+ >>> x = np.linspace(0, 1, 5)
239
+ >>> data_sets = {"u": (x, np.sin(x)), "Iext": (x, np.cos(x))}
240
+ >>> config = PlotConfigs.energy_landscape_1d_static(show=False)
241
+ >>> fig, ax = energy_landscape_1d_static(data_sets, config=config)
242
+ >>> print(fig is not None)
243
+ True
237
244
  """
238
245
 
239
246
  config = _ensure_plot_config(
@@ -291,16 +298,10 @@ def energy_landscape_1d_animation(
291
298
  ) -> animation.FuncAnimation:
292
299
  """Create an animation of an evolving 1D energy landscape.
293
300
 
294
- The docstring intentionally preserves the guidance from the previous
295
- implementation so existing callers can rely on the same parameter
296
- explanations.
297
-
298
301
  Args:
299
- data_sets: Dictionary whose keys are legend labels and values are
300
- ``(x_data, y_data)`` tuples where ``y_data`` is shaped as
301
- ``(time, state)``.
302
- time_steps_per_second: Number of simulation time steps per second of
303
- wall-clock time (e.g., ``1/dt``).
302
+ data_sets: Mapping ``label -> (x, y_series)``, where ``y_series`` is
303
+ shaped ``(timesteps, npoints)``.
304
+ time_steps_per_second: Simulation steps per second (e.g., ``1/dt``).
304
305
  config: Optional :class:`PlotConfig` with shared styling overrides.
305
306
  fps: Frames per second to render in the resulting animation.
306
307
  title: Title used when ``config`` is not provided.
@@ -320,6 +321,22 @@ def energy_landscape_1d_animation(
320
321
 
321
322
  Returns:
322
323
  ``matplotlib.animation.FuncAnimation``: The constructed animation.
324
+
325
+ Examples:
326
+ >>> import numpy as np
327
+ >>> from canns.analyzer.visualization import energy_landscape_1d_animation, PlotConfigs
328
+ >>>
329
+ >>> x = np.linspace(0, 1, 5)
330
+ >>> y_series = np.stack([np.sin(x), np.cos(x)], axis=0)
331
+ >>> data_sets = {"u": (x, y_series), "Iext": (x, y_series)}
332
+ >>> config = PlotConfigs.energy_landscape_1d_animation(
333
+ ... time_steps_per_second=10,
334
+ ... fps=2,
335
+ ... show=False,
336
+ ... )
337
+ >>> anim = energy_landscape_1d_animation(data_sets, config=config)
338
+ >>> print(anim is not None)
339
+ True
323
340
  """
324
341
 
325
342
  config = _ensure_plot_config(
@@ -442,8 +459,12 @@ def energy_landscape_1d_animation(
442
459
 
443
460
  if backend == "imageio":
444
461
  # Use imageio backend with parallel rendering
445
- workers = render_workers if render_workers is not None else get_optimal_worker_count()
446
- ctx = get_multiprocessing_context(prefer_fork=(render_start_method == "fork"))
462
+ workers = (
463
+ render_workers if render_workers is not None else get_optimal_worker_count()
464
+ )
465
+ ctx, start_method = get_multiprocessing_context(
466
+ prefer_fork=(render_start_method == "fork")
467
+ )
447
468
 
448
469
  # Create render options
449
470
  render_options = _Energy1DRenderOptions(
@@ -466,9 +487,10 @@ def energy_landscape_1d_animation(
466
487
  writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
467
488
 
468
489
  try:
469
- import imageio
470
490
  from functools import partial
471
491
 
492
+ import imageio
493
+
472
494
  # Create partial function with options
473
495
  render_func = partial(_render_single_energy_1d_frame, options=render_options)
474
496
 
@@ -504,6 +526,7 @@ def energy_landscape_1d_animation(
504
526
 
505
527
  except Exception as e:
506
528
  import warnings
529
+
507
530
  warnings.warn(
508
531
  f"imageio rendering failed: {e}. Falling back to matplotlib.",
509
532
  RuntimeWarning,
@@ -605,6 +628,16 @@ def energy_landscape_2d_static(
605
628
 
606
629
  Returns:
607
630
  Tuple[plt.Figure, plt.Axes]: The Matplotlib figure and axes objects.
631
+
632
+ Examples:
633
+ >>> import numpy as np
634
+ >>> from canns.analyzer.visualization import energy_landscape_2d_static, PlotConfigs
635
+ >>>
636
+ >>> z = np.random.rand(4, 4)
637
+ >>> config = PlotConfigs.energy_landscape_2d_static(show=False)
638
+ >>> fig, ax = energy_landscape_2d_static(z, config=config)
639
+ >>> print(fig is not None)
640
+ True
608
641
  """
609
642
 
610
643
  config = _ensure_plot_config(
@@ -673,9 +706,6 @@ def energy_landscape_2d_animation(
673
706
  ) -> animation.FuncAnimation:
674
707
  """Create an animation of an evolving 2D landscape.
675
708
 
676
- The long-form description mirrors the previous implementation to maintain
677
- backwards-compatible documentation for downstream users.
678
-
679
709
  Args:
680
710
  zs_data: Array of shape ``(timesteps, dim_y, dim_x)`` describing the
681
711
  landscape at each simulation step.
@@ -701,6 +731,20 @@ def energy_landscape_2d_animation(
701
731
 
702
732
  Returns:
703
733
  ``matplotlib.animation.FuncAnimation``: The constructed animation.
734
+
735
+ Examples:
736
+ >>> import numpy as np
737
+ >>> from canns.analyzer.visualization import energy_landscape_2d_animation, PlotConfigs
738
+ >>>
739
+ >>> zs = np.random.rand(3, 4, 4)
740
+ >>> config = PlotConfigs.energy_landscape_2d_animation(
741
+ ... time_steps_per_second=10,
742
+ ... fps=2,
743
+ ... show=False,
744
+ ... )
745
+ >>> anim = energy_landscape_2d_animation(zs, config=config)
746
+ >>> print(anim is not None)
747
+ True
704
748
  """
705
749
 
706
750
  config = _ensure_plot_config(
@@ -817,8 +861,12 @@ def energy_landscape_2d_animation(
817
861
 
818
862
  if backend == "imageio":
819
863
  # Use imageio backend with parallel rendering
820
- workers = render_workers if render_workers is not None else get_optimal_worker_count()
821
- ctx = get_multiprocessing_context(prefer_fork=(render_start_method == "fork"))
864
+ workers = (
865
+ render_workers if render_workers is not None else get_optimal_worker_count()
866
+ )
867
+ ctx, start_method = get_multiprocessing_context(
868
+ prefer_fork=(render_start_method == "fork")
869
+ )
822
870
 
823
871
  # Create render options
824
872
  render_options = _Energy2DRenderOptions(
@@ -841,9 +889,10 @@ def energy_landscape_2d_animation(
841
889
  writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
842
890
 
843
891
  try:
844
- import imageio
845
892
  from functools import partial
846
893
 
894
+ import imageio
895
+
847
896
  # Create partial function with options
848
897
  render_func = partial(_render_single_energy_2d_frame, options=render_options)
849
898
 
@@ -879,6 +928,7 @@ def energy_landscape_2d_animation(
879
928
 
880
929
  except Exception as e:
881
930
  import warnings
931
+
882
932
  warnings.warn(
883
933
  f"imageio rendering failed: {e}. Falling back to matplotlib.",
884
934
  RuntimeWarning,
@@ -49,9 +49,10 @@ def _render_single_grid_tracking_frame(
49
49
  options: _GridCellTrackingRenderOptions,
50
50
  ) -> np.ndarray:
51
51
  """Render a single frame for grid cell tracking animation (module-level for pickling)."""
52
+ from io import BytesIO
53
+
52
54
  import matplotlib.pyplot as plt
53
55
  import numpy as np
54
- from io import BytesIO
55
56
 
56
57
  fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=options.figsize)
57
58
  sim_idx = options.sim_indices_to_render[frame_index]
@@ -84,12 +85,19 @@ def _render_single_grid_tracking_frame(
84
85
 
85
86
  # Panel 3: Rate map with position
86
87
  im = ax3.imshow(
87
- options.rate_map.T, origin="lower", cmap="hot",
88
- extent=[0, options.env_size, 0, options.env_size], aspect="auto"
88
+ options.rate_map.T,
89
+ origin="lower",
90
+ cmap="hot",
91
+ extent=[0, options.env_size, 0, options.env_size],
92
+ aspect="auto",
89
93
  )
90
94
  ax3.plot(
91
- [options.position[sim_idx, 0]], [options.position[sim_idx, 1]], "c*",
92
- markersize=15, markeredgecolor="white", markeredgewidth=1.5
95
+ [options.position[sim_idx, 0]],
96
+ [options.position[sim_idx, 1]],
97
+ "c*",
98
+ markersize=15,
99
+ markeredgecolor="white",
100
+ markeredgewidth=1.5,
93
101
  )
94
102
  ax3.set_xlabel("X Position (m)", fontsize=10)
95
103
  ax3.set_ylabel("Y Position (m)", fontsize=10)
@@ -97,7 +105,9 @@ def _render_single_grid_tracking_frame(
97
105
  plt.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)
98
106
 
99
107
  # Overall title with time
100
- fig.suptitle(f"{options.title} | Time: {current_time_s:.2f} s", fontsize=13, fontweight="bold")
108
+ fig.suptitle(
109
+ f"{options.title} | Time: {current_time_s:.2f} s", fontsize=13, fontweight="bold"
110
+ )
101
111
 
102
112
  fig.tight_layout()
103
113
 
@@ -179,15 +189,15 @@ def plot_firing_field_heatmap(
179
189
  tuple[plt.Figure, plt.Axes]: The figure and axis objects for further customization.
180
190
 
181
191
  Example:
182
- >>> from canns.analyzer.metrics.spatial_metrics import compute_firing_field
192
+ >>> import numpy as np
183
193
  >>> from canns.analyzer.visualization import plot_firing_field_heatmap, PlotConfig
184
- >>> # Compute firing field
185
- >>> heatmaps = compute_firing_field(activity, positions, 5.0, 5.0, 50, 50)
186
- >>> # Plot single neuron with PlotConfig
187
- >>> config = PlotConfig(figsize=(6, 6), title='Neuron 0', save_path='neuron_0.png', show=False)
188
- >>> fig, ax = plot_firing_field_heatmap(heatmaps[0], config=config)
189
- >>> # Plot with legacy parameters
190
- >>> fig, ax = plot_firing_field_heatmap(heatmaps[1], title='Neuron 1', cmap='viridis', save_path='neuron_1.png')
194
+ >>>
195
+ >>> # Dummy input heatmap (M x K)
196
+ >>> heatmap = np.random.rand(6, 6)
197
+ >>> config = PlotConfig(title="Neuron 0", show=False)
198
+ >>> fig, ax = plot_firing_field_heatmap(heatmap, config=config)
199
+ >>> print(fig is not None)
200
+ True
191
201
  """
192
202
  # Handle configuration
193
203
  if config is None:
@@ -272,14 +282,16 @@ def plot_autocorrelation(
272
282
  tuple[plt.Figure, plt.Axes]: Figure and axes objects.
273
283
 
274
284
  Example:
285
+ >>> import numpy as np
275
286
  >>> from canns.analyzer.metrics.spatial_metrics import compute_spatial_autocorrelation
276
287
  >>> from canns.analyzer.visualization import plot_autocorrelation, PlotConfigs
288
+ >>>
289
+ >>> rate_map = np.random.rand(10, 10)
277
290
  >>> autocorr = compute_spatial_autocorrelation(rate_map)
278
- >>> # Modern approach
279
- >>> config = PlotConfigs.grid_autocorrelation(save_path='autocorr.png')
291
+ >>> config = PlotConfigs.grid_autocorrelation(show=False)
280
292
  >>> fig, ax = plot_autocorrelation(autocorr, config=config)
281
- >>> # Legacy approach
282
- >>> fig, ax = plot_autocorrelation(autocorr, cmap='RdBu_r', save_path='autocorr.png')
293
+ >>> print(fig is not None)
294
+ True
283
295
 
284
296
  References:
285
297
  Sargolini et al. (2006). Conjunctive representation of position, direction,
@@ -369,12 +381,15 @@ def plot_grid_score(
369
381
  tuple[plt.Figure, plt.Axes]: Figure and axes objects.
370
382
 
371
383
  Example:
384
+ >>> import numpy as np
372
385
  >>> from canns.analyzer.metrics.spatial_metrics import compute_grid_score
373
386
  >>> from canns.analyzer.visualization import plot_grid_score
387
+ >>>
388
+ >>> autocorr = np.random.rand(10, 10)
374
389
  >>> grid_score, rotated_corrs = compute_grid_score(autocorr)
375
- >>> fig, ax = plot_grid_score(rotated_corrs, grid_score)
376
- >>> print(f"Grid score: {grid_score:.3f}")
377
- Grid score: 0.456
390
+ >>> fig, ax = plot_grid_score(rotated_corrs, grid_score, show=False)
391
+ >>> print(isinstance(grid_score, float))
392
+ True
378
393
  """
379
394
  config = _ensure_plot_config(
380
395
  config,
@@ -483,11 +498,20 @@ def plot_grid_spacing_analysis(
483
498
  tuple[plt.Figure, plt.Axes]: Figure and axes objects.
484
499
 
485
500
  Example:
501
+ >>> import numpy as np
486
502
  >>> from canns.analyzer.metrics.spatial_metrics import find_grid_spacing
487
503
  >>> from canns.analyzer.visualization import plot_grid_spacing_analysis
488
- >>> spacing_bins, spacing_m = find_grid_spacing(autocorr, bin_size=0.06)
489
- >>> fig, ax = plot_grid_spacing_analysis(autocorr, spacing_bins, bin_size=0.06)
490
- >>> print(f"Spacing: {spacing_m:.3f}m")
504
+ >>>
505
+ >>> autocorr = np.random.rand(12, 12)
506
+ >>> spacing_bins, spacing_m = find_grid_spacing(autocorr, bin_size=0.05)
507
+ >>> fig, ax = plot_grid_spacing_analysis(
508
+ ... autocorr,
509
+ ... spacing_bins,
510
+ ... bin_size=0.05,
511
+ ... show=False,
512
+ ... )
513
+ >>> print(spacing_m is not None)
514
+ True
491
515
  """
492
516
  config = _ensure_plot_config(
493
517
  config,
@@ -620,18 +644,29 @@ def create_grid_cell_tracking_animation(
620
644
  FuncAnimation | None: Animation object, or None if displayed in Jupyter.
621
645
 
622
646
  Example:
623
- >>> from canns.analyzer.visualization import create_grid_cell_tracking_animation, PlotConfigs
624
- >>> # Create animation
647
+ >>> import numpy as np
648
+ >>> from canns.analyzer.visualization import (
649
+ ... create_grid_cell_tracking_animation,
650
+ ... PlotConfigs,
651
+ ... )
652
+ >>>
653
+ >>> position = np.array([[0.0, 0.0], [0.1, 0.1], [0.2, 0.2]])
654
+ >>> activity = np.array([0.0, 0.5, 1.0])
655
+ >>> rate_map = np.random.rand(5, 5)
625
656
  >>> config = PlotConfigs.grid_cell_tracking_animation(
626
- ... time_steps_per_second=1000, # dt=1.0ms
627
- ... fps=20,
628
- ... save_path="tracking.gif"
657
+ ... time_steps_per_second=10,
658
+ ... fps=2,
659
+ ... show=False,
629
660
  ... )
630
661
  >>> anim = create_grid_cell_tracking_animation(
631
- ... position, activity, rate_map,
662
+ ... position,
663
+ ... activity,
664
+ ... rate_map,
632
665
  ... config=config,
633
- ... env_size=3.0
666
+ ... env_size=1.0,
634
667
  ... )
668
+ >>> print(anim is not None)
669
+ True
635
670
  """
636
671
  config = _ensure_plot_config(
637
672
  config,
@@ -767,8 +802,12 @@ def create_grid_cell_tracking_animation(
767
802
 
768
803
  if backend == "imageio":
769
804
  # Use imageio backend with parallel rendering
770
- workers = render_workers if render_workers is not None else get_optimal_worker_count()
771
- ctx = get_multiprocessing_context(prefer_fork=(render_start_method == "fork"))
805
+ workers = (
806
+ render_workers if render_workers is not None else get_optimal_worker_count()
807
+ )
808
+ ctx, start_method = get_multiprocessing_context(
809
+ prefer_fork=(render_start_method == "fork")
810
+ )
772
811
 
773
812
  # Create render options
774
813
  render_options = _GridCellTrackingRenderOptions(
@@ -788,11 +827,14 @@ def create_grid_cell_tracking_animation(
788
827
  writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
789
828
 
790
829
  try:
791
- import imageio
792
830
  from functools import partial
793
831
 
832
+ import imageio
833
+
794
834
  # Create partial function with options
795
- render_func = partial(_render_single_grid_tracking_frame, options=render_options)
835
+ render_func = partial(
836
+ _render_single_grid_tracking_frame, options=render_options
837
+ )
796
838
 
797
839
  with imageio.get_writer(config.save_path, mode=mode, **writer_kwargs) as writer:
798
840
  if workers > 1 and ctx is not None:
@@ -826,6 +868,7 @@ def create_grid_cell_tracking_animation(
826
868
 
827
869
  except Exception as e:
828
870
  import warnings
871
+
829
872
  warnings.warn(
830
873
  f"imageio rendering failed: {e}. Falling back to matplotlib.",
831
874
  RuntimeWarning,
@@ -855,7 +898,9 @@ def create_grid_cell_tracking_animation(
855
898
  pbar.update(1)
856
899
 
857
900
  try:
858
- ani.save(config.save_path, writer=writer, progress_callback=progress_callback)
901
+ ani.save(
902
+ config.save_path, writer=writer, progress_callback=progress_callback
903
+ )
859
904
  print(f"Animation saved to: {config.save_path}")
860
905
  finally:
861
906
  pbar.close()
@@ -47,9 +47,6 @@ def raster_plot(
47
47
  ):
48
48
  """Generate a raster plot from a spike train matrix.
49
49
 
50
- The explanatory text mirrors the former ``visualize`` module so callers see
51
- the same guidance after the reorganisation.
52
-
53
50
  Args:
54
51
  spike_train: Boolean/integer array of shape ``(timesteps, neurons)``.
55
52
  config: Optional :class:`PlotConfig` with shared styling options.
@@ -62,6 +59,17 @@ def raster_plot(
62
59
  save_path: Optional path used to persist the plot.
63
60
  show: Whether to display the plot interactively.
64
61
  **kwargs: Additional keyword arguments passed through to Matplotlib.
62
+
63
+ Examples:
64
+ >>> import numpy as np
65
+ >>> from canns.analyzer.visualization import raster_plot, PlotConfigs
66
+ >>>
67
+ >>> spike_train = np.zeros((5, 3), dtype=int)
68
+ >>> spike_train[::2, 0] = 1
69
+ >>> config = PlotConfigs.raster_plot(show=False)
70
+ >>> fig, ax = raster_plot(spike_train, config=config)
71
+ >>> print(fig is not None)
72
+ True
65
73
  """
66
74
 
67
75
  config = _ensure_plot_config(
@@ -158,6 +166,16 @@ def average_firing_rate_plot(
158
166
  save_path: Optional path used to persist the plot.
159
167
  show: Whether to display the plot interactively.
160
168
  **kwargs: Additional keyword arguments forwarded to Matplotlib.
169
+
170
+ Examples:
171
+ >>> import numpy as np
172
+ >>> from canns.analyzer.visualization import average_firing_rate_plot, PlotConfigs
173
+ >>>
174
+ >>> spike_train = np.random.randint(0, 2, size=(10, 4))
175
+ >>> config = PlotConfigs.average_firing_rate_plot(mode="population", show=False)
176
+ >>> fig, ax = average_firing_rate_plot(spike_train, dt=0.1, config=config)
177
+ >>> print(fig is not None)
178
+ True
161
179
  """
162
180
 
163
181
  config = _ensure_plot_config(
@@ -267,10 +285,12 @@ def population_activity_heatmap(
267
285
 
268
286
  Example:
269
287
  >>> import numpy as np
270
- >>> from canns.analyzer.visualization.spike_plots import population_activity_heatmap
271
- >>> # Simulate some activity data
272
- >>> activity = np.random.rand(1000, 100) # 1000 timesteps, 100 neurons
273
- >>> fig, ax = population_activity_heatmap(activity, dt=0.001)
288
+ >>> from canns.analyzer.visualization import population_activity_heatmap, PlotConfig
289
+ >>> activity = np.random.rand(10, 5)
290
+ >>> config = PlotConfig(show=False)
291
+ >>> fig, ax = population_activity_heatmap(activity, dt=0.1, config=config)
292
+ >>> print(fig is not None)
293
+ True
274
294
  """
275
295
  if config is None:
276
296
  config = PlotConfig(