canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/__init__.py +39 -3
- canns/analyzer/__init__.py +7 -6
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/metrics/spatial_metrics.py +70 -100
- canns/analyzer/metrics/systematic_ratemap.py +12 -17
- canns/analyzer/metrics/utils.py +28 -0
- canns/analyzer/model_specific/hopfield.py +19 -16
- canns/analyzer/slow_points/checkpoint.py +32 -9
- canns/analyzer/slow_points/finder.py +33 -6
- canns/analyzer/slow_points/fixed_points.py +12 -0
- canns/analyzer/slow_points/visualization.py +22 -10
- canns/analyzer/visualization/core/backend.py +15 -26
- canns/analyzer/visualization/core/config.py +120 -15
- canns/analyzer/visualization/core/jupyter_utils.py +34 -16
- canns/analyzer/visualization/core/rendering.py +42 -40
- canns/analyzer/visualization/core/writers.py +10 -20
- canns/analyzer/visualization/energy_plots.py +78 -28
- canns/analyzer/visualization/spatial_plots.py +81 -36
- canns/analyzer/visualization/spike_plots.py +27 -7
- canns/analyzer/visualization/theta_sweep_plots.py +159 -72
- canns/analyzer/visualization/tuning_plots.py +11 -3
- canns/data/__init__.py +7 -4
- canns/models/__init__.py +10 -0
- canns/models/basic/cann.py +102 -40
- canns/models/basic/grid_cell.py +9 -8
- canns/models/basic/hierarchical_model.py +57 -11
- canns/models/brain_inspired/hopfield.py +26 -14
- canns/models/brain_inspired/linear.py +15 -16
- canns/models/brain_inspired/spiking.py +23 -12
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/closed_loop_navigation.py +54 -13
- canns/task/open_loop_navigation.py +230 -147
- canns/task/tracking.py +156 -24
- canns/trainer/__init__.py +8 -5
- canns/utils/__init__.py +12 -4
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- canns-0.13.0.dist-info/RECORD +91 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- canns-0.12.6.dist-info/RECORD +0 -72
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -27,31 +27,63 @@ __all__ = [
|
|
|
27
27
|
|
|
28
28
|
|
|
29
29
|
def map2pi(a):
|
|
30
|
+
"""Wrap angles to the range [-pi, pi].
|
|
31
|
+
|
|
32
|
+
Workflow:
|
|
33
|
+
Setup -> Provide angles (scalar or array-like).
|
|
34
|
+
Execute -> Call ``map2pi``.
|
|
35
|
+
Result -> Angles wrapped into [-pi, pi].
|
|
36
|
+
|
|
37
|
+
Examples:
|
|
38
|
+
>>> import numpy as np
|
|
39
|
+
>>> import brainpy.math as bm
|
|
40
|
+
>>> from canns.task.open_loop_navigation import map2pi
|
|
41
|
+
>>>
|
|
42
|
+
>>> angles = bm.array([3.5, -4.0])
|
|
43
|
+
>>> wrapped = map2pi(angles)
|
|
44
|
+
>>> bool(((wrapped >= -np.pi) & (wrapped <= np.pi)).all())
|
|
45
|
+
True
|
|
46
|
+
"""
|
|
30
47
|
b = bm.where(a > np.pi, a - np.pi * 2, a)
|
|
31
48
|
c = bm.where(b < -np.pi, b + np.pi * 2, b)
|
|
32
49
|
return c
|
|
33
50
|
|
|
34
51
|
|
|
35
52
|
class ActionPolicy(ABC):
|
|
36
|
-
"""
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
)
|
|
54
|
-
|
|
53
|
+
"""Abstract base class for action policies that control agent movement.
|
|
54
|
+
|
|
55
|
+
Action policies compute parameters for ``agent.update()`` at each simulation
|
|
56
|
+
step, enabling reusable and testable control strategies.
|
|
57
|
+
|
|
58
|
+
Workflow:
|
|
59
|
+
Setup -> Implement ``compute_action``.
|
|
60
|
+
Execute -> Pass the policy into a task and call ``get_data()``.
|
|
61
|
+
Result -> Task data is generated using the policy-controlled actions.
|
|
62
|
+
|
|
63
|
+
Examples:
|
|
64
|
+
>>> import numpy as np
|
|
65
|
+
>>> import brainpy.math as bm
|
|
66
|
+
>>> from canns.task.open_loop_navigation import ActionPolicy, CustomOpenLoopNavigationTask
|
|
67
|
+
>>>
|
|
68
|
+
>>> class ConstantDriftPolicy(ActionPolicy):
|
|
69
|
+
... def __init__(self, drift_direction):
|
|
70
|
+
... self.drift = np.array(drift_direction, dtype=float)
|
|
71
|
+
...
|
|
72
|
+
... def compute_action(self, step_idx, agent):
|
|
73
|
+
... return {"drift_velocity": self.drift, "drift_to_random_strength_ratio": 10.0}
|
|
74
|
+
>>>
|
|
75
|
+
>>> bm.set_dt(0.1)
|
|
76
|
+
>>> task = CustomOpenLoopNavigationTask(
|
|
77
|
+
... duration=0.5,
|
|
78
|
+
... width=1.0,
|
|
79
|
+
... height=1.0,
|
|
80
|
+
... dt=bm.get_dt(),
|
|
81
|
+
... action_policy=ConstantDriftPolicy([0.1, 0.0]),
|
|
82
|
+
... progress_bar=False,
|
|
83
|
+
... )
|
|
84
|
+
>>> task.get_data()
|
|
85
|
+
>>> task.data.position.shape[1]
|
|
86
|
+
2
|
|
55
87
|
"""
|
|
56
88
|
|
|
57
89
|
@abstractmethod
|
|
@@ -75,15 +107,31 @@ class ActionPolicy(ABC):
|
|
|
75
107
|
|
|
76
108
|
@dataclass
|
|
77
109
|
class OpenLoopNavigationData:
|
|
78
|
-
"""
|
|
79
|
-
|
|
80
|
-
It
|
|
81
|
-
and rotational velocity
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
110
|
+
"""Container for open-loop navigation trajectories and derived signals.
|
|
111
|
+
|
|
112
|
+
It stores position, velocity, speed, movement direction, head direction,
|
|
113
|
+
and rotational velocity. Optional fields are added for theta sweep analysis.
|
|
114
|
+
|
|
115
|
+
Workflow:
|
|
116
|
+
Setup -> Create an ``OpenLoopNavigationTask``.
|
|
117
|
+
Execute -> Call ``get_data()``.
|
|
118
|
+
Result -> Access trajectories from ``task.data``.
|
|
119
|
+
|
|
120
|
+
Examples:
|
|
121
|
+
>>> import brainpy.math as bm
|
|
122
|
+
>>> from canns.task.open_loop_navigation import OpenLoopNavigationTask
|
|
123
|
+
>>>
|
|
124
|
+
>>> bm.set_dt(0.1)
|
|
125
|
+
>>> task = OpenLoopNavigationTask(
|
|
126
|
+
... duration=1.0,
|
|
127
|
+
... width=1.0,
|
|
128
|
+
... height=1.0,
|
|
129
|
+
... dt=bm.get_dt(),
|
|
130
|
+
... progress_bar=False,
|
|
131
|
+
... )
|
|
132
|
+
>>> task.get_data()
|
|
133
|
+
>>> task.data.position.shape[1]
|
|
134
|
+
2
|
|
87
135
|
"""
|
|
88
136
|
|
|
89
137
|
position: np.ndarray
|
|
@@ -100,9 +148,32 @@ class OpenLoopNavigationData:
|
|
|
100
148
|
|
|
101
149
|
|
|
102
150
|
class OpenLoopNavigationTask(BaseNavigationTask):
|
|
103
|
-
"""
|
|
104
|
-
|
|
105
|
-
|
|
151
|
+
"""Open-loop navigation task that synthesizes trajectories.
|
|
152
|
+
|
|
153
|
+
The trajectory is generated without real-time feedback control. This is
|
|
154
|
+
useful for producing reproducible paths for model evaluation.
|
|
155
|
+
|
|
156
|
+
Workflow:
|
|
157
|
+
Setup -> Instantiate the task with environment and motion settings.
|
|
158
|
+
Execute -> Call ``get_data()`` to generate a trajectory.
|
|
159
|
+
Result -> Read ``task.data`` for positions, velocities, and speed.
|
|
160
|
+
|
|
161
|
+
Examples:
|
|
162
|
+
>>> import brainpy.math as bm
|
|
163
|
+
>>> from canns.task.open_loop_navigation import OpenLoopNavigationTask
|
|
164
|
+
>>>
|
|
165
|
+
>>> bm.set_dt(0.1)
|
|
166
|
+
>>> task = OpenLoopNavigationTask(
|
|
167
|
+
... duration=1.0,
|
|
168
|
+
... width=1.0,
|
|
169
|
+
... height=1.0,
|
|
170
|
+
... start_pos=(0.5, 0.5),
|
|
171
|
+
... dt=bm.get_dt(),
|
|
172
|
+
... progress_bar=False,
|
|
173
|
+
... )
|
|
174
|
+
>>> task.get_data()
|
|
175
|
+
>>> task.data.position.shape[0] == task.total_steps
|
|
176
|
+
True
|
|
106
177
|
"""
|
|
107
178
|
|
|
108
179
|
def __init__(
|
|
@@ -621,7 +692,9 @@ class OpenLoopNavigationTask(BaseNavigationTask):
|
|
|
621
692
|
|
|
622
693
|
if not hasattr(self, "agent") or self.agent is None:
|
|
623
694
|
self.agent = Agent(
|
|
624
|
-
environment=self.env,
|
|
695
|
+
environment=self.env,
|
|
696
|
+
params=copy.deepcopy(self.agent_params),
|
|
697
|
+
rng_seed=self.rng_seed,
|
|
625
698
|
)
|
|
626
699
|
|
|
627
700
|
# Set initial position
|
|
@@ -696,29 +769,36 @@ class OpenLoopNavigationTask(BaseNavigationTask):
|
|
|
696
769
|
|
|
697
770
|
|
|
698
771
|
class CustomOpenLoopNavigationTask(OpenLoopNavigationTask):
|
|
699
|
-
"""
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
772
|
+
"""Open-loop navigation task driven by a custom action policy.
|
|
773
|
+
|
|
774
|
+
Provide an ``ActionPolicy`` to control how the agent moves at each step.
|
|
775
|
+
|
|
776
|
+
Workflow:
|
|
777
|
+
Setup -> Implement a policy and create the task.
|
|
778
|
+
Execute -> Call ``get_data()``.
|
|
779
|
+
Result -> Trajectory data reflects the policy-driven actions.
|
|
780
|
+
|
|
781
|
+
Examples:
|
|
782
|
+
>>> import numpy as np
|
|
783
|
+
>>> import brainpy.math as bm
|
|
784
|
+
>>> from canns.task.open_loop_navigation import ActionPolicy, CustomOpenLoopNavigationTask
|
|
785
|
+
>>>
|
|
786
|
+
>>> class MyPolicy(ActionPolicy):
|
|
787
|
+
... def compute_action(self, step_idx, agent):
|
|
788
|
+
... return {"drift_velocity": np.array([0.05, 0.0])}
|
|
789
|
+
>>>
|
|
790
|
+
>>> bm.set_dt(0.1)
|
|
791
|
+
>>> task = CustomOpenLoopNavigationTask(
|
|
792
|
+
... duration=0.5,
|
|
793
|
+
... width=1.0,
|
|
794
|
+
... height=1.0,
|
|
795
|
+
... dt=bm.get_dt(),
|
|
796
|
+
... action_policy=MyPolicy(),
|
|
797
|
+
... progress_bar=False,
|
|
798
|
+
... )
|
|
799
|
+
>>> task.get_data()
|
|
800
|
+
>>> task.data.velocity.shape[1]
|
|
801
|
+
2
|
|
722
802
|
"""
|
|
723
803
|
|
|
724
804
|
def __init__(self, *args, action_policy: ActionPolicy | None = None, **kwargs):
|
|
@@ -733,46 +813,38 @@ class CustomOpenLoopNavigationTask(OpenLoopNavigationTask):
|
|
|
733
813
|
|
|
734
814
|
|
|
735
815
|
class StateAwareRasterScanPolicy(ActionPolicy):
|
|
736
|
-
"""
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
task = CustomOpenLoopNavigationTask(
|
|
769
|
-
duration=200,
|
|
770
|
-
action_policy=policy,
|
|
771
|
-
width=1.0,
|
|
772
|
-
height=1.0,
|
|
773
|
-
start_pos=(0.05, 0.95) # Start at top-left
|
|
774
|
-
)
|
|
775
|
-
```
|
|
816
|
+
"""State-aware raster scan policy with cyclic dual-mode exploration.
|
|
817
|
+
|
|
818
|
+
Scanning strategy:
|
|
819
|
+
1) Horizontal mode: left-right sweeps moving downward
|
|
820
|
+
2) Vertical mode: up-down sweeps moving rightward
|
|
821
|
+
3) Cycles continuously to avoid walls and improve coverage
|
|
822
|
+
|
|
823
|
+
Workflow:
|
|
824
|
+
Setup -> Instantiate the policy with environment size.
|
|
825
|
+
Execute -> Use it in ``CustomOpenLoopNavigationTask.get_data()``.
|
|
826
|
+
Result -> The trajectory follows a raster-scan pattern.
|
|
827
|
+
|
|
828
|
+
Examples:
|
|
829
|
+
>>> import brainpy.math as bm
|
|
830
|
+
>>> from canns.task.open_loop_navigation import (
|
|
831
|
+
... StateAwareRasterScanPolicy,
|
|
832
|
+
... CustomOpenLoopNavigationTask,
|
|
833
|
+
... )
|
|
834
|
+
>>>
|
|
835
|
+
>>> bm.set_dt(0.1)
|
|
836
|
+
>>> policy = StateAwareRasterScanPolicy(width=1.0, height=1.0)
|
|
837
|
+
>>> task = CustomOpenLoopNavigationTask(
|
|
838
|
+
... duration=0.5,
|
|
839
|
+
... width=1.0,
|
|
840
|
+
... height=1.0,
|
|
841
|
+
... dt=bm.get_dt(),
|
|
842
|
+
... action_policy=policy,
|
|
843
|
+
... progress_bar=False,
|
|
844
|
+
... )
|
|
845
|
+
>>> task.get_data()
|
|
846
|
+
>>> task.data.position.shape[1]
|
|
847
|
+
2
|
|
776
848
|
"""
|
|
777
849
|
|
|
778
850
|
def __init__(
|
|
@@ -905,47 +977,32 @@ class StateAwareRasterScanPolicy(ActionPolicy):
|
|
|
905
977
|
|
|
906
978
|
|
|
907
979
|
class RasterScanNavigationTask(CustomOpenLoopNavigationTask):
|
|
908
|
-
"""
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
**kwargs: Additional arguments passed to OpenLoopNavigationTask
|
|
935
|
-
|
|
936
|
-
Example:
|
|
937
|
-
```python
|
|
938
|
-
# High coverage dual-mode exploration
|
|
939
|
-
task = RasterScanNavigationTask(
|
|
940
|
-
duration=200,
|
|
941
|
-
width=1.0,
|
|
942
|
-
height=1.0,
|
|
943
|
-
step_size=0.03, # Dense scanning in both directions
|
|
944
|
-
speed=0.15 # Movement speed
|
|
945
|
-
)
|
|
946
|
-
task.get_data()
|
|
947
|
-
task.show_trajectory_analysis()
|
|
948
|
-
```
|
|
980
|
+
"""Preset open-loop task for cyclic dual-mode raster scan exploration.
|
|
981
|
+
|
|
982
|
+
The task alternates between horizontal and vertical sweep phases to cover
|
|
983
|
+
the environment while avoiding walls.
|
|
984
|
+
|
|
985
|
+
Workflow:
|
|
986
|
+
Setup -> Instantiate the task with scan parameters.
|
|
987
|
+
Execute -> Call ``get_data()``.
|
|
988
|
+
Result -> Access the generated trajectory in ``task.data``.
|
|
989
|
+
|
|
990
|
+
Examples:
|
|
991
|
+
>>> import brainpy.math as bm
|
|
992
|
+
>>> from canns.task.open_loop_navigation import RasterScanNavigationTask
|
|
993
|
+
>>>
|
|
994
|
+
>>> bm.set_dt(0.1)
|
|
995
|
+
>>> task = RasterScanNavigationTask(
|
|
996
|
+
... duration=0.5,
|
|
997
|
+
... width=1.0,
|
|
998
|
+
... height=1.0,
|
|
999
|
+
... step_size=0.05,
|
|
1000
|
+
... dt=bm.get_dt(),
|
|
1001
|
+
... progress_bar=False,
|
|
1002
|
+
... )
|
|
1003
|
+
>>> task.get_data()
|
|
1004
|
+
>>> task.data.position.shape[1]
|
|
1005
|
+
2
|
|
949
1006
|
"""
|
|
950
1007
|
|
|
951
1008
|
def __init__(
|
|
@@ -983,11 +1040,24 @@ class RasterScanNavigationTask(CustomOpenLoopNavigationTask):
|
|
|
983
1040
|
|
|
984
1041
|
|
|
985
1042
|
class TMazeOpenLoopNavigationTask(OpenLoopNavigationTask):
|
|
986
|
-
"""
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
1043
|
+
"""Open-loop navigation task in a T-maze environment.
|
|
1044
|
+
|
|
1045
|
+
The environment boundary is configured to a classic T-maze layout.
|
|
1046
|
+
|
|
1047
|
+
Workflow:
|
|
1048
|
+
Setup -> Instantiate the task with maze geometry.
|
|
1049
|
+
Execute -> Call ``get_data()``.
|
|
1050
|
+
Result -> Use ``task.data.position`` as the trajectory.
|
|
1051
|
+
|
|
1052
|
+
Examples:
|
|
1053
|
+
>>> import brainpy.math as bm
|
|
1054
|
+
>>> from canns.task.open_loop_navigation import TMazeOpenLoopNavigationTask
|
|
1055
|
+
>>>
|
|
1056
|
+
>>> bm.set_dt(0.1)
|
|
1057
|
+
>>> task = TMazeOpenLoopNavigationTask(duration=0.5, dt=bm.get_dt(), progress_bar=False)
|
|
1058
|
+
>>> task.get_data()
|
|
1059
|
+
>>> task.data.position.shape[1]
|
|
1060
|
+
2
|
|
991
1061
|
"""
|
|
992
1062
|
|
|
993
1063
|
def __init__(
|
|
@@ -1038,12 +1108,25 @@ class TMazeOpenLoopNavigationTask(OpenLoopNavigationTask):
|
|
|
1038
1108
|
|
|
1039
1109
|
|
|
1040
1110
|
class TMazeRecessOpenLoopNavigationTask(TMazeOpenLoopNavigationTask):
|
|
1041
|
-
"""
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1111
|
+
"""Open-loop navigation task in a T-maze with recesses at the junction.
|
|
1112
|
+
|
|
1113
|
+
Recesses add small indentations near the stem-arm junctions, providing
|
|
1114
|
+
extra spatial structure.
|
|
1115
|
+
|
|
1116
|
+
Workflow:
|
|
1117
|
+
Setup -> Instantiate the task with recess geometry.
|
|
1118
|
+
Execute -> Call ``get_data()``.
|
|
1119
|
+
Result -> Use ``task.data`` for downstream modeling.
|
|
1120
|
+
|
|
1121
|
+
Examples:
|
|
1122
|
+
>>> import brainpy.math as bm
|
|
1123
|
+
>>> from canns.task.open_loop_navigation import TMazeRecessOpenLoopNavigationTask
|
|
1124
|
+
>>>
|
|
1125
|
+
>>> bm.set_dt(0.1)
|
|
1126
|
+
>>> task = TMazeRecessOpenLoopNavigationTask(duration=0.5, dt=bm.get_dt(), progress_bar=False)
|
|
1127
|
+
>>> task.get_data()
|
|
1128
|
+
>>> task.data.position.shape[1]
|
|
1129
|
+
2
|
|
1047
1130
|
"""
|
|
1048
1131
|
|
|
1049
1132
|
def __init__(
|
canns/task/tracking.py
CHANGED
|
@@ -429,10 +429,34 @@ class SmoothTracking(TrackingTask):
|
|
|
429
429
|
|
|
430
430
|
|
|
431
431
|
class PopulationCoding1D(PopulationCoding):
|
|
432
|
-
"""
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
periods of no stimulation
|
|
432
|
+
"""Population coding task for 1D continuous attractor networks.
|
|
433
|
+
|
|
434
|
+
A stimulus is presented for a specific duration, preceded and followed by
|
|
435
|
+
periods of no stimulation.
|
|
436
|
+
|
|
437
|
+
Workflow:
|
|
438
|
+
Setup -> Create a 1D CANN and the task.
|
|
439
|
+
Execute -> Call ``get_data()``.
|
|
440
|
+
Result -> Use ``task.data`` as the input sequence.
|
|
441
|
+
|
|
442
|
+
Examples:
|
|
443
|
+
>>> import brainpy.math as bm
|
|
444
|
+
>>> from canns.models.basic import CANN1D
|
|
445
|
+
>>> from canns.task.tracking import PopulationCoding1D
|
|
446
|
+
>>>
|
|
447
|
+
>>> bm.set_dt(0.1)
|
|
448
|
+
>>> model = CANN1D(num=64)
|
|
449
|
+
>>> task = PopulationCoding1D(
|
|
450
|
+
... cann_instance=model,
|
|
451
|
+
... before_duration=1.0,
|
|
452
|
+
... after_duration=1.0,
|
|
453
|
+
... Iext=0.0,
|
|
454
|
+
... duration=2.0,
|
|
455
|
+
... time_step=bm.get_dt(),
|
|
456
|
+
... )
|
|
457
|
+
>>> task.get_data()
|
|
458
|
+
>>> task.data.shape[0] == task.total_steps
|
|
459
|
+
True
|
|
436
460
|
"""
|
|
437
461
|
|
|
438
462
|
def __init__(
|
|
@@ -472,10 +496,32 @@ class PopulationCoding1D(PopulationCoding):
|
|
|
472
496
|
|
|
473
497
|
|
|
474
498
|
class TemplateMatching1D(TemplateMatching):
|
|
475
|
-
"""
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
499
|
+
"""Template matching task for 1D continuous attractor networks.
|
|
500
|
+
|
|
501
|
+
A fixed stimulus template is presented with noise at each step, testing
|
|
502
|
+
the network's denoising dynamics.
|
|
503
|
+
|
|
504
|
+
Workflow:
|
|
505
|
+
Setup -> Create a 1D CANN and the task.
|
|
506
|
+
Execute -> Call ``get_data()``.
|
|
507
|
+
Result -> Use ``task.data`` as the noisy input sequence.
|
|
508
|
+
|
|
509
|
+
Examples:
|
|
510
|
+
>>> import brainpy.math as bm
|
|
511
|
+
>>> from canns.models.basic import CANN1D
|
|
512
|
+
>>> from canns.task.tracking import TemplateMatching1D
|
|
513
|
+
>>>
|
|
514
|
+
>>> bm.set_dt(0.1)
|
|
515
|
+
>>> model = CANN1D(num=64)
|
|
516
|
+
>>> task = TemplateMatching1D(
|
|
517
|
+
... cann_instance=model,
|
|
518
|
+
... Iext=0.0,
|
|
519
|
+
... duration=1.0,
|
|
520
|
+
... time_step=bm.get_dt(),
|
|
521
|
+
... )
|
|
522
|
+
>>> task.get_data()
|
|
523
|
+
>>> task.data.shape[1] == model.shape[0]
|
|
524
|
+
True
|
|
479
525
|
"""
|
|
480
526
|
|
|
481
527
|
def __init__(
|
|
@@ -504,10 +550,31 @@ class TemplateMatching1D(TemplateMatching):
|
|
|
504
550
|
|
|
505
551
|
|
|
506
552
|
class SmoothTracking1D(SmoothTracking):
|
|
507
|
-
"""
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
553
|
+
"""Smooth tracking task for 1D continuous attractor networks.
|
|
554
|
+
|
|
555
|
+
The external input moves smoothly between key positions.
|
|
556
|
+
|
|
557
|
+
Workflow:
|
|
558
|
+
Setup -> Create a 1D CANN and the task.
|
|
559
|
+
Execute -> Call ``get_data()``.
|
|
560
|
+
Result -> ``task.data`` contains the smoothly varying stimulus.
|
|
561
|
+
|
|
562
|
+
Examples:
|
|
563
|
+
>>> import brainpy.math as bm
|
|
564
|
+
>>> from canns.models.basic import CANN1D
|
|
565
|
+
>>> from canns.task.tracking import SmoothTracking1D
|
|
566
|
+
>>>
|
|
567
|
+
>>> bm.set_dt(0.1)
|
|
568
|
+
>>> model = CANN1D(num=64)
|
|
569
|
+
>>> task = SmoothTracking1D(
|
|
570
|
+
... cann_instance=model,
|
|
571
|
+
... Iext=(0.0, 1.0, 0.5),
|
|
572
|
+
... duration=(0.5, 0.5),
|
|
573
|
+
... time_step=bm.get_dt(),
|
|
574
|
+
... )
|
|
575
|
+
>>> task.get_data()
|
|
576
|
+
>>> task.data.shape[0] == task.total_steps
|
|
577
|
+
True
|
|
511
578
|
"""
|
|
512
579
|
|
|
513
580
|
def __init__(
|
|
@@ -565,10 +632,33 @@ class CustomTracking1D(TrackingTask):
|
|
|
565
632
|
|
|
566
633
|
|
|
567
634
|
class PopulationCoding2D(PopulationCoding):
|
|
568
|
-
"""
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
635
|
+
"""Population coding task for 2D continuous attractor networks.
|
|
636
|
+
|
|
637
|
+
A 2D stimulus is presented for a duration with pre- and post-silence.
|
|
638
|
+
|
|
639
|
+
Workflow:
|
|
640
|
+
Setup -> Create a 2D CANN and the task.
|
|
641
|
+
Execute -> Call ``get_data()``.
|
|
642
|
+
Result -> Use ``task.data`` as the input sequence.
|
|
643
|
+
|
|
644
|
+
Examples:
|
|
645
|
+
>>> import brainpy.math as bm
|
|
646
|
+
>>> from canns.models.basic import CANN2D
|
|
647
|
+
>>> from canns.task.tracking import PopulationCoding2D
|
|
648
|
+
>>>
|
|
649
|
+
>>> bm.set_dt(0.1)
|
|
650
|
+
>>> model = CANN2D(length=8)
|
|
651
|
+
>>> task = PopulationCoding2D(
|
|
652
|
+
... cann_instance=model,
|
|
653
|
+
... before_duration=1.0,
|
|
654
|
+
... after_duration=1.0,
|
|
655
|
+
... Iext=(0.0, 0.0),
|
|
656
|
+
... duration=1.0,
|
|
657
|
+
... time_step=bm.get_dt(),
|
|
658
|
+
... )
|
|
659
|
+
>>> task.get_data()
|
|
660
|
+
>>> task.data.shape[1:] == model.shape
|
|
661
|
+
True
|
|
572
662
|
"""
|
|
573
663
|
|
|
574
664
|
def __init__(
|
|
@@ -609,10 +699,31 @@ class PopulationCoding2D(PopulationCoding):
|
|
|
609
699
|
|
|
610
700
|
|
|
611
701
|
class TemplateMatching2D(TemplateMatching):
|
|
612
|
-
"""
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
702
|
+
"""Template matching task for 2D continuous attractor networks.
|
|
703
|
+
|
|
704
|
+
A 2D template is presented with noise at each step.
|
|
705
|
+
|
|
706
|
+
Workflow:
|
|
707
|
+
Setup -> Create a 2D CANN and the task.
|
|
708
|
+
Execute -> Call ``get_data()``.
|
|
709
|
+
Result -> ``task.data`` contains noisy 2D inputs.
|
|
710
|
+
|
|
711
|
+
Examples:
|
|
712
|
+
>>> import brainpy.math as bm
|
|
713
|
+
>>> from canns.models.basic import CANN2D
|
|
714
|
+
>>> from canns.task.tracking import TemplateMatching2D
|
|
715
|
+
>>>
|
|
716
|
+
>>> bm.set_dt(0.1)
|
|
717
|
+
>>> model = CANN2D(length=8)
|
|
718
|
+
>>> task = TemplateMatching2D(
|
|
719
|
+
... cann_instance=model,
|
|
720
|
+
... Iext=(0.0, 0.0),
|
|
721
|
+
... duration=1.0,
|
|
722
|
+
... time_step=bm.get_dt(),
|
|
723
|
+
... )
|
|
724
|
+
>>> task.get_data()
|
|
725
|
+
>>> task.data.shape[1:] == model.shape
|
|
726
|
+
True
|
|
616
727
|
"""
|
|
617
728
|
|
|
618
729
|
def __init__(
|
|
@@ -642,10 +753,31 @@ class TemplateMatching2D(TemplateMatching):
|
|
|
642
753
|
|
|
643
754
|
|
|
644
755
|
class SmoothTracking2D(SmoothTracking):
|
|
645
|
-
"""
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
756
|
+
"""Smooth tracking task for 2D continuous attractor networks.
|
|
757
|
+
|
|
758
|
+
The external 2D input moves smoothly between key positions.
|
|
759
|
+
|
|
760
|
+
Workflow:
|
|
761
|
+
Setup -> Create a 2D CANN and the task.
|
|
762
|
+
Execute -> Call ``get_data()``.
|
|
763
|
+
Result -> ``task.data`` contains smoothly varying 2D inputs.
|
|
764
|
+
|
|
765
|
+
Examples:
|
|
766
|
+
>>> import brainpy.math as bm
|
|
767
|
+
>>> from canns.models.basic import CANN2D
|
|
768
|
+
>>> from canns.task.tracking import SmoothTracking2D
|
|
769
|
+
>>>
|
|
770
|
+
>>> bm.set_dt(0.1)
|
|
771
|
+
>>> model = CANN2D(length=8)
|
|
772
|
+
>>> task = SmoothTracking2D(
|
|
773
|
+
... cann_instance=model,
|
|
774
|
+
... Iext=((0.0, 0.0), (1.0, 1.0), (0.5, 0.5)),
|
|
775
|
+
... duration=(0.5, 0.5),
|
|
776
|
+
... time_step=bm.get_dt(),
|
|
777
|
+
... )
|
|
778
|
+
>>> task.get_data()
|
|
779
|
+
>>> task.data.shape[1:] == model.shape
|
|
780
|
+
True
|
|
649
781
|
"""
|
|
650
782
|
|
|
651
783
|
def __init__(
|