canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. canns/__init__.py +39 -3
  2. canns/analyzer/__init__.py +7 -6
  3. canns/analyzer/data/__init__.py +3 -11
  4. canns/analyzer/data/asa/__init__.py +74 -0
  5. canns/analyzer/data/asa/cohospace.py +905 -0
  6. canns/analyzer/data/asa/config.py +246 -0
  7. canns/analyzer/data/asa/decode.py +448 -0
  8. canns/analyzer/data/asa/embedding.py +269 -0
  9. canns/analyzer/data/asa/filters.py +208 -0
  10. canns/analyzer/data/asa/fr.py +439 -0
  11. canns/analyzer/data/asa/path.py +389 -0
  12. canns/analyzer/data/asa/plotting.py +1276 -0
  13. canns/analyzer/data/asa/tda.py +901 -0
  14. canns/analyzer/data/legacy/__init__.py +6 -0
  15. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  16. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  17. canns/analyzer/metrics/spatial_metrics.py +70 -100
  18. canns/analyzer/metrics/systematic_ratemap.py +12 -17
  19. canns/analyzer/metrics/utils.py +28 -0
  20. canns/analyzer/model_specific/hopfield.py +19 -16
  21. canns/analyzer/slow_points/checkpoint.py +32 -9
  22. canns/analyzer/slow_points/finder.py +33 -6
  23. canns/analyzer/slow_points/fixed_points.py +12 -0
  24. canns/analyzer/slow_points/visualization.py +22 -10
  25. canns/analyzer/visualization/core/backend.py +15 -26
  26. canns/analyzer/visualization/core/config.py +120 -15
  27. canns/analyzer/visualization/core/jupyter_utils.py +34 -16
  28. canns/analyzer/visualization/core/rendering.py +42 -40
  29. canns/analyzer/visualization/core/writers.py +10 -20
  30. canns/analyzer/visualization/energy_plots.py +78 -28
  31. canns/analyzer/visualization/spatial_plots.py +81 -36
  32. canns/analyzer/visualization/spike_plots.py +27 -7
  33. canns/analyzer/visualization/theta_sweep_plots.py +159 -72
  34. canns/analyzer/visualization/tuning_plots.py +11 -3
  35. canns/data/__init__.py +7 -4
  36. canns/models/__init__.py +10 -0
  37. canns/models/basic/cann.py +102 -40
  38. canns/models/basic/grid_cell.py +9 -8
  39. canns/models/basic/hierarchical_model.py +57 -11
  40. canns/models/brain_inspired/hopfield.py +26 -14
  41. canns/models/brain_inspired/linear.py +15 -16
  42. canns/models/brain_inspired/spiking.py +23 -12
  43. canns/pipeline/__init__.py +4 -8
  44. canns/pipeline/asa/__init__.py +21 -0
  45. canns/pipeline/asa/__main__.py +11 -0
  46. canns/pipeline/asa/app.py +1000 -0
  47. canns/pipeline/asa/runner.py +1095 -0
  48. canns/pipeline/asa/screens.py +215 -0
  49. canns/pipeline/asa/state.py +248 -0
  50. canns/pipeline/asa/styles.tcss +221 -0
  51. canns/pipeline/asa/widgets.py +233 -0
  52. canns/pipeline/gallery/__init__.py +7 -0
  53. canns/task/closed_loop_navigation.py +54 -13
  54. canns/task/open_loop_navigation.py +230 -147
  55. canns/task/tracking.py +156 -24
  56. canns/trainer/__init__.py +8 -5
  57. canns/utils/__init__.py +12 -4
  58. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  59. canns-0.13.0.dist-info/RECORD +91 -0
  60. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  61. canns/pipeline/theta_sweep.py +0 -573
  62. canns-0.12.6.dist-info/RECORD +0 -72
  63. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  64. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -27,31 +27,63 @@ __all__ = [
27
27
 
28
28
 
29
29
  def map2pi(a):
30
+ """Wrap angles to the range [-pi, pi].
31
+
32
+ Workflow:
33
+ Setup -> Provide angles (scalar or array-like).
34
+ Execute -> Call ``map2pi``.
35
+ Result -> Angles wrapped into [-pi, pi].
36
+
37
+ Examples:
38
+ >>> import numpy as np
39
+ >>> import brainpy.math as bm
40
+ >>> from canns.task.open_loop_navigation import map2pi
41
+ >>>
42
+ >>> angles = bm.array([3.5, -4.0])
43
+ >>> wrapped = map2pi(angles)
44
+ >>> bool(((wrapped >= -np.pi) & (wrapped <= np.pi)).all())
45
+ True
46
+ """
30
47
  b = bm.where(a > np.pi, a - np.pi * 2, a)
31
48
  c = bm.where(b < -np.pi, b + np.pi * 2, b)
32
49
  return c
33
50
 
34
51
 
35
52
  class ActionPolicy(ABC):
36
- """
37
- Abstract base class for action policies that control agent movement.
38
-
39
- Action policies compute parameters for agent.update() at each simulation step,
40
- enabling reusable, testable, and composable control strategies.
41
-
42
- Example:
43
- ```python
44
- class ConstantDriftPolicy(ActionPolicy):
45
- def __init__(self, drift_direction):
46
- self.drift = np.array(drift_direction)
47
-
48
- def compute_action(self, step_idx, agent):
49
- return {'drift_velocity': self.drift}
50
-
51
- task = CustomOpenLoopNavigationTask(
52
- duration=100, action_policy=ConstantDriftPolicy([0.1, 0.0])
53
- )
54
- ```
53
+ """Abstract base class for action policies that control agent movement.
54
+
55
+ Action policies compute parameters for ``agent.update()`` at each simulation
56
+ step, enabling reusable and testable control strategies.
57
+
58
+ Workflow:
59
+ Setup -> Implement ``compute_action``.
60
+ Execute -> Pass the policy into a task and call ``get_data()``.
61
+ Result -> Task data is generated using the policy-controlled actions.
62
+
63
+ Examples:
64
+ >>> import numpy as np
65
+ >>> import brainpy.math as bm
66
+ >>> from canns.task.open_loop_navigation import ActionPolicy, CustomOpenLoopNavigationTask
67
+ >>>
68
+ >>> class ConstantDriftPolicy(ActionPolicy):
69
+ ... def __init__(self, drift_direction):
70
+ ... self.drift = np.array(drift_direction, dtype=float)
71
+ ...
72
+ ... def compute_action(self, step_idx, agent):
73
+ ... return {"drift_velocity": self.drift, "drift_to_random_strength_ratio": 10.0}
74
+ >>>
75
+ >>> bm.set_dt(0.1)
76
+ >>> task = CustomOpenLoopNavigationTask(
77
+ ... duration=0.5,
78
+ ... width=1.0,
79
+ ... height=1.0,
80
+ ... dt=bm.get_dt(),
81
+ ... action_policy=ConstantDriftPolicy([0.1, 0.0]),
82
+ ... progress_bar=False,
83
+ ... )
84
+ >>> task.get_data()
85
+ >>> task.data.position.shape[1]
86
+ 2
55
87
  """
56
88
 
57
89
  @abstractmethod
@@ -75,15 +107,31 @@ class ActionPolicy(ABC):
75
107
 
76
108
  @dataclass
77
109
  class OpenLoopNavigationData:
78
- """
79
- Container for the inputs recorded during the open-loop navigation task.
80
- It contains the position, velocity, speed, movement direction, head direction,
81
- and rotational velocity of the agent.
82
-
83
- Additional fields for theta sweep analysis:
84
- - ang_velocity: Angular velocity calculated using unwrap method
85
- - linear_speed_gains: Normalized linear speed gains [0,1]
86
- - ang_speed_gains: Normalized angular speed gains [-1,1]
110
+ """Container for open-loop navigation trajectories and derived signals.
111
+
112
+ It stores position, velocity, speed, movement direction, head direction,
113
+ and rotational velocity. Optional fields are added for theta sweep analysis.
114
+
115
+ Workflow:
116
+ Setup -> Create an ``OpenLoopNavigationTask``.
117
+ Execute -> Call ``get_data()``.
118
+ Result -> Access trajectories from ``task.data``.
119
+
120
+ Examples:
121
+ >>> import brainpy.math as bm
122
+ >>> from canns.task.open_loop_navigation import OpenLoopNavigationTask
123
+ >>>
124
+ >>> bm.set_dt(0.1)
125
+ >>> task = OpenLoopNavigationTask(
126
+ ... duration=1.0,
127
+ ... width=1.0,
128
+ ... height=1.0,
129
+ ... dt=bm.get_dt(),
130
+ ... progress_bar=False,
131
+ ... )
132
+ >>> task.get_data()
133
+ >>> task.data.position.shape[1]
134
+ 2
87
135
  """
88
136
 
89
137
  position: np.ndarray
@@ -100,9 +148,32 @@ class OpenLoopNavigationData:
100
148
 
101
149
 
102
150
  class OpenLoopNavigationTask(BaseNavigationTask):
103
- """
104
- Open-loop spatial navigation task that synthesises trajectories without
105
- incorporating real-time feedback from a controller.
151
+ """Open-loop navigation task that synthesizes trajectories.
152
+
153
+ The trajectory is generated without real-time feedback control. This is
154
+ useful for producing reproducible paths for model evaluation.
155
+
156
+ Workflow:
157
+ Setup -> Instantiate the task with environment and motion settings.
158
+ Execute -> Call ``get_data()`` to generate a trajectory.
159
+ Result -> Read ``task.data`` for positions, velocities, and speed.
160
+
161
+ Examples:
162
+ >>> import brainpy.math as bm
163
+ >>> from canns.task.open_loop_navigation import OpenLoopNavigationTask
164
+ >>>
165
+ >>> bm.set_dt(0.1)
166
+ >>> task = OpenLoopNavigationTask(
167
+ ... duration=1.0,
168
+ ... width=1.0,
169
+ ... height=1.0,
170
+ ... start_pos=(0.5, 0.5),
171
+ ... dt=bm.get_dt(),
172
+ ... progress_bar=False,
173
+ ... )
174
+ >>> task.get_data()
175
+ >>> task.data.position.shape[0] == task.total_steps
176
+ True
106
177
  """
107
178
 
108
179
  def __init__(
@@ -621,7 +692,9 @@ class OpenLoopNavigationTask(BaseNavigationTask):
621
692
 
622
693
  if not hasattr(self, "agent") or self.agent is None:
623
694
  self.agent = Agent(
624
- environment=self.env, params=copy.deepcopy(self.agent_params), rng_seed=self.rng_seed
695
+ environment=self.env,
696
+ params=copy.deepcopy(self.agent_params),
697
+ rng_seed=self.rng_seed,
625
698
  )
626
699
 
627
700
  # Set initial position
@@ -696,29 +769,36 @@ class OpenLoopNavigationTask(BaseNavigationTask):
696
769
 
697
770
 
698
771
  class CustomOpenLoopNavigationTask(OpenLoopNavigationTask):
699
- """
700
- Template class for policy-based open-loop navigation tasks.
701
-
702
- This class enables custom action policies by accepting an ActionPolicy object
703
- that controls agent movement at each simulation step.
704
-
705
- Args:
706
- action_policy: ActionPolicy instance controlling agent movement
707
- **kwargs: All other arguments passed to OpenLoopNavigationTask
708
-
709
- Example:
710
- ```python
711
- # Define custom policy
712
- class MyPolicy(ActionPolicy):
713
- def compute_action(self, step_idx, agent):
714
- return {'drift_velocity': np.array([0.1, 0.0])}
715
-
716
- # Use it
717
- task = CustomOpenLoopNavigationTask(
718
- duration=100, action_policy=MyPolicy()
719
- )
720
- task.get_data()
721
- ```
772
+ """Open-loop navigation task driven by a custom action policy.
773
+
774
+ Provide an ``ActionPolicy`` to control how the agent moves at each step.
775
+
776
+ Workflow:
777
+ Setup -> Implement a policy and create the task.
778
+ Execute -> Call ``get_data()``.
779
+ Result -> Trajectory data reflects the policy-driven actions.
780
+
781
+ Examples:
782
+ >>> import numpy as np
783
+ >>> import brainpy.math as bm
784
+ >>> from canns.task.open_loop_navigation import ActionPolicy, CustomOpenLoopNavigationTask
785
+ >>>
786
+ >>> class MyPolicy(ActionPolicy):
787
+ ... def compute_action(self, step_idx, agent):
788
+ ... return {"drift_velocity": np.array([0.05, 0.0])}
789
+ >>>
790
+ >>> bm.set_dt(0.1)
791
+ >>> task = CustomOpenLoopNavigationTask(
792
+ ... duration=0.5,
793
+ ... width=1.0,
794
+ ... height=1.0,
795
+ ... dt=bm.get_dt(),
796
+ ... action_policy=MyPolicy(),
797
+ ... progress_bar=False,
798
+ ... )
799
+ >>> task.get_data()
800
+ >>> task.data.velocity.shape[1]
801
+ 2
722
802
  """
723
803
 
724
804
  def __init__(self, *args, action_policy: ActionPolicy | None = None, **kwargs):
@@ -733,46 +813,38 @@ class CustomOpenLoopNavigationTask(OpenLoopNavigationTask):
733
813
 
734
814
 
735
815
  class StateAwareRasterScanPolicy(ActionPolicy):
736
- """
737
- State-aware raster scan policy with cyclic dual-mode exploration.
738
-
739
- Scanning strategy (循环扫描):
740
- 1. Horizontal mode: Left-right sweeps moving downward
741
- → When reaching bottom: switch to Vertical mode
742
- 2. Vertical mode: Up-down sweeps moving rightward
743
- → When reaching right edge: switch back to Horizontal mode
744
- 3. Cycles continuously: H → V → H → V → ... (避免撞墙)
745
-
746
- This cyclic dual-mode approach achieves comprehensive coverage by combining
747
- orthogonal scanning patterns and avoiding wall collisions.
748
-
749
- Tested performance (200s, 1.0m x 1.0m environment):
750
- - Cyclic dual-mode: ~75-80%+ coverage (continuous cycling)
751
- - Single horizontal: 54.1% coverage (29 rows)
752
-
753
- Args:
754
- width: Environment width in meters
755
- height: Environment height in meters
756
- margin: Distance from wall to trigger turn (default: 0.05)
757
- step_size: Movement per turn in perpendicular direction (default: 0.03)
758
- speed: Movement speed (default: 0.15)
759
- drift_strength: Drift-to-random ratio for agent.update() (default: 15.0)
760
-
761
- Example:
762
- ```python
763
- policy = StateAwareRasterScanPolicy(
764
- width=1.0, height=1.0,
765
- step_size=0.03, # Dense scanning for high coverage
766
- drift_strength=15.0
767
- )
768
- task = CustomOpenLoopNavigationTask(
769
- duration=200,
770
- action_policy=policy,
771
- width=1.0,
772
- height=1.0,
773
- start_pos=(0.05, 0.95) # Start at top-left
774
- )
775
- ```
816
+ """State-aware raster scan policy with cyclic dual-mode exploration.
817
+
818
+ Scanning strategy:
819
+ 1) Horizontal mode: left-right sweeps moving downward
820
+ 2) Vertical mode: up-down sweeps moving rightward
821
+ 3) Cycles continuously to avoid walls and improve coverage
822
+
823
+ Workflow:
824
+ Setup -> Instantiate the policy with environment size.
825
+ Execute -> Use it in ``CustomOpenLoopNavigationTask.get_data()``.
826
+ Result -> The trajectory follows a raster-scan pattern.
827
+
828
+ Examples:
829
+ >>> import brainpy.math as bm
830
+ >>> from canns.task.open_loop_navigation import (
831
+ ... StateAwareRasterScanPolicy,
832
+ ... CustomOpenLoopNavigationTask,
833
+ ... )
834
+ >>>
835
+ >>> bm.set_dt(0.1)
836
+ >>> policy = StateAwareRasterScanPolicy(width=1.0, height=1.0)
837
+ >>> task = CustomOpenLoopNavigationTask(
838
+ ... duration=0.5,
839
+ ... width=1.0,
840
+ ... height=1.0,
841
+ ... dt=bm.get_dt(),
842
+ ... action_policy=policy,
843
+ ... progress_bar=False,
844
+ ... )
845
+ >>> task.get_data()
846
+ >>> task.data.position.shape[1]
847
+ 2
776
848
  """
777
849
 
778
850
  def __init__(
@@ -905,47 +977,32 @@ class StateAwareRasterScanPolicy(ActionPolicy):
905
977
 
906
978
 
907
979
  class RasterScanNavigationTask(CustomOpenLoopNavigationTask):
908
- """
909
- Preset task for cyclic dual-mode state-aware raster scan exploration.
910
-
911
- Systematically explores the environment using cyclic mode switching:
912
- 1. Horizontal phase: Left-right sweeps moving downward
913
- → Switches to Vertical when reaching bottom
914
- 2. Vertical phase: Up-down sweeps moving rightward
915
- → Switches back to Horizontal when reaching right edge
916
- 3. Cycles continuously: H → V → H → V → ...
917
-
918
- This cyclic dual-mode strategy achieves superior coverage by combining
919
- orthogonal scanning patterns and continuously adapting to avoid walls.
920
-
921
- Performance (200s, 1.0m x 1.0m):
922
- - Cyclic dual-mode: ~75-80% coverage (continuous cycling)
923
- - Single horizontal: 54.1% coverage (29 rows)
924
- - +20-30% improvement over random walk
925
-
926
- Args:
927
- duration: Simulation duration in seconds
928
- width: Environment width (default: 1.0)
929
- height: Environment height (default: 1.0)
930
- step_size: Scan density - smaller = denser scanning (default: 0.03)
931
- margin: Wall detection margin (default: 0.05)
932
- speed: Movement speed in m/s (default: 0.15)
933
- drift_strength: Drift control strength (default: 15.0)
934
- **kwargs: Additional arguments passed to OpenLoopNavigationTask
935
-
936
- Example:
937
- ```python
938
- # High coverage dual-mode exploration
939
- task = RasterScanNavigationTask(
940
- duration=200,
941
- width=1.0,
942
- height=1.0,
943
- step_size=0.03, # Dense scanning in both directions
944
- speed=0.15 # Movement speed
945
- )
946
- task.get_data()
947
- task.show_trajectory_analysis()
948
- ```
980
+ """Preset open-loop task for cyclic dual-mode raster scan exploration.
981
+
982
+ The task alternates between horizontal and vertical sweep phases to cover
983
+ the environment while avoiding walls.
984
+
985
+ Workflow:
986
+ Setup -> Instantiate the task with scan parameters.
987
+ Execute -> Call ``get_data()``.
988
+ Result -> Access the generated trajectory in ``task.data``.
989
+
990
+ Examples:
991
+ >>> import brainpy.math as bm
992
+ >>> from canns.task.open_loop_navigation import RasterScanNavigationTask
993
+ >>>
994
+ >>> bm.set_dt(0.1)
995
+ >>> task = RasterScanNavigationTask(
996
+ ... duration=0.5,
997
+ ... width=1.0,
998
+ ... height=1.0,
999
+ ... step_size=0.05,
1000
+ ... dt=bm.get_dt(),
1001
+ ... progress_bar=False,
1002
+ ... )
1003
+ >>> task.get_data()
1004
+ >>> task.data.position.shape[1]
1005
+ 2
949
1006
  """
