canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/__init__.py +39 -3
- canns/analyzer/__init__.py +7 -6
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/metrics/spatial_metrics.py +70 -100
- canns/analyzer/metrics/systematic_ratemap.py +12 -17
- canns/analyzer/metrics/utils.py +28 -0
- canns/analyzer/model_specific/hopfield.py +19 -16
- canns/analyzer/slow_points/checkpoint.py +32 -9
- canns/analyzer/slow_points/finder.py +33 -6
- canns/analyzer/slow_points/fixed_points.py +12 -0
- canns/analyzer/slow_points/visualization.py +22 -10
- canns/analyzer/visualization/core/backend.py +15 -26
- canns/analyzer/visualization/core/config.py +120 -15
- canns/analyzer/visualization/core/jupyter_utils.py +34 -16
- canns/analyzer/visualization/core/rendering.py +42 -40
- canns/analyzer/visualization/core/writers.py +10 -20
- canns/analyzer/visualization/energy_plots.py +78 -28
- canns/analyzer/visualization/spatial_plots.py +81 -36
- canns/analyzer/visualization/spike_plots.py +27 -7
- canns/analyzer/visualization/theta_sweep_plots.py +159 -72
- canns/analyzer/visualization/tuning_plots.py +11 -3
- canns/data/__init__.py +7 -4
- canns/models/__init__.py +10 -0
- canns/models/basic/cann.py +102 -40
- canns/models/basic/grid_cell.py +9 -8
- canns/models/basic/hierarchical_model.py +57 -11
- canns/models/brain_inspired/hopfield.py +26 -14
- canns/models/brain_inspired/linear.py +15 -16
- canns/models/brain_inspired/spiking.py +23 -12
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/closed_loop_navigation.py +54 -13
- canns/task/open_loop_navigation.py +230 -147
- canns/task/tracking.py +156 -24
- canns/trainer/__init__.py +8 -5
- canns/utils/__init__.py +12 -4
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- canns-0.13.0.dist-info/RECORD +91 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- canns-0.12.6.dist-info/RECORD +0 -72
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -45,9 +45,10 @@ def _render_single_energy_1d_frame(
|
|
|
45
45
|
options: _Energy1DRenderOptions,
|
|
46
46
|
) -> np.ndarray:
|
|
47
47
|
"""Render a single frame for 1D energy landscape animation (module-level for pickling)."""
|
|
48
|
+
from io import BytesIO
|
|
49
|
+
|
|
48
50
|
import matplotlib.pyplot as plt
|
|
49
51
|
import numpy as np
|
|
50
|
-
from io import BytesIO
|
|
51
52
|
|
|
52
53
|
fig, ax = plt.subplots(figsize=options.figsize)
|
|
53
54
|
sim_index = options.sim_indices_to_render[frame_index]
|
|
@@ -121,9 +122,10 @@ def _render_single_energy_2d_frame(
|
|
|
121
122
|
options: _Energy2DRenderOptions,
|
|
122
123
|
) -> np.ndarray:
|
|
123
124
|
"""Render a single frame for 2D energy landscape animation (module-level for pickling)."""
|
|
125
|
+
from io import BytesIO
|
|
126
|
+
|
|
124
127
|
import matplotlib.pyplot as plt
|
|
125
128
|
import numpy as np
|
|
126
|
-
from io import BytesIO
|
|
127
129
|
|
|
128
130
|
fig, ax = plt.subplots(figsize=options.figsize)
|
|
129
131
|
sim_index = options.sim_indices_to_render[frame_index]
|
|
@@ -211,16 +213,10 @@ def energy_landscape_1d_static(
|
|
|
211
213
|
show: bool = True,
|
|
212
214
|
**kwargs: Any,
|
|
213
215
|
):
|
|
214
|
-
"""Plot a 1D static energy landscape
|
|
215
|
-
|
|
216
|
-
This mirrors the long-form description from the pre-reorganisation module so
|
|
217
|
-
existing documentation references stay accurate. The function accepts a
|
|
218
|
-
dictionary of datasets, plotting each curve on the same set of axes while
|
|
219
|
-
honouring the ``PlotConfig`` defaults callers relied on previously.
|
|
216
|
+
"""Plot a 1D static energy landscape.
|
|
220
217
|
|
|
221
218
|
Args:
|
|
222
|
-
data_sets: Mapping
|
|
223
|
-
the energy curve to draw.
|
|
219
|
+
data_sets: Mapping ``label -> (x, y)`` where ``x`` and ``y`` are 1D arrays.
|
|
224
220
|
config: Optional :class:`PlotConfig` carrying shared styling.
|
|
225
221
|
title: Plot title when no config override is supplied.
|
|
226
222
|
xlabel: X-axis label when no config override is supplied.
|
|
@@ -234,6 +230,17 @@ def energy_landscape_1d_static(
|
|
|
234
230
|
|
|
235
231
|
Returns:
|
|
236
232
|
Tuple[plt.Figure, plt.Axes]: The created figure and axes handles.
|
|
233
|
+
|
|
234
|
+
Examples:
|
|
235
|
+
>>> import numpy as np
|
|
236
|
+
>>> from canns.analyzer.visualization import energy_landscape_1d_static, PlotConfigs
|
|
237
|
+
>>>
|
|
238
|
+
>>> x = np.linspace(0, 1, 5)
|
|
239
|
+
>>> data_sets = {"u": (x, np.sin(x)), "Iext": (x, np.cos(x))}
|
|
240
|
+
>>> config = PlotConfigs.energy_landscape_1d_static(show=False)
|
|
241
|
+
>>> fig, ax = energy_landscape_1d_static(data_sets, config=config)
|
|
242
|
+
>>> print(fig is not None)
|
|
243
|
+
True
|
|
237
244
|
"""
|
|
238
245
|
|
|
239
246
|
config = _ensure_plot_config(
|
|
@@ -291,16 +298,10 @@ def energy_landscape_1d_animation(
|
|
|
291
298
|
) -> animation.FuncAnimation:
|
|
292
299
|
"""Create an animation of an evolving 1D energy landscape.
|
|
293
300
|
|
|
294
|
-
The docstring intentionally preserves the guidance from the previous
|
|
295
|
-
implementation so existing callers can rely on the same parameter
|
|
296
|
-
explanations.
|
|
297
|
-
|
|
298
301
|
Args:
|
|
299
|
-
data_sets:
|
|
300
|
-
``(
|
|
301
|
-
|
|
302
|
-
time_steps_per_second: Number of simulation time steps per second of
|
|
303
|
-
wall-clock time (e.g., ``1/dt``).
|
|
302
|
+
data_sets: Mapping ``label -> (x, y_series)``, where ``y_series`` is
|
|
303
|
+
shaped ``(timesteps, npoints)``.
|
|
304
|
+
time_steps_per_second: Simulation steps per second (e.g., ``1/dt``).
|
|
304
305
|
config: Optional :class:`PlotConfig` with shared styling overrides.
|
|
305
306
|
fps: Frames per second to render in the resulting animation.
|
|
306
307
|
title: Title used when ``config`` is not provided.
|
|
@@ -320,6 +321,22 @@ def energy_landscape_1d_animation(
|
|
|
320
321
|
|
|
321
322
|
Returns:
|
|
322
323
|
``matplotlib.animation.FuncAnimation``: The constructed animation.
|
|
324
|
+
|
|
325
|
+
Examples:
|
|
326
|
+
>>> import numpy as np
|
|
327
|
+
>>> from canns.analyzer.visualization import energy_landscape_1d_animation, PlotConfigs
|
|
328
|
+
>>>
|
|
329
|
+
>>> x = np.linspace(0, 1, 5)
|
|
330
|
+
>>> y_series = np.stack([np.sin(x), np.cos(x)], axis=0)
|
|
331
|
+
>>> data_sets = {"u": (x, y_series), "Iext": (x, y_series)}
|
|
332
|
+
>>> config = PlotConfigs.energy_landscape_1d_animation(
|
|
333
|
+
... time_steps_per_second=10,
|
|
334
|
+
... fps=2,
|
|
335
|
+
... show=False,
|
|
336
|
+
... )
|
|
337
|
+
>>> anim = energy_landscape_1d_animation(data_sets, config=config)
|
|
338
|
+
>>> print(anim is not None)
|
|
339
|
+
True
|
|
323
340
|
"""
|
|
324
341
|
|
|
325
342
|
config = _ensure_plot_config(
|
|
@@ -442,8 +459,12 @@ def energy_landscape_1d_animation(
|
|
|
442
459
|
|
|
443
460
|
if backend == "imageio":
|
|
444
461
|
# Use imageio backend with parallel rendering
|
|
445
|
-
workers =
|
|
446
|
-
|
|
462
|
+
workers = (
|
|
463
|
+
render_workers if render_workers is not None else get_optimal_worker_count()
|
|
464
|
+
)
|
|
465
|
+
ctx, start_method = get_multiprocessing_context(
|
|
466
|
+
prefer_fork=(render_start_method == "fork")
|
|
467
|
+
)
|
|
447
468
|
|
|
448
469
|
# Create render options
|
|
449
470
|
render_options = _Energy1DRenderOptions(
|
|
@@ -466,9 +487,10 @@ def energy_landscape_1d_animation(
|
|
|
466
487
|
writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
|
|
467
488
|
|
|
468
489
|
try:
|
|
469
|
-
import imageio
|
|
470
490
|
from functools import partial
|
|
471
491
|
|
|
492
|
+
import imageio
|
|
493
|
+
|
|
472
494
|
# Create partial function with options
|
|
473
495
|
render_func = partial(_render_single_energy_1d_frame, options=render_options)
|
|
474
496
|
|
|
@@ -504,6 +526,7 @@ def energy_landscape_1d_animation(
|
|
|
504
526
|
|
|
505
527
|
except Exception as e:
|
|
506
528
|
import warnings
|
|
529
|
+
|
|
507
530
|
warnings.warn(
|
|
508
531
|
f"imageio rendering failed: {e}. Falling back to matplotlib.",
|
|
509
532
|
RuntimeWarning,
|
|
@@ -605,6 +628,16 @@ def energy_landscape_2d_static(
|
|
|
605
628
|
|
|
606
629
|
Returns:
|
|
607
630
|
Tuple[plt.Figure, plt.Axes]: The Matplotlib figure and axes objects.
|
|
631
|
+
|
|
632
|
+
Examples:
|
|
633
|
+
>>> import numpy as np
|
|
634
|
+
>>> from canns.analyzer.visualization import energy_landscape_2d_static, PlotConfigs
|
|
635
|
+
>>>
|
|
636
|
+
>>> z = np.random.rand(4, 4)
|
|
637
|
+
>>> config = PlotConfigs.energy_landscape_2d_static(show=False)
|
|
638
|
+
>>> fig, ax = energy_landscape_2d_static(z, config=config)
|
|
639
|
+
>>> print(fig is not None)
|
|
640
|
+
True
|
|
608
641
|
"""
|
|
609
642
|
|
|
610
643
|
config = _ensure_plot_config(
|
|
@@ -673,9 +706,6 @@ def energy_landscape_2d_animation(
|
|
|
673
706
|
) -> animation.FuncAnimation:
|
|
674
707
|
"""Create an animation of an evolving 2D landscape.
|
|
675
708
|
|
|
676
|
-
The long-form description mirrors the previous implementation to maintain
|
|
677
|
-
backwards-compatible documentation for downstream users.
|
|
678
|
-
|
|
679
709
|
Args:
|
|
680
710
|
zs_data: Array of shape ``(timesteps, dim_y, dim_x)`` describing the
|
|
681
711
|
landscape at each simulation step.
|
|
@@ -701,6 +731,20 @@ def energy_landscape_2d_animation(
|
|
|
701
731
|
|
|
702
732
|
Returns:
|
|
703
733
|
``matplotlib.animation.FuncAnimation``: The constructed animation.
|
|
734
|
+
|
|
735
|
+
Examples:
|
|
736
|
+
>>> import numpy as np
|
|
737
|
+
>>> from canns.analyzer.visualization import energy_landscape_2d_animation, PlotConfigs
|
|
738
|
+
>>>
|
|
739
|
+
>>> zs = np.random.rand(3, 4, 4)
|
|
740
|
+
>>> config = PlotConfigs.energy_landscape_2d_animation(
|
|
741
|
+
... time_steps_per_second=10,
|
|
742
|
+
... fps=2,
|
|
743
|
+
... show=False,
|
|
744
|
+
... )
|
|
745
|
+
>>> anim = energy_landscape_2d_animation(zs, config=config)
|
|
746
|
+
>>> print(anim is not None)
|
|
747
|
+
True
|
|
704
748
|
"""
|
|
705
749
|
|
|
706
750
|
config = _ensure_plot_config(
|
|
@@ -817,8 +861,12 @@ def energy_landscape_2d_animation(
|
|
|
817
861
|
|
|
818
862
|
if backend == "imageio":
|
|
819
863
|
# Use imageio backend with parallel rendering
|
|
820
|
-
workers =
|
|
821
|
-
|
|
864
|
+
workers = (
|
|
865
|
+
render_workers if render_workers is not None else get_optimal_worker_count()
|
|
866
|
+
)
|
|
867
|
+
ctx, start_method = get_multiprocessing_context(
|
|
868
|
+
prefer_fork=(render_start_method == "fork")
|
|
869
|
+
)
|
|
822
870
|
|
|
823
871
|
# Create render options
|
|
824
872
|
render_options = _Energy2DRenderOptions(
|
|
@@ -841,9 +889,10 @@ def energy_landscape_2d_animation(
|
|
|
841
889
|
writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
|
|
842
890
|
|
|
843
891
|
try:
|
|
844
|
-
import imageio
|
|
845
892
|
from functools import partial
|
|
846
893
|
|
|
894
|
+
import imageio
|
|
895
|
+
|
|
847
896
|
# Create partial function with options
|
|
848
897
|
render_func = partial(_render_single_energy_2d_frame, options=render_options)
|
|
849
898
|
|
|
@@ -879,6 +928,7 @@ def energy_landscape_2d_animation(
|
|
|
879
928
|
|
|
880
929
|
except Exception as e:
|
|
881
930
|
import warnings
|
|
931
|
+
|
|
882
932
|
warnings.warn(
|
|
883
933
|
f"imageio rendering failed: {e}. Falling back to matplotlib.",
|
|
884
934
|
RuntimeWarning,
|
|
@@ -49,9 +49,10 @@ def _render_single_grid_tracking_frame(
|
|
|
49
49
|
options: _GridCellTrackingRenderOptions,
|
|
50
50
|
) -> np.ndarray:
|
|
51
51
|
"""Render a single frame for grid cell tracking animation (module-level for pickling)."""
|
|
52
|
+
from io import BytesIO
|
|
53
|
+
|
|
52
54
|
import matplotlib.pyplot as plt
|
|
53
55
|
import numpy as np
|
|
54
|
-
from io import BytesIO
|
|
55
56
|
|
|
56
57
|
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=options.figsize)
|
|
57
58
|
sim_idx = options.sim_indices_to_render[frame_index]
|
|
@@ -84,12 +85,19 @@ def _render_single_grid_tracking_frame(
|
|
|
84
85
|
|
|
85
86
|
# Panel 3: Rate map with position
|
|
86
87
|
im = ax3.imshow(
|
|
87
|
-
options.rate_map.T,
|
|
88
|
-
|
|
88
|
+
options.rate_map.T,
|
|
89
|
+
origin="lower",
|
|
90
|
+
cmap="hot",
|
|
91
|
+
extent=[0, options.env_size, 0, options.env_size],
|
|
92
|
+
aspect="auto",
|
|
89
93
|
)
|
|
90
94
|
ax3.plot(
|
|
91
|
-
[options.position[sim_idx, 0]],
|
|
92
|
-
|
|
95
|
+
[options.position[sim_idx, 0]],
|
|
96
|
+
[options.position[sim_idx, 1]],
|
|
97
|
+
"c*",
|
|
98
|
+
markersize=15,
|
|
99
|
+
markeredgecolor="white",
|
|
100
|
+
markeredgewidth=1.5,
|
|
93
101
|
)
|
|
94
102
|
ax3.set_xlabel("X Position (m)", fontsize=10)
|
|
95
103
|
ax3.set_ylabel("Y Position (m)", fontsize=10)
|
|
@@ -97,7 +105,9 @@ def _render_single_grid_tracking_frame(
|
|
|
97
105
|
plt.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)
|
|
98
106
|
|
|
99
107
|
# Overall title with time
|
|
100
|
-
fig.suptitle(
|
|
108
|
+
fig.suptitle(
|
|
109
|
+
f"{options.title} | Time: {current_time_s:.2f} s", fontsize=13, fontweight="bold"
|
|
110
|
+
)
|
|
101
111
|
|
|
102
112
|
fig.tight_layout()
|
|
103
113
|
|
|
@@ -179,15 +189,15 @@ def plot_firing_field_heatmap(
|
|
|
179
189
|
tuple[plt.Figure, plt.Axes]: The figure and axis objects for further customization.
|
|
180
190
|
|
|
181
191
|
Example:
|
|
182
|
-
>>>
|
|
192
|
+
>>> import numpy as np
|
|
183
193
|
>>> from canns.analyzer.visualization import plot_firing_field_heatmap, PlotConfig
|
|
184
|
-
>>>
|
|
185
|
-
>>>
|
|
186
|
-
>>>
|
|
187
|
-
>>> config = PlotConfig(
|
|
188
|
-
>>> fig, ax = plot_firing_field_heatmap(
|
|
189
|
-
>>>
|
|
190
|
-
|
|
194
|
+
>>>
|
|
195
|
+
>>> # Dummy input heatmap (M x K)
|
|
196
|
+
>>> heatmap = np.random.rand(6, 6)
|
|
197
|
+
>>> config = PlotConfig(title="Neuron 0", show=False)
|
|
198
|
+
>>> fig, ax = plot_firing_field_heatmap(heatmap, config=config)
|
|
199
|
+
>>> print(fig is not None)
|
|
200
|
+
True
|
|
191
201
|
"""
|
|
192
202
|
# Handle configuration
|
|
193
203
|
if config is None:
|
|
@@ -272,14 +282,16 @@ def plot_autocorrelation(
|
|
|
272
282
|
tuple[plt.Figure, plt.Axes]: Figure and axes objects.
|
|
273
283
|
|
|
274
284
|
Example:
|
|
285
|
+
>>> import numpy as np
|
|
275
286
|
>>> from canns.analyzer.metrics.spatial_metrics import compute_spatial_autocorrelation
|
|
276
287
|
>>> from canns.analyzer.visualization import plot_autocorrelation, PlotConfigs
|
|
288
|
+
>>>
|
|
289
|
+
>>> rate_map = np.random.rand(10, 10)
|
|
277
290
|
>>> autocorr = compute_spatial_autocorrelation(rate_map)
|
|
278
|
-
>>>
|
|
279
|
-
>>> config = PlotConfigs.grid_autocorrelation(save_path='autocorr.png')
|
|
291
|
+
>>> config = PlotConfigs.grid_autocorrelation(show=False)
|
|
280
292
|
>>> fig, ax = plot_autocorrelation(autocorr, config=config)
|
|
281
|
-
>>>
|
|
282
|
-
|
|
293
|
+
>>> print(fig is not None)
|
|
294
|
+
True
|
|
283
295
|
|
|
284
296
|
References:
|
|
285
297
|
Sargolini et al. (2006). Conjunctive representation of position, direction,
|
|
@@ -369,12 +381,15 @@ def plot_grid_score(
|
|
|
369
381
|
tuple[plt.Figure, plt.Axes]: Figure and axes objects.
|
|
370
382
|
|
|
371
383
|
Example:
|
|
384
|
+
>>> import numpy as np
|
|
372
385
|
>>> from canns.analyzer.metrics.spatial_metrics import compute_grid_score
|
|
373
386
|
>>> from canns.analyzer.visualization import plot_grid_score
|
|
387
|
+
>>>
|
|
388
|
+
>>> autocorr = np.random.rand(10, 10)
|
|
374
389
|
>>> grid_score, rotated_corrs = compute_grid_score(autocorr)
|
|
375
|
-
>>> fig, ax = plot_grid_score(rotated_corrs, grid_score)
|
|
376
|
-
>>> print(
|
|
377
|
-
|
|
390
|
+
>>> fig, ax = plot_grid_score(rotated_corrs, grid_score, show=False)
|
|
391
|
+
>>> print(isinstance(grid_score, float))
|
|
392
|
+
True
|
|
378
393
|
"""
|
|
379
394
|
config = _ensure_plot_config(
|
|
380
395
|
config,
|
|
@@ -483,11 +498,20 @@ def plot_grid_spacing_analysis(
|
|
|
483
498
|
tuple[plt.Figure, plt.Axes]: Figure and axes objects.
|
|
484
499
|
|
|
485
500
|
Example:
|
|
501
|
+
>>> import numpy as np
|
|
486
502
|
>>> from canns.analyzer.metrics.spatial_metrics import find_grid_spacing
|
|
487
503
|
>>> from canns.analyzer.visualization import plot_grid_spacing_analysis
|
|
488
|
-
>>>
|
|
489
|
-
>>>
|
|
490
|
-
>>>
|
|
504
|
+
>>>
|
|
505
|
+
>>> autocorr = np.random.rand(12, 12)
|
|
506
|
+
>>> spacing_bins, spacing_m = find_grid_spacing(autocorr, bin_size=0.05)
|
|
507
|
+
>>> fig, ax = plot_grid_spacing_analysis(
|
|
508
|
+
... autocorr,
|
|
509
|
+
... spacing_bins,
|
|
510
|
+
... bin_size=0.05,
|
|
511
|
+
... show=False,
|
|
512
|
+
... )
|
|
513
|
+
>>> print(spacing_m is not None)
|
|
514
|
+
True
|
|
491
515
|
"""
|
|
492
516
|
config = _ensure_plot_config(
|
|
493
517
|
config,
|
|
@@ -620,18 +644,29 @@ def create_grid_cell_tracking_animation(
|
|
|
620
644
|
FuncAnimation | None: Animation object, or None if displayed in Jupyter.
|
|
621
645
|
|
|
622
646
|
Example:
|
|
623
|
-
>>>
|
|
624
|
-
>>>
|
|
647
|
+
>>> import numpy as np
|
|
648
|
+
>>> from canns.analyzer.visualization import (
|
|
649
|
+
... create_grid_cell_tracking_animation,
|
|
650
|
+
... PlotConfigs,
|
|
651
|
+
... )
|
|
652
|
+
>>>
|
|
653
|
+
>>> position = np.array([[0.0, 0.0], [0.1, 0.1], [0.2, 0.2]])
|
|
654
|
+
>>> activity = np.array([0.0, 0.5, 1.0])
|
|
655
|
+
>>> rate_map = np.random.rand(5, 5)
|
|
625
656
|
>>> config = PlotConfigs.grid_cell_tracking_animation(
|
|
626
|
-
... time_steps_per_second=
|
|
627
|
-
... fps=
|
|
628
|
-
...
|
|
657
|
+
... time_steps_per_second=10,
|
|
658
|
+
... fps=2,
|
|
659
|
+
... show=False,
|
|
629
660
|
... )
|
|
630
661
|
>>> anim = create_grid_cell_tracking_animation(
|
|
631
|
-
... position,
|
|
662
|
+
... position,
|
|
663
|
+
... activity,
|
|
664
|
+
... rate_map,
|
|
632
665
|
... config=config,
|
|
633
|
-
... env_size=
|
|
666
|
+
... env_size=1.0,
|
|
634
667
|
... )
|
|
668
|
+
>>> print(anim is not None)
|
|
669
|
+
True
|
|
635
670
|
"""
|
|
636
671
|
config = _ensure_plot_config(
|
|
637
672
|
config,
|
|
@@ -767,8 +802,12 @@ def create_grid_cell_tracking_animation(
|
|
|
767
802
|
|
|
768
803
|
if backend == "imageio":
|
|
769
804
|
# Use imageio backend with parallel rendering
|
|
770
|
-
workers =
|
|
771
|
-
|
|
805
|
+
workers = (
|
|
806
|
+
render_workers if render_workers is not None else get_optimal_worker_count()
|
|
807
|
+
)
|
|
808
|
+
ctx, start_method = get_multiprocessing_context(
|
|
809
|
+
prefer_fork=(render_start_method == "fork")
|
|
810
|
+
)
|
|
772
811
|
|
|
773
812
|
# Create render options
|
|
774
813
|
render_options = _GridCellTrackingRenderOptions(
|
|
@@ -788,11 +827,14 @@ def create_grid_cell_tracking_animation(
|
|
|
788
827
|
writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
|
|
789
828
|
|
|
790
829
|
try:
|
|
791
|
-
import imageio
|
|
792
830
|
from functools import partial
|
|
793
831
|
|
|
832
|
+
import imageio
|
|
833
|
+
|
|
794
834
|
# Create partial function with options
|
|
795
|
-
render_func = partial(
|
|
835
|
+
render_func = partial(
|
|
836
|
+
_render_single_grid_tracking_frame, options=render_options
|
|
837
|
+
)
|
|
796
838
|
|
|
797
839
|
with imageio.get_writer(config.save_path, mode=mode, **writer_kwargs) as writer:
|
|
798
840
|
if workers > 1 and ctx is not None:
|
|
@@ -826,6 +868,7 @@ def create_grid_cell_tracking_animation(
|
|
|
826
868
|
|
|
827
869
|
except Exception as e:
|
|
828
870
|
import warnings
|
|
871
|
+
|
|
829
872
|
warnings.warn(
|
|
830
873
|
f"imageio rendering failed: {e}. Falling back to matplotlib.",
|
|
831
874
|
RuntimeWarning,
|
|
@@ -855,7 +898,9 @@ def create_grid_cell_tracking_animation(
|
|
|
855
898
|
pbar.update(1)
|
|
856
899
|
|
|
857
900
|
try:
|
|
858
|
-
ani.save(
|
|
901
|
+
ani.save(
|
|
902
|
+
config.save_path, writer=writer, progress_callback=progress_callback
|
|
903
|
+
)
|
|
859
904
|
print(f"Animation saved to: {config.save_path}")
|
|
860
905
|
finally:
|
|
861
906
|
pbar.close()
|
|
@@ -47,9 +47,6 @@ def raster_plot(
|
|
|
47
47
|
):
|
|
48
48
|
"""Generate a raster plot from a spike train matrix.
|
|
49
49
|
|
|
50
|
-
The explanatory text mirrors the former ``visualize`` module so callers see
|
|
51
|
-
the same guidance after the reorganisation.
|
|
52
|
-
|
|
53
50
|
Args:
|
|
54
51
|
spike_train: Boolean/integer array of shape ``(timesteps, neurons)``.
|
|
55
52
|
config: Optional :class:`PlotConfig` with shared styling options.
|
|
@@ -62,6 +59,17 @@ def raster_plot(
|
|
|
62
59
|
save_path: Optional path used to persist the plot.
|
|
63
60
|
show: Whether to display the plot interactively.
|
|
64
61
|
**kwargs: Additional keyword arguments passed through to Matplotlib.
|
|
62
|
+
|
|
63
|
+
Examples:
|
|
64
|
+
>>> import numpy as np
|
|
65
|
+
>>> from canns.analyzer.visualization import raster_plot, PlotConfigs
|
|
66
|
+
>>>
|
|
67
|
+
>>> spike_train = np.zeros((5, 3), dtype=int)
|
|
68
|
+
>>> spike_train[::2, 0] = 1
|
|
69
|
+
>>> config = PlotConfigs.raster_plot(show=False)
|
|
70
|
+
>>> fig, ax = raster_plot(spike_train, config=config)
|
|
71
|
+
>>> print(fig is not None)
|
|
72
|
+
True
|
|
65
73
|
"""
|
|
66
74
|
|
|
67
75
|
config = _ensure_plot_config(
|
|
@@ -158,6 +166,16 @@ def average_firing_rate_plot(
|
|
|
158
166
|
save_path: Optional path used to persist the plot.
|
|
159
167
|
show: Whether to display the plot interactively.
|
|
160
168
|
**kwargs: Additional keyword arguments forwarded to Matplotlib.
|
|
169
|
+
|
|
170
|
+
Examples:
|
|
171
|
+
>>> import numpy as np
|
|
172
|
+
>>> from canns.analyzer.visualization import average_firing_rate_plot, PlotConfigs
|
|
173
|
+
>>>
|
|
174
|
+
>>> spike_train = np.random.randint(0, 2, size=(10, 4))
|
|
175
|
+
>>> config = PlotConfigs.average_firing_rate_plot(mode="population", show=False)
|
|
176
|
+
>>> fig, ax = average_firing_rate_plot(spike_train, dt=0.1, config=config)
|
|
177
|
+
>>> print(fig is not None)
|
|
178
|
+
True
|
|
161
179
|
"""
|
|
162
180
|
|
|
163
181
|
config = _ensure_plot_config(
|
|
@@ -267,10 +285,12 @@ def population_activity_heatmap(
|
|
|
267
285
|
|
|
268
286
|
Example:
|
|
269
287
|
>>> import numpy as np
|
|
270
|
-
>>> from canns.analyzer.visualization
|
|
271
|
-
>>>
|
|
272
|
-
>>>
|
|
273
|
-
>>> fig, ax = population_activity_heatmap(activity, dt=0.
|
|
288
|
+
>>> from canns.analyzer.visualization import population_activity_heatmap, PlotConfig
|
|
289
|
+
>>> activity = np.random.rand(10, 5)
|
|
290
|
+
>>> config = PlotConfig(show=False)
|
|
291
|
+
>>> fig, ax = population_activity_heatmap(activity, dt=0.1, config=config)
|
|
292
|
+
>>> print(fig is not None)
|
|
293
|
+
True
|
|
274
294
|
"""
|
|
275
295
|
if config is None:
|
|
276
296
|
config = PlotConfig(
|