brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250423__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. brainstate/_compatible_import.py +15 -0
  2. brainstate/_state.py +5 -4
  3. brainstate/_state_test.py +2 -1
  4. brainstate/augment/_autograd_test.py +3 -2
  5. brainstate/augment/_eval_shape.py +2 -1
  6. brainstate/augment/_mapping.py +0 -1
  7. brainstate/augment/_mapping_test.py +1 -0
  8. brainstate/compile/_ad_checkpoint.py +2 -1
  9. brainstate/compile/_conditions.py +3 -3
  10. brainstate/compile/_conditions_test.py +2 -1
  11. brainstate/compile/_error_if.py +2 -1
  12. brainstate/compile/_error_if_test.py +2 -1
  13. brainstate/compile/_jit.py +3 -2
  14. brainstate/compile/_jit_test.py +2 -1
  15. brainstate/compile/_loop_collect_return.py +2 -2
  16. brainstate/compile/_loop_collect_return_test.py +2 -1
  17. brainstate/compile/_loop_no_collection.py +1 -1
  18. brainstate/compile/_make_jaxpr.py +2 -2
  19. brainstate/compile/_make_jaxpr_test.py +2 -1
  20. brainstate/compile/_progress_bar.py +2 -1
  21. brainstate/compile/_unvmap.py +1 -2
  22. brainstate/environ.py +4 -4
  23. brainstate/environ_test.py +2 -1
  24. brainstate/functional/_activations.py +2 -1
  25. brainstate/functional/_activations_test.py +1 -1
  26. brainstate/functional/_normalization.py +2 -1
  27. brainstate/functional/_others.py +2 -1
  28. brainstate/graph/_graph_operation.py +3 -2
  29. brainstate/graph/_graph_operation_test.py +4 -3
  30. brainstate/init/_base.py +2 -1
  31. brainstate/init/_generic.py +2 -1
  32. brainstate/nn/__init__.py +4 -0
  33. brainstate/nn/_collective_ops.py +1 -0
  34. brainstate/nn/_collective_ops_test.py +0 -4
  35. brainstate/nn/_common.py +0 -1
  36. brainstate/nn/_dyn_impl/__init__.py +0 -4
  37. brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
  38. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
  39. brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
  40. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
  41. brainstate/nn/_dyn_impl/_inputs.py +236 -29
  42. brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
  43. brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
  44. brainstate/nn/_dyn_impl/_readout.py +91 -8
  45. brainstate/nn/_dyn_impl/_readout_test.py +2 -1
  46. brainstate/nn/_dynamics/_dynamics_base.py +676 -96
  47. brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
  48. brainstate/nn/_dynamics/_projection_base.py +29 -30
  49. brainstate/nn/_dynamics/_state_delay.py +3 -3
  50. brainstate/nn/_dynamics/_synouts_test.py +2 -1
  51. brainstate/nn/_elementwise/_dropout.py +3 -2
  52. brainstate/nn/_elementwise/_dropout_test.py +2 -1
  53. brainstate/nn/_elementwise/_elementwise.py +2 -1
  54. brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
  55. brainstate/nn/_event/_fixedprob_mv.py +169 -0
  56. brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
  57. brainstate/nn/_event/_linear_mv.py +85 -0
  58. brainstate/nn/_event/_linear_mv_test.py +121 -0
  59. brainstate/nn/_exp_euler.py +2 -1
  60. brainstate/nn/_exp_euler_test.py +2 -1
  61. brainstate/nn/_interaction/_conv.py +2 -1
  62. brainstate/nn/_interaction/_linear.py +2 -1
  63. brainstate/nn/_interaction/_linear_test.py +2 -1
  64. brainstate/nn/_interaction/_normalizations.py +3 -2
  65. brainstate/nn/_interaction/_poolings.py +4 -3
  66. brainstate/nn/_module_test.py +2 -1
  67. brainstate/nn/metrics.py +4 -3
  68. brainstate/optim/_lr_scheduler.py +2 -1
  69. brainstate/optim/_lr_scheduler_test.py +2 -1
  70. brainstate/optim/_optax_optimizer_test.py +2 -1
  71. brainstate/optim/_sgd_optimizer.py +3 -2
  72. brainstate/random/_rand_funs.py +2 -1
  73. brainstate/random/_rand_funs_test.py +3 -2
  74. brainstate/random/_rand_seed.py +3 -2
  75. brainstate/random/_rand_seed_test.py +2 -1
  76. brainstate/random/_rand_state.py +4 -3
  77. brainstate/surrogate.py +1 -2
  78. brainstate/typing.py +4 -4
  79. brainstate/util/_caller.py +2 -1
  80. brainstate/util/_others.py +4 -4
  81. brainstate/util/_pretty_pytree.py +1 -1
  82. brainstate/util/_pretty_pytree_test.py +2 -1
  83. brainstate/util/_pretty_table.py +43 -43
  84. brainstate/util/_struct.py +2 -1
  85. brainstate/util/filter.py +0 -1
  86. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/METADATA +3 -3
  87. brainstate-0.1.0.post20250423.dist-info/RECORD +133 -0
  88. brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
  89. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/LICENSE +0 -0
  90. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/WHEEL +0 -0
  91. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/top_level.txt +0 -0
@@ -17,9 +17,10 @@
 
  from __future__ import annotations
 
- import numpy as np
  import unittest
 
+ import numpy as np
+
  import brainstate as bst
 
 
@@ -104,35 +104,38 @@ class _AlignPost(Module):
 
  class AlignPostProj(Interaction):
      """
-     Full-chain synaptic projection with the align-post reduction and the automatic synapse merging.
-
-     The ``full-chain`` means that the model needs to provide all information needed for a projection,
-     including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``.
-
-     The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
-
-     The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
-     parameters (such like time constants) will also share the same synaptic variables.
-
-     All align-post projection models prefer to use the event-driven computation mode. This means that the
-     ``comm`` model should be the event-driven model.
-
-     Moreover, it's worth noting that ``FullProjAlignPost`` has a different updating order with all align-pre
-     projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``.
-     While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``.
+     Align-post projection of the neural network.
+
+
+     Examples
+     --------
+
+     Here is an example of using the `AlignPostProj` to create a synaptic projection.
+     Note that this projection needs the manual input of pre-synaptic spikes.
+
+     >>> import brainstate
+     >>> import brainevent
+     >>> import brainunit as u
+     >>> n_exc = 3200
+     >>> n_inh = 800
+     >>> num = n_exc + n_inh
+     >>> pop = brainstate.nn.LIFRef(
+     ...     num,
+     ...     V_rest=-49. * u.mV, V_th=-50. * u.mV, V_reset=-60. * u.mV,
+     ...     tau=20. * u.ms, tau_ref=5. * u.ms,
+     ...     V_initializer=brainstate.init.Normal(-55., 2., unit=u.mV)
+     ... )
+     >>> pop.reset_state()
+     >>> E = brainstate.nn.AlignPostProj(
+     ...     comm=brainevent.nn.FixedProb(n_exc, num, prob=80 / num, weight=1.62 * u.mS),
+     ...     syn=brainstate.nn.Expon.desc(num, tau=5. * u.ms),
+     ...     out=brainstate.nn.CUBA.desc(scale=u.volt),
+     ...     post=pop
+     ... )
+     >>> exe_current = E(pop.spike.value)
 
 
-     # brainstate.nn.AlignPostProj(
-     # LIF().prefetch('V').delay.at('I'), bst.surrogate.ReluGrad(), comm, syn, out, post
-     # )
 
-     Args:
-         pre: The pre-synaptic neuron group.
-         delay: The synaptic delay.
-         comm: The synaptic communication.
-         syn: The synaptic dynamics.
-         out: The synaptic output.
-         post: The post-synaptic neuron group.
      """
      __module__ = 'brainstate.nn'
 
@@ -314,10 +317,6 @@ class CurrentProj(Interaction):
      This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
      than the spiking. To facilitate the event-driven computation, please use align post projections.
 
-     # bint.CurrentInteraction(
-     # LIF().align_pre(bst.nn.Expon.desc()).prefetch('g'), comm, out, post
-     # )
-
      Args:
          prefetch: The synaptic dynamics.
          comm: The synaptic communication.
@@ -16,14 +16,14 @@
  from __future__ import annotations
 
  import math
+ import numbers
+ from functools import partial
+ from typing import Optional, Dict, Callable, Union, Sequence
 
  import brainunit as u
  import jax
  import jax.numpy as jnp
- import numbers
  import numpy as np
- from functools import partial
- from typing import Optional, Dict, Callable, Union, Sequence
 
  from brainstate import environ
  from brainstate._state import ShortTermState, State
@@ -15,10 +15,11 @@
 
  from __future__ import annotations
 
+ import unittest
+
  import brainunit as u
  import jax.numpy as jnp
  import numpy as np
- import unittest
 
  import brainstate as bst
 
@@ -16,11 +16,12 @@
 
  from __future__ import annotations
 
- import brainunit as u
- import jax.numpy as jnp
  from functools import partial
  from typing import Optional, Sequence
 
+ import brainunit as u
+ import jax.numpy as jnp
+
  from brainstate import random, environ, init
  from brainstate._state import ShortTermState
  from brainstate.nn._module import ElementWiseBlock
@@ -14,9 +14,10 @@
  # ==============================================================================
 
 
- import numpy as np
  import unittest
 
+ import numpy as np
+
  import brainstate as bst
 
 
@@ -17,10 +17,11 @@
 
  from __future__ import annotations
 
+ from typing import Optional
+
  import brainunit as u
  import jax.numpy as jnp
  import jax.typing
- from typing import Optional
 
  from brainstate import random, functional as F
  from brainstate._state import ParamState
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -13,13 +13,14 @@
  # limitations under the License.
  # ==============================================================================
 
- from __future__ import annotations
+ # -*- coding: utf-8 -*-
 
- from brainstate.nn._dynamics._dynamics_base import Projection
+
+ from ._fixedprob_mv import EventFixedProb, EventFixedNumConn
+ from ._linear_mv import EventLinear
 
  __all__ = [
+     'EventLinear',
+     'EventFixedProb',
+     'EventFixedNumConn',
  ]
-
-
- class ExponentialSynapse(Projection):
-     pass
@@ -0,0 +1,169 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ from typing import Union, Callable, Optional
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ import numpy as np
+
+ from brainstate import random, augment, environ, init
+ from brainstate._compatible_import import brainevent
+ from brainstate._state import ParamState
+ from brainstate.compile import for_loop
+ from brainstate.nn._module import Module
+ from brainstate.typing import Size, ArrayLike
+
+ __all__ = [
+     'EventFixedNumConn',
+     'EventFixedProb',
+ ]
+
+
+ def init_indices_without_replace(
+     conn_num: int,
+     n_pre: int,
+     n_post: int,
+     seed: int | None,
+     method: str
+ ):
+     rng = random.default_rng(seed)
+
+     if method == 'vmap':
+         @augment.vmap
+         def rand_indices(key):
+             rng.set_key(key)
+             return rng.choice(n_post, size=(conn_num,), replace=False)
+
+         return rand_indices(rng.split_key(n_pre))
+
+     elif method == 'for_loop':
+         return for_loop(
+             lambda *args: rng.choice(n_post, size=(conn_num,), replace=False),
+             length=n_pre
+         )
+
+     else:
+         raise ValueError(f"Unknown method: {method}")
+
+
+ class EventFixedNumConn(Module):
+     """
+     The FixedProb module implements a fixed probability connection with CSR sparse data structure.
+
+     Parameters
+     ----------
+     in_size : Size
+         Number of pre-synaptic neurons, i.e., input size.
+     out_size : Size
+         Number of post-synaptic neurons, i.e., output size.
+     conn_num : float, int
+         If it is a float, representing the probability of connection, i.e., connection probability.
+
+         If it is an integer, representing the number of connections.
+     conn_weight : float or callable or jax.Array or brainunit.Quantity
+         Maximum synaptic conductance, i.e., synaptic weight.
+     conn_target : str, optional
+         The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
+         a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
+
+         If 'pre', each post-synaptic neuron connects to a fixed number of pre-synaptic neurons.
+     conn_init : str, optional
+         The initialization method of the connection weight. Default is 'vmap', meaning that the connection weight
+         is initialized by parallelized across multiple threads.
+
+         If 'for_loop', the connection weight is initialized by a for loop.
+     allow_multi_conn : bool, optional
+         Whether multiple connections are allowed from a single pre-synaptic neuron.
+         Default is True, meaning that a value of ``a`` can be selected multiple times.
+     seed: int, optional
+         Random seed. Default is None. If None, the default random seed will be used.
+     name : str, optional
+         Name of the module.
+     """
+
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         conn_num: Union[int, float],
+         conn_weight: Union[Callable, ArrayLike],
+         conn_target: str = 'post',  # 'pre' or 'post'
+         allow_multi_conn: bool = True,
+         seed: Optional[int] = None,
+         name: Optional[str] = None,
+         conn_init: str = 'vmap',  # 'vmap' or 'for_loop'
+         param_type: type = ParamState,
+     ):
+         super().__init__(name=name)
+
+         # network parameters
+         self.in_size = in_size
+         self.out_size = out_size
+         self.conn_target = conn_target
+         assert conn_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
+         if isinstance(conn_num, float):
+             assert 0. <= conn_num <= 1., 'Connection probability must be in [0, 1].'
+             conn_num = int(self.out_size[-1] * conn_num) if conn_target == 'post' else int(self.in_size[-1] * conn_num)
+         assert isinstance(conn_num, int), 'Connection number must be an integer.'
+         self.conn_num = conn_num
+         self.seed = seed
+         self.allow_multi_conn = allow_multi_conn
+
+         # connections
+         if self.conn_num >= 1:
+             if self.conn_target == 'post':
+                 n_post = self.out_size[-1]
+                 n_pre = self.in_size[-1]
+             else:
+                 n_post = self.in_size[-1]
+                 n_pre = self.out_size[-1]
+
+             # indices of post connected neurons
+             with jax.ensure_compile_time_eval():
+                 if allow_multi_conn:
+                     rng = np.random if seed is None else np.random.RandomState(seed)
+                     indices = rng.randint(0, n_post, size=(n_pre, self.conn_num))
+                 else:
+                     indices = init_indices_without_replace(self.conn_num, n_pre, n_post, seed, conn_init)
+                 indices = u.math.asarray(indices, dtype=environ.ditype())
+             conn_weight = init.param(conn_weight, (n_pre, self.conn_num), allow_none=False)
+             conn_weight = u.math.asarray(conn_weight)
+             self.weight = param_type(conn_weight)
+             csr = (
+                 brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
+                 if self.conn_target == 'post' else
+                 brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
+             )
+             self.conn = csr
+
+     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
+         if self.conn_num >= 1:
+             csr = self.conn.with_data(self.weight.value)
+             return brainevent.EventArray(spk) @ csr
+         else:
+             weight = self.weight.value
+             unit = u.get_unit(weight)
+             r = jnp.zeros(spk.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
+             r = u.maybe_decimal(u.Quantity(r, unit=unit))
+             return u.math.asarray(r, dtype=environ.dftype())
+
+
+ EventFixedProb = EventFixedNumConn
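
A minimal usage sketch of the new ``EventFixedNumConn`` module (and its ``EventFixedProb`` alias), assuming the ``brainstate.nn`` re-export that the test file below relies on; the argument values mirror those tests rather than a verified API reference.

>>> import brainstate
>>> spikes = brainstate.random.rand(20) < 0.1    # boolean spike vector from 20 pre-synaptic neurons
>>> conn = brainstate.nn.EventFixedProb(20, 40, 0.1, 1.0, seed=123)    # 10% fixed-probability connectivity, weight 1.0
>>> post_current = conn(spikes)    # event-driven sparse product, one value per post-synaptic neuron (40,)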
@@ -0,0 +1,115 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import jax.numpy
+ import jax.numpy as jnp
+ import pytest
+
+ import brainstate
+
+
+ class TestFixedProbCSR:
+     @pytest.mark.parametrize('allow_multi_conn', [True, False, ])
+     def test1(self, allow_multi_conn):
+         x = brainstate.random.rand(20) < 0.1
+         # x = brainstate.random.rand(20)
+         m = brainstate.nn.EventFixedProb(20, 40, 0.1, 1.0, seed=123, allow_multi_conn=allow_multi_conn)
+         y = m(x)
+         print(y)
+
+         m2 = brainstate.nn.EventFixedProb(20, 40, 0.1, brainstate.init.KaimingUniform(), seed=123)
+         print(m2(x))
+
+     def test_grad_bool(self):
+         n_in = 20
+         n_out = 30
+         x = jax.numpy.asarray(brainstate.random.rand(n_in) < 0.3, dtype=float)
+         fn = brainstate.nn.EventFixedProb(n_in, n_out, 0.1, brainstate.init.KaimingUniform(), seed=123)
+
+         def f(x):
+             return fn(x).sum()
+
+         print(jax.grad(f)(x))
+
+     @pytest.mark.parametrize('homo_w', [True, False])
+     def test_vjp(self, homo_w):
+         n_in = 20
+         n_out = 30
+         x = jax.numpy.asarray(brainstate.random.rand(n_in) < 0.3, dtype=float)
+
+         if homo_w:
+             fn = brainstate.nn.EventFixedProb(n_in, n_out, 0.1, 1.5, seed=123)
+         else:
+             fn = brainstate.nn.EventFixedProb(n_in, n_out, 0.1, brainstate.init.KaimingUniform(), seed=123)
+         w = fn.weight.value
+
+         def f(x, w):
+             fn.weight.value = w
+             return fn(x).sum()
+
+         r = brainstate.augment.grad(f, argnums=(0, 1))(x, w)
+
+         # -------------------
+         # TRUE gradients
+
+         def true_fn(x, w, indices, n_post):
+             post = jnp.zeros((n_post,))
+             for i in range(n_in):
+                 post = post.at[indices[i]].add(w * x[i] if homo_w else w[i] * x[i])
+             return post
+
+         def f2(x, w):
+             return true_fn(x, w, fn.conn.indices, n_out).sum()
+
+         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+         assert (jnp.allclose(r[0], r2[0]))
+         assert (jnp.allclose(r[1], r2[1]))
+
+     @pytest.mark.parametrize('homo_w', [True, False])
+     def test_jvp(self, homo_w):
+         n_in = 20
+         n_out = 30
+         x = jax.numpy.asarray(brainstate.random.rand(n_in) < 0.3, dtype=float)
+
+         fn = brainstate.nn.EventFixedProb(
+             n_in, n_out, 0.1, 1.5 if homo_w else brainstate.init.KaimingUniform(),
+             seed=123,
+         )
+         w = fn.weight.value
+
+         def f(x, w):
+             fn.weight.value = w
+             return fn(x)
+
+         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+
+         # -------------------
+         # TRUE gradients
+
+         def true_fn(x, w, indices, n_post):
+             post = jnp.zeros((n_post,))
+             for i in range(n_in):
+                 post = post.at[indices[i]].add(w * x[i] if homo_w else w[i] * x[i])
+             return post
+
+         def f2(x, w):
+             return true_fn(x, w, fn.conn.indices, n_out)
+
+         o2, r2 = jax.jvp(f2, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+         assert (jnp.allclose(o1, o2))
+         # assert jnp.allclose(r1, r2), f'r1={r1}, r2={r2}'
+         assert (jnp.allclose(r1, r2, rtol=1e-4, atol=1e-4))
@@ -0,0 +1,85 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ from typing import Union, Callable, Optional
+
+ import brainunit as u
+ import jax
+
+ from brainstate import init
+ from brainstate._compatible_import import brainevent
+ from brainstate._state import ParamState
+ from brainstate.nn._module import Module
+ from brainstate.typing import Size, ArrayLike
+
+ __all__ = [
+     'EventLinear',
+ ]
+
+
+ class EventLinear(Module):
+     """
+
+     Parameters
+     ----------
+     in_size : Size
+         Number of pre-synaptic neurons, i.e., input size.
+     out_size : Size
+         Number of post-synaptic neurons, i.e., output size.
+     weight : float or callable or jax.Array or brainunit.Quantity
+         Maximum synaptic conductance.
+     block_size : int, optional
+         Block size for parallel computation.
+     float_as_event : bool, optional
+         Whether to treat float as event.
+     name : str, optional
+         Name of the module.
+     """
+
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         weight: Union[Callable, ArrayLike],
+         float_as_event: bool = True,
+         block_size: int = 64,
+         name: Optional[str] = None,
+         param_type: type = ParamState,
+     ):
+         super().__init__(name=name)
+
+         # network parameters
+         self.in_size = in_size
+         self.out_size = out_size
+         self.float_as_event = float_as_event
+         self.block_size = block_size
+
+         # maximum synaptic conductance
+         weight = init.param(weight, (self.in_size[-1], self.out_size[-1]), allow_none=False)
+         self.weight = param_type(weight)
+
+     def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
+         weight = self.weight.value
+         if u.math.size(weight) == 1:
+             return u.math.ones(self.out_size) * (u.math.sum(spk) * weight)
+
+         if self.float_as_event:
+             return brainevent.EventArray(spk) @ weight
+         else:
+             return spk @ weight
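
Similarly, a minimal usage sketch of the new ``EventLinear`` module, again assuming the ``brainstate.nn`` re-export exercised by the test file below; with the default ``float_as_event=True`` the input is wrapped as a ``brainevent.EventArray`` before the dense product.

>>> import brainstate
>>> spikes = brainstate.random.rand(20) < 0.1    # boolean spike vector from 20 pre-synaptic neurons
>>> layer = brainstate.nn.EventLinear(20, 40, brainstate.init.KaimingUniform())
>>> out = layer(spikes)    # event-driven dense product, shape (40,)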
@@ -0,0 +1,121 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import jax
+ import jax.numpy as jnp
+ import pytest
+
+ import brainstate
+
+
+ class TestEventLinear:
+     @pytest.mark.parametrize('bool_x', [True, False])
+     @pytest.mark.parametrize('homo_w', [True, False])
+     def test1(self, homo_w, bool_x):
+         x = brainstate.random.rand(20) < 0.1
+         if not bool_x:
+             x = jnp.asarray(x, dtype=float)
+         m = brainstate.nn.EventLinear(
+             20, 40,
+             1.5 if homo_w else brainstate.init.KaimingUniform(),
+             float_as_event=bool_x
+         )
+         y = m(x)
+         print(y)
+
+         assert (jnp.allclose(y, (x.sum() * m.weight.value) if homo_w else (x @ m.weight.value)))
+
+     def test_grad_bool(self):
+         n_in = 20
+         n_out = 30
+         x = brainstate.random.rand(n_in) < 0.3
+         fn = brainstate.nn.EventLinear(n_in, n_out, brainstate.init.KaimingUniform())
+
+         with pytest.raises(TypeError):
+             print(jax.grad(lambda x: fn(x).sum())(x))
+
+     @pytest.mark.parametrize('bool_x', [True, False])
+     @pytest.mark.parametrize('homo_w', [True, False])
+     def test_vjp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jax.numpy.asarray(brainstate.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = brainstate.random.rand(n_in)
+
+         fn = brainstate.nn.EventLinear(
+             n_in,
+             n_out,
+             1.5 if homo_w else brainstate.init.KaimingUniform(),
+             float_as_event=bool_x
+         )
+         w = fn.weight.value
+
+         def f(x, w):
+             fn.weight.value = w
+             return fn(x).sum()
+
+         r1 = jax.grad(f, argnums=(0, 1))(x, w)
+
+         # -------------------
+         # TRUE gradients
+
+         def f2(x, w):
+             y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
+             return y.sum()
+
+         r2 = jax.grad(f2, argnums=(0, 1))(x, w)
+         assert (jnp.allclose(r1[0], r2[0]))
+
+         if not jnp.allclose(r1[1], r2[1]):
+             print(r1[1] - r2[1])
+
+         assert (jnp.allclose(r1[1], r2[1]))
+
+     @pytest.mark.parametrize('bool_x', [True, False])
+     @pytest.mark.parametrize('homo_w', [True, False])
+     def test_jvp(self, bool_x, homo_w):
+         n_in = 20
+         n_out = 30
+         if bool_x:
+             x = jax.numpy.asarray(brainstate.random.rand(n_in) < 0.3, dtype=float)
+         else:
+             x = brainstate.random.rand(n_in)
+
+         fn = brainstate.nn.EventLinear(
+             n_in, n_out, 1.5 if homo_w else brainstate.init.KaimingUniform(),
+             float_as_event=bool_x
+         )
+         w = fn.weight.value
+
+         def f(x, w):
+             fn.weight.value = w
+             return fn(x)
+
+         o1, r1 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+
+         # -------------------
+         # TRUE gradients
+
+         def f2(x, w):
+             y = (x @ (jnp.ones([n_in, n_out]) * w)) if homo_w else (x @ w)
+             return y
+
+         o2, r2 = jax.jvp(f, (x, w), (jnp.ones_like(x), jnp.ones_like(w)))
+         assert (jnp.allclose(o1, o2))
+         assert (jnp.allclose(r1, r2))