canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. canns/__init__.py +39 -3
  2. canns/analyzer/__init__.py +7 -6
  3. canns/analyzer/data/__init__.py +3 -11
  4. canns/analyzer/data/asa/__init__.py +74 -0
  5. canns/analyzer/data/asa/cohospace.py +905 -0
  6. canns/analyzer/data/asa/config.py +246 -0
  7. canns/analyzer/data/asa/decode.py +448 -0
  8. canns/analyzer/data/asa/embedding.py +269 -0
  9. canns/analyzer/data/asa/filters.py +208 -0
  10. canns/analyzer/data/asa/fr.py +439 -0
  11. canns/analyzer/data/asa/path.py +389 -0
  12. canns/analyzer/data/asa/plotting.py +1276 -0
  13. canns/analyzer/data/asa/tda.py +901 -0
  14. canns/analyzer/data/legacy/__init__.py +6 -0
  15. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  16. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  17. canns/analyzer/metrics/spatial_metrics.py +70 -100
  18. canns/analyzer/metrics/systematic_ratemap.py +12 -17
  19. canns/analyzer/metrics/utils.py +28 -0
  20. canns/analyzer/model_specific/hopfield.py +19 -16
  21. canns/analyzer/slow_points/checkpoint.py +32 -9
  22. canns/analyzer/slow_points/finder.py +33 -6
  23. canns/analyzer/slow_points/fixed_points.py +12 -0
  24. canns/analyzer/slow_points/visualization.py +22 -10
  25. canns/analyzer/visualization/core/backend.py +15 -26
  26. canns/analyzer/visualization/core/config.py +120 -15
  27. canns/analyzer/visualization/core/jupyter_utils.py +34 -16
  28. canns/analyzer/visualization/core/rendering.py +42 -40
  29. canns/analyzer/visualization/core/writers.py +10 -20
  30. canns/analyzer/visualization/energy_plots.py +78 -28
  31. canns/analyzer/visualization/spatial_plots.py +81 -36
  32. canns/analyzer/visualization/spike_plots.py +27 -7
  33. canns/analyzer/visualization/theta_sweep_plots.py +159 -72
  34. canns/analyzer/visualization/tuning_plots.py +11 -3
  35. canns/data/__init__.py +7 -4
  36. canns/models/__init__.py +10 -0
  37. canns/models/basic/cann.py +102 -40
  38. canns/models/basic/grid_cell.py +9 -8
  39. canns/models/basic/hierarchical_model.py +57 -11
  40. canns/models/brain_inspired/hopfield.py +26 -14
  41. canns/models/brain_inspired/linear.py +15 -16
  42. canns/models/brain_inspired/spiking.py +23 -12
  43. canns/pipeline/__init__.py +4 -8
  44. canns/pipeline/asa/__init__.py +21 -0
  45. canns/pipeline/asa/__main__.py +11 -0
  46. canns/pipeline/asa/app.py +1000 -0
  47. canns/pipeline/asa/runner.py +1095 -0
  48. canns/pipeline/asa/screens.py +215 -0
  49. canns/pipeline/asa/state.py +248 -0
  50. canns/pipeline/asa/styles.tcss +221 -0
  51. canns/pipeline/asa/widgets.py +233 -0
  52. canns/pipeline/gallery/__init__.py +7 -0
  53. canns/task/closed_loop_navigation.py +54 -13
  54. canns/task/open_loop_navigation.py +230 -147
  55. canns/task/tracking.py +156 -24
  56. canns/trainer/__init__.py +8 -5
  57. canns/utils/__init__.py +12 -4
  58. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  59. canns-0.13.0.dist-info/RECORD +91 -0
  60. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  61. canns/pipeline/theta_sweep.py +0 -573
  62. canns-0.12.6.dist-info/RECORD +0 -72
  63. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  64. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -25,7 +25,6 @@ from tqdm import tqdm
25
25
 
