brainstate 0.1.3__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +1 -16
  3. brainstate/compile/_jit.py +14 -5
  4. brainstate/compile/_make_jaxpr.py +78 -22
  5. brainstate/compile/_make_jaxpr_test.py +13 -2
  6. brainstate/graph/_graph_node.py +1 -1
  7. brainstate/graph/_graph_operation.py +4 -4
  8. brainstate/mixin.py +31 -2
  9. brainstate/nn/__init__.py +8 -5
  10. brainstate/nn/_delay.py +13 -1
  11. brainstate/nn/_dropout.py +5 -4
  12. brainstate/nn/_dynamics.py +39 -44
  13. brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
  14. brainstate/nn/_linear_mv.py +1 -1
  15. brainstate/nn/_module.py +5 -5
  16. brainstate/nn/_projection.py +190 -98
  17. brainstate/nn/_synapse.py +5 -9
  18. brainstate/nn/_synaptic_projection.py +376 -86
  19. brainstate/surrogate.py +1 -1
  20. brainstate/typing.py +1 -1
  21. brainstate/util/__init__.py +14 -14
  22. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  23. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
  24. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/RECORD +35 -35
  25. /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  26. /brainstate/util/{_caller.py → caller.py} +0 -0
  27. /brainstate/util/{_error.py → error.py} +0 -0
  28. /brainstate/util/{_others.py → others.py} +0 -0
  29. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  30. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  31. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  32. /brainstate/util/{_struct.py → struct.py} +0 -0
  33. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
  34. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
  35. {brainstate-0.1.3.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
@@ -36,13 +36,14 @@ For handling the delays:
36
36
  from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
37
37
 
38
38
  import brainunit as u
39
+ import jax
39
40
  import numpy as np
40
41
 
41
42
  from brainstate import environ
42
43
  from brainstate._state import State
43
44
  from brainstate.graph import Node
44
45
  from brainstate.mixin import ParamDescriber, UpdateReturn
45
- from brainstate.typing import Size, ArrayLike
46
+ from brainstate.typing import Size, ArrayLike, PyTree
46
47
  from ._delay import StateWithDelay, Delay
47
48
  from ._module import Module
48
49
 
@@ -811,18 +812,25 @@ class Dynamics(Module, UpdateReturn):
811
812
  >>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
812
813
  """
813
814
  if isinstance(dyn, Dynamics):
814
- self._add_after_update(dyn.name, dyn)
815
+ self._add_after_update(id(dyn), dyn)
815
816
  return dyn
816
817
  elif isinstance(dyn, ParamDescriber):
817
818
  if not issubclass(dyn.cls, Dynamics):
818
819
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
819
820
  if not self._has_after_update(dyn.identifier):
820
- self._add_after_update(dyn.identifier, dyn())
821
+ self._add_after_update(
822
+ dyn.identifier,
823
+ dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
824
+ )
821
825
  return self._get_after_update(dyn.identifier)
822
826
  else:
823
827
  raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
824
828
 
825
- def prefetch_delay(self, state: str, delay: Optional[ArrayLike] = None) -> 'PrefetchDelayAt':
829
+ def prefetch_delay(
830
+ self,
831
+ state: str,
832
+ delay: Optional[ArrayLike] = None
833
+ ) -> 'PrefetchDelayAt':
826
834
  """
827
835
  Create a reference to a delayed state or variable in the module.
828
836
 
@@ -840,7 +848,11 @@ class Dynamics(Module, UpdateReturn):
840
848
  """
841
849
  return self.prefetch(state).delay.at(delay)
842
850
 
843
- def output_delay(self, delay: Optional[ArrayLike] = None) -> 'OutputDelayAt':
851
+ def output_delay(
852
+ self,
853
+ delay: Optional[ArrayLike] = None,
854
+ variable_like: PyTree = None
855
+ ) -> 'OutputDelayAt':
844
856
  """
845
857
  Create a reference to the delayed output of the module.
846
858
 
@@ -851,6 +863,7 @@ class Dynamics(Module, UpdateReturn):
851
863
  Args:
852
864
  delay (Optional[ArrayLike]): The amount of time to delay the output access,
853
865
  typically in time units (e.g., milliseconds). Defaults to None.
866
+ variable_like:
854
867
 
855
868
  Returns:
856
869
  OutputDelayAt: An object that provides access to the module's output at the specified delay time.
@@ -1102,8 +1115,6 @@ class OutputDelayAt(Node):
1102
1115
  ----------
1103
1116
  module : Dynamics
1104
1117
  The dynamics module that contains the referenced state or variable.
1105
- item : str
1106
- The name of the state or variable to access with delay.
1107
1118
  time : ArrayLike
1108
1119
  The amount of time to delay access by, typically in time units (e.g., milliseconds).
1109
1120
 
@@ -1113,56 +1124,40 @@ class OutputDelayAt(Node):
1113
1124
  >>> import brainunit as u
1114
1125
  >>> neuron = brainstate.nn.LIF(10)
1115
1126
  >>> # Create a reference to voltage delayed by 5ms
1116
- >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
1127
+ >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
1117
1128
  >>> # Get the delayed value
1118
- >>> v_value = delayed_v()
1129
+ >>> v_value = delayed_spike()
1119
1130
  """
1120
1131
 
1121
1132
  def __init__(
1122
1133
  self,
1123
1134
  module: Dynamics,
1124
1135
  time: Optional[ArrayLike] = None,
1136
+ variable_like: Optional[PyTree] = None,
1125
1137
  ):
1126
- """
1127
- Initialize a PrefetchDelayAt object.
1128
-
1129
- Parameters
1130
- ----------
1131
- module : AlignPre, Module
1132
- The dynamics module that contains the referenced state or variable.
1133
- time : ArrayLike
1134
- The amount of time to delay access by, typically in time units.
1135
- """
1136
1138
  super().__init__()
1137
- assert isinstance(module, UpdateReturn), 'The module should implement the `update_return` method.'
1138
- assert isinstance(module, Module), 'The module should be an instance of Module.'
1139
+ assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
1139
1140
  self.module = module
1141
+ dt = environ.get_dt()
1142
+ if time is None:
1143
+ time = u.math.zeros_like(dt)
1140
1144
  self.time = time
1141
- if time is not None:
1142
- self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
1143
-
1144
- # register the delay
1145
- key = _get_output_delay_key()
1146
- if not module._has_after_update(key):
1147
- # TODO: unit processing
1148
- delay = Delay(module.update_return(), time)
1149
- module._add_after_update(key, receive_update_output(delay))
1150
- self.out_delay: Delay = module._get_after_update(key)
1151
- self.out_delay.register_delay(time)
1145
+ self.step = u.math.asarray(time / dt, dtype=environ.ditype())
1146
+
1147
+ # register the delay
1148
+ key = _get_output_delay_key()
1149
+ if not module._has_after_update(key):
1150
+ delay = Delay(
1151
+ jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()),
1152
+ time,
1153
+ take_aware_unit=True
1154
+ )
1155
+ module._add_after_update(key, receive_update_output(delay))
1156
+ self.out_delay: Delay = module._get_after_update(key)
1157
+ self.out_delay.register_delay(time)
1152
1158
 
1153
1159
  def __call__(self, *args, **kwargs):
1154
- """
1155
- Retrieve the value of the state at the specified delay time.
1156
-
1157
- Returns
1158
- -------
1159
- Any
1160
- The value of the state or variable at the specified delay time.
1161
- """
1162
- if self.time is None:
1163
- return self.module.update_return()
1164
- else:
1165
- return self.out_delay.retrieve_at_step(self.step)
1160
+ return self.out_delay.retrieve_at_step(self.step)
1166
1161
 
1167
1162
 
1168
1163
  def _get_prefetch_delay_key(item) -> str:
@@ -16,19 +16,20 @@
16
16
 
17
17
  from typing import Union, Callable, Optional
18
18
 
19
+ import brainevent
19
20
  import brainunit as u
20
21
  import jax
21
22
  import jax.numpy as jnp
22
23
  import numpy as np
23
24
 
24
25
  from brainstate import random, augment, environ, init
25
- from brainstate._compatible_import import brainevent
26
- from brainstate._state import ParamState
26
+ from brainstate._state import ParamState, FakeState
27
27
  from brainstate.compile import for_loop
28
28
  from brainstate.typing import Size, ArrayLike
29
29
  from ._module import Module
30
30
 
31
31
  __all__ = [
32
+ 'FixedNumConn',
32
33
  'EventFixedNumConn',
33
34
  'EventFixedProb',
34
35
  ]
@@ -44,12 +45,11 @@ def init_indices_without_replace(
44
45
  rng = random.default_rng(seed)
45
46
 
46
47
  if method == 'vmap':
47
- @augment.vmap
48
- def rand_indices(key):
49
- rng.set_key(key)
48
+ @augment.vmap(axis_size=n_pre)
49
+ def rand_indices():
50
50
  return rng.choice(n_post, size=(conn_num,), replace=False)
51
51
 
52
- return rand_indices(rng.split_key(n_pre))
52
+ return rand_indices()
53
53
 
54
54
  elif method == 'for_loop':
55
55
  return for_loop(
@@ -61,9 +61,9 @@ def init_indices_without_replace(
61
61
  raise ValueError(f"Unknown method: {method}")
62
62
 
63
63
 
64
- class EventFixedNumConn(Module):
64
+ class FixedNumConn(Module):
65
65
  """
66
- The FixedProb module implements a fixed probability connection with CSR sparse data structure.
66
+ The ``FixedNumConn`` module implements a fixed probability connection with CSR sparse data structure.
67
67
 
68
68
  Parameters
69
69
  ----------
@@ -77,7 +77,7 @@ class EventFixedNumConn(Module):
77
77
  If it is an integer, representing the number of connections.
78
78
  conn_weight : float or callable or jax.Array or brainunit.Quantity
79
79
  Maximum synaptic conductance, i.e., synaptic weight.
80
- conn_target : str, optional
80
+ efferent_target : str, optional
81
81
  The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
82
82
  a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
83
83
 
@@ -104,7 +104,8 @@ class EventFixedNumConn(Module):
104
104
  out_size: Size,
105
105
  conn_num: Union[int, float],
106
106
  conn_weight: Union[Callable, ArrayLike],
107
- conn_target: str = 'post', # 'pre' or 'post'
107
+ efferent_target: str = 'post', # 'pre' or 'post'
108
+ afferent_ratio: Union[int, float] = 1.,
108
109
  allow_multi_conn: bool = True,
109
110
  seed: Optional[int] = None,
110
111
  name: Optional[str] = None,
@@ -116,11 +117,14 @@ class EventFixedNumConn(Module):
116
117
  # network parameters
117
118
  self.in_size = in_size
118
119
  self.out_size = out_size
119
- self.conn_target = conn_target
120
- assert conn_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
120
+ self.efferent_target = efferent_target
121
+ assert efferent_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
122
+ assert 0. <= afferent_ratio <= 1., 'Afferent ratio must be in [0, 1].'
121
123
  if isinstance(conn_num, float):
122
124
  assert 0. <= conn_num <= 1., 'Connection probability must be in [0, 1].'
123
- conn_num = int(self.out_size[-1] * conn_num) if conn_target == 'post' else int(self.in_size[-1] * conn_num)
125
+ conn_num = (int(self.out_size[-1] * conn_num)
126
+ if efferent_target == 'post' else
127
+ int(self.in_size[-1] * conn_num))
124
128
  assert isinstance(conn_num, int), 'Connection number must be an integer.'
125
129
  self.conn_num = conn_num
126
130
  self.seed = seed
@@ -128,14 +132,13 @@ class EventFixedNumConn(Module):
128
132
 
129
133
  # connections
130
134
  if self.conn_num >= 1:
131
- if self.conn_target == 'post':
135
+ if self.efferent_target == 'post':
132
136
  n_post = self.out_size[-1]
133
137
  n_pre = self.in_size[-1]
134
138
  else:
135
139
  n_post = self.in_size[-1]
136
140
  n_pre = self.out_size[-1]
137
141
 
138
- # indices of post connected neurons
139
142
  with jax.ensure_compile_time_eval():
140
143
  if allow_multi_conn:
141
144
  rng = np.random if seed is None else np.random.RandomState(seed)
@@ -143,15 +146,83 @@ class EventFixedNumConn(Module):
143
146
  else:
144
147
  indices = init_indices_without_replace(self.conn_num, n_pre, n_post, seed, conn_init)
145
148
  indices = u.math.asarray(indices, dtype=environ.ditype())
146
- conn_weight = init.param(conn_weight, (n_pre, self.conn_num), allow_none=False)
147
- conn_weight = u.math.asarray(conn_weight)
148
- self.weight = param_type(conn_weight)
149
- csr = (
150
- brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
151
- if self.conn_target == 'post' else
152
- brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
153
- )
154
- self.conn = csr
149
+
150
+ if afferent_ratio == 1.:
151
+ conn_weight = u.math.asarray(init.param(conn_weight, (n_pre, self.conn_num), allow_none=False))
152
+ self.weight = param_type(conn_weight)
153
+ csr = (
154
+ brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
155
+ if self.efferent_target == 'post' else
156
+ brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
157
+ )
158
+ self.conn = csr
159
+
160
+ else:
161
+ self.pre_selected = np.random.random(n_pre) < afferent_ratio
162
+ indices = indices[self.pre_selected].flatten()
163
+ conn_weight = u.math.asarray(init.param(conn_weight, (indices.size,), allow_none=False))
164
+ self.weight = param_type(conn_weight)
165
+ indptr = (jnp.arange(1, n_pre + 1) * self.conn_num -
166
+ jnp.cumsum(~self.pre_selected) * self.conn_num)
167
+ indptr = jnp.insert(indptr, 0, 0) # insert 0 at the beginning
168
+ csr = (
169
+ brainevent.CSR((conn_weight, indices, indptr), shape=(n_pre, n_post))
170
+ if self.efferent_target == 'post' else
171
+ brainevent.CSC((conn_weight, indices, indptr), shape=(n_pre, n_post))
172
+ )
173
+ self.conn = csr
174
+
175
+ else:
176
+ conn_weight = u.math.asarray(init.param(conn_weight, (), allow_none=False))
177
+ self.weight = FakeState(conn_weight)
178
+
179
+ def update(self, x: jax.Array) -> Union[jax.Array, u.Quantity]:
180
+ if self.conn_num >= 1:
181
+ csr = self.conn.with_data(self.weight.value)
182
+ return x @ csr
183
+ else:
184
+ weight = self.weight.value
185
+ r = u.math.zeros(x.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
186
+ r = u.maybe_decimal(u.Quantity(r, unit=u.get_unit(weight)))
187
+ return u.math.asarray(r, dtype=environ.dftype())
188
+
189
+
190
+ class EventFixedNumConn(FixedNumConn):
191
+ """
192
+ The FixedProb module implements a fixed probability connection with CSR sparse data structure.
193
+
194
+ Parameters
195
+ ----------
196
+ in_size : Size
197
+ Number of pre-synaptic neurons, i.e., input size.
198
+ out_size : Size
199
+ Number of post-synaptic neurons, i.e., output size.
200
+ conn_num : float, int
201
+ If it is a float, representing the probability of connection, i.e., connection probability.
202
+
203
+ If it is an integer, representing the number of connections.
204
+ conn_weight : float or callable or jax.Array or brainunit.Quantity
205
+ Maximum synaptic conductance, i.e., synaptic weight.
206
+ conn_target : str, optional
207
+ The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
208
+ a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
209
+
210
+ If 'pre', each post-synaptic neuron connects to a fixed number of pre-synaptic neurons.
211
+ conn_init : str, optional
212
+ The initialization method of the connection weight. Default is 'vmap', meaning that the connection weight
213
+ is initialized by parallelized across multiple threads.
214
+
215
+ If 'for_loop', the connection weight is initialized by a for loop.
216
+ allow_multi_conn : bool, optional
217
+ Whether multiple connections are allowed from a single pre-synaptic neuron.
218
+ Default is True, meaning that a value of ``a`` can be selected multiple times.
219
+ seed: int, optional
220
+ Random seed. Default is None. If None, the default random seed will be used.
221
+ name : str, optional
222
+ Name of the module.
223
+ """
224
+
225
+ __module__ = 'brainstate.nn'
155
226
 
156
227
  def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
157
228
  if self.conn_num >= 1:
@@ -19,7 +19,7 @@ import brainunit as u
19
19
  import jax
20
20
 
21
21
  from brainstate import init
22
- from brainstate._compatible_import import brainevent
22
+ import brainevent
23
23
  from brainstate._state import ParamState
24
24
  from brainstate.typing import Size, ArrayLike
25
25
  from ._module import Module
brainstate/nn/_module.py CHANGED
@@ -34,7 +34,7 @@ import numpy as np
34
34
  from brainstate._state import State
35
35
  from brainstate.graph import Node, states, nodes, flatten
36
36
  from brainstate.mixin import ParamDescriber, ParamDesc
37
- from brainstate.typing import PathParts
37
+ from brainstate.typing import PathParts, Size
38
38
  from brainstate.util import FlattedDict, NestedDict, BrainStateError
39
39
 
40
40
  # maximum integer
@@ -62,8 +62,8 @@ class Module(Node, ParamDesc):
62
62
 
63
63
  __module__ = 'brainstate.nn'
64
64
 
65
- _in_size: Optional[Tuple[int, ...]]
66
- _out_size: Optional[Tuple[int, ...]]
65
+ _in_size: Optional[Size]
66
+ _out_size: Optional[Size]
67
67
  _name: Optional[str]
68
68
 
69
69
  if not TYPE_CHECKING:
@@ -87,7 +87,7 @@ class Module(Node, ParamDesc):
87
87
  raise AttributeError('The name of the model is read-only.')
88
88
 
89
89
  @property
90
- def in_size(self) -> Tuple[int, ...]:
90
+ def in_size(self) -> Size:
91
91
  return self._in_size
92
92
 
93
93
  @in_size.setter
@@ -98,7 +98,7 @@ class Module(Node, ParamDesc):
98
98
  self._in_size = tuple(in_size)
99
99
 
100
100
  @property
101
- def out_size(self) -> Tuple[int, ...]:
101
+ def out_size(self) -> Size:
102
102
  return self._out_size
103
103
 
104
104
  @out_size.setter