brainstate 0.1.2__py2.py3-none-any.whl → 0.1.4__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +0 -15
  3. brainstate/compile/_jit.py +14 -5
  4. brainstate/compile/_make_jaxpr.py +78 -22
  5. brainstate/compile/_make_jaxpr_test.py +13 -2
  6. brainstate/graph/_graph_node.py +1 -1
  7. brainstate/graph/_graph_operation.py +4 -4
  8. brainstate/mixin.py +30 -14
  9. brainstate/nn/__init__.py +84 -17
  10. brainstate/nn/{_interaction/_conv.py → _conv.py} +1 -1
  11. brainstate/nn/{_dynamics/_state_delay.py → _delay.py} +19 -3
  12. brainstate/nn/{_elementwise/_dropout.py → _dropout.py} +6 -5
  13. brainstate/nn/{_dynamics/_dynamics_base.py → _dynamics.py} +137 -21
  14. brainstate/nn/{_elementwise/_elementwise.py → _elementwise.py} +1 -1
  15. brainstate/nn/{_interaction/_embedding.py → _embedding.py} +1 -1
  16. brainstate/nn/{_event/_fixedprob_mv.py → _fixedprob.py} +96 -25
  17. brainstate/nn/{_dyn_impl/_inputs.py → _inputs.py} +4 -5
  18. brainstate/nn/{_interaction/_linear.py → _linear.py} +2 -5
  19. brainstate/nn/{_event/_linear_mv.py → _linear_mv.py} +2 -2
  20. brainstate/nn/{_event/__init__.py → _ltp.py} +7 -5
  21. brainstate/nn/_module.py +5 -5
  22. brainstate/nn/{_dyn_impl/_dynamics_neuron.py → _neuron.py} +2 -2
  23. brainstate/nn/{_interaction/_normalizations.py → _normalizations.py} +1 -1
  24. brainstate/nn/{_interaction/_poolings.py → _poolings.py} +1 -1
  25. brainstate/nn/{_interaction/_poolings_test.py → _poolings_test.py} +1 -1
  26. brainstate/nn/_projection.py +486 -0
  27. brainstate/nn/{_dyn_impl/_rate_rnns.py → _rate_rnns.py} +2 -2
  28. brainstate/nn/{_dyn_impl/_readout.py → _readout.py} +3 -3
  29. brainstate/nn/_stp.py +236 -0
  30. brainstate/nn/{_dyn_impl/_dynamics_synapse.py → _synapse.py} +19 -212
  31. brainstate/nn/_synaptic_projection.py +423 -0
  32. brainstate/nn/{_dynamics/_synouts.py → _synouts.py} +4 -1
  33. brainstate/surrogate.py +1 -1
  34. brainstate/typing.py +1 -1
  35. brainstate/util/__init__.py +14 -14
  36. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  37. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/METADATA +1 -1
  38. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/RECORD +61 -63
  39. brainstate/nn/_dyn_impl/__init__.py +0 -42
  40. brainstate/nn/_dynamics/__init__.py +0 -37
  41. brainstate/nn/_dynamics/_projection_base.py +0 -362
  42. brainstate/nn/_elementwise/__init__.py +0 -22
  43. brainstate/nn/_interaction/__init__.py +0 -41
  44. /brainstate/nn/{_interaction/_conv_test.py → _conv_test.py} +0 -0
  45. /brainstate/nn/{_elementwise/_dropout_test.py → _dropout_test.py} +0 -0
  46. /brainstate/nn/{_dynamics/_dynamics_base_test.py → _dynamics_test.py} +0 -0
  47. /brainstate/nn/{_elementwise/_elementwise_test.py → _elementwise_test.py} +0 -0
  48. /brainstate/nn/{_event/_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  49. /brainstate/nn/{_event/_linear_mv_test.py → _linear_mv_test.py} +0 -0
  50. /brainstate/nn/{_interaction/_linear_test.py → _linear_test.py} +0 -0
  51. /brainstate/nn/{_dyn_impl/_dynamics_neuron_test.py → _neuron_test.py} +0 -0
  52. /brainstate/nn/{_interaction/_normalizations_test.py → _normalizations_test.py} +0 -0
  53. /brainstate/nn/{_dyn_impl/_rate_rnns_test.py → _rate_rnns_test.py} +0 -0
  54. /brainstate/nn/{_dyn_impl/_readout_test.py → _readout_test.py} +0 -0
  55. /brainstate/nn/{_dyn_impl/_dynamics_synapse_test.py → _synapse_test.py} +0 -0
  56. /brainstate/nn/{_dynamics/_synouts_test.py → _synouts_test.py} +0 -0
  57. /brainstate/util/{_caller.py → caller.py} +0 -0
  58. /brainstate/util/{_error.py → error.py} +0 -0
  59. /brainstate/util/{_others.py → others.py} +0 -0
  60. /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  61. /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  62. /brainstate/util/{_scaling.py → scaling.py} +0 -0
  63. /brainstate/util/{_struct.py → struct.py} +0 -0
  64. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/LICENSE +0 -0
  65. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/WHEEL +0 -0
  66. {brainstate-0.1.2.dist-info → brainstate-0.1.4.dist-info}/top_level.txt +0 -0
@@ -16,19 +16,20 @@
16
16
 
17
17
  from typing import Union, Callable, Optional
18
18
 
19
+ import brainevent
19
20
  import brainunit as u
20
21
  import jax
21
22
  import jax.numpy as jnp
22
23
  import numpy as np
23
24
 
24
25
  from brainstate import random, augment, environ, init
25
- from brainstate._compatible_import import brainevent
26
- from brainstate._state import ParamState
26
+ from brainstate._state import ParamState, FakeState
27
27
  from brainstate.compile import for_loop
28
- from brainstate.nn._module import Module
29
28
  from brainstate.typing import Size, ArrayLike
29
+ from ._module import Module
30
30
 
31
31
  __all__ = [
32
+ 'FixedNumConn',
32
33
  'EventFixedNumConn',
33
34
  'EventFixedProb',
34
35
  ]
@@ -44,12 +45,11 @@ def init_indices_without_replace(
44
45
  rng = random.default_rng(seed)
45
46
 
46
47
  if method == 'vmap':
47
- @augment.vmap
48
- def rand_indices(key):
49
- rng.set_key(key)
48
+ @augment.vmap(axis_size=n_pre)
49
+ def rand_indices():
50
50
  return rng.choice(n_post, size=(conn_num,), replace=False)
51
51
 
52
- return rand_indices(rng.split_key(n_pre))
52
+ return rand_indices()
53
53
 
54
54
  elif method == 'for_loop':
55
55
  return for_loop(
@@ -61,9 +61,9 @@ def init_indices_without_replace(
61
61
  raise ValueError(f"Unknown method: {method}")
62
62
 
63
63
 
64
- class EventFixedNumConn(Module):
64
+ class FixedNumConn(Module):
65
65
  """
66
- The FixedProb module implements a fixed probability connection with CSR sparse data structure.
66
+ The ``FixedNumConn`` module implements a fixed probability connection with CSR sparse data structure.
67
67
 
68
68
  Parameters
69
69
  ----------
@@ -77,7 +77,7 @@ class EventFixedNumConn(Module):
77
77
  If it is an integer, representing the number of connections.
78
78
  conn_weight : float or callable or jax.Array or brainunit.Quantity
79
79
  Maximum synaptic conductance, i.e., synaptic weight.
80
- conn_target : str, optional
80
+ efferent_target : str, optional
81
81
  The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
82
82
  a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
83
83
 
@@ -104,7 +104,8 @@ class EventFixedNumConn(Module):
104
104
  out_size: Size,
105
105
  conn_num: Union[int, float],
106
106
  conn_weight: Union[Callable, ArrayLike],
107
- conn_target: str = 'post', # 'pre' or 'post'
107
+ efferent_target: str = 'post', # 'pre' or 'post'
108
+ afferent_ratio: Union[int, float] = 1.,
108
109
  allow_multi_conn: bool = True,
109
110
  seed: Optional[int] = None,
110
111
  name: Optional[str] = None,
@@ -116,11 +117,14 @@ class EventFixedNumConn(Module):
116
117
  # network parameters
117
118
  self.in_size = in_size
118
119
  self.out_size = out_size
119
- self.conn_target = conn_target
120
- assert conn_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
120
+ self.efferent_target = efferent_target
121
+ assert efferent_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
122
+ assert 0. <= afferent_ratio <= 1., 'Afferent ratio must be in [0, 1].'
121
123
  if isinstance(conn_num, float):
122
124
  assert 0. <= conn_num <= 1., 'Connection probability must be in [0, 1].'
123
- conn_num = int(self.out_size[-1] * conn_num) if conn_target == 'post' else int(self.in_size[-1] * conn_num)
125
+ conn_num = (int(self.out_size[-1] * conn_num)
126
+ if efferent_target == 'post' else
127
+ int(self.in_size[-1] * conn_num))
124
128
  assert isinstance(conn_num, int), 'Connection number must be an integer.'
125
129
  self.conn_num = conn_num
126
130
  self.seed = seed
@@ -128,14 +132,13 @@ class EventFixedNumConn(Module):
128
132
 
129
133
  # connections
130
134
  if self.conn_num >= 1:
131
- if self.conn_target == 'post':
135
+ if self.efferent_target == 'post':
132
136
  n_post = self.out_size[-1]
133
137
  n_pre = self.in_size[-1]
134
138
  else:
135
139
  n_post = self.in_size[-1]
136
140
  n_pre = self.out_size[-1]
137
141
 
138
- # indices of post connected neurons
139
142
  with jax.ensure_compile_time_eval():
140
143
  if allow_multi_conn:
141
144
  rng = np.random if seed is None else np.random.RandomState(seed)
@@ -143,15 +146,83 @@ class EventFixedNumConn(Module):
143
146
  else:
144
147
  indices = init_indices_without_replace(self.conn_num, n_pre, n_post, seed, conn_init)
145
148
  indices = u.math.asarray(indices, dtype=environ.ditype())
146
- conn_weight = init.param(conn_weight, (n_pre, self.conn_num), allow_none=False)
147
- conn_weight = u.math.asarray(conn_weight)
148
- self.weight = param_type(conn_weight)
149
- csr = (
150
- brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
151
- if self.conn_target == 'post' else
152
- brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
153
- )
154
- self.conn = csr
149
+
150
+ if afferent_ratio == 1.:
151
+ conn_weight = u.math.asarray(init.param(conn_weight, (n_pre, self.conn_num), allow_none=False))
152
+ self.weight = param_type(conn_weight)
153
+ csr = (
154
+ brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
155
+ if self.efferent_target == 'post' else
156
+ brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
157
+ )
158
+ self.conn = csr
159
+
160
+ else:
161
+ self.pre_selected = np.random.random(n_pre) < afferent_ratio
162
+ indices = indices[self.pre_selected].flatten()
163
+ conn_weight = u.math.asarray(init.param(conn_weight, (indices.size,), allow_none=False))
164
+ self.weight = param_type(conn_weight)
165
+ indptr = (jnp.arange(1, n_pre + 1) * self.conn_num -
166
+ jnp.cumsum(~self.pre_selected) * self.conn_num)
167
+ indptr = jnp.insert(indptr, 0, 0) # insert 0 at the beginning
168
+ csr = (
169
+ brainevent.CSR((conn_weight, indices, indptr), shape=(n_pre, n_post))
170
+ if self.efferent_target == 'post' else
171
+ brainevent.CSC((conn_weight, indices, indptr), shape=(n_pre, n_post))
172
+ )
173
+ self.conn = csr
174
+
175
+ else:
176
+ conn_weight = u.math.asarray(init.param(conn_weight, (), allow_none=False))
177
+ self.weight = FakeState(conn_weight)
178
+
179
+ def update(self, x: jax.Array) -> Union[jax.Array, u.Quantity]:
180
+ if self.conn_num >= 1:
181
+ csr = self.conn.with_data(self.weight.value)
182
+ return x @ csr
183
+ else:
184
+ weight = self.weight.value
185
+ r = u.math.zeros(x.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
186
+ r = u.maybe_decimal(u.Quantity(r, unit=u.get_unit(weight)))
187
+ return u.math.asarray(r, dtype=environ.dftype())
188
+
189
+
190
+ class EventFixedNumConn(FixedNumConn):
191
+ """
192
+ The FixedProb module implements a fixed probability connection with CSR sparse data structure.
193
+
194
+ Parameters
195
+ ----------
196
+ in_size : Size
197
+ Number of pre-synaptic neurons, i.e., input size.
198
+ out_size : Size
199
+ Number of post-synaptic neurons, i.e., output size.
200
+ conn_num : float, int
201
+ If it is a float, representing the probability of connection, i.e., connection probability.
202
+
203
+ If it is an integer, representing the number of connections.
204
+ conn_weight : float or callable or jax.Array or brainunit.Quantity
205
+ Maximum synaptic conductance, i.e., synaptic weight.
206
+ conn_target : str, optional
207
+ The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
208
+ a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
209
+
210
+ If 'pre', each post-synaptic neuron connects to a fixed number of pre-synaptic neurons.
211
+ conn_init : str, optional
212
+ The initialization method of the connection weight. Default is 'vmap', meaning that the connection weight
213
+ is initialized by parallelized across multiple threads.
214
+
215
+ If 'for_loop', the connection weight is initialized by a for loop.
216
+ allow_multi_conn : bool, optional
217
+ Whether multiple connections are allowed from a single pre-synaptic neuron.
218
+ Default is True, meaning that a value of ``a`` can be selected multiple times.
219
+ seed: int, optional
220
+ Random seed. Default is None. If None, the default random seed will be used.
221
+ name : str, optional
222
+ Name of the module.
223
+ """
224
+
225
+ __module__ = 'brainstate.nn'
155
226
 
156
227
  def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
157
228
  if self.conn_num >= 1:
@@ -20,12 +20,11 @@ import jax
20
20
  import numpy as np
21
21
 
22
22
  from brainstate import environ, init, random
23
- from brainstate._state import ShortTermState
24
- from brainstate._state import State, maybe_state
23
+ from brainstate._state import ShortTermState, State, maybe_state
25
24
  from brainstate.compile import while_loop
26
- from brainstate.nn._dynamics._dynamics_base import Dynamics, Prefetch
27
- from brainstate.nn._module import Module
28
25
  from brainstate.typing import ArrayLike, Size, DTypeLike
26
+ from ._dynamics import Dynamics, Prefetch
27
+ from ._module import Module
29
28
 
30
29
  __all__ = [
31
30
  'SpikeTime',
@@ -134,7 +133,7 @@ class PoissonSpike(Dynamics):
134
133
  self.freqs = init.param(freqs, self.varshape, allow_none=False)
135
134
 
136
135
  def update(self):
137
- spikes = random.rand(self.varshape) <= (self.freqs * environ.get_dt())
136
+ spikes = random.rand(*self.varshape) <= (self.freqs * environ.get_dt())
138
137
  spikes = u.math.asarray(spikes, dtype=self.spk_type)
139
138
  return spikes
140
139
 
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import init, functional
24
24
  from brainstate._state import ParamState
25
- from brainstate.nn._module import Module
26
25
  from brainstate.typing import ArrayLike, Size
26
+ from ._module import Module
27
27
 
28
28
  __all__ = [
29
29
  'Linear',
@@ -350,10 +350,7 @@ class OneToOne(Module):
350
350
  self.weight = param_type(param)
351
351
 
352
352
  def update(self, pre_val):
353
- pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
354
- w_val, w_unit = u.get_mantissa(self.weight.value['weight']), u.get_unit(self.weight.value['weight'])
355
- post_val = pre_val * w_val
356
- post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
353
+ post_val = pre_val * self.weight.value['weight']
357
354
  if 'bias' in self.weight.value:
358
355
  post_val = post_val + self.weight.value['bias']
359
356
  return post_val
@@ -19,10 +19,10 @@ import brainunit as u
19
19
  import jax
20
20
 
21
21
  from brainstate import init
22
- from brainstate._compatible_import import brainevent
22
+ import brainevent
23
23
  from brainstate._state import ParamState
24
- from brainstate.nn._module import Module
25
24
  from brainstate.typing import Size, ArrayLike
25
+ from ._module import Module
26
26
 
27
27
  __all__ = [
28
28
  'EventLinear',
@@ -16,11 +16,13 @@
16
16
  # -*- coding: utf-8 -*-
17
17
 
18
18
 
19
- from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
20
- from ._linear_mv import EventLinear
19
+ from ._synapse import Synapse
21
20
 
22
21
  __all__ = [
23
- 'EventLinear',
24
- 'EventFixedProb',
25
- 'EventFixedNumConn',
22
+ 'LongTermPlasticity',
26
23
  ]
24
+
25
+
26
+ class LongTermPlasticity(Synapse):
27
+ pass
28
+
brainstate/nn/_module.py CHANGED
@@ -34,7 +34,7 @@ import numpy as np
34
34
  from brainstate._state import State
35
35
  from brainstate.graph import Node, states, nodes, flatten
36
36
  from brainstate.mixin import ParamDescriber, ParamDesc
37
- from brainstate.typing import PathParts
37
+ from brainstate.typing import PathParts, Size
38
38
  from brainstate.util import FlattedDict, NestedDict, BrainStateError
39
39
 
40
40
  # maximum integer
@@ -62,8 +62,8 @@ class Module(Node, ParamDesc):
62
62
 
63
63
  __module__ = 'brainstate.nn'
64
64
 
65
- _in_size: Optional[Tuple[int, ...]]
66
- _out_size: Optional[Tuple[int, ...]]
65
+ _in_size: Optional[Size]
66
+ _out_size: Optional[Size]
67
67
  _name: Optional[str]
68
68
 
69
69
  if not TYPE_CHECKING:
@@ -87,7 +87,7 @@ class Module(Node, ParamDesc):
87
87
  raise AttributeError('The name of the model is read-only.')
88
88
 
89
89
  @property
90
- def in_size(self) -> Tuple[int, ...]:
90
+ def in_size(self) -> Size:
91
91
  return self._in_size
92
92
 
93
93
  @in_size.setter
@@ -98,7 +98,7 @@ class Module(Node, ParamDesc):
98
98
  self._in_size = tuple(in_size)
99
99
 
100
100
  @property
101
- def out_size(self) -> Tuple[int, ...]:
101
+ def out_size(self) -> Size:
102
102
  return self._out_size
103
103
 
104
104
  @out_size.setter
@@ -22,9 +22,9 @@ import jax
22
22
 
23
23
  from brainstate import init, surrogate, environ
24
24
  from brainstate._state import HiddenState, ShortTermState
25
- from brainstate.nn._dynamics._dynamics_base import Dynamics
26
- from brainstate.nn._exp_euler import exp_euler_step
27
25
  from brainstate.typing import ArrayLike, Size
26
+ from ._dynamics import Dynamics
27
+ from ._exp_euler import exp_euler_step
28
28
 
29
29
  __all__ = [
30
30
  'Neuron', 'IF', 'LIF', 'LIFRef', 'ALIF',
@@ -22,8 +22,8 @@ import jax.numpy as jnp
22
22
 
23
23
  from brainstate import environ, init
24
24
  from brainstate._state import ParamState, BatchState
25
- from brainstate.nn._module import Module
26
25
  from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
26
+ from ._module import Module
27
27
 
28
28
  __all__ = [
29
29
  'BatchNorm0d',
@@ -25,8 +25,8 @@ import jax.numpy as jnp
25
25
  import numpy as np
26
26
 
27
27
  from brainstate import environ
28
- from brainstate.nn._module import Module
29
28
  from brainstate.typing import Size
29
+ from ._module import Module
30
30
 
31
31
  __all__ = [
32
32
  'Flatten', 'Unflatten',
@@ -103,7 +103,7 @@ class TestPool(parameterized.TestCase):
103
103
  for target_size in [10, 9, 8, 7, 6]
104
104
  )
105
105
  def test_adaptive_pool1d(self, target_size):
106
- from brainstate.nn._interaction._poolings import _adaptive_pool1d
106
+ from brainstate.nn._poolings import _adaptive_pool1d
107
107
 
108
108
  arr = brainstate.random.rand(100)
109
109
  op = jax.numpy.mean