950
1007
 
951
1008
  def __init__(
@@ -983,11 +1040,24 @@ class RasterScanNavigationTask(CustomOpenLoopNavigationTask):
983
1040
 
984
1041
 
985
1042
  class TMazeOpenLoopNavigationTask(OpenLoopNavigationTask):
986
- """
987
- Open-loop navigation task in a T-maze environment.
988
-
989
- This subclass configures the environment with a T-maze boundary, which is useful
990
- for studying decision-making and spatial navigation in a controlled setting.
1043
+ """Open-loop navigation task in a T-maze environment.
1044
+
1045
+ The environment boundary is configured to a classic T-maze layout.
1046
+
1047
+ Workflow:
1048
+ Setup -> Instantiate the task with maze geometry.
1049
+ Execute -> Call ``get_data()``.
1050
+ Result -> Use ``task.data.position`` as the trajectory.
1051
+
1052
+ Examples:
1053
+ >>> import brainpy.math as bm
1054
+ >>> from canns.task.open_loop_navigation import TMazeOpenLoopNavigationTask
1055
+ >>>
1056
+ >>> bm.set_dt(0.1)
1057
+ >>> task = TMazeOpenLoopNavigationTask(duration=0.5, dt=bm.get_dt(), progress_bar=False)
1058
+ >>> task.get_data()
1059
+ >>> task.data.position.shape[1]
1060
+ 2
991
1061
  """
992
1062
 
993
1063
  def __init__(
@@ -1038,12 +1108,25 @@ class TMazeOpenLoopNavigationTask(OpenLoopNavigationTask):
1038
1108
 
1039
1109
 
1040
1110
  class TMazeRecessOpenLoopNavigationTask(TMazeOpenLoopNavigationTask):
1041
- """
1042
- Open-loop navigation task in a T-maze environment with recesses at stem-arm junctions.
1043
-
1044
- This variant adds small rectangular indentations at the T-junction, creating
1045
- additional spatial features that may be useful for studying spatial navigation
1046
- and decision-making.
1111
+ """Open-loop navigation task in a T-maze with recesses at the junction.
1112
+
1113
+ Recesses add small indentations near the stem-arm junctions, providing
1114
+ extra spatial structure.
1115
+
1116
+ Workflow:
1117
+ Setup -> Instantiate the task with recess geometry.
1118
+ Execute -> Call ``get_data()``.
1119
+ Result -> Use ``task.data`` for downstream modeling.
1120
+
1121
+ Examples:
1122
+ >>> import brainpy.math as bm
1123
+ >>> from canns.task.open_loop_navigation import TMazeRecessOpenLoopNavigationTask
1124
+ >>>
1125
+ >>> bm.set_dt(0.1)
1126
+ >>> task = TMazeRecessOpenLoopNavigationTask(duration=0.5, dt=bm.get_dt(), progress_bar=False)
1127
+ >>> task.get_data()
1128
+ >>> task.data.position.shape[1]
1129
+ 2
1047
1130
  """
1048
1131
 
1049
1132
  def __init__(
canns/task/tracking.py CHANGED
@@ -429,10 +429,34 @@ class SmoothTracking(TrackingTask):
429
429
 
430
430
 
431
431
  class PopulationCoding1D(PopulationCoding):
432
- """
433
- Population coding task for 1D continuous attractor networks.
434
- In this task, a stimulus is presented for a specific duration, preceded and followed by
435
- periods of no stimulation, to test the network's ability to form and maintain a memory bump.
432
+ """Population coding task for 1D continuous attractor networks.
433
+
434
+ A stimulus is presented for a specific duration, preceded and followed by
435
+ periods of no stimulation.
436
+
437
+ Workflow:
438
+ Setup -> Create a 1D CANN and the task.
439
+ Execute -> Call ``get_data()``.
440
+ Result -> Use ``task.data`` as the input sequence.
441
+
442
+ Examples:
443
+ >>> import brainpy.math as bm
444
+ >>> from canns.models.basic import CANN1D
445
+ >>> from canns.task.tracking import PopulationCoding1D
446
+ >>>
447
+ >>> bm.set_dt(0.1)
448
+ >>> model = CANN1D(num=64)
449
+ >>> task = PopulationCoding1D(
450
+ ... cann_instance=model,
451
+ ... before_duration=1.0,
452
+ ... after_duration=1.0,
453
+ ... Iext=0.0,
454
+ ... duration=2.0,
455
+ ... time_step=bm.get_dt(),
456
+ ... )
457
+ >>> task.get_data()
458
+ >>> task.data.shape[0] == task.total_steps
459
+ True
436
460
  """
437
461
 
438
462
  def __init__(
@@ -472,10 +496,32 @@ class PopulationCoding1D(PopulationCoding):
472
496
 
473
497
 
474
498
  class TemplateMatching1D(TemplateMatching):
475
- """
476
- Template matching task for 1D continuous attractor networks.
477
- This task presents a stimulus with added noise to test the network's ability to
478
- denoise the input and settle on the correct underlying pattern (template).
499
+ """Template matching task for 1D continuous attractor networks.
500
+
501
+ A fixed stimulus template is presented with noise at each step, testing
502
+ the network's denoising dynamics.
503
+
504
+ Workflow:
505
+ Setup -> Create a 1D CANN and the task.
506
+ Execute -> Call ``get_data()``.
507
+ Result -> Use ``task.data`` as the noisy input sequence.
508
+
509
+ Examples:
510
+ >>> import brainpy.math as bm
511
+ >>> from canns.models.basic import CANN1D
512
+ >>> from canns.task.tracking import TemplateMatching1D
513
+ >>>
514
+ >>> bm.set_dt(0.1)
515
+ >>> model = CANN1D(num=64)
516
+ >>> task = TemplateMatching1D(
517
+ ... cann_instance=model,
518
+ ... Iext=0.0,
519
+ ... duration=1.0,
520
+ ... time_step=bm.get_dt(),
521
+ ... )
522
+ >>> task.get_data()
523
+ >>> task.data.shape[1] == model.shape[0]
524
+ True
479
525
  """
480
526
 
481
527
  def __init__(
@@ -504,10 +550,31 @@ class TemplateMatching1D(TemplateMatching):
504
550
 
505
551
 
506
552
  class SmoothTracking1D(SmoothTracking):
507
- """
508
- Smooth tracking task for 1D continuous attractor networks.
509
- This task provides an external input that moves smoothly over time, testing the network's
510
- ability to track a continuously changing stimulus.
553
+ """Smooth tracking task for 1D continuous attractor networks.
554
+
555
+ The external input moves smoothly between key positions.
556
+
557
+ Workflow:
558
+ Setup -> Create a 1D CANN and the task.
559
+ Execute -> Call ``get_data()``.
560
+ Result -> ``task.data`` contains the smoothly varying stimulus.
561
+
562
+ Examples:
563
+ >>> import brainpy.math as bm
564
+ >>> from canns.models.basic import CANN1D
565
+ >>> from canns.task.tracking import SmoothTracking1D
566
+ >>>
567
+ >>> bm.set_dt(0.1)
568
+ >>> model = CANN1D(num=64)
569
+ >>> task = SmoothTracking1D(
570
+ ... cann_instance=model,
571
+ ... Iext=(0.0, 1.0, 0.5),
572
+ ... duration=(0.5, 0.5),
573
+ ... time_step=bm.get_dt(),
574
+ ... )
575
+ >>> task.get_data()
576
+ >>> task.data.shape[0] == task.total_steps
577
+ True
511
578
  """
512
579
 
513
580
  def __init__(
@@ -565,10 +632,33 @@ class CustomTracking1D(TrackingTask):
565
632
 
566
633
 
567
634
  class PopulationCoding2D(PopulationCoding):
568
- """
569
- Population coding task for 2D continuous attractor networks.
570
- In this task, a stimulus is presented for a specific duration, preceded and followed by
571
- periods of no stimulation, to test the network's ability to form and maintain a memory bump.
635
+ """Population coding task for 2D continuous attractor networks.
636
+
637
+ A 2D stimulus is presented for a duration with pre- and post-silence.
638
+
639
+ Workflow:
640
+ Setup -> Create a 2D CANN and the task.
641
+ Execute -> Call ``get_data()``.
642
+ Result -> Use ``task.data`` as the input sequence.
643
+
644
+ Examples:
645
+ >>> import brainpy.math as bm
646
+ >>> from canns.models.basic import CANN2D
647
+ >>> from canns.task.tracking import PopulationCoding2D
648
+ >>>
649
+ >>> bm.set_dt(0.1)
650
+ >>> model = CANN2D(length=8)
651
+ >>> task = PopulationCoding2D(
652
+ ... cann_instance=model,
653
+ ... before_duration=1.0,
654
+ ... after_duration=1.0,
655
+ ... Iext=(0.0, 0.0),
656
+ ... duration=1.0,
657
+ ... time_step=bm.get_dt(),
658
+ ... )
659
+ >>> task.get_data()
660
+ >>> task.data.shape[1:] == model.shape
661
+ True
572
662
  """
573
663
 
574
664
  def __init__(
@@ -609,10 +699,31 @@ class PopulationCoding2D(PopulationCoding):
609
699
 
610
700
 
611
701
  class TemplateMatching2D(TemplateMatching):
612
- """
613
- Template matching task for 2D continuous attractor networks.
614
- This task presents a stimulus with added noise to test the network's ability to
615
- denoise the input and settle on the correct underlying pattern (template).
702
+ """Template matching task for 2D continuous attractor networks.
703
+
704
+ A 2D template is presented with noise at each step.
705
+
706
+ Workflow:
707
+ Setup -> Create a 2D CANN and the task.
708
+ Execute -> Call ``get_data()``.
709
+ Result -> ``task.data`` contains noisy 2D inputs.
710
+
711
+ Examples:
712
+ >>> import brainpy.math as bm
713
+ >>> from canns.models.basic import CANN2D
714
+ >>> from canns.task.tracking import TemplateMatching2D
715
+ >>>
716
+ >>> bm.set_dt(0.1)
717
+ >>> model = CANN2D(length=8)
718
+ >>> task = TemplateMatching2D(
719
+ ... cann_instance=model,
720
+ ... Iext=(0.0, 0.0),
721
+ ... duration=1.0,
722
+ ... time_step=bm.get_dt(),
723
+ ... )
724
+ >>> task.get_data()
725
+ >>> task.data.shape[1:] == model.shape
726
+ True
616
727
  """
617
728
 
618
729
  def __init__(
@@ -642,10 +753,31 @@ class TemplateMatching2D(TemplateMatching):
642
753
 
643
754
 
644
755
  class SmoothTracking2D(SmoothTracking):
645
- """
646
- Smooth tracking task for 2D continuous attractor networks.
647
- This task provides an external input that moves smoothly over time, testing the network's
648
- ability to track a continuously changing stimulus.
756
+ """Smooth tracking task for 2D continuous attractor networks.
757
+
758
+ The external 2D input moves smoothly between key positions.
759
+
760
+ Workflow:
761
+ Setup -> Create a 2D CANN and the task.
762
+ Execute -> Call ``get_data()``.
763
+ Result -> ``task.data`` contains smoothly varying 2D inputs.
764
+
765
+ Examples:
766
+ >>> import brainpy.math as bm
767
+ >>> from canns.models.basic import CANN2D
768
+ >>> from canns.task.tracking import SmoothTracking2D
769
+ >>>
770
+ >>> bm.set_dt(0.1)
771
+ >>> model = CANN2D(length=8)
772
+ >>> task = SmoothTracking2D(
773
+ ... cann_instance=model,
774
+ ... Iext=((0.0, 0.0), (1.0, 1.0), (0.5, 0.5)),
775
+ ... duration=(0.5, 0.5),
776
+ ... time_step=bm.get_dt(),
777
+ ... )
778
+ >>> task.get_data()
779
+ >>> task.data.shape[1:] == model.shape
780
+ True
649
781
  """
650
782
 
651
783
  def __init__(