canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- canns/__init__.py +39 -3
- canns/analyzer/__init__.py +7 -6
- canns/analyzer/data/__init__.py +3 -11
- canns/analyzer/data/asa/__init__.py +74 -0
- canns/analyzer/data/asa/cohospace.py +905 -0
- canns/analyzer/data/asa/config.py +246 -0
- canns/analyzer/data/asa/decode.py +448 -0
- canns/analyzer/data/asa/embedding.py +269 -0
- canns/analyzer/data/asa/filters.py +208 -0
- canns/analyzer/data/asa/fr.py +439 -0
- canns/analyzer/data/asa/path.py +389 -0
- canns/analyzer/data/asa/plotting.py +1276 -0
- canns/analyzer/data/asa/tda.py +901 -0
- canns/analyzer/data/legacy/__init__.py +6 -0
- canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
- canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
- canns/analyzer/metrics/spatial_metrics.py +70 -100
- canns/analyzer/metrics/systematic_ratemap.py +12 -17
- canns/analyzer/metrics/utils.py +28 -0
- canns/analyzer/model_specific/hopfield.py +19 -16
- canns/analyzer/slow_points/checkpoint.py +32 -9
- canns/analyzer/slow_points/finder.py +33 -6
- canns/analyzer/slow_points/fixed_points.py +12 -0
- canns/analyzer/slow_points/visualization.py +22 -10
- canns/analyzer/visualization/core/backend.py +15 -26
- canns/analyzer/visualization/core/config.py +120 -15
- canns/analyzer/visualization/core/jupyter_utils.py +34 -16
- canns/analyzer/visualization/core/rendering.py +42 -40
- canns/analyzer/visualization/core/writers.py +10 -20
- canns/analyzer/visualization/energy_plots.py +78 -28
- canns/analyzer/visualization/spatial_plots.py +81 -36
- canns/analyzer/visualization/spike_plots.py +27 -7
- canns/analyzer/visualization/theta_sweep_plots.py +159 -72
- canns/analyzer/visualization/tuning_plots.py +11 -3
- canns/data/__init__.py +7 -4
- canns/models/__init__.py +10 -0
- canns/models/basic/cann.py +102 -40
- canns/models/basic/grid_cell.py +9 -8
- canns/models/basic/hierarchical_model.py +57 -11
- canns/models/brain_inspired/hopfield.py +26 -14
- canns/models/brain_inspired/linear.py +15 -16
- canns/models/brain_inspired/spiking.py +23 -12
- canns/pipeline/__init__.py +4 -8
- canns/pipeline/asa/__init__.py +21 -0
- canns/pipeline/asa/__main__.py +11 -0
- canns/pipeline/asa/app.py +1000 -0
- canns/pipeline/asa/runner.py +1095 -0
- canns/pipeline/asa/screens.py +215 -0
- canns/pipeline/asa/state.py +248 -0
- canns/pipeline/asa/styles.tcss +221 -0
- canns/pipeline/asa/widgets.py +233 -0
- canns/pipeline/gallery/__init__.py +7 -0
- canns/task/closed_loop_navigation.py +54 -13
- canns/task/open_loop_navigation.py +230 -147
- canns/task/tracking.py +156 -24
- canns/trainer/__init__.py +8 -5
- canns/utils/__init__.py +12 -4
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
- canns-0.13.0.dist-info/RECORD +91 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
- canns/pipeline/theta_sweep.py +0 -573
- canns-0.12.6.dist-info/RECORD +0 -72
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
- {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
canns/models/basic/hierarchical_model.py
CHANGED
@@ -744,12 +744,29 @@ class GridCell(BasicModel):
 class HierarchicalPathIntegrationModel(BasicModelGroup):
     """A hierarchical model combining band cells and grid cells for path integration.
 
-    This model forms a single grid module. It consists of three `BandCell` modules
-
-
-
-
-
+    This model forms a single grid module. It consists of three `BandCell` modules
+    (60 degrees apart) plus one `GridCell`. The band cells integrate velocity,
+    and their combined output drives the grid cell bump. The grid cell activity
+    can be projected to place cells.
+
+    Examples:
+        >>> import brainpy.math as bm
+        >>> from canns.models.basic.hierarchical_model import HierarchicalPathIntegrationModel
+        >>>
+        >>> bm.set_dt(0.1)
+        >>> place_center = bm.array([[0.0, 0.0], [1.0, 1.0]])
+        >>> model = HierarchicalPathIntegrationModel(
+        ...     spacing=2.5,
+        ...     angle=0.0,
+        ...     place_center=place_center,
+        ...     band_size=30,
+        ...     grid_num=10,
+        ... )
+        >>> velocity = bm.array([0.0, 0.0])
+        >>> position = bm.array([0.0, 0.0])
+        >>> model.update(velocity=velocity, loc=position, loc_input_stre=0.0)
+        >>> model.grid_output.value.shape
+        (2,)
 
     Attributes:
         band_cell_x (BandCell): The first band cell module (orientation `angle`).
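The new docstring above shows a single update step. As a rough sketch (not taken from the package documentation), the same constructor and `update()` signature can drive the module along a short trajectory; the velocity, step count, and `bm.get_dt()` bookkeeping below are illustrative assumptions, while the API names come from the diff itself.

```python
# Sketch: step the module along a straight-line trajectory using only the
# constructor/update signature and the `grid_output` attribute shown above.
import brainpy.math as bm
from canns.models.basic.hierarchical_model import HierarchicalPathIntegrationModel

bm.set_dt(0.1)
place_center = bm.array([[0.0, 0.0], [1.0, 1.0]])
model = HierarchicalPathIntegrationModel(
    spacing=2.5,
    angle=0.0,
    place_center=place_center,
    band_size=30,
    grid_num=10,
)

velocity = bm.array([0.05, 0.0])   # constant velocity along x (illustrative)
position = bm.array([0.0, 0.0])
for _ in range(100):               # 100 steps of dt = 0.1
    model.update(velocity=velocity, loc=position, loc_input_stre=0.0)
    position = position + velocity * bm.get_dt()

print(model.grid_output.value.shape)  # (2,): one readout per place cell
```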
@@ -988,6 +1005,16 @@ class HierarchicalPathIntegrationModel(BasicModelGroup):
         )
 
     def update(self, velocity, loc, loc_input_stre=0.0):
+        """Advance the model by one time step.
+
+        Args:
+            velocity (Array): 2D velocity vector, shape ``(2,)``.
+            loc (Array): 2D position vector, shape ``(2,)``.
+            loc_input_stre (float): Strength of optional location-based input.
+
+        Returns:
+            None
+        """
         self.band_cell_x(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre)
         self.band_cell_y(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre)
         self.band_cell_z(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre)
@@ -1030,11 +1057,20 @@ class HierarchicalPathIntegrationModel(BasicModelGroup):
 class HierarchicalNetwork(BasicModelGroup):
     """A full hierarchical network composed of multiple grid modules.
 
-
-
-
-
-
+    Each module is a `HierarchicalPathIntegrationModel` with a different grid
+    spacing. The module outputs are combined to decode a single 2D position.
+
+    Examples:
+        >>> import brainpy.math as bm
+        >>> from canns.models.basic import HierarchicalNetwork
+        >>>
+        >>> bm.set_dt(0.1)
+        >>> model = HierarchicalNetwork(num_module=1, num_place=3)
+        >>> velocity = bm.array([0.0, 0.0])
+        >>> position = bm.array([0.0, 0.0])
+        >>> model.update(velocity=velocity, loc=position, loc_input_stre=0.0)
+        >>> model.decoded_pos.value.shape
+        (2,)
 
     Attributes:
         num_module (int): The number of grid modules in the network.
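A short sketch of reading the decoded position, again using only the API visible in this diff (`num_module`, `num_place`, `update()`, `decoded_pos`); the module count, place-cell count, and step count below are illustrative, not recommended settings.

```python
# Sketch: run a small HierarchicalNetwork for a few steps and read the
# position decoded from the combined grid-module outputs.
import brainpy.math as bm
from canns.models.basic import HierarchicalNetwork

bm.set_dt(0.1)
net = HierarchicalNetwork(num_module=1, num_place=3)

velocity = bm.array([0.0, 0.0])
position = bm.array([0.0, 0.0])
for _ in range(10):
    net.update(velocity=velocity, loc=position, loc_input_stre=0.0)

print(net.decoded_pos.value)   # decoded 2D position, shape (2,)
```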
@@ -1163,6 +1199,16 @@ class HierarchicalNetwork(BasicModelGroup):
         self.decoded_pos = bm.Variable(bm.zeros(2))
 
     def update(self, velocity, loc, loc_input_stre=0.0):
+        """Advance the full network by one time step.
+
+        Args:
+            velocity (Array): 2D velocity vector, shape ``(2,)``.
+            loc (Array): 2D position vector, shape ``(2,)``.
+            loc_input_stre (float): Strength of optional location-based input.
+
+        Returns:
+            None
+        """
         grid_output = bm.zeros(self.num_place)
         for i in range(self.num_module):
             # update the band cell module
canns/models/brain_inspired/hopfield.py
CHANGED
@@ -9,18 +9,25 @@ __all__ = ["AmariHopfieldNetwork"]
 
 
 class AmariHopfieldNetwork(BrainInspiredModel):
-    """
-
-
-
-
-
-
-
-
-
-
+    """Amari-Hopfield network with discrete or continuous dynamics.
+
+    The model performs pattern completion by iteratively updating the state
+    vector ``s`` to reduce energy:
+        E = -0.5 * sum_ij W_ij * s_i * s_j
+
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> from canns.models.brain_inspired import AmariHopfieldNetwork
+        >>>
+        >>> model = AmariHopfieldNetwork(num_neurons=3, activation="sign")
+        >>> pattern = jnp.array([1.0, -1.0, 1.0], dtype=jnp.float32)
+        >>> weights = jnp.outer(pattern, pattern)
+        >>> weights = weights - jnp.diag(jnp.diag(weights))  # zero diagonal
+        >>> model.W.value = weights
+        >>> model.s.value = jnp.array([1.0, 1.0, -1.0], dtype=jnp.float32)
+        >>> model.update(None)
+        >>> model.s.value.shape
+        (3,)
 
     Reference:
         Amari, S. (1977). Neural theory of association and concept-formation.
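To make the energy expression in the new docstring concrete, here is a standalone `jax.numpy` sketch that computes E = -0.5 * sum_ij W_ij * s_i * s_j and builds a Hebbian outer-product weight matrix for two stored patterns. It mirrors the math stated in the diff but does not use the library's classes or trainers; the patterns and helper name `hopfield_energy` are illustrative.

```python
import jax.numpy as jnp

def hopfield_energy(W, s):
    # E = -0.5 * s^T W s
    return -0.5 * s @ W @ s

patterns = jnp.array([[1.0, -1.0, 1.0, -1.0],
                      [1.0,  1.0, -1.0, -1.0]], dtype=jnp.float32)

# Hebbian storage: W = sum_p outer(p, p), then zero the diagonal.
W = sum(jnp.outer(p, p) for p in patterns)
W = W - jnp.diag(jnp.diag(W))

print(hopfield_energy(W, patterns[0]))                        # low energy at a stored pattern
print(hopfield_energy(W, -patterns[0]))                       # inverse pattern has the same energy
print(hopfield_energy(W, jnp.array([1.0, 1.0, 1.0, 1.0])))    # higher energy off-pattern
```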
@@ -78,8 +85,13 @@ class AmariHopfieldNetwork(BrainInspiredModel):
             raise ValueError(f"Unknown activation type: {activation}")
 
     def update(self, e_old):
-        """
-
+        """Update network state for one time step.
+
+        Args:
+            e_old: Unused placeholder for trainer compatibility.
+
+        Returns:
+            None
         """
         if self.asyn:
             self._asynchronous_update()
canns/models/brain_inspired/linear.py
CHANGED
@@ -11,22 +11,22 @@ __all__ = ["LinearLayer"]
 
 
 class LinearLayer(BrainInspiredModel):
-    """
-    Generic linear feedforward layer supporting multiple brain-inspired learning rules.
-
-    This model provides a simple linear transformation with optional sliding threshold
-    for BCM-style plasticity. It can be used with various trainers:
-    - OjaTrainer: Normalized Hebbian learning for PCA
-    - BCMTrainer: Sliding threshold plasticity (requires use_bcm_threshold=True)
-    - HebbianTrainer: Standard Hebbian learning
+    """Generic linear feedforward layer for brain-inspired learning rules.
 
-
+    It computes a simple linear transform:
         y = W @ x
 
-
+    You can pair it with trainers like ``OjaTrainer``, ``BCMTrainer``, or
+    ``HebbianTrainer``.
 
-
-
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> from canns.models.brain_inspired import LinearLayer
+        >>>
+        >>> layer = LinearLayer(input_size=3, output_size=2)
+        >>> y = layer.forward(jnp.array([1.0, 0.5, -1.0], dtype=jnp.float32))
+        >>> y.shape
+        (2,)
 
     References:
         - Oja (1982): Simplified neuron model as a principal component analyzer
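The old docstring named `OjaTrainer` and the new one keeps the Oja (1982) reference, but the trainer API itself is not shown in this diff. As a standalone sketch of the underlying rule (not the library implementation), the normalized Hebbian update can be applied directly to a plain weight matrix with `y = W @ x` as in the docstring; the learning rate, data, and function name `oja_step` are illustrative.

```python
import jax
import jax.numpy as jnp

def oja_step(W, x, lr=0.01):
    # Oja's rule applied row-wise: dW = lr * (y x^T - y^2 * W), with y = W @ x
    y = W @ x
    return W + lr * (jnp.outer(y, x) - (y ** 2)[:, None] * W)

W = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (2, 3))   # output_size=2, input_size=3
data = jax.random.normal(jax.random.PRNGKey(1), (500, 3))    # synthetic inputs

for x in data:
    W = oja_step(W, x)

print(jnp.linalg.norm(W, axis=1))   # row norms drift toward 1, as Oja's rule predicts
```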
@@ -73,14 +73,13 @@ class LinearLayer(BrainInspiredModel):
         self.theta = bm.Variable(jnp.ones(self.output_size, dtype=jnp.float32) * 0.1)
 
     def forward(self, x: jnp.ndarray) -> jnp.ndarray:
-        """
-        Forward pass through the layer.
+        """Compute the layer output for one input vector.
 
         Args:
-            x: Input vector of shape (input_size,)
+            x: Input vector of shape ``(input_size,)``.
 
         Returns:
-            Output vector of shape (output_size,)
+            Output vector of shape ``(output_size,)``.
         """
         self.x.value = jnp.asarray(x, dtype=jnp.float32)
         self.y.value = self.W.value @ self.x.value
canns/models/brain_inspired/spiking.py
CHANGED
@@ -12,15 +12,9 @@ __all__ = ["SpikingLayer"]
 
 
 class SpikingLayer(BrainInspiredModel):
-    """
-    Simple Leaky Integrate-and-Fire (LIF) spiking neuron layer.
+    """Simple Leaky Integrate-and-Fire (LIF) spiking neuron layer.
 
-
-    spike-timing-dependent plasticity (STDP). It features:
-    - Leaky integration of input currents
-    - Threshold-based spike generation
-    - Reset mechanism after spiking
-    - Exponential spike traces for STDP learning
+    It supports STDP-style training by maintaining pre/post spike traces.
 
     Dynamics:
         v[t+1] = leak * v[t] + W @ x[t]
@@ -28,6 +22,23 @@ class SpikingLayer(BrainInspiredModel):
         v = v_reset if spike else v
         trace = decay * trace + spike
 
+    Notes:
+        - x[t] denotes the input current at time t. It can take arbitrary continuous
+          values; binary {0, 1} spike trains are a special case of such inputs.
+        - The layer does not internally binarize x; thresholding only applies to the
+          membrane potential to generate output spikes.
+
+    Examples:
+        >>> import jax.numpy as jnp
+        >>> from canns.models.brain_inspired import SpikingLayer
+        >>>
+        >>> layer = SpikingLayer(input_size=3, output_size=2, threshold=0.5)
+        >>> # Continuous input currents (binary spikes {0,1} are a special case)
+        >>> x = jnp.array([0.2, 0.5, 1.3], dtype=jnp.float32)
+        >>> spikes = layer.forward(x)
+        >>> spikes.shape
+        (2,)
+
     References:
         - Gerstner & Kistler (2002): Spiking Neuron Models
         - Morrison et al. (2008): Phenomenological models of synaptic plasticity
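The docstring states the dynamics explicitly: v[t+1] = leak * v[t] + W @ x[t], spike when v crosses the threshold, reset to v_reset, and an exponentially decaying spike trace. The standalone sketch below runs these equations directly with `jax.numpy`; it is not the `SpikingLayer` implementation itself, and the leak, decay, reset, and weight values are illustrative.

```python
import jax.numpy as jnp

def lif_step(v, trace, W, x, leak=0.9, threshold=0.5, v_reset=0.0, decay=0.8):
    v = leak * v + W @ x                          # leaky integration of input current
    spike = (v >= threshold).astype(jnp.float32)  # threshold-based spike generation
    v = jnp.where(spike > 0, v_reset, v)          # reset after spiking
    trace = decay * trace + spike                 # exponential spike trace (for STDP)
    return v, trace, spike

W = jnp.array([[0.4, 0.3, 0.2],
               [0.1, 0.1, 0.1]], dtype=jnp.float32)   # output_size=2, input_size=3
v = jnp.zeros(2)
trace = jnp.zeros(2)

x = jnp.array([0.2, 0.5, 1.3], dtype=jnp.float32)     # continuous input currents
for _ in range(5):
    v, trace, spike = lif_step(v, trace, W, x)

print(spike, trace)   # 0/1 spikes and their decaying traces after 5 steps
```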
@@ -88,14 +99,14 @@ class SpikingLayer(BrainInspiredModel):
         self.trace_post = bm.Variable(jnp.zeros(self.output_size, dtype=jnp.float32))
 
     def forward(self, x: jnp.ndarray) -> jnp.ndarray:
-        """
-        Forward pass through the spiking layer.
+        """Compute spikes for one time step.
 
         Args:
-            x: Input
+            x: Input currents of shape ``(input_size,)``. Can be continuous-valued
+               (e.g., synaptic currents) or binary spikes {0, 1} as a special case.
 
         Returns:
-            Output spikes of shape (output_size,) with
+            Output spikes of shape ``(output_size,)`` with values 0 or 1.
         """
         self.x.value = jnp.asarray(x, dtype=jnp.float32)
 
canns/pipeline/__init__.py
CHANGED
@@ -7,15 +7,11 @@ the underlying implementations.
 """
 
 from ._base import Pipeline
-from .
-
-    batch_process_trajectories,
-    load_trajectory_from_csv,
-)
+from .asa import ASAApp
+from .asa import main as asa_main
 
 __all__ = [
     "Pipeline",
-    "
-    "
-    "batch_process_trajectories",
+    "ASAApp",
+    "asa_main",
 ]
canns/pipeline/asa/__init__.py
ADDED
@@ -0,0 +1,21 @@
+"""ASA TUI - Terminal User Interface for ASA Analysis.
+
+This module provides a Textual-based TUI for running ASA (Attractor State Analysis)
+with 7 analysis modules: TDA, CohoMap, PathCompare, CohoSpace, FR, FRM, and GridScore.
+"""
+
+import os
+
+__all__ = ["ASAApp", "main"]
+
+
+def main():
+    """Entry point for canns-tui command."""
+    os.environ.setdefault("MPLBACKEND", "Agg")
+    from .app import ASAApp
+
+    app = ASAApp()
+    app.run()
+
+
+from .app import ASAApp
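The new `main()` above is the entry point for the `canns-tui` console command (per its docstring and the updated entry_points.txt), and `canns.pipeline` now re-exports it as `asa_main`. A minimal sketch of launching the TUI programmatically with those exports:

```python
# Launch the ASA terminal UI from Python using the exports added in 0.13.0.
# Running this starts an interactive Textual application in the terminal.
from canns.pipeline import asa_main

if __name__ == "__main__":
    asa_main()   # equivalent to invoking the `canns-tui` console entry point
```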