canns 0.12.6__py3-none-any.whl → 0.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. canns/__init__.py +39 -3
  2. canns/analyzer/__init__.py +7 -6
  3. canns/analyzer/data/__init__.py +3 -11
  4. canns/analyzer/data/asa/__init__.py +74 -0
  5. canns/analyzer/data/asa/cohospace.py +905 -0
  6. canns/analyzer/data/asa/config.py +246 -0
  7. canns/analyzer/data/asa/decode.py +448 -0
  8. canns/analyzer/data/asa/embedding.py +269 -0
  9. canns/analyzer/data/asa/filters.py +208 -0
  10. canns/analyzer/data/asa/fr.py +439 -0
  11. canns/analyzer/data/asa/path.py +389 -0
  12. canns/analyzer/data/asa/plotting.py +1276 -0
  13. canns/analyzer/data/asa/tda.py +901 -0
  14. canns/analyzer/data/legacy/__init__.py +6 -0
  15. canns/analyzer/data/{cann1d.py → legacy/cann1d.py} +2 -2
  16. canns/analyzer/data/{cann2d.py → legacy/cann2d.py} +3 -3
  17. canns/analyzer/metrics/spatial_metrics.py +70 -100
  18. canns/analyzer/metrics/systematic_ratemap.py +12 -17
  19. canns/analyzer/metrics/utils.py +28 -0
  20. canns/analyzer/model_specific/hopfield.py +19 -16
  21. canns/analyzer/slow_points/checkpoint.py +32 -9
  22. canns/analyzer/slow_points/finder.py +33 -6
  23. canns/analyzer/slow_points/fixed_points.py +12 -0
  24. canns/analyzer/slow_points/visualization.py +22 -10
  25. canns/analyzer/visualization/core/backend.py +15 -26
  26. canns/analyzer/visualization/core/config.py +120 -15
  27. canns/analyzer/visualization/core/jupyter_utils.py +34 -16
  28. canns/analyzer/visualization/core/rendering.py +42 -40
  29. canns/analyzer/visualization/core/writers.py +10 -20
  30. canns/analyzer/visualization/energy_plots.py +78 -28
  31. canns/analyzer/visualization/spatial_plots.py +81 -36
  32. canns/analyzer/visualization/spike_plots.py +27 -7
  33. canns/analyzer/visualization/theta_sweep_plots.py +159 -72
  34. canns/analyzer/visualization/tuning_plots.py +11 -3
  35. canns/data/__init__.py +7 -4
  36. canns/models/__init__.py +10 -0
  37. canns/models/basic/cann.py +102 -40
  38. canns/models/basic/grid_cell.py +9 -8
  39. canns/models/basic/hierarchical_model.py +57 -11
  40. canns/models/brain_inspired/hopfield.py +26 -14
  41. canns/models/brain_inspired/linear.py +15 -16
  42. canns/models/brain_inspired/spiking.py +23 -12
  43. canns/pipeline/__init__.py +4 -8
  44. canns/pipeline/asa/__init__.py +21 -0
  45. canns/pipeline/asa/__main__.py +11 -0
  46. canns/pipeline/asa/app.py +1000 -0
  47. canns/pipeline/asa/runner.py +1095 -0
  48. canns/pipeline/asa/screens.py +215 -0
  49. canns/pipeline/asa/state.py +248 -0
  50. canns/pipeline/asa/styles.tcss +221 -0
  51. canns/pipeline/asa/widgets.py +233 -0
  52. canns/pipeline/gallery/__init__.py +7 -0
  53. canns/task/closed_loop_navigation.py +54 -13
  54. canns/task/open_loop_navigation.py +230 -147
  55. canns/task/tracking.py +156 -24
  56. canns/trainer/__init__.py +8 -5
  57. canns/utils/__init__.py +12 -4
  58. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/METADATA +6 -3
  59. canns-0.13.0.dist-info/RECORD +91 -0
  60. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/entry_points.txt +1 -0
  61. canns/pipeline/theta_sweep.py +0 -573
  62. canns-0.12.6.dist-info/RECORD +0 -72
  63. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/WHEEL +0 -0
  64. {canns-0.12.6.dist-info → canns-0.13.0.dist-info}/licenses/LICENSE +0 -0
@@ -744,12 +744,29 @@ class GridCell(BasicModel):
  class HierarchicalPathIntegrationModel(BasicModelGroup):
  """A hierarchical model combining band cells and grid cells for path integration.

- This model forms a single grid module. It consists of three `BandCell` modules,
- each with a different preferred orientation (separated by 60 degrees), and one
- `GridCell` module. The band cells integrate velocity along their respective
- directions, and their combined outputs provide the input to the `GridCell`
- network, effectively driving the grid cell's activity bump. The model can
- also project its grid cell activity to a population of place cells.
+ This model forms a single grid module. It consists of three `BandCell` modules
+ (60 degrees apart) plus one `GridCell`. The band cells integrate velocity,
+ and their combined output drives the grid cell bump. The grid cell activity
+ can be projected to place cells.
+
+ Examples:
+ >>> import brainpy.math as bm
+ >>> from canns.models.basic.hierarchical_model import HierarchicalPathIntegrationModel
+ >>>
+ >>> bm.set_dt(0.1)
+ >>> place_center = bm.array([[0.0, 0.0], [1.0, 1.0]])
+ >>> model = HierarchicalPathIntegrationModel(
+ ... spacing=2.5,
+ ... angle=0.0,
+ ... place_center=place_center,
+ ... band_size=30,
+ ... grid_num=10,
+ ... )
+ >>> velocity = bm.array([0.0, 0.0])
+ >>> position = bm.array([0.0, 0.0])
+ >>> model.update(velocity=velocity, loc=position, loc_input_stre=0.0)
+ >>> model.grid_output.value.shape
+ (2,)

  Attributes:
  band_cell_x (BandCell): The first band cell module (orientation `angle`).
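A note on the geometry referenced in the docstring above: the three band-cell orientations are separated by 60 degrees, and each band integrates the component of velocity along its preferred direction. Below is a minimal sketch of that projection; it is illustrative only, not the package API, and the array library choice is an assumption.

    import jax.numpy as jnp

    # Illustrative geometry only: three preferred orientations 60 degrees apart,
    # and the 1D velocity component each band-cell module would integrate.
    angle = 0.0
    orientations = jnp.deg2rad(jnp.array([angle, angle + 60.0, angle + 120.0]))
    directions = jnp.stack([jnp.cos(orientations), jnp.sin(orientations)], axis=1)

    velocity = jnp.array([0.3, 0.1])
    band_inputs = directions @ velocity  # one scalar drive per band module
    print(band_inputs.shape)  # (3,)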
@@ -988,6 +1005,16 @@ class HierarchicalPathIntegrationModel(BasicModelGroup):
  )

  def update(self, velocity, loc, loc_input_stre=0.0):
+ """Advance the model by one time step.
+
+ Args:
+ velocity (Array): 2D velocity vector, shape ``(2,)``.
+ loc (Array): 2D position vector, shape ``(2,)``.
+ loc_input_stre (float): Strength of optional location-based input.
+
+ Returns:
+ None
+ """
  self.band_cell_x(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre)
  self.band_cell_y(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre)
  self.band_cell_z(velocity=velocity, loc=loc, loc_input_stre=loc_input_stre)
@@ -1030,11 +1057,20 @@ class HierarchicalPathIntegrationModel(BasicModelGroup):
  class HierarchicalNetwork(BasicModelGroup):
  """A full hierarchical network composed of multiple grid modules.

- This class creates and manages a collection of `HierarchicalPathIntegrationModel`
- modules, each with a different grid spacing. By combining the outputs of these
- modules, the network can represent position unambiguously over a large area.
- The final output is a population of place cells whose activities are used to
- decode the animal's estimated position.
+ Each module is a `HierarchicalPathIntegrationModel` with a different grid
+ spacing. The module outputs are combined to decode a single 2D position.
+
+ Examples:
+ >>> import brainpy.math as bm
+ >>> from canns.models.basic import HierarchicalNetwork
+ >>>
+ >>> bm.set_dt(0.1)
+ >>> model = HierarchicalNetwork(num_module=1, num_place=3)
+ >>> velocity = bm.array([0.0, 0.0])
+ >>> position = bm.array([0.0, 0.0])
+ >>> model.update(velocity=velocity, loc=position, loc_input_stre=0.0)
+ >>> model.decoded_pos.value.shape
+ (2,)

  Attributes:
  num_module (int): The number of grid modules in the network.
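The removed docstring's point that combining modules with different spacings removes ambiguity can be illustrated with a toy 1D example (a sketch only; the real network is 2D and decodes through place cells, and the spacing values here are hypothetical): each module alone sees position modulo its spacing, but the joint pattern of phases is locally unique.

    import jax.numpy as jnp

    # Toy 1D illustration of multi-scale disambiguation (not the package API).
    spacings = jnp.array([2.5, 3.5, 4.5])   # hypothetical module spacings
    position = 7.3
    phases = position % spacings            # what each module observes on its own

    candidates = jnp.arange(0.0, 20.0, 0.01)
    err = jnp.abs(candidates[:, None] % spacings - phases).sum(axis=1)
    print(candidates[jnp.argmin(err)])      # ~7.3, recovered only by combining modules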
@@ -1163,6 +1199,16 @@ class HierarchicalNetwork(BasicModelGroup):
  self.decoded_pos = bm.Variable(bm.zeros(2))

  def update(self, velocity, loc, loc_input_stre=0.0):
+ """Advance the full network by one time step.
+
+ Args:
+ velocity (Array): 2D velocity vector, shape ``(2,)``.
+ loc (Array): 2D position vector, shape ``(2,)``.
+ loc_input_stre (float): Strength of optional location-based input.
+
+ Returns:
+ None
+ """
  grid_output = bm.zeros(self.num_place)
  for i in range(self.num_module):
  # update the band cell module
@@ -9,18 +9,25 @@ __all__ = ["AmariHopfieldNetwork"]


  class AmariHopfieldNetwork(BrainInspiredModel):
- """
- Amari-Hopfield Network implementation supporting both discrete and continuous dynamics.
-
- This class implements Hopfield networks with flexible activation functions,
- supporting both discrete binary states and continuous dynamics. The network
- performs pattern completion through energy minimization using asynchronous
- or synchronous updates.
-
- The network energy function:
- E = -0.5 * Σ_ij W_ij * s_i * s_j
-
- Where s_i can be discrete {-1, +1} or continuous depending on activation function.
+ """Amari-Hopfield network with discrete or continuous dynamics.
+
+ The model performs pattern completion by iteratively updating the state
+ vector ``s`` to reduce energy:
+ E = -0.5 * sum_ij W_ij * s_i * s_j
+
+ Examples:
+ >>> import jax.numpy as jnp
+ >>> from canns.models.brain_inspired import AmariHopfieldNetwork
+ >>>
+ >>> model = AmariHopfieldNetwork(num_neurons=3, activation="sign")
+ >>> pattern = jnp.array([1.0, -1.0, 1.0], dtype=jnp.float32)
+ >>> weights = jnp.outer(pattern, pattern)
+ >>> weights = weights - jnp.diag(jnp.diag(weights)) # zero diagonal
+ >>> model.W.value = weights
+ >>> model.s.value = jnp.array([1.0, 1.0, -1.0], dtype=jnp.float32)
+ >>> model.update(None)
+ >>> model.s.value.shape
+ (3,)

  Reference:
  Amari, S. (1977). Neural theory of association and concept-formation.
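For readers skimming the diff, the energy expression in the new docstring can be evaluated directly. The standalone sketch below (plain JAX, not the package API, with a hypothetical `energy` helper) reuses the weights from the doctest above and shows that the stored pattern sits at a lower energy than a corrupted state.

    import jax.numpy as jnp

    def energy(s, W):
        # E = -0.5 * sum_ij W_ij * s_i * s_j, as in the docstring above
        return -0.5 * s @ W @ s

    pattern = jnp.array([1.0, -1.0, 1.0], dtype=jnp.float32)
    W = jnp.outer(pattern, pattern)
    W = W - jnp.diag(jnp.diag(W))  # zero self-connections, as in the doctest

    print(energy(pattern, W))                                         # -3.0 for the stored pattern
    print(energy(jnp.array([1.0, 1.0, -1.0], dtype=jnp.float32), W))  # 1.0 for a corrupted state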
@@ -78,8 +85,13 @@ class AmariHopfieldNetwork(BrainInspiredModel):
  raise ValueError(f"Unknown activation type: {activation}")

  def update(self, e_old):
- """
- Update network state for one time step.
+ """Update network state for one time step.
+
+ Args:
+ e_old: Unused placeholder for trainer compatibility.
+
+ Returns:
+ None
  """
  if self.asyn:
  self._asynchronous_update()
@@ -11,22 +11,22 @@ __all__ = ["LinearLayer"]


  class LinearLayer(BrainInspiredModel):
- """
- Generic linear feedforward layer supporting multiple brain-inspired learning rules.
-
- This model provides a simple linear transformation with optional sliding threshold
- for BCM-style plasticity. It can be used with various trainers:
- OjaTrainer: Normalized Hebbian learning for PCA
- BCMTrainer: Sliding threshold plasticity (requires use_bcm_threshold=True)
- HebbianTrainer: Standard Hebbian learning
+ """Generic linear feedforward layer for brain-inspired learning rules.

- Computation:
+ It computes a simple linear transform:
  y = W @ x

- where W is the weight matrix, x is the input, and y is the output.
+ You can pair it with trainers like ``OjaTrainer``, ``BCMTrainer``, or
+ ``HebbianTrainer``.

- For BCM learning, an optional sliding threshold θ tracks output activity:
- θ ← θ + (1/τ) * (y² - θ)
+ Examples:
+ >>> import jax.numpy as jnp
+ >>> from canns.models.brain_inspired import LinearLayer
+ >>>
+ >>> layer = LinearLayer(input_size=3, output_size=2)
+ >>> y = layer.forward(jnp.array([1.0, 0.5, -1.0], dtype=jnp.float32))
+ >>> y.shape
+ (2,)

  References:
  - Oja (1982): Simplified neuron model as a principal component analyzer
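To make the two formulas mentioned in this hunk concrete (the linear transform kept in the new docstring and the BCM sliding-threshold update from the removed lines), here is a standalone sketch. It is not the LinearLayer or BCMTrainer API; `tau` and the toy weights are assumptions.

    import jax.numpy as jnp

    def forward(W, x):
        # y = W @ x
        return W @ x

    def bcm_threshold_step(theta, y, tau=100.0):
        # theta <- theta + (1/tau) * (y**2 - theta), the sliding threshold update
        return theta + (1.0 / tau) * (y**2 - theta)

    W = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], dtype=jnp.float32)  # toy 2x3 weights
    x = jnp.array([1.0, 0.5, -1.0], dtype=jnp.float32)
    y = forward(W, x)
    theta = bcm_threshold_step(jnp.full(2, 0.1, dtype=jnp.float32), y)
    print(y, theta)  # y = [1.0, 0.5]; theta moves toward y**2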
@@ -73,14 +73,13 @@ class LinearLayer(BrainInspiredModel):
  self.theta = bm.Variable(jnp.ones(self.output_size, dtype=jnp.float32) * 0.1)

  def forward(self, x: jnp.ndarray) -> jnp.ndarray:
- """
- Forward pass through the layer.
+ """Compute the layer output for one input vector.

  Args:
- x: Input vector of shape (input_size,)
+ x: Input vector of shape ``(input_size,)``.

  Returns:
- Output vector of shape (output_size,)
+ Output vector of shape ``(output_size,)``.
  """
  self.x.value = jnp.asarray(x, dtype=jnp.float32)
  self.y.value = self.W.value @ self.x.value
@@ -12,15 +12,9 @@ __all__ = ["SpikingLayer"]


  class SpikingLayer(BrainInspiredModel):
- """
- Simple Leaky Integrate-and-Fire (LIF) spiking neuron layer.
+ """Simple Leaky Integrate-and-Fire (LIF) spiking neuron layer.

- This model provides a minimal spiking neuron implementation for demonstrating
- spike-timing-dependent plasticity (STDP). It features:
- - Leaky integration of input currents
- - Threshold-based spike generation
- - Reset mechanism after spiking
- - Exponential spike traces for STDP learning
+ It supports STDP-style training by maintaining pre/post spike traces.

  Dynamics:
  v[t+1] = leak * v[t] + W @ x[t]
@@ -28,6 +22,23 @@ class SpikingLayer(BrainInspiredModel):
  v = v_reset if spike else v
  trace = decay * trace + spike

+ Notes:
+ - x[t] denotes the input current at time t. It can take arbitrary continuous
+ values; binary {0, 1} spike trains are a special case of such inputs.
+ - The layer does not internally binarize x; thresholding only applies to the
+ membrane potential to generate output spikes.
+
+ Examples:
+ >>> import jax.numpy as jnp
+ >>> from canns.models.brain_inspired import SpikingLayer
+ >>>
+ >>> layer = SpikingLayer(input_size=3, output_size=2, threshold=0.5)
+ >>> # Continuous input currents (binary spikes {0,1} are a special case)
+ >>> x = jnp.array([0.2, 0.5, 1.3], dtype=jnp.float32)
+ >>> spikes = layer.forward(x)
+ >>> spikes.shape
+ (2,)
+
  References:
  - Gerstner & Kistler (2002): Spiking Neuron Models
  - Morrison et al. (2008): Phenomenological models of synaptic plasticity
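The dynamics listed in the SpikingLayer docstring map almost line for line onto code. Here is a standalone sketch of one LIF step; it is illustrative only, not the SpikingLayer implementation, and the leak, decay, threshold, and reset values are assumptions.

    import jax.numpy as jnp

    def lif_step(v, trace, W, x, leak=0.9, threshold=0.5, v_reset=0.0, decay=0.8):
        v = leak * v + W @ x                     # leaky integration of input current
        spike = (v > threshold).astype(v.dtype)  # threshold crossing -> binary spike
        v = jnp.where(spike > 0, v_reset, v)     # reset membrane potential after a spike
        trace = decay * trace + spike            # exponential spike trace used for STDP
        return v, trace, spike

    W = jnp.full((2, 3), 0.3, dtype=jnp.float32)
    x = jnp.array([0.2, 0.5, 1.3], dtype=jnp.float32)   # continuous input currents
    v, trace, spike = lif_step(jnp.zeros(2), jnp.zeros(2), W, x)
    print(spike)  # [1. 1.] since 0.3 * (0.2 + 0.5 + 1.3) = 0.6 > 0.5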
@@ -88,14 +99,14 @@ class SpikingLayer(BrainInspiredModel):
  self.trace_post = bm.Variable(jnp.zeros(self.output_size, dtype=jnp.float32))

  def forward(self, x: jnp.ndarray) -> jnp.ndarray:
- """
- Forward pass through the spiking layer.
+ """Compute spikes for one time step.

  Args:
- x: Input spikes of shape (input_size,) with binary values (0 or 1)
+ x: Input currents of shape ``(input_size,)``. Can be continuous-valued
+ (e.g., synaptic currents) or binary spikes {0, 1} as a special case.

  Returns:
- Output spikes of shape (output_size,) with binary values (0 or 1)
+ Output spikes of shape ``(output_size,)`` with values 0 or 1.
  """
  self.x.value = jnp.asarray(x, dtype=jnp.float32)

@@ -7,15 +7,11 @@ the underlying implementations.
  """

  from ._base import Pipeline
- from .theta_sweep import (
- ThetaSweepPipeline,
- batch_process_trajectories,
- load_trajectory_from_csv,
- )
+ from .asa import ASAApp
+ from .asa import main as asa_main

  __all__ = [
  "Pipeline",
- "ThetaSweepPipeline",
- "load_trajectory_from_csv",
- "batch_process_trajectories",
+ "ASAApp",
+ "asa_main",
  ]
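With the re-exports above, the ASA TUI can be launched from Python as well as from the `canns-tui` console script added to entry_points.txt; a minimal sketch, assuming the Textual dependency is installed:

    from canns.pipeline import asa_main

    if __name__ == "__main__":
        # asa_main() sets MPLBACKEND=Agg if unset and starts the Textual ASAApp,
        # matching the canns-tui entry point added in this release.
        asa_main()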
@@ -0,0 +1,21 @@
+ """ASA TUI - Terminal User Interface for ASA Analysis.
+
+ This module provides a Textual-based TUI for running ASA (Attractor State Analysis)
+ with 7 analysis modules: TDA, CohoMap, PathCompare, CohoSpace, FR, FRM, and GridScore.
+ """
+
+ import os
+
+ __all__ = ["ASAApp", "main"]
+
+
+ def main():
+ """Entry point for canns-tui command."""
+ os.environ.setdefault("MPLBACKEND", "Agg")
+ from .app import ASAApp
+
+ app = ASAApp()
+ app.run()
+
+
+ from .app import ASAApp
@@ -0,0 +1,11 @@
+ """Main entry point for running ASA TUI as a module."""
+
+ import os
+
+ os.environ.setdefault("MPLBACKEND", "Agg")
+
+ from .app import ASAApp
+
+ if __name__ == "__main__":
+ app = ASAApp()
+ app.run()