brainstate 0.1.0.post20250420__py2.py3-none-any.whl → 0.1.0.post20250422__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_compatible_import.py +15 -0
- brainstate/_state.py +5 -4
- brainstate/_state_test.py +2 -1
- brainstate/augment/_autograd_test.py +3 -2
- brainstate/augment/_eval_shape.py +2 -1
- brainstate/augment/_mapping.py +0 -1
- brainstate/augment/_mapping_test.py +1 -0
- brainstate/compile/_ad_checkpoint.py +2 -1
- brainstate/compile/_conditions.py +3 -3
- brainstate/compile/_conditions_test.py +2 -1
- brainstate/compile/_error_if.py +2 -1
- brainstate/compile/_error_if_test.py +2 -1
- brainstate/compile/_jit.py +3 -2
- brainstate/compile/_jit_test.py +2 -1
- brainstate/compile/_loop_collect_return.py +2 -2
- brainstate/compile/_loop_collect_return_test.py +2 -1
- brainstate/compile/_loop_no_collection.py +1 -1
- brainstate/compile/_make_jaxpr.py +2 -2
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +2 -1
- brainstate/compile/_unvmap.py +1 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +2 -1
- brainstate/functional/_activations.py +2 -1
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +2 -1
- brainstate/functional/_others.py +2 -1
- brainstate/graph/_graph_operation.py +3 -2
- brainstate/graph/_graph_operation_test.py +4 -3
- brainstate/init/_base.py +2 -1
- brainstate/init/_generic.py +2 -1
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +1 -0
- brainstate/nn/_collective_ops_test.py +0 -4
- brainstate/nn/_common.py +0 -1
- brainstate/nn/_dyn_impl/__init__.py +0 -4
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +431 -13
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +2 -1
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +405 -103
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +2 -1
- brainstate/nn/_dyn_impl/_inputs.py +236 -29
- brainstate/nn/_dyn_impl/_rate_rnns.py +238 -82
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +2 -1
- brainstate/nn/_dyn_impl/_readout.py +91 -8
- brainstate/nn/_dyn_impl/_readout_test.py +2 -1
- brainstate/nn/_dynamics/_dynamics_base.py +676 -96
- brainstate/nn/_dynamics/_dynamics_base_test.py +2 -1
- brainstate/nn/_dynamics/_projection_base.py +29 -30
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +2 -1
- brainstate/nn/_elementwise/_dropout.py +3 -2
- brainstate/nn/_elementwise/_dropout_test.py +2 -1
- brainstate/nn/_elementwise/_elementwise.py +2 -1
- brainstate/nn/{_dyn_impl/_projection_alignpost.py → _event/__init__.py} +8 -7
- brainstate/nn/_event/_fixedprob_mv.py +169 -0
- brainstate/nn/_event/_fixedprob_mv_test.py +115 -0
- brainstate/nn/_event/_linear_mv.py +85 -0
- brainstate/nn/_event/_linear_mv_test.py +121 -0
- brainstate/nn/_exp_euler.py +2 -1
- brainstate/nn/_exp_euler_test.py +2 -1
- brainstate/nn/_interaction/_conv.py +2 -1
- brainstate/nn/_interaction/_linear.py +2 -1
- brainstate/nn/_interaction/_linear_test.py +2 -1
- brainstate/nn/_interaction/_normalizations.py +3 -2
- brainstate/nn/_interaction/_poolings.py +4 -3
- brainstate/nn/_module_test.py +2 -1
- brainstate/nn/metrics.py +4 -3
- brainstate/optim/_lr_scheduler.py +2 -1
- brainstate/optim/_lr_scheduler_test.py +2 -1
- brainstate/optim/_optax_optimizer_test.py +2 -1
- brainstate/optim/_sgd_optimizer.py +3 -2
- brainstate/random/_rand_funs.py +2 -1
- brainstate/random/_rand_funs_test.py +3 -2
- brainstate/random/_rand_seed.py +3 -2
- brainstate/random/_rand_seed_test.py +2 -1
- brainstate/random/_rand_state.py +4 -3
- brainstate/surrogate.py +1 -2
- brainstate/typing.py +4 -4
- brainstate/util/_caller.py +2 -1
- brainstate/util/_others.py +4 -4
- brainstate/util/_pretty_pytree.py +1 -1
- brainstate/util/_pretty_pytree_test.py +2 -1
- brainstate/util/_pretty_table.py +43 -43
- brainstate/util/_struct.py +2 -1
- brainstate/util/filter.py +0 -1
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/METADATA +3 -3
- brainstate-0.1.0.post20250422.dist-info/RECORD +133 -0
- brainstate-0.1.0.post20250420.dist-info/RECORD +0 -129
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250420.dist-info → brainstate-0.1.0.post20250422.dist-info}/top_level.txt +0 -0
brainstate/nn/_dyn_impl/_rate_rnns.py

```diff
@@ -17,9 +17,10 @@
 
 from __future__ import annotations
 
-import jax.numpy as jnp
 from typing import Callable, Union
 
+import jax.numpy as jnp
+
 from brainstate import random, init, functional
 from brainstate._state import HiddenState, ParamState
 from brainstate.nn._interaction._linear import Linear
```
```diff
@@ -33,23 +34,83 @@ __all__ = [
 
 class RNNCell(Module):
     """
-    Base class for RNN
+    Base class for all recurrent neural network (RNN) cell implementations.
+
+    This abstract class serves as the foundation for implementing various RNN cell types
+    such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
+    Module class and provides common functionality and interface for recurrent units.
+
+    All RNN cell implementations should inherit from this class and implement the required
+    methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
+    define the state initialization and recurrent dynamics.
+
+    The RNNCell typically maintains hidden state(s) that are updated at each time step
+    based on the current input and previous state values.
+
+    Methods
+    -------
+    init_state(batch_size=None, **kwargs)
+        Initialize the cell state variables with appropriate dimensions.
+    reset_state(batch_size=None, **kwargs)
+        Reset the cell state variables to their initial values.
+    update(x)
+        Update the cell state for one time step based on input x and return output.
     """
     pass
 
 
 class ValinaRNNCell(RNNCell):
-    """
-    Vanilla RNN cell.
-
-
-
-
-
-
-
-
-
+    r"""
+    Vanilla Recurrent Neural Network (RNN) cell implementation.
+
+    This class implements the basic RNN model that updates a hidden state based on
+    the current input and previous hidden state. The standard RNN cell follows the
+    mathematical formulation:
+
+    .. math::
+
+        h_t = \phi(W [x_t, h_{t-1}] + b)
+
+    where:
+
+    - :math:`x_t` is the input vector at time t
+    - :math:`h_t` is the hidden state at time t
+    - :math:`h_{t-1}` is the hidden state at previous time step
+    - :math:`W` is the weight matrix for the combined input-hidden linear transformation
+    - :math:`b` is the bias vector
+    - :math:`\phi` is the activation function
+
+    Parameters
+    ----------
+    num_in : int
+        The number of input units.
+    num_out : int
+        The number of hidden units.
+    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the hidden state.
+    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
+        Initializer for the weight matrix.
+    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the bias vector.
+    activation : str or Callable, default='relu'
+        Activation function to use. Can be a string (e.g., 'relu', 'tanh')
+        or a callable function.
+    name : str, optional
+        Name of the module.
+
+    State Variables
+    ---------------
+    h : HiddenState
+        Hidden state of the RNN cell.
+
+    Methods
+    -------
+    init_state(batch_size=None, **kwargs)
+        Initialize the cell hidden state.
+    reset_state(batch_size=None, **kwargs)
+        Reset the cell hidden state to its initial value.
+    update(x)
+        Update the hidden state for one time step and return the new state.
     """
     __module__ = 'brainstate.nn'
 
```
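For reference, the update rule in the new `ValinaRNNCell` docstring reduces to a few lines of standalone Python. This is a sketch of the math only, assuming a combined weight of shape `(num_in + num_out, num_out)`; it does not use brainstate's `Linear`/`HiddenState` machinery:

```python
import numpy as np

def vanilla_rnn_step(x, h, W, b, phi=np.tanh):
    """One step of h_t = phi(W [x_t, h_{t-1}] + b)."""
    xh = np.concatenate([x, h])   # [x_t, h_{t-1}]
    return phi(xh @ W + b)        # new hidden state h_t

# Toy dimensions: 3 inputs, 4 hidden units.
rng = np.random.default_rng(0)
W = rng.normal(size=(3 + 4, 4)) * 0.1
b = np.zeros(4)
h = np.zeros(4)
for _ in range(5):                # unroll a few time steps
    h = vanilla_rnn_step(rng.normal(size=3), h, W, b)
```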
```diff
@@ -96,17 +157,62 @@ class ValinaRNNCell(RNNCell):
 
 
 class GRUCell(RNNCell):
-    """
-    Gated Recurrent Unit (GRU) cell.
-
-
-
-
-
-
-
-
-
+    r"""
+    Gated Recurrent Unit (GRU) cell implementation.
+
+    This class implements the GRU model that uses gating mechanisms to control
+    information flow. The GRU has fewer parameters than LSTM as it combines
+    the forget and input gates into a single update gate. The GRU follows the
+    mathematical formulation:
+
+    .. math::
+
+        r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
+        z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
+        \tilde{h}_t &= \tanh(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
+        h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
+
+    where:
+
+    - :math:`x_t` is the input vector at time t
+    - :math:`h_t` is the hidden state at time t
+    - :math:`r_t` is the reset gate vector
+    - :math:`z_t` is the update gate vector
+    - :math:`\tilde{h}_t` is the candidate hidden state
+    - :math:`\odot` represents element-wise multiplication
+    - :math:`\sigma` is the sigmoid activation function
+
+    Parameters
+    ----------
+    num_in : int
+        The number of input units.
+    num_out : int
+        The number of hidden units.
+    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
+        Initializer for the weight matrices.
+    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the bias vectors.
+    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the hidden state.
+    activation : str or Callable, default='tanh'
+        Activation function to use. Can be a string (e.g., 'tanh')
+        or a callable function.
+    name : str, optional
+        Name of the module.
+
+    State Variables
+    ---------------
+    h : HiddenState
+        Hidden state of the GRU cell.
+
+    Methods
+    -------
+    init_state(batch_size=None, **kwargs)
+        Initialize the cell hidden state.
+    reset_state(batch_size=None, **kwargs)
+        Reset the cell hidden state to its initial value.
+    update(x)
+        Update the hidden state for one time step and return the new state.
     """
     __module__ = 'brainstate.nn'
 
```
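The GRU equations above map line-for-line onto a short sketch in the same spirit; the separate per-gate weight matrices `Wr`, `Wz`, `Wh` (each of shape `(num_in + num_out, num_out)`) are assumptions of this illustration, not necessarily brainstate's internal layout:

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def gru_step(x, h, Wr, br, Wz, bz, Wh, bh):
    xh = np.concatenate([x, h])
    r = sigmoid(xh @ Wr + br)                                # reset gate r_t
    z = sigmoid(xh @ Wz + bz)                                # update gate z_t
    h_tilde = np.tanh(np.concatenate([x, r * h]) @ Wh + bh)  # candidate state
    return (1.0 - z) * h + z * h_tilde                       # new hidden state h_t
```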
```diff
@@ -159,32 +265,60 @@ class GRUCell(RNNCell):
 
 class MGUCell(RNNCell):
     r"""
-    Minimal Gated Recurrent Unit (MGU) cell.
+    Minimal Gated Recurrent Unit (MGU) cell implementation.
+
+    MGU is a simplified version of GRU that uses a single forget gate instead of
+    separate reset and update gates. This results in fewer parameters while
+    maintaining much of the gating capability. The MGU follows the mathematical
+    formulation:
 
     .. math::
 
-
-
-
-        h_{t}&=(1-f_{t})\odot h_{t-1}+f_{t}\odot {\hat {h}}_{t}
-    \end{aligned}
+        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+        \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
+        h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
 
     where:
 
-    - :math:`
-    - :math:`
-    - :math:`
-    - :math
-    - :math
-
-
-
-
-
-
-
-
-
+    - :math:`x_t` is the input vector at time t
+    - :math:`h_t` is the hidden state at time t
+    - :math:`f_t` is the forget gate vector
+    - :math:`\tilde{h}_t` is the candidate hidden state
+    - :math:`\odot` represents element-wise multiplication
+    - :math:`\sigma` is the sigmoid activation function
+    - :math:`\phi` is the activation function (typically tanh)
+
+    Parameters
+    ----------
+    num_in : int
+        The number of input units.
+    num_out : int
+        The number of hidden units.
+    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
+        Initializer for the weight matrices.
+    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the bias vectors.
+    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the hidden state.
+    activation : str or Callable, default='tanh'
+        Activation function to use. Can be a string (e.g., 'tanh')
+        or a callable function.
+    name : str, optional
+        Name of the module.
+
+    State Variables
+    ---------------
+    h : HiddenState
+        Hidden state of the MGU cell.
+
+    Methods
+    -------
+    init_state(batch_size=None, **kwargs)
+        Initialize the cell hidden state.
+    reset_state(batch_size=None, **kwargs)
+        Reset the cell hidden state to its initial value.
+    update(x)
+        Update the hidden state for one time step and return the new state.
     """
     __module__ = 'brainstate.nn'
 
```
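For comparison with the GRU sketch above, an MGU step differs only in collapsing the reset and update gates into a single forget gate (same illustrative weight layout assumed):

```python
import numpy as np

def mgu_step(x, h, Wf, bf, Wh, bh, phi=np.tanh):
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    xh = np.concatenate([x, h])
    f = sigmoid(xh @ Wf + bf)                            # forget gate f_t
    h_tilde = phi(np.concatenate([x, f * h]) @ Wh + bh)  # candidate state
    return (1.0 - f) * h + f * h_tilde                   # new hidden state h_t
```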
```diff
@@ -235,59 +369,81 @@ class MGUCell(RNNCell):
 
 
 class LSTMCell(RNNCell):
-    r"""
+    r"""
+    Long Short-Term Memory (LSTM) cell implementation.
 
-
-
-
+    This class implements the LSTM architecture which uses multiple gating mechanisms
+    to regulate information flow and address the vanishing gradient problem in RNNs.
+    The LSTM follows the mathematical formulation:
 
     .. math::
 
-
-
-
-
-
-
-        h_t = o_t \tanh(c_t)
-    \end{array}
+        i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
+        f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
+        g_t &= \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
+        o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
+        c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
+        h_t &= o_t \odot \tanh(c_t)
 
-    where
-    output gate activations, and :math:`g_t` is a vector of cell updates.
+    where:
 
-
+    - :math:`x_t` is the input vector at time t
+    - :math:`h_t` is the hidden state at time t
+    - :math:`c_t` is the cell state at time t
+    - :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and output gate activations
+    - :math:`g_t` is the cell update vector
+    - :math:`\odot` represents element-wise multiplication
+    - :math:`\sigma` is the sigmoid activation function
 
     Notes
     -----
-
-
-
-    the beginning of the training.
-
+    Forget gate initialization: Following Jozefowicz et al. (2015), we add 1.0
+    to the forget gate bias after initialization to reduce forgetting at the
+    beginning of training.
 
     Parameters
     ----------
-    num_in: int
-
-    num_out: int
-
-
-
-
-
-
-
-    activation: str,
-
+    num_in : int
+        The number of input units.
+    num_out : int
+        The number of hidden/cell units.
+    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
+        Initializer for the weight matrices.
+    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the bias vectors.
+    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the hidden and cell states.
+    activation : str or Callable, default='tanh'
+        Activation function to use. Can be a string (e.g., 'tanh')
+        or a callable function.
+    name : str, optional
+        Name of the module.
+
+    State Variables
+    ---------------
+    h : HiddenState
+        Hidden state of the LSTM cell.
+    c : HiddenState
+        Cell state of the LSTM cell.
+
+    Methods
+    -------
+    init_state(batch_size=None, **kwargs)
+        Initialize the cell and hidden states.
+    reset_state(batch_size=None, **kwargs)
+        Reset the cell and hidden states to their initial values.
+    update(x)
+        Update the states for one time step and return the new hidden state.
 
     References
     ----------
-
-
-
-
-
-
+    .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
+           Neural computation, 9(8), 1735-1780.
+    .. [2] Zaremba, W., Sutskever, I., & Vinyals, O. (2014). Recurrent neural
+           network regularization. arXiv preprint arXiv:1409.2329.
+    .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
+           exploration of recurrent network architectures. In International
+           conference on machine learning, pp. 2342-2350.
     """
     __module__ = 'brainstate.nn'
 
```
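The LSTM equations and the forget-bias note can likewise be checked with a standalone sketch. Packing the four gates into one `(num_in + num_out, 4 * num_out)` matrix is an assumption of this illustration:

```python
import numpy as np

def lstm_step(x, h, c, W, b):
    sigmoid = lambda a: 1.0 / (1.0 + np.exp(-a))
    z = np.concatenate([x, h]) @ W + b
    i, f, g, o = np.split(z, 4)                       # gate pre-activations
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # cell state c_t
    h_new = sigmoid(o) * np.tanh(c_new)               # hidden state h_t
    return h_new, c_new

# Toy setup: 3 inputs, 4 hidden units; forget-gate bias set to 1.0,
# mirroring the Jozefowicz et al. (2015) initialization from the Notes.
rng = np.random.default_rng(0)
W = rng.normal(size=(3 + 4, 4 * 4)) * 0.1
b = np.zeros(4 * 4)
b[4:8] = 1.0                                          # the forget-gate slice
h = c = np.zeros(4)
h, c = lstm_step(rng.normal(size=3), h, c, W, b)
```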
brainstate/nn/_dyn_impl/_readout.py

```diff
@@ -17,11 +17,12 @@
 
 from __future__ import annotations
 
-import brainunit as u
-import jax
 import numbers
 from typing import Callable
 
+import brainunit as u
+import jax
+
 from brainstate import environ, init, surrogate
 from brainstate._state import HiddenState, ParamState
 from brainstate.nn._exp_euler import exp_euler_step
```
```diff
@@ -36,8 +37,45 @@ __all__ = [
 
 
 class LeakyRateReadout(Module):
-    """
-    Leaky dynamics for the read-out module
+    r"""
+    Leaky dynamics for the read-out module.
+
+    This module implements a leaky integrator with the following dynamics:
+
+    .. math::
+        r_{t} = \alpha r_{t-1} + x_{t} W
+
+    where:
+    - :math:`r_{t}` is the output at time t
+    - :math:`\alpha = e^{-\Delta t / \tau}` is the decay factor
+    - :math:`x_{t}` is the input at time t
+    - :math:`W` is the weight matrix
+
+    The leaky integrator acts as a low-pass filter, allowing the network
+    to maintain memory of past inputs with an exponential decay determined
+    by the time constant tau.
+
+    Parameters
+    ----------
+    in_size : int or sequence of int
+        Size of the input dimension(s)
+    out_size : int or sequence of int
+        Size of the output dimension(s)
+    tau : ArrayLike, optional
+        Time constant of the leaky dynamics, by default 5ms
+    w_init : Callable, optional
+        Weight initialization function, by default KaimingNormal()
+    name : str, optional
+        Name of the module, by default None
+
+    Attributes
+    ----------
+    decay : float
+        Decay factor computed as exp(-dt/tau)
+    weight : ParamState
+        Weight matrix connecting input to output
+    r : HiddenState
+        Hidden state representing the output values
     """
     __module__ = 'brainstate.nn'
 
```
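The decay relationship α = exp(−Δt/τ) in the new docstring reduces to two lines once units are stripped; the sketch below assumes `dt` and `tau` share a unit (ms) and sidesteps the brainunit handling the module actually uses:

```python
import numpy as np

def leaky_readout_step(x, r, W, dt=0.1, tau=5.0):
    decay = np.exp(-dt / tau)    # alpha = e^{-dt/tau}, a low-pass coefficient
    return decay * r + x @ W     # r_t = alpha * r_{t-1} + x_t W
```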
```diff
@@ -72,8 +110,53 @@ class LeakyRateReadout(Module):
 
 
 class LeakySpikeReadout(Neuron):
-    """
-    Integrate-and-fire neuron model with leaky dynamics.
+    r"""
+    Integrate-and-fire neuron model with leaky dynamics for readout functionality.
+
+    This class implements a spiking neuron with the following dynamics:
+
+    .. math::
+        \frac{dV}{dt} = \frac{-V + I_{in}}{\tau}
+
+    where:
+    - :math:`V` is the membrane potential
+    - :math:`\tau` is the membrane time constant
+    - :math:`I_{in}` is the input current
+
+    Spike generation occurs when :math:`V > V_{th}` according to:
+
+    .. math::
+        S_t = \text{surrogate}\left(\frac{V - V_{th}}{V_{th}}\right)
+
+    After spiking, the membrane potential is reset according to the reset mode:
+    - Soft reset: :math:`V \leftarrow V - V_{th} \cdot S_t`
+    - Hard reset: :math:`V \leftarrow V - V_t \cdot S_t` (where :math:`V_t` is detached)
+
+    Parameters
+    ----------
+    in_size : Size
+        Size of the input dimension
+    tau : ArrayLike, optional
+        Membrane time constant, by default 5ms
+    V_th : ArrayLike, optional
+        Spike threshold, by default 1mV
+    w_init : Callable, optional
+        Weight initialization function, by default KaimingNormal(unit=mV)
+    V_initializer : ArrayLike, optional
+        Initial membrane potential, by default ZeroInit(unit=mV)
+    spk_fun : Callable, optional
+        Surrogate gradient function for spike generation, by default ReluGrad()
+    spk_reset : str, optional
+        Reset mechanism after spike ('soft' or 'hard'), by default 'soft'
+    name : str, optional
+        Name of the module, by default None
+
+    Attributes
+    ----------
+    V : HiddenState
+        Membrane potential state variable
+    weight : ParamState
+        Synaptic weight matrix
     """
 
     __module__ = 'brainstate.nn'
```
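Combining the membrane ODE, the threshold, and the soft reset from the docstring gives the following unit-free sketch; a hard Heaviside step stands in for the surrogate-gradient `spk_fun`, and the linear ODE is integrated with its exact exponential update:

```python
import numpy as np

def leaky_spike_readout_step(spikes, V, W, dt=0.1, tau=5.0, V_th=1.0):
    I_in = spikes @ W                                # input current from spikes
    V = V + (1.0 - np.exp(-dt / tau)) * (I_in - V)   # step of dV/dt = (-V + I_in)/tau
    S = (V > V_th).astype(V.dtype)                   # spike where threshold is crossed
    V = V - V_th * S                                 # soft reset
    return S, V
```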
```diff
@@ -92,8 +175,8 @@ class LeakySpikeReadout(Neuron):
         super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)
 
         # parameters
-        self.tau = init.param(tau,
-        self.V_th = init.param(V_th,
+        self.tau = init.param(tau, self.varshape)
+        self.V_th = init.param(V_th, self.varshape)
         self.V_initializer = V_initializer
 
         # weights
```