canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/__init__.py +39 -3
- canns/analyzer/__init__.py +7 -6
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/metrics/spatial_metrics.py +70 -100
- canns/analyzer/metrics/systematic_ratemap.py +12 -17
- canns/analyzer/metrics/utils.py +28 -0
- canns/analyzer/model_specific/hopfield.py +19 -16
- canns/analyzer/slow_points/checkpoint.py +32 -9
- canns/analyzer/slow_points/finder.py +33 -6
- canns/analyzer/slow_points/fixed_points.py +12 -0
- canns/analyzer/slow_points/visualization.py +22 -10
- canns/analyzer/visualization/core/backend.py +15 -26
- canns/analyzer/visualization/core/config.py +120 -15
- canns/analyzer/visualization/core/jupyter_utils.py +34 -16
- canns/analyzer/visualization/core/rendering.py +42 -40
- canns/analyzer/visualization/core/writers.py +10 -20
- canns/analyzer/visualization/energy_plots.py +78 -28
- canns/analyzer/visualization/spatial_plots.py +81 -36
- canns/analyzer/visualization/spike_plots.py +27 -7
- canns/analyzer/visualization/theta_sweep_plots.py +159 -72
- canns/analyzer/visualization/tuning_plots.py +11 -3
- canns/data/__init__.py +7 -4
- canns/models/__init__.py +10 -0
- canns/models/basic/cann.py +102 -40
- canns/models/basic/grid_cell.py +9 -8
- canns/models/basic/hierarchical_model.py +57 -11
- canns/models/brain_inspired/hopfield.py +26 -14
- canns/models/brain_inspired/linear.py +15 -16
- canns/models/brain_inspired/spiking.py +23 -12
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/closed_loop_navigation.py +54 -13
- canns/task/open_loop_navigation.py +230 -147
- canns/task/tracking.py +156 -24
- canns/trainer/__init__.py +8 -5
- canns/utils/__init__.py +12 -4
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- canns-0.13.0.dist-info/RECORD +91 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- canns-0.12.6.dist-info/RECORD +0 -72
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -25,7 +25,6 @@ from tqdm import tqdm
|
|
|
25
25
|
|
|
26
26
|
from .core.backend import (
|
|
27
27
|
emit_backend_warnings,
|
|
28
|
-
get_imageio_writer_kwargs,
|
|
29
28
|
get_multiprocessing_context,
|
|
30
29
|
get_optimal_worker_count,
|
|
31
30
|
select_animation_backend,
|
|
@@ -501,21 +500,35 @@ def plot_population_activity_with_theta(
|
|
|
501
500
|
save_path: str | None = None,
|
|
502
501
|
**kwargs,
|
|
503
502
|
) -> tuple[plt.Figure, plt.Axes]:
|
|
504
|
-
"""
|
|
505
|
-
Plot neural population activity with theta oscillation markers and direction trace.
|
|
503
|
+
"""Plot neural population activity with theta phase markers.
|
|
506
504
|
|
|
507
505
|
Args:
|
|
508
|
-
time_steps: Array of time points
|
|
509
|
-
theta_phase:
|
|
510
|
-
net_activity: 2D array of network activity (time, neurons)
|
|
511
|
-
direction:
|
|
512
|
-
config: PlotConfig object for unified configuration
|
|
513
|
-
add_lines: Whether to add vertical lines at theta phase zeros
|
|
514
|
-
atol: Tolerance for detecting theta phase zeros
|
|
515
|
-
**kwargs: Additional parameters for backward compatibility
|
|
506
|
+
time_steps: Array of time points.
|
|
507
|
+
theta_phase: Theta phase values in ``[-pi, pi]``.
|
|
508
|
+
net_activity: 2D array of network activity ``(time, neurons)``.
|
|
509
|
+
direction: Direction values (radians) over time.
|
|
510
|
+
config: PlotConfig object for unified configuration.
|
|
511
|
+
add_lines: Whether to add vertical lines at theta phase zeros.
|
|
512
|
+
atol: Tolerance for detecting theta phase zeros.
|
|
513
|
+
**kwargs: Additional parameters for backward compatibility.
|
|
516
514
|
|
|
517
515
|
Returns:
|
|
518
|
-
tuple: (figure, axis) objects
|
|
516
|
+
tuple: ``(figure, axis)`` objects.
|
|
517
|
+
|
|
518
|
+
Examples:
|
|
519
|
+
>>> import numpy as np
|
|
520
|
+
>>> from canns.analyzer.visualization import plot_population_activity_with_theta, PlotConfig
|
|
521
|
+
>>>
|
|
522
|
+
>>> time_steps = np.linspace(0, 1, 5)
|
|
523
|
+
>>> theta_phase = np.linspace(-np.pi, np.pi, 5)
|
|
524
|
+
>>> net_activity = np.random.rand(5, 4)
|
|
525
|
+
>>> direction = np.linspace(-np.pi, np.pi, 5)
|
|
526
|
+
>>> config = PlotConfig(show=False)
|
|
527
|
+
>>> fig, ax = plot_population_activity_with_theta(
|
|
528
|
+
... time_steps, theta_phase, net_activity, direction, config=config
|
|
529
|
+
... )
|
|
530
|
+
>>> print(fig is not None)
|
|
531
|
+
True
|
|
519
532
|
"""
|
|
520
533
|
# Handle configuration
|
|
521
534
|
if config is None:
|
|
@@ -649,18 +662,28 @@ def plot_grid_cell_manifold(
|
|
|
649
662
|
save_path: str | None = None,
|
|
650
663
|
**kwargs,
|
|
651
664
|
) -> tuple[plt.Figure, plt.Axes]:
|
|
652
|
-
"""
|
|
653
|
-
Plot grid cell activity on the twisted torus manifold.
|
|
665
|
+
"""Plot grid cell activity on the twisted torus manifold.
|
|
654
666
|
|
|
655
667
|
Args:
|
|
656
|
-
value_grid_twisted: Coordinates on twisted manifold
|
|
657
|
-
grid_cell_activity: 2D array of grid cell activities
|
|
658
|
-
config: PlotConfig object for unified configuration
|
|
659
|
-
ax: Optional axis to draw on instead of creating a new figure
|
|
660
|
-
**kwargs: Additional parameters for backward compatibility
|
|
668
|
+
value_grid_twisted: Coordinates on the twisted manifold ``(N, 2)``.
|
|
669
|
+
grid_cell_activity: 2D array of grid cell activities.
|
|
670
|
+
config: PlotConfig object for unified configuration.
|
|
671
|
+
ax: Optional axis to draw on instead of creating a new figure.
|
|
672
|
+
**kwargs: Additional parameters for backward compatibility.
|
|
661
673
|
|
|
662
674
|
Returns:
|
|
663
|
-
tuple: (figure, axis) objects
|
|
675
|
+
tuple: ``(figure, axis)`` objects.
|
|
676
|
+
|
|
677
|
+
Examples:
|
|
678
|
+
>>> import numpy as np
|
|
679
|
+
>>> from canns.analyzer.visualization import plot_grid_cell_manifold, PlotConfig
|
|
680
|
+
>>>
|
|
681
|
+
>>> value_grid_twisted = np.random.rand(9, 2)
|
|
682
|
+
>>> grid_cell_activity = np.random.rand(3, 3)
|
|
683
|
+
>>> config = PlotConfig(show=False)
|
|
684
|
+
>>> fig, ax = plot_grid_cell_manifold(value_grid_twisted, grid_cell_activity, config=config)
|
|
685
|
+
>>> print(fig is not None)
|
|
686
|
+
True
|
|
664
687
|
"""
|
|
665
688
|
# Handle configuration
|
|
666
689
|
if config is None:
|
|
@@ -827,9 +850,10 @@ def _render_single_place_cell_frame(
|
|
|
827
850
|
options: _PlaceCellRenderOptions,
|
|
828
851
|
) -> np.ndarray:
|
|
829
852
|
"""Render a single frame for place cell animation (module-level for pickling)."""
|
|
853
|
+
from io import BytesIO
|
|
854
|
+
|
|
830
855
|
import matplotlib.pyplot as plt
|
|
831
856
|
import numpy as np
|
|
832
|
-
from io import BytesIO
|
|
833
857
|
|
|
834
858
|
fig, axes = plt.subplots(1, 2, figsize=options.figsize, width_ratios=[1, 1])
|
|
835
859
|
ax_env, ax_activity = axes
|
|
@@ -959,32 +983,51 @@ def create_theta_sweep_place_cell_animation(
|
|
|
959
983
|
render_start_method: str | None = None,
|
|
960
984
|
**kwargs,
|
|
961
985
|
) -> FuncAnimation | None:
|
|
962
|
-
"""
|
|
963
|
-
Create theta sweep animation for place cell network with 2 panels:
|
|
964
|
-
1. Environment trajectory with place cell bump overlay
|
|
965
|
-
2. Population activity heatmap over time
|
|
986
|
+
"""Create theta sweep animation for a place cell network.
|
|
966
987
|
|
|
967
988
|
Args:
|
|
968
|
-
position_data: Animal position data (time, 2)
|
|
969
|
-
pc_activity_data: Place cell activity (time, num_cells)
|
|
970
|
-
pc_network: PlaceCellNetwork
|
|
971
|
-
navigation_task: BaseNavigationTask
|
|
972
|
-
dt: Time step size
|
|
973
|
-
config: PlotConfig object for unified configuration
|
|
974
|
-
n_step: Subsample every n_step frames for animation
|
|
975
|
-
fps: Frames per second for animation
|
|
976
|
-
figsize: Figure size (width, height)
|
|
977
|
-
save_path: Path to save animation (GIF or MP4)
|
|
978
|
-
show: Whether to display animation
|
|
979
|
-
show_progress_bar: Whether to show progress bar during saving
|
|
980
|
-
render_backend: Rendering backend ('imageio', 'matplotlib', or 'auto')
|
|
981
|
-
output_dpi: DPI for rendered frames (affects file size and quality)
|
|
982
|
-
render_workers: Number of parallel workers (None = auto-detect)
|
|
983
|
-
render_start_method: Multiprocessing start method ('fork', 'spawn', or None)
|
|
984
|
-
**kwargs: Additional parameters (cmap, alpha, etc.)
|
|
989
|
+
position_data: Animal position data ``(time, 2)``.
|
|
990
|
+
pc_activity_data: Place cell activity ``(time, num_cells)``.
|
|
991
|
+
pc_network: PlaceCellNetwork-like object with ``geodesic_result``.
|
|
992
|
+
navigation_task: BaseNavigationTask-like object with ``env``.
|
|
993
|
+
dt: Time step size.
|
|
994
|
+
config: PlotConfig object for unified configuration.
|
|
995
|
+
n_step: Subsample every n_step frames for animation.
|
|
996
|
+
fps: Frames per second for animation.
|
|
997
|
+
figsize: Figure size (width, height).
|
|
998
|
+
save_path: Path to save animation (GIF or MP4).
|
|
999
|
+
show: Whether to display animation.
|
|
1000
|
+
show_progress_bar: Whether to show progress bar during saving.
|
|
1001
|
+
render_backend: Rendering backend ('imageio', 'matplotlib', or 'auto').
|
|
1002
|
+
output_dpi: DPI for rendered frames (affects file size and quality).
|
|
1003
|
+
render_workers: Number of parallel workers (None = auto-detect).
|
|
1004
|
+
render_start_method: Multiprocessing start method ('fork', 'spawn', or None).
|
|
1005
|
+
**kwargs: Additional parameters (cmap, alpha, etc.).
|
|
985
1006
|
|
|
986
1007
|
Returns:
|
|
987
|
-
FuncAnimation: Matplotlib animation object
|
|
1008
|
+
FuncAnimation: Matplotlib animation object.
|
|
1009
|
+
|
|
1010
|
+
Examples:
|
|
1011
|
+
This example demonstrates the basic structure. For complete usage, see the
|
|
1012
|
+
documentation or example scripts.
|
|
1013
|
+
|
|
1014
|
+
>>> import numpy as np
|
|
1015
|
+
>>> from canns.analyzer.visualization import PlotConfig
|
|
1016
|
+
>>>
|
|
1017
|
+
>>> # Prepare your data from simulation
|
|
1018
|
+
>>> position_data = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]])
|
|
1019
|
+
>>> pc_activity_data = np.random.rand(3, 4) # (time, num_cells)
|
|
1020
|
+
>>>
|
|
1021
|
+
>>> # Assuming you have pc_network and navigation_task from your model
|
|
1022
|
+
>>> # anim = create_theta_sweep_place_cell_animation(
|
|
1023
|
+
>>> # position_data,
|
|
1024
|
+
>>> # pc_activity_data,
|
|
1025
|
+
>>> # pc_network, # Your PlaceCellNetwork instance
|
|
1026
|
+
>>> # navigation_task, # Your BaseNavigationTask instance
|
|
1027
|
+
>>> # config=PlotConfig(show=False),
|
|
1028
|
+
>>> # n_step=1,
|
|
1029
|
+
>>> # fps=10,
|
|
1030
|
+
>>> # ) # doctest: +SKIP
|
|
988
1031
|
"""
|
|
989
1032
|
# Handle configuration
|
|
990
1033
|
if config is None:
|
|
@@ -1180,8 +1223,9 @@ def create_theta_sweep_place_cell_animation(
|
|
|
1180
1223
|
if backend == "imageio":
|
|
1181
1224
|
# Use imageio backend with parallel rendering
|
|
1182
1225
|
workers = render_workers if render_workers is not None else get_optimal_worker_count()
|
|
1183
|
-
ctx = get_multiprocessing_context(
|
|
1184
|
-
|
|
1226
|
+
ctx, start_method = get_multiprocessing_context(
|
|
1227
|
+
prefer_fork=(render_start_method == "fork")
|
|
1228
|
+
)
|
|
1185
1229
|
|
|
1186
1230
|
_emit_info(
|
|
1187
1231
|
f"Parallel rendering enabled: {workers} workers (start_method={start_method})"
|
|
@@ -1190,7 +1234,11 @@ def create_theta_sweep_place_cell_animation(
|
|
|
1190
1234
|
# Prepare environment data
|
|
1191
1235
|
env = navigation_task.env
|
|
1192
1236
|
boundary_array = np.array(env.boundary) if env.boundary is not None else None
|
|
1193
|
-
walls_arrays =
|
|
1237
|
+
walls_arrays = (
|
|
1238
|
+
[np.array(wall) for wall in env.walls]
|
|
1239
|
+
if env.walls is not None and len(env.walls) > 0
|
|
1240
|
+
else None
|
|
1241
|
+
)
|
|
1194
1242
|
|
|
1195
1243
|
# Get accessible indices
|
|
1196
1244
|
accessible_indices = pc_network.geodesic_result.accessible_indices
|
|
@@ -1215,11 +1263,14 @@ def create_theta_sweep_place_cell_animation(
|
|
|
1215
1263
|
writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
|
|
1216
1264
|
|
|
1217
1265
|
try:
|
|
1218
|
-
import imageio
|
|
1219
1266
|
from functools import partial
|
|
1220
1267
|
|
|
1268
|
+
import imageio
|
|
1269
|
+
|
|
1221
1270
|
# Create partial function with data and options
|
|
1222
|
-
render_func = partial(
|
|
1271
|
+
render_func = partial(
|
|
1272
|
+
_render_single_place_cell_frame, data=data, options=render_options
|
|
1273
|
+
)
|
|
1223
1274
|
|
|
1224
1275
|
with imageio.get_writer(config.save_path, mode=mode, **writer_kwargs) as writer:
|
|
1225
1276
|
if workers > 1 and ctx is not None:
|
|
@@ -1335,32 +1386,69 @@ def create_theta_sweep_grid_cell_animation(
|
|
|
1335
1386
|
render_start_method: str | None = None,
|
|
1336
1387
|
**kwargs,
|
|
1337
1388
|
) -> FuncAnimation | None:
|
|
1338
|
-
"""
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1389
|
+
"""Create a theta sweep animation with four panels.
|
|
1390
|
+
|
|
1391
|
+
Panels:
|
|
1392
|
+
1) Animal trajectory
|
|
1393
|
+
2) Direction cell polar plot
|
|
1394
|
+
3) Grid cell activity on manifold
|
|
1395
|
+
4) Grid cell activity in real space
|
|
1344
1396
|
|
|
1345
1397
|
Args:
|
|
1346
|
-
position_data: Animal position data (time, 2)
|
|
1347
|
-
direction_data: Direction data (time,)
|
|
1348
|
-
dc_activity_data: Direction cell activity (time, neurons)
|
|
1349
|
-
gc_activity_data: Grid cell activity (time, neurons)
|
|
1350
|
-
gc_network: GridCellNetwork instance for coordinate
|
|
1351
|
-
env_size: Environment size
|
|
1352
|
-
mapping_ratio: Mapping ratio for grid cells
|
|
1353
|
-
dt: Time step size
|
|
1354
|
-
config: PlotConfig object for unified configuration
|
|
1355
|
-
n_step: Subsample every n_step frames for animation
|
|
1356
|
-
render_backend: Rendering backend. Use 'matplotlib', 'imageio', or 'auto'
|
|
1357
|
-
output_dpi: Target DPI
|
|
1358
|
-
render_workers: Worker processes for imageio backend.
|
|
1359
|
-
render_start_method: Multiprocessing start method ('fork', 'spawn',
|
|
1360
|
-
**kwargs: Additional parameters for backward compatibility
|
|
1398
|
+
position_data: Animal position data ``(time, 2)``.
|
|
1399
|
+
direction_data: Direction data ``(time,)``.
|
|
1400
|
+
dc_activity_data: Direction cell activity ``(time, neurons)``.
|
|
1401
|
+
gc_activity_data: Grid cell activity ``(time, neurons)``.
|
|
1402
|
+
gc_network: GridCellNetwork instance for coordinate transforms.
|
|
1403
|
+
env_size: Environment size.
|
|
1404
|
+
mapping_ratio: Mapping ratio for grid cells.
|
|
1405
|
+
dt: Time step size.
|
|
1406
|
+
config: PlotConfig object for unified configuration.
|
|
1407
|
+
n_step: Subsample every n_step frames for animation.
|
|
1408
|
+
render_backend: Rendering backend. Use 'matplotlib', 'imageio', or 'auto'.
|
|
1409
|
+
output_dpi: Target DPI for non-interactive rendering.
|
|
1410
|
+
render_workers: Worker processes for imageio backend.
|
|
1411
|
+
render_start_method: Multiprocessing start method ('fork', 'spawn', or None).
|
|
1412
|
+
**kwargs: Additional parameters for backward compatibility.
|
|
1361
1413
|
|
|
1362
1414
|
Returns:
|
|
1363
|
-
FuncAnimation | None:
|
|
1415
|
+
FuncAnimation | None: Animation object (None if displayed inline).
|
|
1416
|
+
|
|
1417
|
+
Examples:
|
|
1418
|
+
This is a minimal structural example using synthetic data to demonstrate
|
|
1419
|
+
the API. For realistic usage, run a GridCellNetwork simulation to obtain
|
|
1420
|
+
actual activity data.
|
|
1421
|
+
|
|
1422
|
+
>>> import numpy as np
|
|
1423
|
+
>>> import brainpy.math as bm
|
|
1424
|
+
>>> from canns.models.basic.theta_sweep_model import GridCellNetwork
|
|
1425
|
+
>>> from canns.analyzer.visualization import PlotConfig
|
|
1426
|
+
>>>
|
|
1427
|
+
>>> # Minimal example with synthetic data (for structure demonstration)
|
|
1428
|
+
>>> bm.set_dt(1.0)
|
|
1429
|
+
>>> gc_network = GridCellNetwork(num_dc=4, num_gc_x=4, mapping_ratio=1.0)
|
|
1430
|
+
>>> T = 5
|
|
1431
|
+
>>> # NOTE: In real usage, obtain these from actual model simulation
|
|
1432
|
+
>>> position_data = np.random.rand(T, 2)
|
|
1433
|
+
>>> direction_data = np.linspace(-np.pi, np.pi, T)
|
|
1434
|
+
>>> dc_activity_data = np.random.rand(T, gc_network.num_dc)
|
|
1435
|
+
>>> gc_activity_data = np.random.rand(T, gc_network.num)
|
|
1436
|
+
>>>
|
|
1437
|
+
>>> config = PlotConfig(show=False)
|
|
1438
|
+
>>> anim = create_theta_sweep_grid_cell_animation(
|
|
1439
|
+
... position_data,
|
|
1440
|
+
... direction_data,
|
|
1441
|
+
... dc_activity_data,
|
|
1442
|
+
... gc_activity_data,
|
|
1443
|
+
... gc_network,
|
|
1444
|
+
... env_size=1.0,
|
|
1445
|
+
... mapping_ratio=1.0,
|
|
1446
|
+
... config=config,
|
|
1447
|
+
... n_step=1,
|
|
1448
|
+
... fps=2,
|
|
1449
|
+
... )
|
|
1450
|
+
>>> print(anim is not None)
|
|
1451
|
+
True
|
|
1364
1452
|
"""
|
|
1365
1453
|
# Handle configuration
|
|
1366
1454
|
if config is None:
|
|
@@ -1426,8 +1514,7 @@ def create_theta_sweep_grid_cell_animation(
|
|
|
1426
1514
|
workers = render_workers if render_workers is not None else get_optimal_worker_count()
|
|
1427
1515
|
|
|
1428
1516
|
# Get multiprocessing context
|
|
1429
|
-
ctx = get_multiprocessing_context(prefer_fork=True)
|
|
1430
|
-
start_method = ctx.method if ctx is not None else None
|
|
1517
|
+
ctx, start_method = get_multiprocessing_context(prefer_fork=True)
|
|
1431
1518
|
|
|
1432
1519
|
if workers > 0 and ctx is None:
|
|
1433
1520
|
warnings.warn(
|
|
@@ -69,9 +69,6 @@ def tuning_curve(
|
|
|
69
69
|
):
|
|
70
70
|
"""Plot the tuning curve for one or more neurons.
|
|
71
71
|
|
|
72
|
-
The wording mirrors the original ``visualize`` module to avoid API drift and
|
|
73
|
-
to keep existing references valid.
|
|
74
|
-
|
|
75
72
|
Args:
|
|
76
73
|
stimulus: 1D array with the stimulus value at each time step.
|
|
77
74
|
firing_rates: 2D array of firing rates shaped ``(timesteps, neurons)``.
|
|
@@ -86,6 +83,17 @@ def tuning_curve(
|
|
|
86
83
|
save_path: Optional location where the figure should be stored.
|
|
87
84
|
show: Whether to display the plot interactively.
|
|
88
85
|
**kwargs: Additional keyword arguments passed through to ``ax.plot``.
|
|
86
|
+
|
|
87
|
+
Examples:
|
|
88
|
+
>>> import numpy as np
|
|
89
|
+
>>> from canns.analyzer.visualization import tuning_curve, PlotConfigs
|
|
90
|
+
>>>
|
|
91
|
+
>>> stimulus = np.linspace(0, 1, 10)
|
|
92
|
+
>>> firing_rates = np.random.rand(10, 3)
|
|
93
|
+
>>> config = PlotConfigs.tuning_curve(num_bins=5, pref_stim=np.array([0.2, 0.5, 0.8]), show=False)
|
|
94
|
+
>>> fig, ax = tuning_curve(stimulus, firing_rates, neuron_indices=[0, 1], config=config)
|
|
95
|
+
>>> print(fig is not None)
|
|
96
|
+
True
|
|
89
97
|
"""
|
|
90
98
|
|
|
91
99
|
config = _ensure_plot_config(
|
canns/data/__init__.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
1
|
+
"""Data utilities for CANNs.
|
|
2
|
+
|
|
3
|
+
This namespace provides dataset registry, download helpers, and convenience
|
|
4
|
+
loaders for common CANNs datasets.
|
|
3
5
|
|
|
4
|
-
|
|
5
|
-
|
|
6
|
+
Examples:
|
|
7
|
+
>>> from canns import data
|
|
8
|
+
>>> print(list(data.DATASETS))
|
|
6
9
|
"""
|
|
7
10
|
|
|
8
11
|
from .datasets import (
|
canns/models/__init__.py
CHANGED
|
@@ -1,3 +1,13 @@
|
|
|
1
|
+
"""Model definitions for CANNs.
|
|
2
|
+
|
|
3
|
+
This namespace exposes common model families, such as basic CANNs and
|
|
4
|
+
brain-inspired variants.
|
|
5
|
+
|
|
6
|
+
Examples:
|
|
7
|
+
>>> from canns import models
|
|
8
|
+
>>> print(models.basic)
|
|
9
|
+
"""
|
|
10
|
+
|
|
1
11
|
from . import basic as basic
|
|
2
12
|
from . import brain_inspired as brain_inspired
|
|
3
13
|
# from .hybrid import *
|
canns/models/basic/cann.py
CHANGED
|
@@ -76,11 +76,20 @@ class BaseCANN(BasicModel):
|
|
|
76
76
|
|
|
77
77
|
|
|
78
78
|
class BaseCANN1D(BaseCANN):
|
|
79
|
-
"""
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
79
|
+
"""Base class for 1D Continuous Attractor Neural Network (CANN) models.
|
|
80
|
+
|
|
81
|
+
It builds the 1D feature space, connectivity kernel, and stimulus helpers
|
|
82
|
+
shared by 1D CANN variants.
|
|
83
|
+
|
|
84
|
+
Examples:
|
|
85
|
+
>>> import brainpy.math as bm
|
|
86
|
+
>>> from canns.models.basic.cann import BaseCANN1D
|
|
87
|
+
>>>
|
|
88
|
+
>>> bm.set_dt(0.1)
|
|
89
|
+
>>> model = BaseCANN1D(num=64)
|
|
90
|
+
>>> stimulus = model.get_stimulus_by_pos(0.0)
|
|
91
|
+
>>> stimulus.shape
|
|
92
|
+
(64,)
|
|
84
93
|
"""
|
|
85
94
|
|
|
86
95
|
def __init__(
|
|
@@ -181,10 +190,21 @@ class BaseCANN1D(BaseCANN):
|
|
|
181
190
|
|
|
182
191
|
|
|
183
192
|
class CANN1D(BaseCANN1D):
|
|
184
|
-
"""
|
|
185
|
-
|
|
186
|
-
This model
|
|
187
|
-
|
|
193
|
+
"""Standard 1D Continuous Attractor Neural Network (CANN) model.
|
|
194
|
+
|
|
195
|
+
This model sustains a localized "bump" of activity that can be driven by
|
|
196
|
+
external input.
|
|
197
|
+
|
|
198
|
+
Examples:
|
|
199
|
+
>>> import brainpy.math as bm
|
|
200
|
+
>>> from canns.models.basic import CANN1D
|
|
201
|
+
>>>
|
|
202
|
+
>>> bm.set_dt(0.1)
|
|
203
|
+
>>> model = CANN1D(num=64)
|
|
204
|
+
>>> stimulus = model.get_stimulus_by_pos(0.0)
|
|
205
|
+
>>> model.update(stimulus)
|
|
206
|
+
>>> model.r.value.shape
|
|
207
|
+
(64,)
|
|
188
208
|
|
|
189
209
|
Reference:
|
|
190
210
|
Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors.
|
|
@@ -210,11 +230,13 @@ class CANN1D(BaseCANN1D):
|
|
|
210
230
|
self.inp = bm.Variable(bm.zeros(self.shape))
|
|
211
231
|
|
|
212
232
|
def update(self, inp):
|
|
213
|
-
"""
|
|
214
|
-
The main update function, defining the dynamics of the network for one time step.
|
|
233
|
+
"""Advance the network by one time step.
|
|
215
234
|
|
|
216
235
|
Args:
|
|
217
|
-
inp (Array):
|
|
236
|
+
inp (Array): External input vector of shape ``(num,)``.
|
|
237
|
+
|
|
238
|
+
Returns:
|
|
239
|
+
None
|
|
218
240
|
"""
|
|
219
241
|
self.inp.value = inp
|
|
220
242
|
# The numerator for the firing rate calculation (a non-linear activation function).
|
|
@@ -231,10 +253,21 @@ class CANN1D(BaseCANN1D):
|
|
|
231
253
|
|
|
232
254
|
|
|
233
255
|
class CANN1D_SFA(BaseCANN1D):
|
|
234
|
-
"""
|
|
235
|
-
|
|
236
|
-
SFA
|
|
237
|
-
|
|
256
|
+
"""1D CANN model with spike-frequency adaptation (SFA).
|
|
257
|
+
|
|
258
|
+
SFA adds a slow negative feedback term that can create anticipative tracking
|
|
259
|
+
under sustained inputs.
|
|
260
|
+
|
|
261
|
+
Examples:
|
|
262
|
+
>>> import brainpy.math as bm
|
|
263
|
+
>>> from canns.models.basic import CANN1D_SFA
|
|
264
|
+
>>>
|
|
265
|
+
>>> bm.set_dt(0.1)
|
|
266
|
+
>>> model = CANN1D_SFA(num=64)
|
|
267
|
+
>>> stimulus = model.get_stimulus_by_pos(0.0)
|
|
268
|
+
>>> model.update(stimulus)
|
|
269
|
+
>>> model.r.value.shape
|
|
270
|
+
(64,)
|
|
238
271
|
|
|
239
272
|
Reference:
|
|
240
273
|
Mi, Y., Fung, C. C., Wong, K. Y., & Wu, S. (2014). Spike frequency adaptation
|
|
@@ -278,12 +311,13 @@ class CANN1D_SFA(BaseCANN1D):
|
|
|
278
311
|
self.inp = bm.Variable(bm.zeros(self.shape)) # External input.
|
|
279
312
|
|
|
280
313
|
def update(self, inp):
|
|
281
|
-
"""
|
|
282
|
-
The main update function for the SFA model. It includes dynamics for both
|
|
283
|
-
the membrane potential and the adaptation variable.
|
|
314
|
+
"""Advance the network by one time step with adaptation.
|
|
284
315
|
|
|
285
316
|
Args:
|
|
286
|
-
inp (Array):
|
|
317
|
+
inp (Array): External input vector of shape ``(num,)``.
|
|
318
|
+
|
|
319
|
+
Returns:
|
|
320
|
+
None
|
|
287
321
|
"""
|
|
288
322
|
self.inp.value = inp
|
|
289
323
|
# Firing rate calculation is the same as the standard CANN model.
|
|
@@ -302,11 +336,20 @@ class CANN1D_SFA(BaseCANN1D):
|
|
|
302
336
|
|
|
303
337
|
|
|
304
338
|
class BaseCANN2D(BaseCANN):
|
|
305
|
-
"""
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
339
|
+
"""Base class for 2D Continuous Attractor Neural Network (CANN) models.
|
|
340
|
+
|
|
341
|
+
It builds the 2D feature space, connectivity kernel, and stimulus helpers
|
|
342
|
+
shared by 2D CANN variants.
|
|
343
|
+
|
|
344
|
+
Examples:
|
|
345
|
+
>>> import brainpy.math as bm
|
|
346
|
+
>>> from canns.models.basic.cann import BaseCANN2D
|
|
347
|
+
>>>
|
|
348
|
+
>>> bm.set_dt(0.1)
|
|
349
|
+
>>> model = BaseCANN2D(length=16)
|
|
350
|
+
>>> stimulus = model.get_stimulus_by_pos([0.0, 0.0])
|
|
351
|
+
>>> stimulus.shape
|
|
352
|
+
(16, 16)
|
|
310
353
|
"""
|
|
311
354
|
|
|
312
355
|
def __init__(
|
|
@@ -456,10 +499,18 @@ class BaseCANN2D(BaseCANN):
|
|
|
456
499
|
|
|
457
500
|
|
|
458
501
|
class CANN2D(BaseCANN2D):
|
|
459
|
-
"""
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
502
|
+
"""2D Continuous Attractor Neural Network (CANN) model.
|
|
503
|
+
|
|
504
|
+
Examples:
|
|
505
|
+
>>> import brainpy.math as bm
|
|
506
|
+
>>> from canns.models.basic import CANN2D
|
|
507
|
+
>>>
|
|
508
|
+
>>> bm.set_dt(0.1)
|
|
509
|
+
>>> model = CANN2D(length=16)
|
|
510
|
+
>>> stimulus = model.get_stimulus_by_pos([0.0, 0.0])
|
|
511
|
+
>>> model.update(stimulus)
|
|
512
|
+
>>> model.r.value.shape
|
|
513
|
+
(16, 16)
|
|
463
514
|
|
|
464
515
|
Reference:
|
|
465
516
|
Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors.
|
|
@@ -485,11 +536,13 @@ class CANN2D(BaseCANN2D):
|
|
|
485
536
|
self.inp = bm.Variable(bm.zeros((self.length, self.length)))
|
|
486
537
|
|
|
487
538
|
def update(self, inp):
|
|
488
|
-
"""
|
|
489
|
-
The main update function, defining the dynamics of the network for one time step.
|
|
539
|
+
"""Advance the network by one time step.
|
|
490
540
|
|
|
491
541
|
Args:
|
|
492
|
-
inp (Array):
|
|
542
|
+
inp (Array): External input grid of shape ``(length, length)``.
|
|
543
|
+
|
|
544
|
+
Returns:
|
|
545
|
+
None
|
|
493
546
|
"""
|
|
494
547
|
self.inp.value = inp
|
|
495
548
|
# The numerator for the firing rate calculation (a non-linear activation function).
|
|
@@ -505,10 +558,18 @@ class CANN2D(BaseCANN2D):
|
|
|
505
558
|
|
|
506
559
|
|
|
507
560
|
class CANN2D_SFA(BaseCANN2D):
|
|
508
|
-
"""
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
561
|
+
"""2D CANN model with spike-frequency adaptation (SFA) dynamics.
|
|
562
|
+
|
|
563
|
+
Examples:
|
|
564
|
+
>>> import brainpy.math as bm
|
|
565
|
+
>>> from canns.models.basic import CANN2D_SFA
|
|
566
|
+
>>>
|
|
567
|
+
>>> bm.set_dt(0.1)
|
|
568
|
+
>>> model = CANN2D_SFA(length=16)
|
|
569
|
+
>>> stimulus = model.get_stimulus_by_pos([0.0, 0.0])
|
|
570
|
+
>>> model.update(stimulus)
|
|
571
|
+
>>> model.r.value.shape
|
|
572
|
+
(16, 16)
|
|
512
573
|
"""
|
|
513
574
|
|
|
514
575
|
def __init__(
|
|
@@ -544,12 +605,13 @@ class CANN2D_SFA(BaseCANN2D):
|
|
|
544
605
|
self.inp = bm.Variable(bm.zeros((self.length, self.length))) # External input.
|
|
545
606
|
|
|
546
607
|
def update(self, inp):
|
|
547
|
-
"""
|
|
548
|
-
The main update function for the SFA model. It includes dynamics for both
|
|
549
|
-
the membrane potential and the adaptation variable.
|
|
608
|
+
"""Advance the network by one time step with adaptation.
|
|
550
609
|
|
|
551
610
|
Args:
|
|
552
|
-
inp (Array):
|
|
611
|
+
inp (Array): External input grid of shape ``(length, length)``.
|
|
612
|
+
|
|
613
|
+
Returns:
|
|
614
|
+
None
|
|
553
615
|
"""
|
|
554
616
|
self.inp.value = inp
|
|
555
617
|
# Firing rate calculation is the same as the standard CANN model.
|
canns/models/basic/grid_cell.py
CHANGED
|
@@ -54,18 +54,19 @@ class GridCell2DPosition(BasicModel):
|
|
|
54
54
|
|
|
55
55
|
Example:
|
|
56
56
|
>>> import brainpy.math as bm
|
|
57
|
-
>>> from canns.models.basic import
|
|
57
|
+
>>> from canns.models.basic import GridCell2DPosition
|
|
58
58
|
>>>
|
|
59
59
|
>>> bm.set_dt(1.0)
|
|
60
|
-
>>> model =
|
|
60
|
+
>>> model = GridCell2DPosition(length=16, mapping_ratio=1.5)
|
|
61
61
|
>>>
|
|
62
|
-
>>> # Update with 2D position
|
|
63
|
-
>>> position = [0.5, 0.3]
|
|
62
|
+
>>> # Update with a 2D position
|
|
63
|
+
>>> position = bm.array([0.5, 0.3])
|
|
64
64
|
>>> model.update(position)
|
|
65
65
|
>>>
|
|
66
66
|
>>> # Access decoded position
|
|
67
67
|
>>> decoded_pos = model.center_position.value
|
|
68
|
-
>>>
|
|
68
|
+
>>> decoded_pos.shape
|
|
69
|
+
(2,)
|
|
69
70
|
|
|
70
71
|
References:
|
|
71
72
|
Burak, Y., & Fiete, I. R. (2009).
|
|
@@ -387,11 +388,11 @@ class GridCell2DVelocity(BasicModel):
|
|
|
387
388
|
>>> bm.set_dt(5e-4) # Small timestep for accurate integration
|
|
388
389
|
>>> model = GridCell2DVelocity(length=40)
|
|
389
390
|
>>>
|
|
390
|
-
>>> # Healing process (
|
|
391
|
-
>>> model.heal_network()
|
|
391
|
+
>>> # Healing process (recommended before simulation)
|
|
392
|
+
>>> model.heal_network(num_healing_steps=50, dt_healing=1e-3)
|
|
392
393
|
>>>
|
|
393
394
|
>>> # Update with 2D velocity
|
|
394
|
-
>>> velocity = [0.1, 0.05] # [vx, vy]
|
|
395
|
+
>>> velocity = bm.array([0.1, 0.05]) # [vx, vy]
|
|
395
396
|
>>> model.update(velocity)
|
|
396
397
|
|
|
397
398
|
References:
|