26
26
  from .core.backend import (
27
27
  emit_backend_warnings,
28
- get_imageio_writer_kwargs,
29
28
  get_multiprocessing_context,
30
29
  get_optimal_worker_count,
31
30
  select_animation_backend,
@@ -501,21 +500,35 @@ def plot_population_activity_with_theta(
501
500
  save_path: str | None = None,
502
501
  **kwargs,
503
502
  ) -> tuple[plt.Figure, plt.Axes]:
504
- """
505
- Plot neural population activity with theta oscillation markers and direction trace.
503
+ """Plot neural population activity with theta phase markers.
506
504
 
507
505
  Args:
508
- time_steps: Array of time points
509
- theta_phase: Array of theta phase values [-π, π]
510
- net_activity: 2D array of network activity (time, neurons)
511
- direction: Array of direction values
512
- config: PlotConfig object for unified configuration
513
- add_lines: Whether to add vertical lines at theta phase zeros
514
- atol: Tolerance for detecting theta phase zeros
515
- **kwargs: Additional parameters for backward compatibility
506
+ time_steps: Array of time points.
507
+ theta_phase: Theta phase values in ``[-pi, pi]``.
508
+ net_activity: 2D array of network activity ``(time, neurons)``.
509
+ direction: Direction values (radians) over time.
510
+ config: PlotConfig object for unified configuration.
511
+ add_lines: Whether to add vertical lines at theta phase zeros.
512
+ atol: Tolerance for detecting theta phase zeros.
513
+ **kwargs: Additional parameters for backward compatibility.
516
514
 
517
515
  Returns:
518
- tuple: (figure, axis) objects
516
+ tuple: ``(figure, axis)`` objects.
517
+
518
+ Examples:
519
+ >>> import numpy as np
520
+ >>> from canns.analyzer.visualization import plot_population_activity_with_theta, PlotConfig
521
+ >>>
522
+ >>> time_steps = np.linspace(0, 1, 5)
523
+ >>> theta_phase = np.linspace(-np.pi, np.pi, 5)
524
+ >>> net_activity = np.random.rand(5, 4)
525
+ >>> direction = np.linspace(-np.pi, np.pi, 5)
526
+ >>> config = PlotConfig(show=False)
527
+ >>> fig, ax = plot_population_activity_with_theta(
528
+ ... time_steps, theta_phase, net_activity, direction, config=config
529
+ ... )
530
+ >>> print(fig is not None)
531
+ True
519
532
  """
520
533
  # Handle configuration
521
534
  if config is None:
@@ -649,18 +662,28 @@ def plot_grid_cell_manifold(
649
662
  save_path: str | None = None,
650
663
  **kwargs,
651
664
  ) -> tuple[plt.Figure, plt.Axes]:
652
- """
653
- Plot grid cell activity on the twisted torus manifold.
665
+ """Plot grid cell activity on the twisted torus manifold.
654
666
 
655
667
  Args:
656
- value_grid_twisted: Coordinates on twisted manifold
657
- grid_cell_activity: 2D array of grid cell activities
658
- config: PlotConfig object for unified configuration
659
- ax: Optional axis to draw on instead of creating a new figure
660
- **kwargs: Additional parameters for backward compatibility
668
+ value_grid_twisted: Coordinates on the twisted manifold ``(N, 2)``.
669
+ grid_cell_activity: 2D array of grid cell activities.
670
+ config: PlotConfig object for unified configuration.
671
+ ax: Optional axis to draw on instead of creating a new figure.
672
+ **kwargs: Additional parameters for backward compatibility.
661
673
 
662
674
  Returns:
663
- tuple: (figure, axis) objects
675
+ tuple: ``(figure, axis)`` objects.
676
+
677
+ Examples:
678
+ >>> import numpy as np
679
+ >>> from canns.analyzer.visualization import plot_grid_cell_manifold, PlotConfig
680
+ >>>
681
+ >>> value_grid_twisted = np.random.rand(9, 2)
682
+ >>> grid_cell_activity = np.random.rand(3, 3)
683
+ >>> config = PlotConfig(show=False)
684
+ >>> fig, ax = plot_grid_cell_manifold(value_grid_twisted, grid_cell_activity, config=config)
685
+ >>> print(fig is not None)
686
+ True
664
687
  """
665
688
  # Handle configuration
666
689
  if config is None:
@@ -827,9 +850,10 @@ def _render_single_place_cell_frame(
827
850
  options: _PlaceCellRenderOptions,
828
851
  ) -> np.ndarray:
829
852
  """Render a single frame for place cell animation (module-level for pickling)."""
853
+ from io import BytesIO
854
+
830
855
  import matplotlib.pyplot as plt
831
856
  import numpy as np
832
- from io import BytesIO
833
857
 
834
858
  fig, axes = plt.subplots(1, 2, figsize=options.figsize, width_ratios=[1, 1])
835
859
  ax_env, ax_activity = axes
@@ -959,32 +983,51 @@ def create_theta_sweep_place_cell_animation(
959
983
  render_start_method: str | None = None,
960
984
  **kwargs,
961
985
  ) -> FuncAnimation | None:
962
- """
963
- Create theta sweep animation for place cell network with 2 panels:
964
- 1. Environment trajectory with place cell bump overlay
965
- 2. Population activity heatmap over time
986
+ """Create theta sweep animation for a place cell network.
966
987
 
967
988
  Args:
968
- position_data: Animal position data (time, 2)
969
- pc_activity_data: Place cell activity (time, num_cells)
970
- pc_network: PlaceCellNetwork instance
971
- navigation_task: BaseNavigationTask instance for environment visualization
972
- dt: Time step size
973
- config: PlotConfig object for unified configuration
974
- n_step: Subsample every n_step frames for animation
975
- fps: Frames per second for animation
976
- figsize: Figure size (width, height)
977
- save_path: Path to save animation (GIF or MP4)
978
- show: Whether to display animation
979
- show_progress_bar: Whether to show progress bar during saving
980
- render_backend: Rendering backend ('imageio', 'matplotlib', or 'auto')
981
- output_dpi: DPI for rendered frames (affects file size and quality)
982
- render_workers: Number of parallel workers (None = auto-detect)
983
- render_start_method: Multiprocessing start method ('fork', 'spawn', or None)
984
- **kwargs: Additional parameters (cmap, alpha, etc.)
989
+ position_data: Animal position data ``(time, 2)``.
990
+ pc_activity_data: Place cell activity ``(time, num_cells)``.
991
+ pc_network: PlaceCellNetwork-like object with ``geodesic_result``.
992
+ navigation_task: BaseNavigationTask-like object with ``env``.
993
+ dt: Time step size.
994
+ config: PlotConfig object for unified configuration.
995
+ n_step: Subsample every n_step frames for animation.
996
+ fps: Frames per second for animation.
997
+ figsize: Figure size (width, height).
998
+ save_path: Path to save animation (GIF or MP4).
999
+ show: Whether to display animation.
1000
+ show_progress_bar: Whether to show progress bar during saving.
1001
+ render_backend: Rendering backend ('imageio', 'matplotlib', or 'auto').
1002
+ output_dpi: DPI for rendered frames (affects file size and quality).
1003
+ render_workers: Number of parallel workers (None = auto-detect).
1004
+ render_start_method: Multiprocessing start method ('fork', 'spawn', or None).
1005
+ **kwargs: Additional parameters (cmap, alpha, etc.).
985
1006
 
986
1007
  Returns:
987
- FuncAnimation: Matplotlib animation object
1008
+ FuncAnimation: Matplotlib animation object.
1009
+
1010
+ Examples:
1011
+ This example demonstrates the basic structure. For complete usage, see the
1012
+ documentation or example scripts.
1013
+
1014
+ >>> import numpy as np
1015
+ >>> from canns.analyzer.visualization import PlotConfig
1016
+ >>>
1017
+ >>> # Prepare your data from simulation
1018
+ >>> position_data = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]])
1019
+ >>> pc_activity_data = np.random.rand(3, 4) # (time, num_cells)
1020
+ >>>
1021
+ >>> # Assuming you have pc_network and navigation_task from your model
1022
+ >>> # anim = create_theta_sweep_place_cell_animation(
1023
+ >>> # position_data,
1024
+ >>> # pc_activity_data,
1025
+ >>> # pc_network, # Your PlaceCellNetwork instance
1026
+ >>> # navigation_task, # Your BaseNavigationTask instance
1027
+ >>> # config=PlotConfig(show=False),
1028
+ >>> # n_step=1,
1029
+ >>> # fps=10,
1030
+ >>> # ) # doctest: +SKIP
988
1031
  """
989
1032
  # Handle configuration
990
1033
  if config is None:
@@ -1180,8 +1223,9 @@ def create_theta_sweep_place_cell_animation(
1180
1223
  if backend == "imageio":
1181
1224
  # Use imageio backend with parallel rendering
1182
1225
  workers = render_workers if render_workers is not None else get_optimal_worker_count()
1183
- ctx = get_multiprocessing_context(prefer_fork=(render_start_method == "fork"))
1184
- start_method = ctx.method if ctx is not None else None
1226
+ ctx, start_method = get_multiprocessing_context(
1227
+ prefer_fork=(render_start_method == "fork")
1228
+ )
1185
1229
 
1186
1230
  _emit_info(
1187
1231
  f"Parallel rendering enabled: {workers} workers (start_method={start_method})"
@@ -1190,7 +1234,11 @@ def create_theta_sweep_place_cell_animation(
1190
1234
  # Prepare environment data
1191
1235
  env = navigation_task.env
1192
1236
  boundary_array = np.array(env.boundary) if env.boundary is not None else None
1193
- walls_arrays = [np.array(wall) for wall in env.walls] if env.walls is not None and len(env.walls) > 0 else None
1237
+ walls_arrays = (
1238
+ [np.array(wall) for wall in env.walls]
1239
+ if env.walls is not None and len(env.walls) > 0
1240
+ else None
1241
+ )
1194
1242
 
1195
1243
  # Get accessible indices
1196
1244
  accessible_indices = pc_network.geodesic_result.accessible_indices
@@ -1215,11 +1263,14 @@ def create_theta_sweep_place_cell_animation(
1215
1263
  writer_kwargs, mode = get_imageio_writer_kwargs(config.save_path, config.fps)
1216
1264
 
1217
1265
  try:
1218
- import imageio
1219
1266
  from functools import partial
1220
1267
 
1268
+ import imageio
1269
+
1221
1270
  # Create partial function with data and options
1222
- render_func = partial(_render_single_place_cell_frame, data=data, options=render_options)
1271
+ render_func = partial(
1272
+ _render_single_place_cell_frame, data=data, options=render_options
1273
+ )
1223
1274
 
1224
1275
  with imageio.get_writer(config.save_path, mode=mode, **writer_kwargs) as writer:
1225
1276
  if workers > 1 and ctx is not None:
@@ -1335,32 +1386,69 @@ def create_theta_sweep_grid_cell_animation(
1335
1386
  render_start_method: str | None = None,
1336
1387
  **kwargs,
1337
1388
  ) -> FuncAnimation | None:
1338
- """
1339
- Create comprehensive theta sweep animation with 4 panels (optimized for speed):
1340
- 1. Animal trajectory
1341
- 2. Direction cell polar plot
1342
- 3. Grid cell activity on manifold
1343
- 4. Grid cell activity in real space
1389
+ """Create a theta sweep animation with four panels.
1390
+
1391
+ Panels:
1392
+ 1) Animal trajectory
1393
+ 2) Direction cell polar plot
1394
+ 3) Grid cell activity on manifold
1395
+ 4) Grid cell activity in real space
1344
1396
 
1345
1397
  Args:
1346
- position_data: Animal position data (time, 2)
1347
- direction_data: Direction data (time,)
1348
- dc_activity_data: Direction cell activity (time, neurons)
1349
- gc_activity_data: Grid cell activity (time, neurons)
1350
- gc_network: GridCellNetwork instance for coordinate transformations
1351
- env_size: Environment size
1352
- mapping_ratio: Mapping ratio for grid cells
1353
- dt: Time step size
1354
- config: PlotConfig object for unified configuration
1355
- n_step: Subsample every n_step frames for animation
1356
- render_backend: Rendering backend. Use 'matplotlib', 'imageio', or 'auto'/'None' for auto-detect.
1357
- output_dpi: Target DPI when rendering frames with non-interactive backends
1358
- render_workers: Worker processes for imageio backend. ``None`` auto-selects, 0 disables.
1359
- render_start_method: Multiprocessing start method ('fork', 'spawn', 'forkserver') or None for auto
1360
- **kwargs: Additional parameters for backward compatibility
1398
+ position_data: Animal position data ``(time, 2)``.
1399
+ direction_data: Direction data ``(time,)``.
1400
+ dc_activity_data: Direction cell activity ``(time, neurons)``.
1401
+ gc_activity_data: Grid cell activity ``(time, neurons)``.
1402
+ gc_network: GridCellNetwork instance for coordinate transforms.
1403
+ env_size: Environment size.
1404
+ mapping_ratio: Mapping ratio for grid cells.
1405
+ dt: Time step size.
1406
+ config: PlotConfig object for unified configuration.
1407
+ n_step: Subsample every n_step frames for animation.
1408
+ render_backend: Rendering backend. Use 'matplotlib', 'imageio', or 'auto'.
1409
+ output_dpi: Target DPI for non-interactive rendering.
1410
+ render_workers: Worker processes for imageio backend.
1411
+ render_start_method: Multiprocessing start method ('fork', 'spawn', or None).
1412
+ **kwargs: Additional parameters for backward compatibility.
1361
1413
 
1362
1414
  Returns:
1363
- FuncAnimation | None: Matplotlib animation object for interactive backend, otherwise None
1415
+ FuncAnimation | None: Animation object (None if displayed inline).
1416
+
1417
+ Examples:
1418
+ This is a minimal structural example using synthetic data to demonstrate
1419
+ the API. For realistic usage, run a GridCellNetwork simulation to obtain
1420
+ actual activity data.
1421
+
1422
+ >>> import numpy as np
1423
+ >>> import brainpy.math as bm
1424
+ >>> from canns.models.basic.theta_sweep_model import GridCellNetwork
1425
+ >>> from canns.analyzer.visualization import PlotConfig
1426
+ >>>
1427
+ >>> # Minimal example with synthetic data (for structure demonstration)
1428
+ >>> bm.set_dt(1.0)
1429
+ >>> gc_network = GridCellNetwork(num_dc=4, num_gc_x=4, mapping_ratio=1.0)
1430
+ >>> T = 5
1431
+ >>> # NOTE: In real usage, obtain these from actual model simulation
1432
+ >>> position_data = np.random.rand(T, 2)
1433
+ >>> direction_data = np.linspace(-np.pi, np.pi, T)
1434
+ >>> dc_activity_data = np.random.rand(T, gc_network.num_dc)
1435
+ >>> gc_activity_data = np.random.rand(T, gc_network.num)
1436
+ >>>
1437
+ >>> config = PlotConfig(show=False)
1438
+ >>> anim = create_theta_sweep_grid_cell_animation(
1439
+ ... position_data,
1440
+ ... direction_data,
1441
+ ... dc_activity_data,
1442
+ ... gc_activity_data,
1443
+ ... gc_network,
1444
+ ... env_size=1.0,
1445
+ ... mapping_ratio=1.0,
1446
+ ... config=config,
1447
+ ... n_step=1,
1448
+ ... fps=2,
1449
+ ... )
1450
+ >>> print(anim is not None)
1451
+ True
1364
1452
  """
