brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_neuron.py DELETED
@@ -1,705 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import Callable, Optional
19
-
20
- import brainunit as u
21
- import jax
22
-
23
- from brainstate import init, surrogate, environ
24
- from brainstate._state import HiddenState, ShortTermState
25
- from brainstate.typing import ArrayLike, Size
26
- from ._dynamics import Dynamics
27
- from ._exp_euler import exp_euler_step
28
-
29
# Public neuron models exported by this module.
__all__ = [
    'Neuron',
    'IF',
    'LIF',
    'LIFRef',
    'ALIF',
]


class Neuron(Dynamics):
    """
    Base class for all spiking neuron models.

    This abstract class serves as the foundation for implementing spiking neuron
    models. It extends :class:`Dynamics` and stores the two settings shared by
    every spiking neuron: the surrogate spike function and the reset mode.

    All neuron models should inherit from this class and implement
    :meth:`get_spike`, which maps the neuron's state variables to a
    (surrogate-differentiable) spike output.

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    spk_fun : Callable, default=surrogate.InvSquareGrad()
        Surrogate gradient function for the non-differentiable spike generation.
        The default instance is created once, when the class is defined, and is
        shared by every neuron that does not pass its own.
    spk_reset : {'soft', 'hard'}, default='soft'
        Reset mechanism subclasses apply after spike generation:

        - 'soft': subtract the threshold from the membrane potential
        - 'hard': reset the membrane potential, using ``jax.lax.stop_gradient``
          so that no gradient flows through the reset term
    name : str, optional
        Name of the neuron layer.

    Raises
    ------
    ValueError
        If ``spk_reset`` is neither ``'soft'`` nor ``'hard'``.

    Methods
    -------
    get_spike(*args, **kwargs)
        Abstract method that generates spikes based on neuron state variables.
        Must be implemented by subclasses.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        spk_fun: Callable = surrogate.InvSquareGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        # Subclasses branch on ``spk_reset == 'soft'`` and treat every other
        # value as a hard reset, so reject typos up front instead of silently
        # switching the reset semantics.
        if spk_reset not in ('soft', 'hard'):
            raise ValueError(
                f"spk_reset must be 'soft' or 'hard', but got {spk_reset!r}."
            )
        super().__init__(in_size, name=name)
        self.spk_reset = spk_reset
        self.spk_fun = spk_fun

    def get_spike(self, *args, **kwargs):
        """Return the spike output for the given state; must be overridden."""
        raise NotImplementedError


class IF(Neuron):
    r"""Integrate-and-Fire (IF) neuron model.

    This class implements a simple integrate-and-fire neuron. Input current is
    integrated into the membrane potential. When the potential reaches the
    firing threshold, the neuron emits a spike and the potential is reset.

    The subthreshold dynamics are

    $$
    \tau \frac{dV}{dt} = -V + R \cdot I(t)
    $$

    so, in the absence of input, the membrane potential relaxes exponentially
    towards 0 with time constant $\tau$.

    Spike condition:
    If $V \geq V_{th}$: emit a spike. On the next update the potential is reset
    to $V = V - V_{th}$ (soft reset) or $V = 0$ (hard reset).

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    V_th : ArrayLike, default=1. * u.mV
        Firing threshold voltage. It must be non-zero (it is used as the
        normalizer in :meth:`get_spike`) and should be positive.
    V_initializer : Callable, default=init.Constant(0. * u.mV)
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:

        - 'soft': subtract the threshold, V = V - V_th
        - 'hard': V = V - stop_gradient(V) * spike, i.e. V = 0 for a spike
          value of 1, with no gradient through the reset term
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the neuron state variables.
    reset_state(batch_size=None, **kwargs)
        Reset the neuron state variables.
    get_spike(V=None)
        Generate spikes based on the membrane potential.
    update(x=0. * u.mA)
        Update the neuron state for one time step and return spikes.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>>
    >>> # Create an IF neuron layer with 10 neurons
    >>> if_neuron = bs.nn.IF(10, tau=8*u.ms, V_th=1.2*u.mV)
    >>>
    >>> # Initialize the state
    >>> if_neuron.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state
    >>> spikes = if_neuron.update(x=2.0*u.mA)
    >>>
    >>> # Create a network with IF neurons
    >>> network = bs.nn.Sequential([
    ...     bs.nn.IF(100, tau=5.0*u.ms),
    ...     bs.nn.Linear(100, 10)
    ... ])

    Notes
    -----
    - The ``-V`` term drives the potential back towards 0 mV. Unlike
      :class:`LIF`, this model has no configurable resting potential
      (``V_rest``) or reset potential (``V_reset``).
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - The spike returned by :meth:`update` is computed from the freshly updated
      potential. The corresponding reset is applied at the beginning of the
      *next* call to :meth:`update`, so the stored ``V`` may temporarily exceed
      ``V_th``.

    References
    ----------
    .. [1] Lapicque, L. (1907). Recherches quantitatives sur l'excitation électrique
           des nerfs traitée comme une polarisation. Journal de Physiologie et de
           Pathologie Générale, 9, 620-635.
    .. [2] Abbott, L. F. (1999). Lapicque's introduction of the integrate-and-fire
           model neuron (1907). Brain Research Bulletin, 50(5-6), 303-304.
    .. [3] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model:
           I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19.
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,  # should be positive
        V_initializer: Callable = init.Constant(0. * u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the layer's variable shape
        self.R = init.param(R, self.varshape)
        self.tau = init.param(tau, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the membrane-potential state ``V``."""
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Re-initialize ``V`` in place using ``V_initializer``."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)

    def get_spike(self, V=None):
        """Return the surrogate spike for ``V`` (defaults to the stored potential).

        The potential is normalized as ``(V - V_th) / V_th`` before being passed
        to ``spk_fun``.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / self.V_th
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        """Advance the neuron by one step with input current ``x``; return spikes."""
        last_V = self.V.value
        # Reset driven by the spike emitted at the end of the previous step.
        last_spike = self.get_spike(last_V)
        # soft: subtract the threshold; hard: subtract the gradient-stopped
        # pre-spike potential itself, which yields V = 0 for a spike value of 1.
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
        V = last_V - V_th * last_spike
        # membrane potential
        dv = lambda v: (-v + self.R * self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, V)
        V = self.sum_delta_inputs(V)
        self.V.value = V
        return self.get_spike(V)


class LIF(Neuron):
    r"""Leaky Integrate-and-Fire (LIF) neuron model.

    This class implements the Leaky Integrate-and-Fire neuron model. A leak term
    makes the membrane potential decay towards a resting value in the absence of
    input.

    The subthreshold dynamics are

    $$
    \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
    $$

    Spike condition:
    If $V \geq V_{th}$: emit a spike. On the next update the potential is reset to
    $V = V - (V_{th} - V_{reset})$ (soft reset) or $V = V_{reset}$ (hard reset).

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    V_th : ArrayLike, default=1. * u.mV
        Firing threshold voltage. It must differ from ``V_reset``, because
        ``V_th - V_reset`` is the normalizer in :meth:`get_spike`.
    V_reset : ArrayLike, default=0. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=0. * u.mV
        Resting membrane potential.
    V_initializer : Callable, default=init.Constant(0. * u.mV)
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:

        - 'soft': subtract the threshold-to-reset gap, V = V - (V_th - V_reset)
        - 'hard': V = V_reset for a spike value of 1, with the pre-spike
          potential wrapped in ``jax.lax.stop_gradient``
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the neuron state variables.
    reset_state(batch_size=None, **kwargs)
        Reset the neuron state variables.
    get_spike(V=None)
        Generate spikes based on the membrane potential.
    update(x=0. * u.mA)
        Update the neuron state for one time step and return spikes.

    Examples
    --------
    >>> import brainstate
    >>> import brainunit as u
    >>>
    >>> # Create a LIF neuron layer with 10 neurons
    >>> lif = brainstate.nn.LIF(10, tau=10*u.ms, V_th=0.8*u.mV)
    >>>
    >>> # Initialize the state
    >>> lif.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state
    >>> spikes = lif.update(x=1.5*u.mA)

    Notes
    -----
    - The leak term causes the membrane potential to decay exponentially towards
      V_rest with time constant tau when no input is present.
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - Spike generation is non-differentiable, so surrogate gradients are used for
      backpropagation during training.
    - The reset caused by a spike is applied at the beginning of the *next* call
      to :meth:`update`, so the stored ``V`` may temporarily exceed ``V_th``.

    References
    ----------
    .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
           Neuronal dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model:
           I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        V_reset: ArrayLike = 0. * u.mV,
        V_rest: ArrayLike = 0. * u.mV,
        V_initializer: Callable = init.Constant(0. * u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the layer's variable shape
        self.R = init.param(R, self.varshape)
        self.tau = init.param(tau, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_rest = init.param(V_rest, self.varshape)
        self.V_reset = init.param(V_reset, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the membrane-potential state ``V``."""
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Re-initialize ``V`` in place using ``V_initializer``."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)

    def get_spike(self, V: ArrayLike = None):
        """Return the surrogate spike for ``V`` (defaults to the stored potential).

        The potential is normalized as ``(V - V_th) / (V_th - V_reset)`` before
        being passed to ``spk_fun``.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        """Advance the neuron by one step with input current ``x``; return spikes."""
        last_v = self.V.value
        # Reset driven by the spike emitted at the end of the previous step.
        lst_spk = self.get_spike(last_v)
        # soft: subtract (V_th - V_reset); hard: substituting the gradient-stopped
        # potential for V_th makes the result exactly V_reset for a spike of 1.
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v)
        V = last_v - (V_th - self.V_reset) * lst_spk
        # membrane potential
        dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, V)
        V = self.sum_delta_inputs(V)
        self.V.value = V
        return self.get_spike(V)


class LIFRef(Neuron):
    r"""Leaky Integrate-and-Fire neuron model with refractory period.

    This class implements a Leaky Integrate-and-Fire neuron with a refractory
    period. After a spike, the membrane potential is frozen for ``tau_ref``
    regardless of input.

    Outside the refractory period:

    $$
    \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
    $$

    During the refractory period the potential is held at its post-reset value:
    $V_{reset}$ for a hard reset, $V - (V_{th} - V_{reset})$ for a soft reset.

    Spike condition:
    If $V \geq V_{th}$: emit a spike, record the current time as the last spike
    time, and apply the reset on the next update.

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    tau_ref : ArrayLike, default=5. * u.ms
        Refractory period duration.
    V_th : ArrayLike, default=1. * u.mV
        Firing threshold voltage. It must differ from ``V_reset``, because
        ``V_th - V_reset`` is the normalizer in :meth:`get_spike`.
    V_reset : ArrayLike, default=0. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=0. * u.mV
        Resting membrane potential.
    V_initializer : Callable, default=init.Constant(0. * u.mV)
        Initializer for the membrane potential state.
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:

        - 'soft': subtract the threshold-to-reset gap, V = V - (V_th - V_reset)
        - 'hard': V = V_reset for a spike value of 1, with the pre-spike
          potential wrapped in ``jax.lax.stop_gradient``
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    last_spike_time : ShortTermState
        Time of the last spike, used to implement the refractory period.
        It is initialized to ``-1e7 ms`` so no neuron starts out refractory.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the neuron state variables.
    reset_state(batch_size=None, **kwargs)
        Reset the neuron state variables.
    get_spike(V=None)
        Generate spikes based on the membrane potential.
    update(x=0. * u.mA)
        Update the neuron state for one time step and return spikes.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>>
    >>> # Create a LIFRef neuron layer with 10 neurons
    >>> lifref = bs.nn.LIFRef(10,
    ...                       tau=10*u.ms,
    ...                       tau_ref=5*u.ms,
    ...                       V_th=0.8*u.mV)
    >>>
    >>> # Initialize the state
    >>> lifref.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state
    >>> spikes = lifref.update(x=1.5*u.mA)
    >>>
    >>> # Create a network with refractory neurons
    >>> network = bs.nn.Sequential([
    ...     bs.nn.LIFRef(100, tau_ref=4*u.ms),
    ...     bs.nn.Linear(100, 10)
    ... ])

    Notes
    -----
    - The refractory period is implemented by tracking the time of the last spike.
      While the elapsed time is less than ``tau_ref``, the integrated update is
      discarded and the post-reset potential is kept instead.
    - The last spike time is recorded with a plain ``V >= V_th`` comparison and
      is wrapped in ``jax.lax.stop_gradient``.
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - The simulation environment variable ``'t'`` (read via ``environ.get``) is
      used to track the refractory state.

    References
    ----------
    .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
           Neuronal dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [2] Burkitt, A. N. (2006). A review of the integrate-and-fire neuron model:
           I. Homogeneous synaptic input. Biological cybernetics, 95(1), 1-19.
    .. [3] Izhikevich, E. M. (2003). Simple model of spiking neurons. IEEE Transactions on
           neural networks, 14(6), 1569-1572.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        tau_ref: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        V_reset: ArrayLike = 0. * u.mV,
        V_rest: ArrayLike = 0. * u.mV,
        V_initializer: Callable = init.Constant(0. * u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the layer's variable shape
        self.R = init.param(R, self.varshape)
        self.tau = init.param(tau, self.varshape)
        self.tau_ref = init.param(tau_ref, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_rest = init.param(V_rest, self.varshape)
        self.V_reset = init.param(V_reset, self.varshape)
        self.V_initializer = V_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``V`` and ``last_spike_time`` (far in the past, i.e. not refractory)."""
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))
        self.last_spike_time = ShortTermState(init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Re-initialize ``V`` and ``last_spike_time`` in place."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)
        self.last_spike_time.value = init.param(init.Constant(-1e7 * u.ms), self.varshape, batch_size)

    def get_spike(self, V: ArrayLike = None):
        """Return the surrogate spike for ``V`` (defaults to the stored potential).

        The potential is normalized as ``(V - V_th) / (V_th - V_reset)`` before
        being passed to ``spk_fun``.
        """
        V = self.V.value if V is None else V
        v_scaled = (V - self.V_th) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        """Advance the neuron by one step with input current ``x``; return spikes."""
        t = environ.get('t')
        last_v = self.V.value
        # Reset driven by the spike emitted at the end of the previous step.
        lst_spk = self.get_spike(last_v)
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v)
        last_v = last_v - (V_th - self.V_reset) * lst_spk
        # membrane potential, integrated from the post-reset value
        dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, last_v)
        V = self.sum_delta_inputs(V)
        # While refractory, discard the integrated value and keep the post-reset potential.
        in_refractory = t - self.last_spike_time.value < self.tau_ref
        self.V.value = u.math.where(in_refractory, last_v, V)
        # spike time evaluation: stamp the current time wherever threshold is reached
        lst_spk_time = u.math.where(self.V.value >= self.V_th, t, self.last_spike_time.value)
        self.last_spike_time.value = jax.lax.stop_gradient(lst_spk_time)
        return self.get_spike()


class ALIF(Neuron):
    r"""Adaptive Leaky Integrate-and-Fire (ALIF) neuron model.

    This class implements the Adaptive Leaky Integrate-and-Fire neuron model. It
    extends the LIF model with an adaptation variable that raises the effective
    firing threshold after each spike. This produces spike-frequency adaptation:
    the firing rate drops during sustained stimulation.

    The subthreshold dynamics are

    $$
    \tau \frac{dV}{dt} = -(V - V_{rest}) + R \cdot I(t)
    $$

    $$
    \tau_a \frac{da}{dt} = -a
    $$

    Spike condition:
    If $V \geq V_{th} + \beta \cdot a$: emit a spike $s$. On the next update,
    $a \leftarrow a + s$ and the potential is reset to $V - (V_{th} - V_{reset})$
    (soft reset) or $V_{reset}$ (hard reset).

    Parameters
    ----------
    in_size : Size
        Size of the input to the neuron.
    R : ArrayLike, default=1. * u.ohm
        Membrane resistance.
    tau : ArrayLike, default=5. * u.ms
        Membrane time constant.
    tau_a : ArrayLike, default=100. * u.ms
        Adaptation time constant (typically much longer than tau).
    V_th : ArrayLike, default=1. * u.mV
        Base firing threshold voltage. It must differ from ``V_reset``, because
        ``V_th - V_reset`` is the normalizer in :meth:`get_spike`.
    V_reset : ArrayLike, default=0. * u.mV
        Reset voltage after spike.
    V_rest : ArrayLike, default=0. * u.mV
        Resting membrane potential.
    beta : ArrayLike, default=0.1 * u.mV
        Adaptation coupling parameter. ``beta * a`` is added to ``V_th``, so
        ``beta`` carries voltage units when ``a`` is dimensionless (as with the
        default ``a_initializer``).
    spk_fun : Callable, default=surrogate.ReluGrad()
        Surrogate gradient function for the non-differentiable spike generation.
    spk_reset : str, default='soft'
        Reset mechanism after spike generation:

        - 'soft': subtract the base threshold-to-reset gap, V = V - (V_th - V_reset)
        - 'hard': V = V_reset for a spike value of 1, with the pre-spike
          potential wrapped in ``jax.lax.stop_gradient``
    V_initializer : Callable, default=init.Constant(0. * u.mV)
        Initializer for the membrane potential state.
    a_initializer : Callable, default=init.Constant(0.)
        Initializer for the adaptation variable.
    name : str, optional
        Name of the neuron layer.

    Attributes
    ----------
    V : HiddenState
        Membrane potential.
    a : HiddenState
        Adaptation variable that increases after each spike and decays exponentially.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the neuron state variables.
    reset_state(batch_size=None, **kwargs)
        Reset the neuron state variables.
    get_spike(V=None, a=None)
        Generate spikes based on the membrane potential and adaptation variable.
    update(x=0. * u.mA)
        Update the neuron state for one time step and return spikes.

    Examples
    --------
    >>> import brainstate as bs
    >>> import brainunit as u
    >>>
    >>> # Create an ALIF neuron layer with 10 neurons
    >>> alif = bs.nn.ALIF(10,
    ...                   tau=10*u.ms,
    ...                   tau_a=200*u.ms,
    ...                   beta=0.2*u.mV)
    >>>
    >>> # Initialize the state
    >>> alif.init_state(batch_size=1)
    >>>
    >>> # Apply an input current and update the neuron state
    >>> spikes = alif.update(x=1.5*u.mA)
    >>>
    >>> # Create a network with adaptation for burst detection
    >>> network = bs.nn.Sequential([
    ...     bs.nn.ALIF(100, tau_a=150*u.ms, beta=0.3*u.mV),
    ...     bs.nn.Linear(100, 10)
    ... ])

    Notes
    -----
    - The adaptation variable ``a`` is incremented by the spike output and decays
      exponentially with time constant ``tau_a`` between spikes.
    - The effective threshold is ``V_th + beta * a``, so it becomes harder for
      the neuron to fire after it has recently been active. The soft reset,
      however, subtracts only the *base* gap ``V_th - V_reset``.
    - The adaptation time constant tau_a is typically much larger than the membrane
      time constant tau, creating a longer-lasting adaptation effect.
    - The time-dependent dynamics are integrated using an exponential Euler method.
    - Resets and adaptation increments caused by a spike are applied at the
      beginning of the *next* call to :meth:`update`.

    References
    ----------
    .. [1] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014).
           Neuronal dynamics: From single neurons to networks and models of cognition.
           Cambridge University Press.
    .. [2] Brette, R., & Gerstner, W. (2005). Adaptive exponential integrate-and-fire model
           as an effective description of neuronal activity. Journal of neurophysiology,
           94(5), 3637-3642.
    .. [3] Naud, R., Marcille, N., Clopath, C., & Gerstner, W. (2008). Firing patterns in
           the adaptive exponential integrate-and-fire model. Biological cybernetics,
           99(4), 335-347.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        R: ArrayLike = 1. * u.ohm,
        tau: ArrayLike = 5. * u.ms,
        tau_a: ArrayLike = 100. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        V_reset: ArrayLike = 0. * u.mV,
        V_rest: ArrayLike = 0. * u.mV,
        beta: ArrayLike = 0.1 * u.mV,
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        V_initializer: Callable = init.Constant(0. * u.mV),
        a_initializer: Callable = init.Constant(0.),
        name: Optional[str] = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters, broadcast to the layer's variable shape
        self.R = init.param(R, self.varshape)
        self.tau = init.param(tau, self.varshape)
        self.tau_a = init.param(tau_a, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_reset = init.param(V_reset, self.varshape)
        self.V_rest = init.param(V_rest, self.varshape)
        self.beta = init.param(beta, self.varshape)

        # state initializers
        self.V_initializer = V_initializer
        self.a_initializer = a_initializer

    def init_state(self, batch_size: int = None, **kwargs):
        """Create the membrane-potential state ``V`` and adaptation state ``a``."""
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))
        self.a = HiddenState(init.param(self.a_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Re-initialize ``V`` and ``a`` in place."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)
        self.a.value = init.param(self.a_initializer, self.varshape, batch_size)

    def get_spike(self, V=None, a=None):
        """Return the surrogate spike for ``V`` and ``a`` (default: stored states).

        The input to ``spk_fun`` is ``(V - V_th - beta * a) / (V_th - V_reset)``,
        i.e. the distance to the adaptive threshold, normalized by the base gap.
        """
        V = self.V.value if V is None else V
        a = self.a.value if a is None else a
        v_scaled = (V - self.V_th - self.beta * a) / (self.V_th - self.V_reset)
        return self.spk_fun(v_scaled)

    def update(self, x=0. * u.mA):
        """Advance the neuron by one step with input current ``x``; return spikes."""
        last_v = self.V.value
        last_a = self.a.value
        # Reset and adaptation jump driven by the spike emitted at the end of the previous step.
        lst_spk = self.get_spike(last_v, last_a)
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_v)
        V = last_v - (V_th - self.V_reset) * lst_spk
        a = last_a + lst_spk
        # membrane potential and adaptation dynamics
        dv = lambda v: (-v + self.V_rest + self.R * self.sum_current_inputs(x, v)) / self.tau
        da = lambda a_: -a_ / self.tau_a
        V = exp_euler_step(dv, V)
        a = exp_euler_step(da, a)
        self.V.value = self.sum_delta_inputs(V)
        self.a.value = a
        return self.get_spike(self.V.value, self.a.value)