brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250423__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. brainstate/_compatible_import.py +15 -0
  2. brainstate/_state.py +5 -4
  3. brainstate/_state_test.py +2 -1
  4. brainstate/augment/_autograd_test.py +3 -2
  5. brainstate/augment/_eval_shape.py +2 -1
  6. brainstate/augment/_mapping.py +0 -1
  7. brainstate/augment/_mapping_test.py +1 -0
  8. brainstate/compile/_ad_checkpoint.py +2 -1
  9. brainstate/compile/_conditions.py +3 -3
  10. brainstate/compile/_conditions_test.py +2 -1
  11. brainstate/compile/_error_if.py +2 -1
  12. brainstate/compile/_error_if_test.py +2 -1
  13. brainstate/compile/_jit.py +3 -2
  14. brainstate/compile/_jit_test.py +2 -1
  15. brainstate/compile/_loop_collect_return.py +2 -2
  16. brainstate/compile/_loop_collect_return_test.py +2 -1
  17. brainstate/compile/_loop_no_collection.py +1 -1
  18. brainstate/compile/_make_jaxpr.py +2 -2
  19. brainstate/compile/_make_jaxpr_test.py +2 -1
  20. brainstate/compile/_progress_bar.py +2 -1
  21. brainstate/compile/_unvmap.py +1 -2
  22. brainstate/environ.py +4 -4
  23. brainstate/environ_test.py +2 -1
  24. brainstate/functional/_activations.py +2 -1
  25. brainstate/functional/_activations_test.py +1 -1
  26. brainstate/functional/_normalization.py +2 -1
  27. brainstate/functional/_others.py +2 -1
  28. brainstate/graph/_graph_operation.py +3 -2
  29. brainstate/graph/_graph_operation_test.py +4 -3
  30. brainstate/init/_base.py +2 -1
  31. brainstate/init/_generic.py +2 -1
  32. brainstate/nn/__init__.py +4 -0
  33. brainstate/nn/_collective_ops.py +1 -0
  34. brainstate/nn/_collective_ops_test.py +0 -4
  35. brainstate/nn/_common.py +0 -1
  36. brainstate/nn/_dyn_impl/__init__.py +0 -4
  37. brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
  38. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
  39. brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
  40. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
  41. brainstate/nn/_dyn_impl/_inputs.py +236 -29
  42. brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
  43. brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
  44. brainstate/nn/_dyn_impl/_readout.py +91 -8
  45. brainstate/nn/_dyn_impl/_readout_test.py +2 -1
  46. brainstate/nn/_dynamics/_dynamics_base.py +676 -96
  47. brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
  48. brainstate/nn/_dynamics/_projection_base.py +29 -30
  49. brainstate/nn/_dynamics/_state_delay.py +3 -3
  50. brainstate/nn/_dynamics/_synouts_test.py +2 -1
  51. brainstate/nn/_elementwise/_dropout.py +3 -2
  52. brainstate/nn/_elementwise/_dropout_test.py +2 -1
  53. brainstate/nn/_elementwise/_elementwise.py +2 -1
  54. brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
  55. brainstate/nn/_event/_fixedprob_mv.py +169 -0
  56. brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
  57. brainstate/nn/_event/_linear_mv.py +85 -0
  58. brainstate/nn/_event/_linear_mv_test.py +121 -0
  59. brainstate/nn/_exp_euler.py +2 -1
  60. brainstate/nn/_exp_euler_test.py +2 -1
  61. brainstate/nn/_interaction/_conv.py +2 -1
  62. brainstate/nn/_interaction/_linear.py +2 -1
  63. brainstate/nn/_interaction/_linear_test.py +2 -1
  64. brainstate/nn/_interaction/_normalizations.py +3 -2
  65. brainstate/nn/_interaction/_poolings.py +4 -3
  66. brainstate/nn/_module_test.py +2 -1
  67. brainstate/nn/metrics.py +4 -3
  68. brainstate/optim/_lr_scheduler.py +2 -1
  69. brainstate/optim/_lr_scheduler_test.py +2 -1
  70. brainstate/optim/_optax_optimizer_test.py +2 -1
  71. brainstate/optim/_sgd_optimizer.py +3 -2
  72. brainstate/random/_rand_funs.py +2 -1
  73. brainstate/random/_rand_funs_test.py +3 -2
  74. brainstate/random/_rand_seed.py +3 -2
  75. brainstate/random/_rand_seed_test.py +2 -1
  76. brainstate/random/_rand_state.py +4 -3
  77. brainstate/surrogate.py +1 -2
  78. brainstate/typing.py +4 -4
  79. brainstate/util/_caller.py +2 -1
  80. brainstate/util/_others.py +4 -4
  81. brainstate/util/_pretty_pytree.py +1 -1
  82. brainstate/util/_pretty_pytree_test.py +2 -1
  83. brainstate/util/_pretty_table.py +43 -43
  84. brainstate/util/_struct.py +2 -1
  85. brainstate/util/filter.py +0 -1
  86. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/METADATA +3 -3
  87. brainstate-0.1.0.post20250423.dist-info/RECORD +133 -0
  88. brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
  89. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/LICENSE +0 -0
  90. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/WHEEL +0 -0
  91. {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250423.dist-info}/top_level.txt +0 -0
brainstate/nn/_dyn_impl/_rate_rnns.py
@@ -17,9 +17,10 @@
 
  from __future__ import annotations
 
- import jax.numpy as jnp
  from typing import Callable, Union
 
+ import jax.numpy as jnp
+
  from brainstate import random, init, functional
  from brainstate._state import HiddenState, ParamState
  from brainstate.nn._interaction._linear import Linear
@@ -33,23 +34,83 @@ __all__ = [
 
  class RNNCell(Module):
  """
- Base class for RNN cells.
+ Base class for all recurrent neural network (RNN) cell implementations.
+
+ This abstract class serves as the foundation for implementing various RNN cell types
+ such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
+ Module class and provides common functionality and interface for recurrent units.
+
+ All RNN cell implementations should inherit from this class and implement the required
+ methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
+ define the state initialization and recurrent dynamics.
+
+ The RNNCell typically maintains hidden state(s) that are updated at each time step
+ based on the current input and previous state values.
+
+ Methods
+ -------
+ init_state(batch_size=None, **kwargs)
+ Initialize the cell state variables with appropriate dimensions.
+ reset_state(batch_size=None, **kwargs)
+ Reset the cell state variables to their initial values.
+ update(x)
+ Update the cell state for one time step based on input x and return output.
  """
  pass
 
 
  class ValinaRNNCell(RNNCell):
- """
- Vanilla RNN cell.
-
- Args:
- num_in: int. The number of input units.
- num_out: int. The number of hidden units.
- state_init: callable, ArrayLike. The state initializer.
- w_init: callable, ArrayLike. The input weight initializer.
- b_init: optional, callable, ArrayLike. The bias weight initializer.
- activation: str, callable. The activation function. It can be a string or a callable function.
- name: optional, str. The name of the module.
+ r"""
+ Vanilla Recurrent Neural Network (RNN) cell implementation.
+
+ This class implements the basic RNN model that updates a hidden state based on
+ the current input and previous hidden state. The standard RNN cell follows the
+ mathematical formulation:
+
+ .. math::
+
+ h_t = \phi(W [x_t, h_{t-1}] + b)
+
+ where:
+
+ - :math:`x_t` is the input vector at time t
+ - :math:`h_t` is the hidden state at time t
+ - :math:`h_{t-1}` is the hidden state at previous time step
+ - :math:`W` is the weight matrix for the combined input-hidden linear transformation
+ - :math:`b` is the bias vector
+ - :math:`\phi` is the activation function
+
+ Parameters
+ ----------
+ num_in : int
+ The number of input units.
+ num_out : int
+ The number of hidden units.
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the hidden state.
+ w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
+ Initializer for the weight matrix.
+ b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the bias vector.
+ activation : str or Callable, default='relu'
+ Activation function to use. Can be a string (e.g., 'relu', 'tanh')
+ or a callable function.
+ name : str, optional
+ Name of the module.
+
+ State Variables
+ --------------
+ h : HiddenState
+ Hidden state of the RNN cell.
+
+ Methods
+ -------
+ init_state(batch_size=None, **kwargs)
+ Initialize the cell hidden state.
+ reset_state(batch_size=None, **kwargs)
+ Reset the cell hidden state to its initial value.
+ update(x)
+ Update the hidden state for one time step and return the new state.
  """
  __module__ = 'brainstate.nn'
 
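The new ValinaRNNCell docstring above documents the constructor parameters and the init_state()/update() interface. A minimal usage sketch (not part of the diff), assuming the class is exported from brainstate.nn as its __module__ attribute suggests; the batch size, shapes, and 'tanh' choice are illustrative:

    import jax.numpy as jnp
    import brainstate as bst

    cell = bst.nn.ValinaRNNCell(num_in=8, num_out=16, activation='tanh')
    cell.init_state(batch_size=4)      # allocate the hidden state h
    x = jnp.ones((4, 8))               # one batched time step of input
    h = cell.update(x)                 # new hidden state, expected shape (4, 16)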
@@ -96,17 +157,62 @@ class ValinaRNNCell(RNNCell):
 
 
  class GRUCell(RNNCell):
- """
- Gated Recurrent Unit (GRU) cell.
-
- Args:
- num_in: int. The number of input units.
- num_out: int. The number of hidden units.
- state_init: callable, ArrayLike. The state initializer.
- w_init: callable, ArrayLike. The input weight initializer.
- b_init: optional, callable, ArrayLike. The bias weight initializer.
- activation: str, callable. The activation function. It can be a string or a callable function.
- name: optional, str. The name of the module.
+ r"""
+ Gated Recurrent Unit (GRU) cell implementation.
+
+ This class implements the GRU model that uses gating mechanisms to control
+ information flow. The GRU has fewer parameters than LSTM as it combines
+ the forget and input gates into a single update gate. The GRU follows the
+ mathematical formulation:
+
+ .. math::
+
+ r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
+ z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
+ \tilde{h}_t &= \tanh(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
+ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
+
+ where:
+
+ - :math:`x_t` is the input vector at time t
+ - :math:`h_t` is the hidden state at time t
+ - :math:`r_t` is the reset gate vector
+ - :math:`z_t` is the update gate vector
+ - :math:`\tilde{h}_t` is the candidate hidden state
+ - :math:`\odot` represents element-wise multiplication
+ - :math:`\sigma` is the sigmoid activation function
+
+ Parameters
+ ----------
+ num_in : int
+ The number of input units.
+ num_out : int
+ The number of hidden units.
+ w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
+ Initializer for the weight matrices.
+ b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the bias vectors.
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the hidden state.
+ activation : str or Callable, default='tanh'
+ Activation function to use. Can be a string (e.g., 'tanh')
+ or a callable function.
+ name : str, optional
+ Name of the module.
+
+ State Variables
+ --------------
+ h : HiddenState
+ Hidden state of the GRU cell.
+
+ Methods
+ -------
+ init_state(batch_size=None, **kwargs)
+ Initialize the cell hidden state.
+ reset_state(batch_size=None, **kwargs)
+ Reset the cell hidden state to its initial value.
+ update(x)
+ Update the hidden state for one time step and return the new state.
  """
  __module__ = 'brainstate.nn'
 
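Following the GRU update equations documented above, a short sketch (not part of the diff) of stepping a GRUCell over a sequence through the same documented interface; the sequence length and sizes are illustrative assumptions:

    import jax.numpy as jnp
    import brainstate as bst

    gru = bst.nn.GRUCell(num_in=8, num_out=32)
    gru.init_state(batch_size=2)
    xs = jnp.ones((10, 2, 8))            # (time, batch, features)
    hs = [gru.update(x_t) for x_t in xs] # one update() call per time step
    print(hs[-1].shape)                  # expected (2, 32)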
@@ -159,32 +265,60 @@ class GRUCell(RNNCell):
 
 
  class MGUCell(RNNCell):
  r"""
- Minimal Gated Recurrent Unit (MGU) cell.
+ Minimal Gated Recurrent Unit (MGU) cell implementation.
+
+ MGU is a simplified version of GRU that uses a single forget gate instead of
+ separate reset and update gates. This results in fewer parameters while
+ maintaining much of the gating capability. The MGU follows the mathematical
+ formulation:
 
  .. math::
 
- \begin{aligned}
- f_{t}&=\sigma (W_{f}x_{t}+U_{f}h_{t-1}+b_{f})\\
- {\hat {h}}_{t}&=\phi (W_{h}x_{t}+U_{h}(f_{t}\odot h_{t-1})+b_{h})\\
- h_{t}&=(1-f_{t})\odot h_{t-1}+f_{t}\odot {\hat {h}}_{t}
- \end{aligned}
+ f_t &= \\sigma(W_f [x_t, h_{t-1}] + b_f) \\\\
+ \\tilde{h}_t &= \\phi(W_h [x_t, (f_t \\odot h_{t-1})] + b_h) \\\\
+ h_t &= (1 - f_t) \\odot h_{t-1} + f_t \\odot \\tilde{h}_t
 
  where:
 
- - :math:`x_{t}`: input vector
- - :math:`h_{t}`: output vector
- - :math:`{\hat {h}}_{t}`: candidate activation vector
- - :math:`f_{t}`: forget vector
- - :math:`W, U, b`: parameter matrices and vector
-
- Args:
- num_in: int. The number of input units.
- num_out: int. The number of hidden units.
- state_init: callable, ArrayLike. The state initializer.
- w_init: callable, ArrayLike. The input weight initializer.
- b_init: optional, callable, ArrayLike. The bias weight initializer.
- activation: str, callable. The activation function. It can be a string or a callable function.
- name: optional, str. The name of the module.
+ - :math:`x_t` is the input vector at time t
+ - :math:`h_t` is the hidden state at time t
+ - :math:`f_t` is the forget gate vector
+ - :math:`\\tilde{h}_t` is the candidate hidden state
+ - :math:`\\odot` represents element-wise multiplication
+ - :math:`\\sigma` is the sigmoid activation function
+ - :math:`\\phi` is the activation function (typically tanh)
+
+ Parameters
+ ----------
+ num_in : int
+ The number of input units.
+ num_out : int
+ The number of hidden units.
+ w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
+ Initializer for the weight matrices.
+ b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the bias vectors.
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the hidden state.
+ activation : str or Callable, default='tanh'
+ Activation function to use. Can be a string (e.g., 'tanh')
+ or a callable function.
+ name : str, optional
+ Name of the module.
+
+ State Variables
+ --------------
+ h : HiddenState
+ Hidden state of the MGU cell.
+
+ Methods
+ -------
+ init_state(batch_size=None, **kwargs)
+ Initialize the cell hidden state.
+ reset_state(batch_size=None, **kwargs)
+ Reset the cell hidden state to its initial value.
+ update(x)
+ Update the hidden state for one time step and return the new state.
  """
  __module__ = 'brainstate.nn'
 
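The MGU equations above use two weight blocks (forget gate and candidate state) where the GRU uses three, each mapping the concatenation [x_t, h_{t-1}] to num_out units. A rough parameter-count comparison implied by those equations, with illustrative sizes (not measured from the package):

    num_in, num_out = 8, 32
    per_block = (num_in + num_out) * num_out + num_out   # one W block plus its bias
    print("GRU ~", 3 * per_block, "parameters")          # r, z, candidate blocks
    print("MGU ~", 2 * per_block, "parameters")          # f, candidate blocks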
@@ -235,59 +369,81 @@ class MGUCell(RNNCell):
 
 
  class LSTMCell(RNNCell):
- r"""Long short-term memory (LSTM) RNN core.
+ r"""
+ Long Short-Term Memory (LSTM) cell implementation.
 
- The implementation is based on (zaremba, et al., 2014) [1]_. Given
- :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core
- computes
+ This class implements the LSTM architecture which uses multiple gating mechanisms
+ to regulate information flow and address the vanishing gradient problem in RNNs.
+ The LSTM follows the mathematical formulation:
 
  .. math::
 
- \begin{array}{ll}
- i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
- f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
- g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
- o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
- c_t = f_t c_{t-1} + i_t g_t \\
- h_t = o_t \tanh(c_t)
- \end{array}
+ i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
+ f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
+ g_t &= \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
+ o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
+ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
+ h_t &= o_t \odot \tanh(c_t)
 
- where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and
- output gate activations, and :math:`g_t` is a vector of cell updates.
+ where:
 
- The output is equal to the new hidden, :math:`h_t`.
+ - :math:`x_t` is the input vector at time t
+ - :math:`h_t` is the hidden state at time t
+ - :math:`c_t` is the cell state at time t
+ - :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and output gate activations
+ - :math:`g_t` is the cell update vector
+ - :math:`\odot` represents element-wise multiplication
+ - :math:`\sigma` is the sigmoid activation function
 
  Notes
  -----
-
- Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0
- to :math:`b_f` after initialization in order to reduce the scale of forgetting in
- the beginning of the training.
-
+ Forget gate initialization: Following Jozefowicz et al. (2015), we add 1.0
+ to the forget gate bias after initialization to reduce forgetting at the
+ beginning of training.
 
  Parameters
  ----------
- num_in: int
- The dimension of the input vector
- num_out: int
- The number of hidden unit in the node.
- state_init: callable, ArrayLike
- The state initializer.
- w_init: callable, ArrayLike
- The input weight initializer.
- b_init: optional, callable, ArrayLike
- The bias weight initializer.
- activation: str, callable
- The activation function. It can be a string or a callable function.
+ num_in : int
+ The number of input units.
+ num_out : int
+ The number of hidden/cell units.
+ w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
+ Initializer for the weight matrices.
+ b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the bias vectors.
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the hidden and cell states.
+ activation : str or Callable, default='tanh'
+ Activation function to use. Can be a string (e.g., 'tanh')
+ or a callable function.
+ name : str, optional
+ Name of the module.
+
+ State Variables
+ --------------
+ h : HiddenState
+ Hidden state of the LSTM cell.
+ c : HiddenState
+ Cell state of the LSTM cell.
+
+ Methods
+ -------
+ init_state(batch_size=None, **kwargs)
+ Initialize the cell and hidden states.
+ reset_state(batch_size=None, **kwargs)
+ Reset the cell and hidden states to their initial values.
+ update(x)
+ Update the states for one time step and return the new hidden state.
 
  References
  ----------
-
- .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural
- network regularization." arXiv preprint arXiv:1409.2329 (2014).
- .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical
- exploration of recurrent network architectures." In International conference
- on machine learning, pp. 2342-2350. PMLR, 2015.
+ .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
+ Neural computation, 9(8), 1735-1780.
+ .. [2] Zaremba, W., Sutskever, I., & Vinyals, O. (2014). Recurrent neural
+ network regularization. arXiv preprint arXiv:1409.2329.
+ .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
+ exploration of recurrent network architectures. In International
+ conference on machine learning, pp. 2342-2350.
  """
  __module__ = 'brainstate.nn'
 
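Unlike the single-state cells above, the LSTMCell docstring documents two state variables (h and c), so init_state() must run before the first update(). A hedged usage sketch (not part of the diff), with illustrative names and shapes:

    import jax.numpy as jnp
    import brainstate as bst

    lstm = bst.nn.LSTMCell(num_in=8, num_out=16)
    lstm.init_state(batch_size=4)            # allocates both h and c
    for _ in range(5):                       # unrolled loop over five time steps
        h = lstm.update(jnp.ones((4, 8)))    # returns the new hidden state h
    lstm.reset_state(batch_size=4)           # restore h and c before a new sequence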
brainstate/nn/_dyn_impl/_rate_rnns_test.py
@@ -15,9 +15,10 @@
 
  from __future__ import annotations
 
- import jax.numpy as jnp
  import unittest
 
+ import jax.numpy as jnp
+
  import brainstate as bst
 
 
brainstate/nn/_dyn_impl/_readout.py
@@ -17,11 +17,12 @@
 
  from __future__ import annotations
 
- import brainunit as u
- import jax
  import numbers
  from typing import Callable
 
+ import brainunit as u
+ import jax
+
  from brainstate import environ, init, surrogate
  from brainstate._state import HiddenState, ParamState
  from brainstate.nn._exp_euler import exp_euler_step
@@ -36,8 +37,45 @@ __all__ = [
 
 
  class LeakyRateReadout(Module):
- """
- Leaky dynamics for the read-out module used in the Real-Time Recurrent Learning.
+ r"""
+ Leaky dynamics for the read-out module.
+
+ This module implements a leaky integrator with the following dynamics:
+
+ .. math::
+ r_{t} = \alpha r_{t-1} + x_{t} W
+
+ where:
+ - :math:`r_{t}` is the output at time t
+ - :math:`\alpha = e^{-\Delta t / \tau}` is the decay factor
+ - :math:`x_{t}` is the input at time t
+ - :math:`W` is the weight matrix
+
+ The leaky integrator acts as a low-pass filter, allowing the network
+ to maintain memory of past inputs with an exponential decay determined
+ by the time constant tau.
+
+ Parameters
+ ----------
+ in_size : int or sequence of int
+ Size of the input dimension(s)
+ out_size : int or sequence of int
+ Size of the output dimension(s)
+ tau : ArrayLike, optional
+ Time constant of the leaky dynamics, by default 5ms
+ w_init : Callable, optional
+ Weight initialization function, by default KaimingNormal()
+ name : str, optional
+ Name of the module, by default None
+
+ Attributes
+ ----------
+ decay : float
+ Decay factor computed as exp(-dt/tau)
+ weight : ParamState
+ Weight matrix connecting input to output
+ r : HiddenState
+ Hidden state representing the output values
  """
  __module__ = 'brainstate.nn'
 
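The LeakyRateReadout docstring above defines the decay factor as alpha = exp(-dt / tau). A small numeric sketch (not part of the diff) of that relation, using the stated 5 ms default tau and a hypothetical 0.1 ms simulation step:

    import math

    dt, tau = 0.1, 5.0              # ms; dt is an illustrative assumption, not a package default
    alpha = math.exp(-dt / tau)
    # leaky update documented above: r_t = alpha * r_{t-1} + x_t @ W
    print(round(alpha, 4))          # ~ 0.9802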
@@ -72,8 +110,53 @@ class LeakyRateReadout(Module):
 
 
  class LeakySpikeReadout(Neuron):
- """
- Integrate-and-fire neuron model with leaky dynamics.
+ r"""
+ Integrate-and-fire neuron model with leaky dynamics for readout functionality.
+
+ This class implements a spiking neuron with the following dynamics:
+
+ .. math::
+ \frac{dV}{dt} = \frac{-V + I_{in}}{\tau}
+
+ where:
+ - :math:`V` is the membrane potential
+ - :math:`\tau` is the membrane time constant
+ - :math:`I_{in}` is the input current
+
+ Spike generation occurs when :math:`V > V_{th}` according to:
+
+ .. math::
+ S_t = \text{surrogate}\left(\frac{V - V_{th}}{V_{th}}\right)
+
+ After spiking, the membrane potential is reset according to the reset mode:
+ - Soft reset: :math:`V \leftarrow V - V_{th} \cdot S_t`
+ - Hard reset: :math:`V \leftarrow V - V_t \cdot S_t` (where :math:`V_t` is detached)
+
+ Parameters
+ ----------
+ in_size : Size
+ Size of the input dimension
+ tau : ArrayLike, optional
+ Membrane time constant, by default 5ms
+ V_th : ArrayLike, optional
+ Spike threshold, by default 1mV
+ w_init : Callable, optional
+ Weight initialization function, by default KaimingNormal(unit=mV)
+ V_initializer : ArrayLike, optional
+ Initial membrane potential, by default ZeroInit(unit=mV)
+ spk_fun : Callable, optional
+ Surrogate gradient function for spike generation, by default ReluGrad()
+ spk_reset : str, optional
+ Reset mechanism after spike ('soft' or 'hard'), by default 'soft'
+ name : str, optional
+ Name of the module, by default None
+
+ Attributes
+ ----------
+ V : HiddenState
+ Membrane potential state variable
+ weight : ParamState
+ Synaptic weight matrix
  """
 
  __module__ = 'brainstate.nn'
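The soft/hard reset rules documented in the LeakySpikeReadout docstring can be written out with plain arrays. A hedged sketch (not part of the diff); spike stands in for the surrogate spike S_t, and the detach used for the hard reset during training is omitted here:

    import jax.numpy as jnp

    V = jnp.array([1.3, 0.4])                # membrane potentials, arbitrary units
    V_th = 1.0
    spike = (V > V_th).astype(V.dtype)       # forward value of the spike function
    V_soft = V - V_th * spike                # soft reset: subtract the threshold
    V_hard = V - V * spike                   # hard reset: zero the spiking entries
    print(V_soft, V_hard)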
@@ -92,8 +175,8 @@ class LeakySpikeReadout(Neuron):
  super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)
 
  # parameters
- self.tau = init.param(tau, (self.varshape,))
- self.V_th = init.param(V_th, (self.varshape,))
+ self.tau = init.param(tau, self.varshape)
+ self.V_th = init.param(V_th, self.varshape)
  self.V_initializer = V_initializer
 
  # weights
brainstate/nn/_dyn_impl/_readout_test.py
@@ -15,9 +15,10 @@
 
  from __future__ import annotations
 
- import jax.numpy as jnp
  import unittest
 
+ import jax.numpy as jnp
+
  import brainstate as bst
 
 