brainstate 0.1.3__py2.py3-none-any.whl → 0.1.5__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. brainstate/__init__.py +1 -1
  2. brainstate/_compatible_import.py +1 -16
  3. brainstate/_state.py +1 -0
  4. brainstate/augment/_mapping.py +9 -9
  5. brainstate/augment/_mapping_test.py +162 -0
  6. brainstate/compile/_jit.py +14 -5
  7. brainstate/compile/_make_jaxpr.py +78 -22
  8. brainstate/compile/_make_jaxpr_test.py +13 -2
  9. brainstate/graph/_graph_node.py +1 -1
  10. brainstate/graph/_graph_operation.py +4 -4
  11. brainstate/mixin.py +31 -2
  12. brainstate/nn/__init__.py +8 -5
  13. brainstate/nn/_common.py +7 -19
  14. brainstate/nn/_delay.py +13 -1
  15. brainstate/nn/_dropout.py +5 -4
  16. brainstate/nn/_dynamics.py +39 -44
  17. brainstate/nn/_exp_euler.py +13 -16
  18. brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
  19. brainstate/nn/_inputs.py +1 -1
  20. brainstate/nn/_linear_mv.py +1 -1
  21. brainstate/nn/_module.py +5 -5
  22. brainstate/nn/_projection.py +190 -98
  23. brainstate/nn/_synapse.py +5 -9
  24. brainstate/nn/_synaptic_projection.py +376 -86
  25. brainstate/random/_rand_state.py +13 -7
  26. brainstate/surrogate.py +1 -1
  27. brainstate/typing.py +1 -1
  28. brainstate/util/__init__.py +14 -14
  29. brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
  30. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/METADATA +1 -1
  31. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/RECORD +42 -42
  32. brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
  33. brainstate/util/{_caller.py → caller.py} +0 -0
  34. brainstate/util/{_error.py → error.py} +0 -0
  35. brainstate/util/{_others.py → others.py} +0 -0
  36. brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
  37. brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
  38. brainstate/util/{_scaling.py → scaling.py} +0 -0
  39. brainstate/util/{_struct.py → struct.py} +0 -0
  40. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/LICENSE +0 -0
  41. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/WHEEL +0 -0
  42. {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/top_level.txt +0 -0
brainstate/compile/_make_jaxpr_test.py CHANGED
@@ -88,7 +88,7 @@ class TestMakeJaxpr(unittest.TestCase):
         self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
                                      f3(jnp.zeros(1))))
 
-    def test_compar_jax_make_jaxpr2(self):
+    def test_compare_jax_make_jaxpr2(self):
         st1 = brainstate.State(jnp.ones(10))
 
         def fa(x):
@@ -108,7 +108,7 @@ class TestMakeJaxpr(unittest.TestCase):
         print(jaxpr)
         print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
 
-    def test_compar_jax_make_jaxpr3(self):
+    def test_compare_jax_make_jaxpr3(self):
         def fa(x):
             return 1.
 
@@ -121,6 +121,17 @@ class TestMakeJaxpr(unittest.TestCase):
         print(jaxpr)
         # print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
 
+    def test_static_argnames(self):
+        def func4(a, b):  # Arg is a pair
+            temp = a + jnp.sin(b) * 3.
+            c = brainstate.random.rand_like(a)
+            return jnp.sum(temp + c)
+
+        jaxpr, states = brainstate.compile.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
+        print()
+        print(jaxpr)
+        print(states)
+
 
 def test_return_states():
     import jax.numpy
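Note: the new `test_static_argnames` test exercises the `static_argnames` support added to `brainstate.compile.make_jaxpr` in this release. A minimal usage sketch mirroring that test (assuming brainstate 0.1.5 and JAX are installed):

    import jax.numpy as jnp
    import brainstate

    def f(a, b):
        # 'b' is treated as a static, compile-time constant
        return jnp.sum(a + jnp.sin(b) * 3.)

    # returns the traced jaxpr plus the brainstate States the function touches
    jaxpr, states = brainstate.compile.make_jaxpr(f, static_argnames='b')(jnp.zeros(8), 1.)
    print(jaxpr)   # 'b' is folded into the trace rather than appearing as an input
    print(states)  # empty here, since f reads no State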
brainstate/graph/_graph_node.py CHANGED
@@ -25,7 +25,7 @@ import numpy as np
 
 from brainstate._state import State, TreefyState
 from brainstate.typing import Key
-from brainstate.util._pretty_pytree import PrettyObject
+from brainstate.util.pretty_pytree import PrettyObject
 from ._graph_operation import register_graph_node_type
 
 __all__ = [
brainstate/graph/_graph_operation.py CHANGED
@@ -30,10 +30,10 @@ from typing_extensions import TypeGuard, Unpack
 from brainstate._state import State, TreefyState
 from brainstate._utils import set_module_as
 from brainstate.typing import PathParts, Filter, Predicate, Key
-from brainstate.util._caller import ApplyCaller, CallableProxy, DelayedAccessor
-from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
-from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
-from brainstate.util._struct import FrozenDict
+from brainstate.util.caller import ApplyCaller, CallableProxy, DelayedAccessor
+from brainstate.util.pretty_pytree import NestedDict, FlattedDict, PrettyDict
+from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
+from brainstate.util.struct import FrozenDict
 from brainstate.util.filter import to_predicate
 
 _max_int = np.iinfo(np.int32).max
brainstate/mixin.py CHANGED
@@ -41,6 +41,14 @@ __all__ = [
 ]
 
 
+def hashable(x):
+    try:
+        hash(x)
+        return True
+    except TypeError:
+        return False
+
+
 class Mixin(object):
     """Base Mixin object.
 
@@ -67,6 +75,14 @@ class ParamDesc(Mixin):
 
 
 class HashableDict(dict):
+    def __init__(self, the_dict: dict):
+        out = dict()
+        for k, v in the_dict.items():
+            if not hashable(v):
+                v = str(v)  # convert to string if not hashable
+            out[k] = v
+        super().__init__(out)
+
     def __hash__(self):
         return hash(tuple(sorted(self.items())))
 
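`HashableDict` now coerces unhashable values to strings at construction time, so parameter dictionaries containing e.g. lists can still serve as cache keys. A self-contained sketch of the same pattern (a standalone re-implementation for illustration, not the library import):

    def hashable(x):
        try:
            hash(x)
            return True
        except TypeError:
            return False

    class HashableDict(dict):
        def __init__(self, the_dict: dict):
            # unhashable values (e.g., lists) are stored as their string form
            super().__init__({k: (v if hashable(v) else str(v)) for k, v in the_dict.items()})

        def __hash__(self):
            return hash(tuple(sorted(self.items())))

    d = HashableDict({'size': 10, 'shape': [2, 5]})  # [2, 5] becomes '[2, 5]'
    cache = {d: 'ok'}                                # usable as a dict key now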
@@ -132,7 +148,6 @@ class AlignPost(Mixin):
         raise NotImplementedError
 
 
-
 class BindCondData(Mixin):
     """Bind temporary conductance data.
 
@@ -147,12 +162,26 @@ class BindCondData(Mixin):
         self._conductance = None
 
 
+def not_implemented(func):
+
+    def wrapper(*args, **kwargs):
+        raise NotImplementedError(f'{func.__name__} is not implemented.')
+
+    wrapper.not_implemented = True
+    return wrapper
+
+
+
 class UpdateReturn(Mixin):
+    @not_implemented
     def update_return(self) -> PyTree:
         """
         The update function return of the model.
 
-        It should be a pytree, with each element as a ``jax.ShapeDtypeStruct`` or ``jax.core.ShapedArray``.
+        This function requires no parameters and must return a PyTree.
+
+        It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
+        initialize the output delay.
 
         """
         raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
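The `not_implemented` decorator both raises on call and tags the wrapper with a `not_implemented` attribute, so callers can probe whether a subclass actually overrode `update_return` without triggering the exception. A sketch of that probe (class names here are hypothetical):

    def not_implemented(func):
        def wrapper(*args, **kwargs):
            raise NotImplementedError(f'{func.__name__} is not implemented.')
        wrapper.not_implemented = True
        return wrapper

    class Base:
        @not_implemented
        def update_return(self):
            pass

    class Child(Base):
        def update_return(self):
            return 0.0

    # bound methods forward attribute lookups to the underlying function
    print(getattr(Base().update_return, 'not_implemented', False))   # True
    print(getattr(Child().update_return, 'not_implemented', False))  # False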
brainstate/nn/__init__.py CHANGED
@@ -33,12 +33,14 @@ from ._embedding import *
 from ._embedding import __all__ as embed_all
 from ._exp_euler import *
 from ._exp_euler import __all__ as exp_euler_all
-from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
+from ._fixedprob import *
+from ._fixedprob import __all__ as fixedprob_all
 from ._inputs import *
 from ._inputs import __all__ as inputs_all
 from ._linear import *
 from ._linear import __all__ as linear_all
-from ._linear_mv import EventLinear
+from ._linear_mv import *
+from ._linear_mv import __all__ as linear_mv_all
 from ._ltp import *
 from ._ltp import __all__ as ltp_all
 from ._module import *
@@ -69,9 +71,6 @@ from ._utils import __all__ as utils_all
 __all__ = (
     [
         'metrics',
-        'EventLinear',
-        'EventFixedProb',
-        'EventFixedNumConn',
     ]
     + collective_ops_all
     + common_all
@@ -87,6 +86,8 @@ __all__ = (
     + linear_all
     + normalizations_all
     + poolings_all
+    + fixedprob_all
+    + linear_mv_all
     + embed_all
     + dropout_all
     + elementwise_all
@@ -115,6 +116,8 @@ del (
     normalizations_all,
     poolings_all,
     embed_all,
+    fixedprob_all,
+    linear_mv_all,
     dropout_all,
     elementwise_all,
     dyn_neuron_all,
brainstate/nn/_common.py CHANGED
@@ -118,14 +118,14 @@ class Vmap(Module):
     This class wraps a module and applies vectorized mapping to its execution,
     allowing for efficient parallel processing across specified axes.
 
-    Attributes:
+    Args:
         module (Module): The module to be vmapped.
-        in_axes (int | None | Sequence[Any]): Specifies how to map over inputs.
-        out_axes (Any): Specifies how to map over outputs.
-        vmap_states (Filter | Dict[Filter, int]): Specifies which states to vmap and on which axes.
-        vmap_out_states (Filter | Dict[Filter, int]): Specifies which output states to vmap and on which axes.
-        axis_name (AxisName | None): Name of the axis being mapped over.
-        axis_size (int | None): Size of the axis being mapped over.
+        in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
+        out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
+        vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
+        vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
+        axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
+        axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
     """
 
     def __init__(
@@ -138,18 +138,6 @@ class Vmap(Module):
         axis_name: AxisName | None = None,
         axis_size: int | None = None,
     ):
-        """
-        Initialize the Vmap instance.
-
-        Args:
-            module (Module): The module to be vmapped.
-            in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
-            out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
-            vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
-            vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
-            axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
-            axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
-        """
         super().__init__()
 
         # parameters
brainstate/nn/_delay.py CHANGED
@@ -330,7 +330,14 @@ class Delay(Module):
             indices = (delay_idx,) + indices
 
         # the delay data
-        return jax.tree.map(lambda a: a[indices], self.history.value)
+        if self._unit is None:
+            return jax.tree.map(lambda a: a[indices], self.history.value)
+        else:
+            return jax.tree.map(
+                lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
+                self.history.value,
+                self._unit
+            )
 
     def retrieve_at_time(self, delay_time, *indices) -> PyTree:
         """
@@ -393,6 +400,9 @@
         """
         assert self.history is not None, 'The delay history is not initialized.'
 
+        if self.take_aware_unit and self._unit is None:
+            self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
+
         # update the delay data at the rotation index
         if self.delay_method == _DELAY_ROTATE:
             i = environ.get(environ.I)
@@ -419,6 +429,8 @@
             raise ValueError(f'Unknown updating method "{self.delay_method}"')
 
 
+
+
 class StateWithDelay(Delay):
     """
     A ``State`` type that defines the state in a differential equation.
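With `take_aware_unit` enabled, `Delay` records each leaf's unit on the first `update()` and reattaches it on retrieval, so histories are stored as raw mantissas. A sketch of that round-trip using only the brainunit calls visible in the hunks above (names and data are illustrative):

    import brainunit as u
    import jax
    import jax.numpy as jnp

    current = {'V': jnp.ones(3) * u.mV, 'spk': jnp.zeros(3)}

    # capture units once; plain arrays yield a dimensionless unit
    units = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
    history = jax.tree.map(lambda x: u.get_mantissa(x), current, is_leaf=u.math.is_quantity)

    # retrieval multiplies the stored mantissa back by its unit;
    # u.maybe_decimal returns a plain array when the unit is dimensionless
    restored = jax.tree.map(lambda hist, unit: u.maybe_decimal(hist * unit), history, units)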
brainstate/nn/_dropout.py CHANGED
@@ -409,7 +409,8 @@ class DropoutFixed(ElementWiseBlock):
         self.out_size = in_size
 
     def init_state(self, batch_size=None, **kwargs):
-        self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
+        if self.prob < 1.:
+            self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
 
     def update(self, x):
         dtype = u.math.get_dtype(x)
@@ -418,8 +419,8 @@ class DropoutFixed(ElementWiseBlock):
             if self.mask.value.shape != x.shape:
                 raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
                                  f"Please call `init_state()` method first.")
-            return jnp.where(self.mask.value,
-                             jnp.asarray(x / self.prob, dtype=dtype),
-                             jnp.asarray(0., dtype=dtype))
+            return u.math.where(self.mask.value,
+                                u.math.asarray(x / self.prob, dtype=dtype),
+                                u.math.asarray(0., dtype=dtype) * u.get_unit(x))
         else:
            return x
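The `DropoutFixed.update` change swaps `jnp.where` for `u.math.where` and multiplies the zero branch by the input's unit, so inputs carrying physical units no longer lose them. A minimal sketch of the unit-preserving branch (assuming brainunit is installed; values are illustrative):

    import brainunit as u
    import jax.numpy as jnp

    x = jnp.array([1., 2., 3.]) * u.mV
    mask = jnp.array([True, False, True])
    prob = 0.5
    dtype = u.math.get_dtype(x)

    # both branches carry mV, so the result is a well-formed Quantity
    y = u.math.where(mask,
                     u.math.asarray(x / prob, dtype=dtype),
                     u.math.asarray(0., dtype=dtype) * u.get_unit(x))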
brainstate/nn/_dynamics.py CHANGED
@@ -36,13 +36,14 @@ For handling the delays:
 from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
 
 import brainunit as u
+import jax
 import numpy as np
 
 from brainstate import environ
 from brainstate._state import State
 from brainstate.graph import Node
 from brainstate.mixin import ParamDescriber, UpdateReturn
-from brainstate.typing import Size, ArrayLike
+from brainstate.typing import Size, ArrayLike, PyTree
 from ._delay import StateWithDelay, Delay
 from ._module import Module
 
@@ -811,18 +812,25 @@ class Dynamics(Module, UpdateReturn):
         >>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape))  # n2 will run after n1
         """
         if isinstance(dyn, Dynamics):
-            self._add_after_update(dyn.name, dyn)
+            self._add_after_update(id(dyn), dyn)
             return dyn
         elif isinstance(dyn, ParamDescriber):
             if not issubclass(dyn.cls, Dynamics):
                 raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
             if not self._has_after_update(dyn.identifier):
-                self._add_after_update(dyn.identifier, dyn())
+                self._add_after_update(
+                    dyn.identifier,
+                    dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
+                )
             return self._get_after_update(dyn.identifier)
         else:
             raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
 
-    def prefetch_delay(self, state: str, delay: Optional[ArrayLike] = None) -> 'PrefetchDelayAt':
+    def prefetch_delay(
+        self,
+        state: str,
+        delay: Optional[ArrayLike] = None
+    ) -> 'PrefetchDelayAt':
         """
         Create a reference to a delayed state or variable in the module.
 
@@ -840,7 +848,11 @@ class Dynamics(Module, UpdateReturn):
         """
         return self.prefetch(state).delay.at(delay)
 
-    def output_delay(self, delay: Optional[ArrayLike] = None) -> 'OutputDelayAt':
+    def output_delay(
+        self,
+        delay: Optional[ArrayLike] = None,
+        variable_like: PyTree = None
+    ) -> 'OutputDelayAt':
         """
         Create a reference to the delayed output of the module.
 
@@ -851,6 +863,7 @@ class Dynamics(Module, UpdateReturn):
         Args:
             delay (Optional[ArrayLike]): The amount of time to delay the output access,
                 typically in time units (e.g., milliseconds). Defaults to None.
+            variable_like:
 
         Returns:
             OutputDelayAt: An object that provides access to the module's output at the specified delay time.
@@ -1102,8 +1115,6 @@ class OutputDelayAt(Node):
     ----------
     module : Dynamics
         The dynamics module that contains the referenced state or variable.
-    item : str
-        The name of the state or variable to access with delay.
     time : ArrayLike
         The amount of time to delay access by, typically in time units (e.g., milliseconds).
 
@@ -1113,56 +1124,40 @@
     >>> import brainunit as u
    >>> neuron = brainstate.nn.LIF(10)
     >>> # Create a reference to voltage delayed by 5ms
-    >>> delayed_v = PrefetchDelayAt(neuron, 'V', 5.0 * u.ms)
+    >>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
     >>> # Get the delayed value
-    >>> v_value = delayed_v()
+    >>> v_value = delayed_spike()
     """
 
     def __init__(
         self,
         module: Dynamics,
         time: Optional[ArrayLike] = None,
+        variable_like: Optional[PyTree] = None,
    ):
-        """
-        Initialize a PrefetchDelayAt object.
-
-        Parameters
-        ----------
-        module : AlignPre, Module
-            The dynamics module that contains the referenced state or variable.
-        time : ArrayLike
-            The amount of time to delay access by, typically in time units.
-        """
         super().__init__()
-        assert isinstance(module, UpdateReturn), 'The module should implement the `update_return` method.'
-        assert isinstance(module, Module), 'The module should be an instance of Module.'
+        assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
         self.module = module
+        dt = environ.get_dt()
+        if time is None:
+            time = u.math.zeros_like(dt)
         self.time = time
-        if time is not None:
-            self.step = u.math.asarray(time / environ.get_dt(), dtype=environ.ditype())
-
-        # register the delay
-        key = _get_output_delay_key()
-        if not module._has_after_update(key):
-            # TODO: unit processing
-            delay = Delay(module.update_return(), time)
-            module._add_after_update(key, receive_update_output(delay))
-        self.out_delay: Delay = module._get_after_update(key)
-        self.out_delay.register_delay(time)
+        self.step = u.math.asarray(time / dt, dtype=environ.ditype())
+
+        # register the delay
+        key = _get_output_delay_key()
+        if not module._has_after_update(key):
+            delay = Delay(
+                jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()),
+                time,
+                take_aware_unit=True
+            )
+            module._add_after_update(key, receive_update_output(delay))
+        self.out_delay: Delay = module._get_after_update(key)
+        self.out_delay.register_delay(time)
 
     def __call__(self, *args, **kwargs):
-        """
-        Retrieve the value of the state at the specified delay time.
-
-        Returns
-        -------
-        Any
-            The value of the state or variable at the specified delay time.
-        """
-        if self.time is None:
-            return self.module.update_return()
-        else:
-            return self.out_delay.retrieve_at_step(self.step)
+        return self.out_delay.retrieve_at_step(self.step)
 
 
 def _get_prefetch_delay_key(item) -> str:
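`output_delay` now always registers a `Delay` sized from `module.out_size` (via a `jax.ShapeDtypeStruct` placeholder) instead of calling `update_return()`, and `OutputDelayAt` always retrieves by step. A hedged usage sketch based on the docstring example above (assumes `brainstate.environ.set(dt=...)` defines the simulation step; retrieval shown commented because it must run inside a stepped simulation):

    import brainunit as u
    import brainstate

    brainstate.environ.set(dt=0.1 * u.ms)   # delay steps are computed as time / dt

    neuron = brainstate.nn.LIF(10)
    delayed_spk = neuron.output_delay(5.0 * u.ms)  # registers the after-update Delay once

    # inside a simulation step, after neuron.update(...) has run:
    # spk_5ms_ago = delayed_spk()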
brainstate/nn/_exp_euler.py CHANGED
@@ -69,27 +69,24 @@ def exp_euler_step(
             f'The input data type should be float64, float32, float16, or bfloat16 '
             f'when using Exponential Euler method. But we got {args[0].dtype}.'
         )
+
+    # drift
     dt = environ.get('dt')
     linear, derivative = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
     linear = u.Quantity(u.get_mantissa(linear), u.get_unit(derivative) / u.get_unit(linear))
     phi = u.math.exprel(dt * linear)
     x_next = args[0] + dt * phi * derivative
 
+    # diffusion
     if diffusion is not None:
-        # unit checking
-        diffusion = diffusion(*args, **kwargs)
-        time_unit = u.get_unit(dt)
-        drift_unit = u.get_unit(derivative)
-        diffusion_unit = u.get_unit(diffusion)
-        # if drift_unit.is_unitless:
-        #     assert diffusion_unit.is_unitless, 'The diffusion term should be unitless when the drift term is unitless.'
-        # else:
-        #     u.fail_for_dimension_mismatch(
-        #         drift_unit, diffusion_unit * time_unit ** 0.5,
-        #         "Drift unit is {drift}, diffusion unit is {diffusion}, ",
-        #         drift=drift_unit, diffusion=diffusion_unit * time_unit ** 0.5
-        #     )
-
-        # diffusion
-        x_next += diffusion * u.math.sqrt(dt) * random.randn_like(args[0])
+        diffusion_part = diffusion(*args, **kwargs) * u.math.sqrt(dt) * random.randn_like(args[0])
+        if u.get_dim(x_next) != u.get_dim(diffusion_part):
+            drift_unit = u.get_unit(x_next)
+            time_unit = u.get_unit(dt)
+            raise ValueError(
+                f"Drift unit is {drift_unit}, "
+                f"expected diffusion unit is {drift_unit / time_unit ** 0.5}, "
+                f"but we got {u.get_unit(diffusion_part)}."
+            )
+        x_next += diffusion_part
     return x_next
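The rewritten diffusion branch enforces dimensional consistency directly: the diffusion term times sqrt(dt) must carry the same dimension as the integrated state. A small check of that rule in brainunit terms (values are illustrative):

    import brainunit as u

    dt = 0.1 * u.ms
    x_unit = u.mV                               # unit of the integrated state
    diffusion_value = 2.0 * u.mV / u.ms ** 0.5  # required unit: [x] / sqrt([dt])

    diffusion_part = diffusion_value * u.math.sqrt(dt)   # -> mV
    assert u.get_dim(diffusion_part) == u.get_dim(1.0 * x_unit)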
brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} RENAMED
@@ -16,19 +16,20 @@
 
 from typing import Union, Callable, Optional
 
+import brainevent
 import brainunit as u
 import jax
 import jax.numpy as jnp
 import numpy as np
 
 from brainstate import random, augment, environ, init
-from brainstate._compatible_import import brainevent
-from brainstate._state import ParamState
+from brainstate._state import ParamState, FakeState
 from brainstate.compile import for_loop
 from brainstate.typing import Size, ArrayLike
 from ._module import Module
 
 __all__ = [
+    'FixedNumConn',
     'EventFixedNumConn',
     'EventFixedProb',
 ]
@@ -44,12 +45,11 @@ def init_indices_without_replace(
     rng = random.default_rng(seed)
 
     if method == 'vmap':
-        @augment.vmap
-        def rand_indices(key):
-            rng.set_key(key)
+        @augment.vmap(axis_size=n_pre)
+        def rand_indices():
             return rng.choice(n_post, size=(conn_num,), replace=False)
 
-        return rand_indices(rng.split_key(n_pre))
+        return rand_indices()
 
     elif method == 'for_loop':
         return for_loop(
@@ -61,9 +61,9 @@ def init_indices_without_replace(
         raise ValueError(f"Unknown method: {method}")
 
 
-class EventFixedNumConn(Module):
+class FixedNumConn(Module):
     """
-    The FixedProb module implements a fixed probability connection with CSR sparse data structure.
+    The ``FixedNumConn`` module implements a fixed probability connection with CSR sparse data structure.
 
     Parameters
     ----------
@@ -77,7 +77,7 @@ class EventFixedNumConn(Module):
         If it is an integer, representing the number of connections.
     conn_weight : float or callable or jax.Array or brainunit.Quantity
         Maximum synaptic conductance, i.e., synaptic weight.
-    conn_target : str, optional
+    efferent_target : str, optional
         The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
         a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
 
@@ -104,7 +104,8 @@ class EventFixedNumConn(Module):
         out_size: Size,
         conn_num: Union[int, float],
         conn_weight: Union[Callable, ArrayLike],
-        conn_target: str = 'post',  # 'pre' or 'post'
+        efferent_target: str = 'post',  # 'pre' or 'post'
+        afferent_ratio: Union[int, float] = 1.,
         allow_multi_conn: bool = True,
        seed: Optional[int] = None,
        name: Optional[str] = None,
@@ -116,11 +117,14 @@ class EventFixedNumConn(Module):
         # network parameters
         self.in_size = in_size
         self.out_size = out_size
-        self.conn_target = conn_target
-        assert conn_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
+        self.efferent_target = efferent_target
+        assert efferent_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
+        assert 0. <= afferent_ratio <= 1., 'Afferent ratio must be in [0, 1].'
         if isinstance(conn_num, float):
             assert 0. <= conn_num <= 1., 'Connection probability must be in [0, 1].'
-            conn_num = int(self.out_size[-1] * conn_num) if conn_target == 'post' else int(self.in_size[-1] * conn_num)
+            conn_num = (int(self.out_size[-1] * conn_num)
+                        if efferent_target == 'post' else
+                        int(self.in_size[-1] * conn_num))
         assert isinstance(conn_num, int), 'Connection number must be an integer.'
         self.conn_num = conn_num
         self.seed = seed
@@ -128,14 +132,13 @@ class EventFixedNumConn(Module):
 
         # connections
         if self.conn_num >= 1:
-            if self.conn_target == 'post':
+            if self.efferent_target == 'post':
                 n_post = self.out_size[-1]
                 n_pre = self.in_size[-1]
             else:
                 n_post = self.in_size[-1]
                 n_pre = self.out_size[-1]
 
-            # indices of post connected neurons
             with jax.ensure_compile_time_eval():
                 if allow_multi_conn:
                     rng = np.random if seed is None else np.random.RandomState(seed)
@@ -143,15 +146,83 @@ class EventFixedNumConn(Module):
                 else:
                     indices = init_indices_without_replace(self.conn_num, n_pre, n_post, seed, conn_init)
                 indices = u.math.asarray(indices, dtype=environ.ditype())
-            conn_weight = init.param(conn_weight, (n_pre, self.conn_num), allow_none=False)
-            conn_weight = u.math.asarray(conn_weight)
-            self.weight = param_type(conn_weight)
-            csr = (
-                brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
-                if self.conn_target == 'post' else
-                brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
-            )
-            self.conn = csr
+
+            if afferent_ratio == 1.:
+                conn_weight = u.math.asarray(init.param(conn_weight, (n_pre, self.conn_num), allow_none=False))
+                self.weight = param_type(conn_weight)
+                csr = (
+                    brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
+                    if self.efferent_target == 'post' else
+                    brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
+                )
+                self.conn = csr
+
+            else:
+                self.pre_selected = np.random.random(n_pre) < afferent_ratio
+                indices = indices[self.pre_selected].flatten()
+                conn_weight = u.math.asarray(init.param(conn_weight, (indices.size,), allow_none=False))
+                self.weight = param_type(conn_weight)
+                indptr = (jnp.arange(1, n_pre + 1) * self.conn_num -
+                          jnp.cumsum(~self.pre_selected) * self.conn_num)
+                indptr = jnp.insert(indptr, 0, 0)  # insert 0 at the beginning
+                csr = (
+                    brainevent.CSR((conn_weight, indices, indptr), shape=(n_pre, n_post))
+                    if self.efferent_target == 'post' else
+                    brainevent.CSC((conn_weight, indices, indptr), shape=(n_pre, n_post))
+                )
+                self.conn = csr
+
+        else:
+            conn_weight = u.math.asarray(init.param(conn_weight, (), allow_none=False))
+            self.weight = FakeState(conn_weight)
+
+    def update(self, x: jax.Array) -> Union[jax.Array, u.Quantity]:
+        if self.conn_num >= 1:
+            csr = self.conn.with_data(self.weight.value)
+            return x @ csr
+        else:
+            weight = self.weight.value
+            r = u.math.zeros(x.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
+            r = u.maybe_decimal(u.Quantity(r, unit=u.get_unit(weight)))
+            return u.math.asarray(r, dtype=environ.dftype())
+
+
+class EventFixedNumConn(FixedNumConn):
+    """
+    The FixedProb module implements a fixed probability connection with CSR sparse data structure.
+
+    Parameters
+    ----------
+    in_size : Size
+        Number of pre-synaptic neurons, i.e., input size.
+    out_size : Size
+        Number of post-synaptic neurons, i.e., output size.
+    conn_num : float, int
+        If it is a float, representing the probability of connection, i.e., connection probability.
+
+        If it is an integer, representing the number of connections.
+    conn_weight : float or callable or jax.Array or brainunit.Quantity
+        Maximum synaptic conductance, i.e., synaptic weight.
+    conn_target : str, optional
+        The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
+        a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
+
+        If 'pre', each post-synaptic neuron connects to a fixed number of pre-synaptic neurons.
+    conn_init : str, optional
+        The initialization method of the connection weight. Default is 'vmap', meaning that the connection weight
+        is initialized by parallelized across multiple threads.
+
+        If 'for_loop', the connection weight is initialized by a for loop.
+    allow_multi_conn : bool, optional
+        Whether multiple connections are allowed from a single pre-synaptic neuron.
+        Default is True, meaning that a value of ``a`` can be selected multiple times.
+    seed: int, optional
+        Random seed. Default is None. If None, the default random seed will be used.
+    name : str, optional
+        Name of the module.
+    """
+
+    __module__ = 'brainstate.nn'
 
     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
         if self.conn_num >= 1:
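For the new `afferent_ratio < 1` branch, dropped pre-synaptic rows contribute zero entries to the CSR `indptr`, which the cumulative subtraction encodes without rebuilding the index arrays row by row. A standalone sketch of that computation (shapes are illustrative):

    import numpy as np
    import jax.numpy as jnp

    n_pre, conn_num = 5, 3
    pre_selected = np.array([True, False, True, True, False])  # afferent_ratio mask

    # rows kept contribute conn_num entries; dropped rows contribute zero
    indptr = (jnp.arange(1, n_pre + 1) * conn_num
              - jnp.cumsum(~pre_selected) * conn_num)
    indptr = jnp.insert(indptr, 0, 0)
    print(indptr)  # [0 3 3 6 9 9] -> rows 1 and 4 have no connections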
brainstate/nn/_inputs.py CHANGED
@@ -547,7 +547,7 @@ def poisson_input(
             num_input,
             p,
             tar[indices].shape,
-            # check_valid=False,
+            check_valid=False,
             dtype=tar.dtype
         ),
         tar_val,
  tar_val,