1365
1453
  # Handle configuration
1366
1454
  if config is None:
@@ -1426,8 +1514,7 @@ def create_theta_sweep_grid_cell_animation(
1426
1514
  workers = render_workers if render_workers is not None else get_optimal_worker_count()
1427
1515
 
1428
1516
  # Get multiprocessing context
1429
- ctx = get_multiprocessing_context(prefer_fork=True)
1430
- start_method = ctx.method if ctx is not None else None
1517
+ ctx, start_method = get_multiprocessing_context(prefer_fork=True)
1431
1518
 
1432
1519
  if workers > 0 and ctx is None:
1433
1520
  warnings.warn(
@@ -69,9 +69,6 @@ def tuning_curve(
69
69
  ):
70
70
  """Plot the tuning curve for one or more neurons.
71
71
 
72
- The wording mirrors the original ``visualize`` module to avoid API drift and
73
- to keep existing references valid.
74
-
75
72
  Args:
76
73
  stimulus: 1D array with the stimulus value at each time step.
77
74
  firing_rates: 2D array of firing rates shaped ``(timesteps, neurons)``.
@@ -86,6 +83,17 @@ def tuning_curve(
86
83
  save_path: Optional location where the figure should be stored.
87
84
  show: Whether to display the plot interactively.
88
85
  **kwargs: Additional keyword arguments passed through to ``ax.plot``.
86
+
87
+ Examples:
88
+ >>> import numpy as np
89
+ >>> from canns.analyzer.visualization import tuning_curve, PlotConfigs
90
+ >>>
91
+ >>> stimulus = np.linspace(0, 1, 10)
92
+ >>> firing_rates = np.random.rand(10, 3)
93
+ >>> config = PlotConfigs.tuning_curve(num_bins=5, pref_stim=np.array([0.2, 0.5, 0.8]), show=False)
94
+ >>> fig, ax = tuning_curve(stimulus, firing_rates, neuron_indices=[0, 1], config=config)
95
+ >>> print(fig is not None)
96
+ True
89
97
  """
90
98
 
91
99
  config = _ensure_plot_config(
canns/data/__init__.py CHANGED
@@ -1,8 +1,11 @@
1
- """
2
- Data utilities for CANNs.
1
+ """Data utilities for CANNs.
2
+
3
+ This namespace provides dataset registry, download helpers, and convenience
4
+ loaders for common CANNs datasets.
3
5
 
4
- This module provides dataset management, loading, and downloading utilities.
5
- It consolidates data-related functionality previously scattered across the codebase.
6
+ Examples:
7
+ >>> from canns import data
8
+ >>> print(list(data.DATASETS))
6
9
  """
7
10
 
8
11
  from .datasets import (
canns/models/__init__.py CHANGED
@@ -1,3 +1,13 @@
1
+ """Model definitions for CANNs.
2
+
3
+ This namespace exposes common model families, such as basic CANNs and
4
+ brain-inspired variants.
5
+
6
+ Examples:
7
+ >>> from canns import models
8
+ >>> print(models.basic)
9
+ """
10
+
1
11
  from . import basic as basic
2
12
  from . import brain_inspired as brain_inspired
3
13
  # from .hybrid import *
@@ -76,11 +76,20 @@ class BaseCANN(BasicModel):
76
76
 
77
77
 
78
78
  class BaseCANN1D(BaseCANN):
79
- """
80
- Base class for 1D Continuous Attractor Neural Network (CANN) models.
81
- This class sets up the fundamental properties of the network, including
82
- neuronal properties, feature space, and the connectivity matrix, which
83
- are shared by different CANN model variations.
79
+ """Base class for 1D Continuous Attractor Neural Network (CANN) models.
80
+
81
+ It builds the 1D feature space, connectivity kernel, and stimulus helpers
82
+ shared by 1D CANN variants.
83
+
84
+ Examples:
85
+ >>> import brainpy.math as bm
86
+ >>> from canns.models.basic.cann import BaseCANN1D
87
+ >>>
88
+ >>> bm.set_dt(0.1)
89
+ >>> model = BaseCANN1D(num=64)
90
+ >>> stimulus = model.get_stimulus_by_pos(0.0)
91
+ >>> stimulus.shape
92
+ (64,)
84
93
  """
85
94
 
86
95
  def __init__(
@@ -181,10 +190,21 @@ class BaseCANN1D(BaseCANN):
181
190
 
182
191
 
183
192
  class CANN1D(BaseCANN1D):
184
- """
185
- A standard 1D Continuous Attractor Neural Network (CANN) model.
186
- This model implements the core dynamics where a localized "bump" of activity
187
- can be sustained and moved by external inputs.
193
+ """Standard 1D Continuous Attractor Neural Network (CANN) model.
194
+
195
+ This model sustains a localized "bump" of activity that can be driven by
196
+ external input.
197
+
198
+ Examples:
199
+ >>> import brainpy.math as bm
200
+ >>> from canns.models.basic import CANN1D
201
+ >>>
202
+ >>> bm.set_dt(0.1)
203
+ >>> model = CANN1D(num=64)
204
+ >>> stimulus = model.get_stimulus_by_pos(0.0)
205
+ >>> model.update(stimulus)
206
+ >>> model.r.value.shape
207
+ (64,)
188
208
 
189
209
  Reference:
190
210
  Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors.
@@ -210,11 +230,13 @@ class CANN1D(BaseCANN1D):
210
230
  self.inp = bm.Variable(bm.zeros(self.shape))
211
231
 
212
232
  def update(self, inp):
213
- """
214
- The main update function, defining the dynamics of the network for one time step.
233
+ """Advance the network by one time step.
215
234
 
216
235
  Args:
217
- inp (Array): The external input for the current time step.
236
+ inp (Array): External input vector of shape ``(num,)``.
237
+
238
+ Returns:
239
+ None
218
240
  """
219
241
  self.inp.value = inp
220
242
  # The numerator for the firing rate calculation (a non-linear activation function).
@@ -231,10 +253,21 @@ class CANN1D(BaseCANN1D):
231
253
 
232
254
 
233
255
  class CANN1D_SFA(BaseCANN1D):
234
- """
235
- A 1D CANN model that incorporates Spike-Frequency Adaptation (SFA).
236
- SFA is a slow negative feedback mechanism that causes neurons to fire less
237
- over time for a sustained input, which can induce anticipative tracking behavior.
256
+ """1D CANN model with spike-frequency adaptation (SFA).
257
+
258
+ SFA adds a slow negative feedback term that can create anticipative tracking
259
+ under sustained inputs.
260
+
261
+ Examples:
262
+ >>> import brainpy.math as bm
263
+ >>> from canns.models.basic import CANN1D_SFA
264
+ >>>
265
+ >>> bm.set_dt(0.1)
266
+ >>> model = CANN1D_SFA(num=64)
267
+ >>> stimulus = model.get_stimulus_by_pos(0.0)
268
+ >>> model.update(stimulus)
269
+ >>> model.r.value.shape
270
+ (64,)
238
271
 
239
272
  Reference:
240
273
  Mi, Y., Fung, C. C., Wong, K. Y., & Wu, S. (2014). Spike frequency adaptation
@@ -278,12 +311,13 @@ class CANN1D_SFA(BaseCANN1D):
278
311
  self.inp = bm.Variable(bm.zeros(self.shape)) # External input.
279
312
 
280
313
  def update(self, inp):
281
- """
282
- The main update function for the SFA model. It includes dynamics for both
283
- the membrane potential and the adaptation variable.
314
+ """Advance the network by one time step with adaptation.
284
315
 
285
316
  Args:
286
- inp (Array): The external input for the current time step.
317
+ inp (Array): External input vector of shape ``(num,)``.
318
+
319
+ Returns:
320
+ None
287
321
  """
288
322
  self.inp.value = inp
289
323
  # Firing rate calculation is the same as the standard CANN model.
@@ -302,11 +336,20 @@ class CANN1D_SFA(BaseCANN1D):
302
336
 
303
337
 
304
338
  class BaseCANN2D(BaseCANN):
305
- """
306
- Base class for 2D Continuous Attractor Neural Network (CANN) models.
307
- This class sets up the fundamental properties of the network, including
308
- neuronal properties, feature space, and the connectivity matrix, which
309
- are shared by different CANN model variations.
339
+ """Base class for 2D Continuous Attractor Neural Network (CANN) models.
340
+
341
+ It builds the 2D feature space, connectivity kernel, and stimulus helpers
342
+ shared by 2D CANN variants.
343
+
344
+ Examples:
345
+ >>> import brainpy.math as bm
346
+ >>> from canns.models.basic.cann import BaseCANN2D
347
+ >>>
348
+ >>> bm.set_dt(0.1)
349
+ >>> model = BaseCANN2D(length=16)
350
+ >>> stimulus = model.get_stimulus_by_pos([0.0, 0.0])
351
+ >>> stimulus.shape
352
+ (16, 16)
310
353
  """
311
354
 
312
355
  def __init__(
@@ -456,10 +499,18 @@ class BaseCANN2D(BaseCANN):
456
499
 
457
500
 
458
501
  class CANN2D(BaseCANN2D):
459
- """
460
- A 2D Continuous Attractor Neural Network (CANN) model.
461
- This model extends the base CANN2D class to include specific dynamics
462
- and properties for a 2D neural network.
502
+ """2D Continuous Attractor Neural Network (CANN) model.
503
+
504
+ Examples:
505
+ >>> import brainpy.math as bm
506
+ >>> from canns.models.basic import CANN2D
507
+ >>>
508
+ >>> bm.set_dt(0.1)
509
+ >>> model = CANN2D(length=16)
510
+ >>> stimulus = model.get_stimulus_by_pos([0.0, 0.0])
511
+ >>> model.update(stimulus)
512
+ >>> model.r.value.shape
513
+ (16, 16)
463
514
 
464
515
  Reference:
465
516
  Wu, S., Hamaguchi, K., & Amari, S. I. (2008). Dynamics and computation of continuous attractors.
@@ -485,11 +536,13 @@ class CANN2D(BaseCANN2D):
485
536
  self.inp = bm.Variable(bm.zeros((self.length, self.length)))
486
537
 
487
538
  def update(self, inp):
488
- """
489
- The main update function, defining the dynamics of the network for one time step.
539
+ """Advance the network by one time step.
490
540
 
491
541
  Args:
492
- inp (Array): The external input to the network, which can be a stimulus or other driving force.
542
+ inp (Array): External input grid of shape ``(length, length)``.
543
+
544
+ Returns:
545
+ None
493
546
  """
494
547
  self.inp.value = inp
495
548
  # The numerator for the firing rate calculation (a non-linear activation function).
@@ -505,10 +558,18 @@ class CANN2D(BaseCANN2D):
505
558
 
506
559
 
507
560
  class CANN2D_SFA(BaseCANN2D):
508
- """
509
- A 2D Continuous Attractor Neural Network (CANN) model with a specific
510
- implementation of the Synaptic Firing Activity (SFA) dynamics.
511
- This model extends the base CANN2D class to include SFA-specific dynamics.
561
+ """2D CANN model with spike-frequency adaptation (SFA) dynamics.
562
+
563
+ Examples:
564
+ >>> import brainpy.math as bm
565
+ >>> from canns.models.basic import CANN2D_SFA
566
+ >>>
567
+ >>> bm.set_dt(0.1)
568
+ >>> model = CANN2D_SFA(length=16)
569
+ >>> stimulus = model.get_stimulus_by_pos([0.0, 0.0])
570
+ >>> model.update(stimulus)
571
+ >>> model.r.value.shape
572
+ (16, 16)
512
573
  """
513
574
 
514
575
  def __init__(
@@ -544,12 +605,13 @@ class CANN2D_SFA(BaseCANN2D):
544
605
  self.inp = bm.Variable(bm.zeros((self.length, self.length))) # External input.
545
606
 
546
607
  def update(self, inp):
547
- """
548
- The main update function for the SFA model. It includes dynamics for both
549
- the membrane potential and the adaptation variable.
608
+ """Advance the network by one time step with adaptation.
550
609
 
551
610
  Args:
552
- inp (Array): The external input for the current time step.
611
+ inp (Array): External input grid of shape ``(length, length)``.
612
+
613
+ Returns:
614
+ None
553
615
  """
554
616
  self.inp.value = inp
555
617
  # Firing rate calculation is the same as the standard CANN model.
@@ -54,18 +54,19 @@ class GridCell2DPosition(BasicModel):
54
54
 
55
55
  Example:
56
56
  >>> import brainpy.math as bm
57
- >>> from canns.models.basic import GridCell2D
57
+ >>> from canns.models.basic import GridCell2DPosition
58
58
  >>>
59
59
  >>> bm.set_dt(1.0)
60
- >>> model = GridCell2D(length=30, mapping_ratio=1.5)
60
+ >>> model = GridCell2DPosition(length=16, mapping_ratio=1.5)
61
61
  >>>
62
- >>> # Update with 2D position
63
- >>> position = [0.5, 0.3]
62
+ >>> # Update with a 2D position
63
+ >>> position = bm.array([0.5, 0.3])
64
64
  >>> model.update(position)
65
65
  >>>
66
66
  >>> # Access decoded position
67
67
  >>> decoded_pos = model.center_position.value
68
- >>> print(f"Decoded position: {decoded_pos}")
68
+ >>> decoded_pos.shape
69
+ (2,)
69
70
 
70
71
  References:
71
72
  Burak, Y., & Fiete, I. R. (2009).
@@ -387,11 +388,11 @@ class GridCell2DVelocity(BasicModel):
387
388
  >>> bm.set_dt(5e-4) # Small timestep for accurate integration
388
389
  >>> model = GridCell2DVelocity(length=40)
389
390
  >>>
390
- >>> # Healing process (critical!)
391
- >>> model.heal_network()
391
+ >>> # Healing process (recommended before simulation)
392
+ >>> model.heal_network(num_healing_steps=50, dt_healing=1e-3)
392
393
  >>>
393
394
  >>> # Update with 2D velocity
394
- >>> velocity = [0.1, 0.05] # [vx, vy] in m/s
395
+ >>> velocity = bm.array([0.1, 0.05]) # [vx, vy]
395
396
  >>> model.update(velocity)
396
397
 
397
398
  References: