brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl

This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
Files changed (133)
  1. brainstate/__init__.py +58 -51
  2. brainstate/_compatible_import.py +148 -148
  3. brainstate/_state.py +1605 -1663
  4. brainstate/_state_test.py +52 -52
  5. brainstate/_utils.py +47 -47
  6. brainstate/augment/__init__.py +30 -30
  7. brainstate/augment/_autograd.py +778 -778
  8. brainstate/augment/_autograd_test.py +1289 -1289
  9. brainstate/augment/_eval_shape.py +99 -99
  10. brainstate/augment/_eval_shape_test.py +38 -38
  11. brainstate/augment/_mapping.py +1060 -1060
  12. brainstate/augment/_mapping_test.py +597 -597
  13. brainstate/augment/_random.py +151 -151
  14. brainstate/compile/__init__.py +38 -38
  15. brainstate/compile/_ad_checkpoint.py +204 -204
  16. brainstate/compile/_ad_checkpoint_test.py +49 -49
  17. brainstate/compile/_conditions.py +256 -256
  18. brainstate/compile/_conditions_test.py +220 -220
  19. brainstate/compile/_error_if.py +92 -92
  20. brainstate/compile/_error_if_test.py +52 -52
  21. brainstate/compile/_jit.py +346 -346
  22. brainstate/compile/_jit_test.py +143 -143
  23. brainstate/compile/_loop_collect_return.py +536 -536
  24. brainstate/compile/_loop_collect_return_test.py +58 -58
  25. brainstate/compile/_loop_no_collection.py +184 -184
  26. brainstate/compile/_loop_no_collection_test.py +50 -50
  27. brainstate/compile/_make_jaxpr.py +888 -888
  28. brainstate/compile/_make_jaxpr_test.py +156 -156
  29. brainstate/compile/_progress_bar.py +202 -202
  30. brainstate/compile/_unvmap.py +159 -159
  31. brainstate/compile/_util.py +147 -147
  32. brainstate/environ.py +563 -563
  33. brainstate/environ_test.py +62 -62
  34. brainstate/functional/__init__.py +27 -26
  35. brainstate/graph/__init__.py +29 -29
  36. brainstate/graph/_graph_node.py +244 -244
  37. brainstate/graph/_graph_node_test.py +73 -73
  38. brainstate/graph/_graph_operation.py +1738 -1738
  39. brainstate/graph/_graph_operation_test.py +563 -563
  40. brainstate/init/__init__.py +26 -26
  41. brainstate/init/_base.py +52 -52
  42. brainstate/init/_generic.py +244 -244
  43. brainstate/init/_random_inits.py +553 -553
  44. brainstate/init/_random_inits_test.py +149 -149
  45. brainstate/init/_regular_inits.py +105 -105
  46. brainstate/init/_regular_inits_test.py +50 -50
  47. brainstate/mixin.py +365 -363
  48. brainstate/mixin_test.py +77 -73
  49. brainstate/nn/__init__.py +135 -131
  50. brainstate/{functional → nn}/_activations.py +808 -813
  51. brainstate/{functional → nn}/_activations_test.py +331 -331
  52. brainstate/nn/_collective_ops.py +514 -514
  53. brainstate/nn/_collective_ops_test.py +43 -43
  54. brainstate/nn/_common.py +178 -178
  55. brainstate/nn/_conv.py +501 -501
  56. brainstate/nn/_conv_test.py +238 -238
  57. brainstate/nn/_delay.py +588 -502
  58. brainstate/nn/_delay_test.py +238 -184
  59. brainstate/nn/_dropout.py +426 -426
  60. brainstate/nn/_dropout_test.py +100 -100
  61. brainstate/nn/_dynamics.py +1343 -1343
  62. brainstate/nn/_dynamics_test.py +78 -78
  63. brainstate/nn/_elementwise.py +1119 -1119
  64. brainstate/nn/_elementwise_test.py +169 -169
  65. brainstate/nn/_embedding.py +58 -58
  66. brainstate/nn/_exp_euler.py +92 -92
  67. brainstate/nn/_exp_euler_test.py +35 -35
  68. brainstate/nn/_fixedprob.py +239 -239
  69. brainstate/nn/_fixedprob_test.py +114 -114
  70. brainstate/nn/_inputs.py +608 -608
  71. brainstate/nn/_linear.py +424 -424
  72. brainstate/nn/_linear_mv.py +83 -83
  73. brainstate/nn/_linear_mv_test.py +120 -120
  74. brainstate/nn/_linear_test.py +107 -107
  75. brainstate/nn/_ltp.py +28 -28
  76. brainstate/nn/_module.py +377 -377
  77. brainstate/nn/_module_test.py +40 -40
  78. brainstate/nn/_neuron.py +705 -705
  79. brainstate/nn/_neuron_test.py +161 -161
  80. brainstate/nn/_normalizations.py +975 -918
  81. brainstate/nn/_normalizations_test.py +73 -73
  82. brainstate/{functional → nn}/_others.py +46 -46
  83. brainstate/nn/_poolings.py +1177 -1177
  84. brainstate/nn/_poolings_test.py +217 -217
  85. brainstate/nn/_projection.py +486 -486
  86. brainstate/nn/_rate_rnns.py +554 -554
  87. brainstate/nn/_rate_rnns_test.py +63 -63
  88. brainstate/nn/_readout.py +209 -209
  89. brainstate/nn/_readout_test.py +53 -53
  90. brainstate/nn/_stp.py +236 -236
  91. brainstate/nn/_synapse.py +505 -505
  92. brainstate/nn/_synapse_test.py +131 -131
  93. brainstate/nn/_synaptic_projection.py +423 -423
  94. brainstate/nn/_synouts.py +162 -162
  95. brainstate/nn/_synouts_test.py +57 -57
  96. brainstate/nn/_utils.py +89 -89
  97. brainstate/nn/metrics.py +388 -388
  98. brainstate/optim/__init__.py +38 -38
  99. brainstate/optim/_base.py +64 -64
  100. brainstate/optim/_lr_scheduler.py +448 -448
  101. brainstate/optim/_lr_scheduler_test.py +50 -50
  102. brainstate/optim/_optax_optimizer.py +152 -152
  103. brainstate/optim/_optax_optimizer_test.py +53 -53
  104. brainstate/optim/_sgd_optimizer.py +1104 -1104
  105. brainstate/random/__init__.py +24 -24
  106. brainstate/random/_rand_funs.py +3616 -3616
  107. brainstate/random/_rand_funs_test.py +567 -567
  108. brainstate/random/_rand_seed.py +210 -210
  109. brainstate/random/_rand_seed_test.py +48 -48
  110. brainstate/random/_rand_state.py +1409 -1409
  111. brainstate/random/_random_for_unit.py +52 -52
  112. brainstate/surrogate.py +1957 -1957
  113. brainstate/transform.py +23 -23
  114. brainstate/typing.py +304 -304
  115. brainstate/util/__init__.py +50 -50
  116. brainstate/util/caller.py +98 -98
  117. brainstate/util/error.py +55 -55
  118. brainstate/util/filter.py +469 -469
  119. brainstate/util/others.py +540 -540
  120. brainstate/util/pretty_pytree.py +945 -945
  121. brainstate/util/pretty_pytree_test.py +159 -159
  122. brainstate/util/pretty_repr.py +328 -328
  123. brainstate/util/pretty_table.py +2954 -2954
  124. brainstate/util/scaling.py +258 -258
  125. brainstate/util/struct.py +523 -523
  126. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
  127. brainstate-0.1.10.dist-info/RECORD +130 -0
  128. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
  129. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
  130. brainstate/functional/_normalization.py +0 -81
  131. brainstate/functional/_spikes.py +0 -204
  132. brainstate-0.1.8.dist-info/RECORD +0 -132
  133. {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
brainstate/nn/_rate_rnns_test.py CHANGED
@@ -1,63 +1,63 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import unittest

import jax.numpy as jnp

import brainstate


class TestRateRNNModels(unittest.TestCase):
    def setUp(self):
        self.num_in = 3
        self.num_out = 3
        self.batch_size = 4
        self.x = jnp.ones((self.batch_size, self.num_in))

    def test_ValinaRNNCell(self):
        model = brainstate.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_GRUCell(self):
        model = brainstate.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_MGUCell(self):
        model = brainstate.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_LSTMCell(self):
        model = brainstate.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_URLSTMCell(self):
        model = brainstate.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))


if __name__ == '__main__':
    unittest.main()
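The tests above only check output shapes for a single step. As a companion illustration, here is a minimal sketch of unrolling one of these cells over a short input sequence, using the same init_state/update API the tests exercise; the cell choice (GRUCell), hidden size, sequence length, and the constant inputs are illustrative assumptions, not part of the package diff.

# Minimal sketch: unroll a brainstate.nn.GRUCell over a toy sequence.
# Assumes the init_state/update API shown in the tests above; all sizes
# and the input data are illustrative.
import jax.numpy as jnp

import brainstate

num_in, num_out, batch_size, num_steps = 3, 5, 4, 10

cell = brainstate.nn.GRUCell(num_in=num_in, num_out=num_out)
cell.init_state(batch_size=batch_size)          # allocate the recurrent hidden state

xs = jnp.ones((num_steps, batch_size, num_in))  # toy input sequence
outputs = []
for t in range(num_steps):
    outputs.append(cell.update(xs[t]))          # one recurrent step per call
outputs = jnp.stack(outputs)                    # (num_steps, batch_size, num_out)
print(outputs.shape)                            # expected: (10, 4, 5)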
brainstate/nn/_readout.py CHANGED
@@ -1,209 +1,209 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


import numbers
from typing import Callable

import brainunit as u
import jax

from brainstate import environ, init, surrogate
from brainstate._state import HiddenState, ParamState
from brainstate.typing import Size, ArrayLike
from ._exp_euler import exp_euler_step
from ._module import Module
from ._neuron import Neuron

__all__ = [
    'LeakyRateReadout',
    'LeakySpikeReadout',
]


class LeakyRateReadout(Module):
    r"""
    Leaky dynamics for the read-out module.

    This module implements a leaky integrator with the following dynamics:

    .. math::
        r_{t} = \alpha r_{t-1} + x_{t} W

    where:
    - :math:`r_{t}` is the output at time t
    - :math:`\alpha = e^{-\Delta t / \tau}` is the decay factor
    - :math:`x_{t}` is the input at time t
    - :math:`W` is the weight matrix

    The leaky integrator acts as a low-pass filter, allowing the network
    to maintain memory of past inputs with an exponential decay determined
    by the time constant tau.

    Parameters
    ----------
    in_size : int or sequence of int
        Size of the input dimension(s)
    out_size : int or sequence of int
        Size of the output dimension(s)
    tau : ArrayLike, optional
        Time constant of the leaky dynamics, by default 5ms
    w_init : Callable, optional
        Weight initialization function, by default KaimingNormal()
    name : str, optional
        Name of the module, by default None

    Attributes
    ----------
    decay : float
        Decay factor computed as exp(-dt/tau)
    weight : ParamState
        Weight matrix connecting input to output
    r : HiddenState
        Hidden state representing the output values
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        tau: ArrayLike = 5. * u.ms,
        w_init: Callable = init.KaimingNormal(),
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
        self.tau = init.param(tau, self.in_size)
        self.decay = u.math.exp(-environ.get_dt() / self.tau)

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[0], self.out_size[0])))

    def init_state(self, batch_size=None, **kwargs):
        self.r = HiddenState(init.param(init.Constant(0.), self.out_size, batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)

    def update(self, x):
        self.r.value = self.decay * self.r.value + x @ self.weight.value
        return self.r.value


class LeakySpikeReadout(Neuron):
    r"""
    Integrate-and-fire neuron model with leaky dynamics for readout functionality.

    This class implements a spiking neuron with the following dynamics:

    .. math::
        \frac{dV}{dt} = \frac{-V + I_{in}}{\tau}

    where:
    - :math:`V` is the membrane potential
    - :math:`\tau` is the membrane time constant
    - :math:`I_{in}` is the input current

    Spike generation occurs when :math:`V > V_{th}` according to:

    .. math::
        S_t = \text{surrogate}\left(\frac{V - V_{th}}{V_{th}}\right)

    After spiking, the membrane potential is reset according to the reset mode:
    - Soft reset: :math:`V \leftarrow V - V_{th} \cdot S_t`
    - Hard reset: :math:`V \leftarrow V - V_t \cdot S_t` (where :math:`V_t` is detached)

    Parameters
    ----------
    in_size : Size
        Size of the input dimension
    tau : ArrayLike, optional
        Membrane time constant, by default 5ms
    V_th : ArrayLike, optional
        Spike threshold, by default 1mV
    w_init : Callable, optional
        Weight initialization function, by default KaimingNormal(unit=mV)
    V_initializer : ArrayLike, optional
        Initial membrane potential, by default ZeroInit(unit=mV)
    spk_fun : Callable, optional
        Surrogate gradient function for spike generation, by default ReluGrad()
    spk_reset : str, optional
        Reset mechanism after spike ('soft' or 'hard'), by default 'soft'
    name : str, optional
        Name of the module, by default None

    Attributes
    ----------
    V : HiddenState
        Membrane potential state variable
    weight : ParamState
        Synaptic weight matrix
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        w_init: Callable = init.KaimingNormal(unit=u.mV),
        V_initializer: ArrayLike = init.ZeroInit(unit=u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters
        self.tau = init.param(tau, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_initializer = V_initializer

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1])))

    def init_state(self, batch_size, **kwargs):
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size, **kwargs):
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)

    @property
    def spike(self):
        return self.get_spike(self.V.value)

    def get_spike(self, V):
        v_scaled = (V - self.V_th) / self.V_th
        return self.spk_fun(v_scaled)

    def update(self, spk):
        # reset
        last_V = self.V.value
        last_spike = self.get_spike(last_V)
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
        V = last_V - V_th * last_spike
        # membrane potential
        x = spk @ self.weight.value
        dv = lambda v: (-v + self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, V)
        self.V.value = self.sum_delta_inputs(V)
        return self.get_spike(V)
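To make the documented recurrence concrete, the following is a small self-contained sketch of the LeakyRateReadout dynamics, r_t = alpha * r_{t-1} + x_t W with alpha = exp(-dt / tau), written directly in JAX rather than through the module; the time step, time constant, shapes, and weight initialization are illustrative assumptions.

# Self-contained sketch of the LeakyRateReadout recurrence documented above:
#   r_t = alpha * r_{t-1} + x_t @ W,   alpha = exp(-dt / tau)
# dt, tau, shapes, and the weight matrix are illustrative; the real module
# draws dt from brainstate.environ and stores W in a ParamState.
import jax
import jax.numpy as jnp

dt, tau = 0.1, 5.0                      # illustrative time step and time constant
alpha = jnp.exp(-dt / tau)              # decay factor, approximately 0.980

in_size, out_size, batch = 3, 2, 4
key = jax.random.PRNGKey(0)
W = jax.random.normal(key, (in_size, out_size)) / jnp.sqrt(in_size)

r = jnp.zeros((batch, out_size))        # initial readout state
for _ in range(10):
    x = jnp.ones((batch, in_size))      # constant toy input
    r = alpha * r + x @ W               # the leaky-integrator update
print(r.shape)                          # (4, 2)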
brainstate/nn/_readout_test.py CHANGED
@@ -1,53 +1,53 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import unittest

import jax.numpy as jnp

import brainstate


class TestReadoutModels(unittest.TestCase):
    def setUp(self):
        self.in_size = 3
        self.out_size = 3
        self.batch_size = 4
        self.tau = 5.0
        self.V_th = 1.0
        self.x = jnp.ones((self.batch_size, self.in_size))

    def test_LeakyRateReadout(self):
        with brainstate.environ.context(dt=0.1):
            model = brainstate.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
            model.init_state(batch_size=self.batch_size)
            output = model.update(self.x)
            self.assertEqual(output.shape, (self.batch_size, self.out_size))

    def test_LeakySpikeReadout(self):
        with brainstate.environ.context(dt=0.1):
            model = brainstate.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
                                                    V_initializer=brainstate.init.ZeroInit(),
                                                    w_init=brainstate.init.KaimingNormal())
            model.init_state(batch_size=self.batch_size)
            with brainstate.environ.context(t=0.):
                output = model.update(self.x)
            self.assertEqual(output.shape, (self.batch_size, self.out_size))


if __name__ == '__main__':
    with brainstate.environ.context(dt=0.1):
        unittest.main()
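Building on test_LeakySpikeReadout above, here is a hedged sketch of driving the spiking readout for several time steps, advancing the environment time at each step; the step count, the constant input, and the per-step t values are illustrative assumptions extrapolated from the single-step test.

# Sketch: run LeakySpikeReadout for several steps, mirroring the test above.
# The unitless tau/V_th and the ZeroInit/KaimingNormal initializers follow the
# test; the 20-step loop and the per-step time values are illustrative.
import jax.numpy as jnp

import brainstate

in_size, batch_size, dt = 3, 4, 0.1

with brainstate.environ.context(dt=dt):
    model = brainstate.nn.LeakySpikeReadout(
        in_size=in_size, tau=5.0, V_th=1.0,
        V_initializer=brainstate.init.ZeroInit(),
        w_init=brainstate.init.KaimingNormal(),
    )
    model.init_state(batch_size=batch_size)

    x = jnp.ones((batch_size, in_size))          # constant toy input per step
    spikes = []
    for step in range(20):
        with brainstate.environ.context(t=step * dt):
            spikes.append(model.update(x))       # surrogate spike output per step
    spikes = jnp.stack(spikes)                   # (20, batch_size, in_size) when out_size == in_size
print(spikes.shape)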