brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
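The renames above fold `brainstate/compile/` and `brainstate/augment/` into a single `brainstate/transform/` package and move `brainstate/init/` under `brainstate/nn/`. A hedged sketch of the import migration this implies for downstream code (the exact public re-exports are an assumption here; the new `_deprecation.py` suggests old paths are shimmed with warnings rather than removed outright):

```python
# brainstate 0.1.9 (old layout, left-hand side of the renames above)
# from brainstate.compile import jit      # compile/_jit.py
# from brainstate.augment import grad     # augment/_autograd.py
# from brainstate import init             # init/ package

# brainstate 0.2.0 (new layout implied by the renames; verify against the
# package's __init__.py -- the re-exported names are assumptions here)
from brainstate.transform import jit, grad  # compile/ + augment/ -> transform/
from brainstate.nn import init              # init/_random_inits.py -> nn/init.py
```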
brainstate/nn/{_rate_rnns.py → _rnns.py}

```diff
@@ -1,4 +1,4 @@
-# Copyright 2024
+# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
```
```diff
@@ -20,9 +20,11 @@ from typing import Callable, Union
 
 import jax.numpy as jnp
 
-from brainstate import random
+from brainstate import random
 from brainstate._state import HiddenState, ParamState
 from brainstate.typing import ArrayLike
+from . import _activations as functional
+from . import init as init
 from ._linear import Linear
 from ._module import Module
 
```
```diff
@@ -54,7 +56,16 @@ class RNNCell(Module):
         Reset the cell state variables to their initial values.
     update(x)
         Update the cell state for one time step based on input x and return output.
+
+    See Also
+    --------
+    ValinaRNNCell : Vanilla RNN cell implementation
+    GRUCell : Gated Recurrent Unit cell implementation
+    LSTMCell : Long Short-Term Memory cell implementation
+    URLSTMCell : LSTM with UR gating mechanism
+    MGUCell : Minimal Gated Unit cell implementation
     """
+    __module__ = 'brainstate.nn'
     pass
 
 
```
```diff
@@ -97,8 +108,19 @@ class ValinaRNNCell(RNNCell):
     name : str, optional
         Name of the module.
 
+    Attributes
+    ----------
+    num_in : int
+        Number of input features.
+    num_out : int
+        Number of hidden units.
+    in_size : tuple
+        Shape of input (num_in,).
+    out_size : tuple
+        Shape of output (num_out,).
+
     State Variables
-
+    ---------------
     h : HiddenState
         Hidden state of the RNN cell.
 
```
```diff
@@ -110,6 +132,39 @@ class ValinaRNNCell(RNNCell):
         Reset the cell hidden state to its initial value.
     update(x)
         Update the hidden state for one time step and return the new state.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bs
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a vanilla RNN cell
+        >>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
+        >>>
+        >>> # Initialize state for batch size 32
+        >>> cell.init_state(batch_size=32)
+        >>>
+        >>> # Process a single time step
+        >>> x = jnp.ones((32, 10))  # batch_size x num_in
+        >>> output = cell.update(x)
+        >>> print(output.shape)  # (32, 20)
+        >>>
+        >>> # Process a sequence of inputs
+        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
+        >>> outputs = []
+        >>> for t in range(100):
+        ...     output = cell.update(sequence[t])
+        ...     outputs.append(output)
+        >>> outputs = jnp.stack(outputs)
+        >>> print(outputs.shape)  # (100, 32, 20)
+
+    Notes
+    -----
+    Vanilla RNNs can suffer from vanishing or exploding gradient problems
+    when processing long sequences. For better performance on long sequences,
+    consider using gated architectures like GRU or LSTM.
     """
     __module__ = 'brainstate.nn'
 
```
```diff
@@ -143,10 +198,30 @@ class ValinaRNNCell(RNNCell):
         self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
 
     def init_state(self, batch_size: int = None, **kwargs):
-
+        """
+        Initialize the hidden state.
+
+        Parameters
+        ----------
+        batch_size : int, optional
+            The batch size for state initialization.
+        **kwargs
+            Additional keyword arguments.
+        """
+        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
 
     def reset_state(self, batch_size: int = None, **kwargs):
-
+        """
+        Reset the hidden state to initial value.
+
+        Parameters
+        ----------
+        batch_size : int, optional
+            The batch size for state reset.
+        **kwargs
+            Additional keyword arguments.
+        """
+        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
 
     def update(self, x):
         xh = jnp.concatenate([x, self.h.value], axis=-1)
```
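As a quick orientation for this hunk: `update` concatenates the input with the stored hidden state and pushes the result through one `Linear` layer. A minimal standalone sketch of that step in plain `jax.numpy` (the explicit weight dict and fixed `tanh` are illustrative stand-ins for the module's `Linear` layer and configurable activation):

```python
import jax
import jax.numpy as jnp

def vanilla_rnn_step(params, h, x, phi=jnp.tanh):
    # One step of h' = phi(W [x, h] + b), mirroring ValinaRNNCell.update
    xh = jnp.concatenate([x, h], axis=-1)        # (batch, num_in + num_out)
    return phi(xh @ params["W"] + params["b"])   # (batch, num_out)

num_in, num_out, batch = 10, 20, 32
key = jax.random.PRNGKey(0)
params = {
    "W": 0.1 * jax.random.normal(key, (num_in + num_out, num_out)),
    "b": jnp.zeros((num_out,)),
}
h0 = jnp.zeros((batch, num_out))
x = jnp.ones((batch, num_in))
print(vanilla_rnn_step(params, h0, x).shape)     # (32, 20)
```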
```diff
@@ -159,16 +234,18 @@ class GRUCell(RNNCell):
     r"""
     Gated Recurrent Unit (GRU) cell implementation.
 
-
-
-
-
+    The GRU is a gating mechanism in recurrent neural networks that aims to solve
+    the vanishing gradient problem. It uses gating mechanisms to control information
+    flow and has fewer parameters than LSTM as it combines the forget and input gates
+    into a single update gate.
+
+    The GRU cell follows the mathematical formulation:
 
     .. math::
 
         r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
         z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
-        \tilde{h}_t &= \
+        \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
         h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
 
     where:
```
```diff
@@ -180,6 +257,7 @@ class GRUCell(RNNCell):
     - :math:`\tilde{h}_t` is the candidate hidden state
     - :math:`\odot` represents element-wise multiplication
     - :math:`\sigma` is the sigmoid activation function
+    - :math:`\phi` is the activation function (typically tanh)
 
     Parameters
     ----------
```
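The reconstructed equations read off directly as code. A sketch of one GRU step under the docstring's notation (separate per-gate weight matrices are illustrative; the actual cell fuses them into `Linear` layers over the concatenated input):

```python
import jax
import jax.numpy as jnp

def gru_step(p, h, x, phi=jnp.tanh):
    xh = jnp.concatenate([x, h], axis=-1)
    r = jax.nn.sigmoid(xh @ p["Wr"] + p["br"])    # reset gate r_t
    z = jax.nn.sigmoid(xh @ p["Wz"] + p["bz"])    # update gate z_t
    xrh = jnp.concatenate([x, r * h], axis=-1)    # candidate sees r_t * h_{t-1}
    h_tilde = phi(xrh @ p["Wh"] + p["bh"])
    return (1 - z) * h + z * h_tilde              # convex mix of old and new

num_in, num_out = 10, 20
ks = jax.random.split(jax.random.PRNGKey(0), 3)
p = {f"W{n}": 0.1 * jax.random.normal(k, (num_in + num_out, num_out))
     for n, k in zip("rzh", ks)}
p.update({f"b{n}": jnp.zeros((num_out,)) for n in "rzh"})
h = jnp.zeros((32, num_out))
print(gru_step(p, h, jnp.ones((32, num_in))).shape)  # (32, 20)
```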
```diff
@@ -194,13 +272,24 @@ class GRUCell(RNNCell):
     state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
         Initializer for the hidden state.
     activation : str or Callable, default='tanh'
-        Activation function to use. Can be a string (e.g., 'tanh')
+        Activation function to use. Can be a string (e.g., 'tanh', 'relu')
         or a callable function.
     name : str, optional
         Name of the module.
 
+    Attributes
+    ----------
+    num_in : int
+        Number of input features.
+    num_out : int
+        Number of hidden units.
+    in_size : tuple
+        Shape of input (num_in,).
+    out_size : tuple
+        Shape of output (num_out,).
+
     State Variables
-
+    ---------------
     h : HiddenState
         Hidden state of the GRU cell.
 
```
```diff
@@ -212,6 +301,51 @@ class GRUCell(RNNCell):
         Reset the cell hidden state to its initial value.
     update(x)
         Update the hidden state for one time step and return the new state.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bs
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a GRU cell
+        >>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
+        >>>
+        >>> # Initialize state for batch size 32
+        >>> cell.init_state(batch_size=32)
+        >>>
+        >>> # Process a single time step
+        >>> x = jnp.ones((32, 10))  # batch_size x num_in
+        >>> output = cell.update(x)
+        >>> print(output.shape)  # (32, 20)
+        >>>
+        >>> # Process a sequence
+        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
+        >>> outputs = []
+        >>> for t in range(100):
+        ...     output = cell.update(sequence[t])
+        ...     outputs.append(output)
+        >>> outputs = jnp.stack(outputs)
+        >>> print(outputs.shape)  # (100, 32, 20)
+        >>>
+        >>> # Reset state with different batch size
+        >>> cell.reset_state(batch_size=16)
+        >>> x_new = jnp.ones((16, 10))
+        >>> output_new = cell.update(x_new)
+        >>> print(output_new.shape)  # (16, 20)
+
+    Notes
+    -----
+    GRU cells are computationally more efficient than LSTM cells due to having
+    fewer parameters, while often achieving comparable performance on many tasks.
+
+    References
+    ----------
+    .. [1] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F.,
+           Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using
+           RNN encoder-decoder for statistical machine translation.
+           arXiv preprint arXiv:1406.1078.
     """
     __module__ = 'brainstate.nn'
 
```
```diff
@@ -264,28 +398,30 @@ class GRUCell(RNNCell):
 
 class MGUCell(RNNCell):
     r"""
-    Minimal Gated
+    Minimal Gated Unit (MGU) cell implementation.
 
     MGU is a simplified version of GRU that uses a single forget gate instead of
-    separate reset and update gates. This
-    maintaining much of the gating capability.
-
+    separate reset and update gates. This design significantly reduces the number
+    of parameters while maintaining much of the gating capability. MGU provides
+    a good trade-off between model complexity and performance.
+
+    The MGU cell follows the mathematical formulation:
 
     .. math::
 
-        f_t &=
-
-        h_t &= (1 - f_t)
+        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+        \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
+        h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
 
     where:
 
     - :math:`x_t` is the input vector at time t
     - :math:`h_t` is the hidden state at time t
     - :math:`f_t` is the forget gate vector
-    - :math
-    - :math
-    - :math
-    - :math
+    - :math:`\tilde{h}_t` is the candidate hidden state
+    - :math:`\odot` represents element-wise multiplication
+    - :math:`\sigma` is the sigmoid activation function
+    - :math:`\phi` is the activation function (typically tanh)
 
     Parameters
     ----------
```
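Side by side with the GRU above, the MGU's single forget gate plays both the reset and update roles. A sketch of one step (illustrative weights, as in the GRU sketch):

```python
import jax
import jax.numpy as jnp

def mgu_step(p, h, x, phi=jnp.tanh):
    xh = jnp.concatenate([x, h], axis=-1)
    f = jax.nn.sigmoid(xh @ p["Wf"] + p["bf"])    # single forget gate f_t
    xfh = jnp.concatenate([x, f * h], axis=-1)    # f_t also gates the candidate
    h_tilde = phi(xfh @ p["Wh"] + p["bh"])
    return (1 - f) * h + f * h_tilde              # f_t reused as the update gate

num_in, num_out = 10, 20
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
p = {
    "Wf": 0.1 * jax.random.normal(k1, (num_in + num_out, num_out)),
    "bf": jnp.zeros((num_out,)),
    "Wh": 0.1 * jax.random.normal(k2, (num_in + num_out, num_out)),
    "bh": jnp.zeros((num_out,)),
}
print(mgu_step(p, jnp.zeros((32, num_out)), jnp.ones((32, num_in))).shape)  # (32, 20)
```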
```diff
@@ -300,13 +436,24 @@ class MGUCell(RNNCell):
     state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
         Initializer for the hidden state.
     activation : str or Callable, default='tanh'
-        Activation function to use. Can be a string (e.g., 'tanh')
+        Activation function to use. Can be a string (e.g., 'tanh', 'relu')
         or a callable function.
     name : str, optional
         Name of the module.
 
+    Attributes
+    ----------
+    num_in : int
+        Number of input features.
+    num_out : int
+        Number of hidden units.
+    in_size : tuple
+        Shape of input (num_in,).
+    out_size : tuple
+        Shape of output (num_out,).
+
     State Variables
-
+    ---------------
     h : HiddenState
         Hidden state of the MGU cell.
 
```
```diff
@@ -318,6 +465,44 @@ class MGUCell(RNNCell):
         Reset the cell hidden state to its initial value.
     update(x)
         Update the hidden state for one time step and return the new state.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bs
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create an MGU cell
+        >>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
+        >>>
+        >>> # Initialize state for batch size 32
+        >>> cell.init_state(batch_size=32)
+        >>>
+        >>> # Process a single time step
+        >>> x = jnp.ones((32, 10))  # batch_size x num_in
+        >>> output = cell.update(x)
+        >>> print(output.shape)  # (32, 20)
+        >>>
+        >>> # Process a sequence
+        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
+        >>> outputs = []
+        >>> for t in range(100):
+        ...     output = cell.update(sequence[t])
+        ...     outputs.append(output)
+        >>> outputs = jnp.stack(outputs)
+        >>> print(outputs.shape)  # (100, 32, 20)
+
+    Notes
+    -----
+    MGU provides a lightweight alternative to GRU and LSTM, making it suitable
+    for resource-constrained applications or when model simplicity is preferred.
+
+    References
+    ----------
+    .. [1] Zhou, G. B., Wu, J., Zhang, C. L., & Zhou, Z. H. (2016). Minimal gated unit
+           for recurrent neural networks. International Journal of Automation and Computing,
+           13(3), 226-234.
     """
     __module__ = 'brainstate.nn'
 
```
```diff
@@ -371,34 +556,34 @@ class LSTMCell(RNNCell):
     r"""
     Long Short-Term Memory (LSTM) cell implementation.
 
-
-
-
+    LSTM is a type of RNN architecture designed to address the vanishing gradient
+    problem and learn long-term dependencies. It uses a cell state to carry
+    information across time steps and three gates (input, forget, output) to
+    control information flow.
+
+    The LSTM cell follows the mathematical formulation:
 
     .. math::
 
-        i_t &= \sigma(
-        f_t &= \sigma(
-        g_t &= \
-        o_t &= \sigma(
+        i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
+        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+        g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\
+        o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
         c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
-        h_t &= o_t \odot \
+        h_t &= o_t \odot \phi(c_t)
 
     where:
 
     - :math:`x_t` is the input vector at time t
     - :math:`h_t` is the hidden state at time t
     - :math:`c_t` is the cell state at time t
-    - :math:`i_t
-    - :math:`
+    - :math:`i_t` is the input gate activation
+    - :math:`f_t` is the forget gate activation
+    - :math:`o_t` is the output gate activation
+    - :math:`g_t` is the cell update (candidate) vector
     - :math:`\odot` represents element-wise multiplication
     - :math:`\sigma` is the sigmoid activation function
-
-    Notes
-    -----
-    Forget gate initialization: Following Jozefowicz et al. (2015), we add 1.0
-    to the forget gate bias after initialization to reduce forgetting at the
-    beginning of training.
+    - :math:`\phi` is the activation function (typically tanh)
 
     Parameters
     ----------
```
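In practice the four gates are computed as one fused matrix multiply and then split in four, which matches the `Linear(num_in + num_out, num_out * 4, ...)` shape used by the URLSTM `__init__` later in this diff. A sketch of one LSTM step under that layout (fused weights are illustrative; the +1.0 on the forget gate follows the Jozefowicz et al. initialization this diff moves into the Notes section below):

```python
import jax
import jax.numpy as jnp

def lstm_step(p, state, x, phi=jnp.tanh):
    h, c = state
    xh = jnp.concatenate([x, h], axis=-1)
    i, f, g, o = jnp.split(xh @ p["W"] + p["b"], 4, axis=-1)
    i, o = jax.nn.sigmoid(i), jax.nn.sigmoid(o)
    f = jax.nn.sigmoid(f + 1.0)            # forget bias +1.0 eases early training
    c_new = f * c + i * phi(g)             # cell state carries long-range memory
    h_new = o * phi(c_new)                 # hidden state is the gated readout
    return h_new, c_new

num_in, num_out, batch = 10, 20, 32
p = {
    "W": 0.1 * jax.random.normal(jax.random.PRNGKey(0), (num_in + num_out, 4 * num_out)),
    "b": jnp.zeros((4 * num_out,)),
}
state = (jnp.zeros((batch, num_out)), jnp.zeros((batch, num_out)))
h_new, c_new = lstm_step(p, state, jnp.ones((batch, num_in)))
print(h_new.shape, c_new.shape)            # (32, 20) (32, 20)
```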
```diff
@@ -413,13 +598,24 @@ class LSTMCell(RNNCell):
     state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
         Initializer for the hidden and cell states.
     activation : str or Callable, default='tanh'
-        Activation function to use. Can be a string (e.g., 'tanh')
+        Activation function to use. Can be a string (e.g., 'tanh', 'relu')
         or a callable function.
     name : str, optional
         Name of the module.
 
+    Attributes
+    ----------
+    num_in : int
+        Number of input features.
+    num_out : int
+        Number of hidden/cell units.
+    in_size : tuple
+        Shape of input (num_in,).
+    out_size : tuple
+        Shape of output (num_out,).
+
     State Variables
-
+    ---------------
     h : HiddenState
         Hidden state of the LSTM cell.
     c : HiddenState
```
```diff
@@ -434,15 +630,53 @@ class LSTMCell(RNNCell):
     update(x)
         Update the states for one time step and return the new hidden state.
 
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bs
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create an LSTM cell
+        >>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
+        >>>
+        >>> # Initialize states for batch size 32
+        >>> cell.init_state(batch_size=32)
+        >>>
+        >>> # Process a single time step
+        >>> x = jnp.ones((32, 10))  # batch_size x num_in
+        >>> output = cell.update(x)
+        >>> print(output.shape)  # (32, 20)
+        >>>
+        >>> # Process a sequence
+        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
+        >>> outputs = []
+        >>> for t in range(100):
+        ...     output = cell.update(sequence[t])
+        ...     outputs.append(output)
+        >>> outputs = jnp.stack(outputs)
+        >>> print(outputs.shape)  # (100, 32, 20)
+        >>>
+        >>> # Access cell state
+        >>> print(cell.c.value.shape)  # (32, 20)
+        >>> print(cell.h.value.shape)  # (32, 20)
+
+    Notes
+    -----
+    - The forget gate bias is initialized with +1.0 following Jozefowicz et al. (2015)
+      to reduce forgetting at the beginning of training.
+    - LSTM cells are effective for learning long-term dependencies but require
+      more parameters and computation than simpler RNN variants.
+
     References
     ----------
     .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
            Neural computation, 9(8), 1735-1780.
-    .. [2]
-
+    .. [2] Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget:
+           Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.
     .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
            exploration of recurrent network architectures. In International
-           conference on machine learning
+           conference on machine learning (pp. 2342-2350).
     """
     __module__ = 'brainstate.nn'
 
```
```diff
@@ -497,6 +731,109 @@ class LSTMCell(RNNCell):
 
 
 class URLSTMCell(RNNCell):
+    r"""LSTM with UR gating mechanism.
+
+    URLSTM is a modification of the standard LSTM that uses untied (separate) biases
+    for the forget and retention mechanisms, allowing for more flexible gating control.
+    This implementation is based on the paper "Improving the Gating Mechanism of
+    Recurrent Neural Networks" by Gu et al.
+
+    The URLSTM cell follows the mathematical formulation:
+
+    .. math::
+
+        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+        r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\
+        g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
+        \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\
+        c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\
+        o_t &= \sigma(W_o [x_t, h_{t-1}]) \\
+        h_t &= o_t \odot \phi(c_t)
+
+    where:
+
+    - :math:`x_t` is the input vector at time t
+    - :math:`h_t` is the hidden state at time t
+    - :math:`c_t` is the cell state at time t
+    - :math:`f_t` is the forget gate with positive bias
+    - :math:`r_t` is the retention gate with negative bias
+    - :math:`g_t` is the unified gate combining forget and retention
+    - :math:`\tilde{c}_t` is the candidate cell state
+    - :math:`o_t` is the output gate
+    - :math:`\odot` represents element-wise multiplication
+    - :math:`\sigma` is the sigmoid activation function
+    - :math:`\phi` is the activation function (typically tanh)
+
+    The key innovation is the untied bias mechanism where the forget and retention
+    gates use opposite biases, initialized using a uniform distribution to encourage
+    diverse gating behavior across units.
+
+    Parameters
+    ----------
+    num_in : int
+        The number of input units.
+    num_out : int
+        The number of hidden/output units.
+    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
+        Initializer for the weight matrix.
+    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+        Initializer for the hidden and cell states.
+    activation : str or Callable, default='tanh'
+        Activation function to use. Can be a string (e.g., 'relu', 'tanh')
+        or a callable function.
+    name : str, optional
+        Name of the module.
+
+    State Variables
+    ---------------
+    h : HiddenState
+        Hidden state of the URLSTM cell.
+    c : HiddenState
+        Cell state of the URLSTM cell.
+
+    Methods
+    -------
+    init_state(batch_size=None, **kwargs)
+        Initialize the cell and hidden states.
+    reset_state(batch_size=None, **kwargs)
+        Reset the cell and hidden states to their initial values.
+    update(x)
+        Update the cell and hidden states for one time step and return the hidden state.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        >>> import brainstate as bs
+        >>> import jax.numpy as jnp
+        >>>
+        >>> # Create a URLSTM cell
+        >>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
+        >>>
+        >>> # Initialize the state for batch size 32
+        >>> cell.init_state(batch_size=32)
+        >>>
+        >>> # Process a sequence
+        >>> x = jnp.ones((32, 10))  # batch_size x num_in
+        >>> output = cell.update(x)
+        >>> print(output.shape)  # (32, 20)
+        >>>
+        >>> # Process multiple time steps
+        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
+        >>> outputs = []
+        >>> for t in range(100):
+        ...     output = cell.update(sequence[t])
+        ...     outputs.append(output)
+        >>> outputs = jnp.stack(outputs)
+        >>> print(outputs.shape)  # (100, 32, 20)
+
+    References
+    ----------
+    .. [1] Gu, Albert, et al. "Improving the gating mechanism of recurrent neural networks."
+           International conference on machine learning. PMLR, 2020.
+    """
+    __module__ = 'brainstate.nn'
+
     def __init__(
         self,
         num_in: int,
```
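The unified gate is the part worth pausing on: g_t = 2 r_t f_t + (1 - 2 r_t) f_t² collapses to f_t² when r_t = 0 and to 2 f_t - f_t² when r_t = 1, so the retention gate decides whether the forget activation is damped or boosted. A small standalone check of those endpoints (plain arithmetic, not library code):

```python
import jax.numpy as jnp

def unified_gate(f, r):
    # g = 2*r*f + (1 - 2*r) * f**2, from the URLSTM docstring above
    return 2 * r * f + (1 - 2 * r) * f ** 2

f = jnp.linspace(0.05, 0.95, 5)
print(unified_gate(f, jnp.zeros_like(f)))  # r = 0: reduces to f**2 (damped)
print(unified_gate(f, jnp.ones_like(f)))   # r = 1: reduces to 2f - f**2 (boosted)
print(jnp.allclose(unified_gate(f, jnp.ones_like(f)), 1 - (1 - f) ** 2))  # True
```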
```diff
@@ -521,34 +858,89 @@ class URLSTMCell(RNNCell):
         if isinstance(activation, str):
             self.activation = getattr(functional, activation)
         else:
-            assert callable(activation), "The activation function should be a string or a callable function.
+            assert callable(activation), "The activation function should be a string or a callable function."
             self.activation = activation
 
-        # weights
+        # weights - 4 gates: forget, retention, candidate, output
         self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
+
+        # Initialize untied bias using uniform distribution
         self.bias = ParamState(self._forget_bias())
 
     def _forget_bias(self):
+        """Initialize the forget gate bias using uniform distribution."""
+        # Sample from uniform distribution to encourage diverse gating
         u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
+        # Transform to logit space for initialization
         return -jnp.log(1 / u - 1)
 
     def init_state(self, batch_size: int = None, **kwargs):
+        """
+        Initialize the cell and hidden states.
+
+        Parameters
+        ----------
+        batch_size : int, optional
+            The batch size for state initialization.
+        **kwargs
+            Additional keyword arguments.
+        """
         self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
         self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
 
     def reset_state(self, batch_size: int = None, **kwargs):
+        """
+        Reset the cell and hidden states to their initial values.
+
+        Parameters
+        ----------
+        batch_size : int, optional
+            The batch size for state reset.
+        **kwargs
+            Additional keyword arguments.
+        """
         self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
         self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
 
     def update(self, x: ArrayLike) -> ArrayLike:
+        """
+        Update the URLSTM cell for one time step.
+
+        Parameters
+        ----------
+        x : ArrayLike
+            Input tensor with shape (batch_size, num_in).
+
+        Returns
+        -------
+        ArrayLike
+            Hidden state tensor with shape (batch_size, num_out).
+        """
         h, c = self.h.value, self.c.value
-
-
-
-
-
+
+        # Concatenate input and hidden state
+        xh = jnp.concatenate([x, h], axis=-1)
+
+        # Compute all gates in one pass
+        gates = self.W(xh)
+        f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)
+
+        # Apply untied biases to forget and retention gates
+        f_gate = functional.sigmoid(f + self.bias.value)
+        r_gate = functional.sigmoid(r - self.bias.value)
+
+        # Compute unified gate
+        g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2
+
+        # Update cell state
         next_cell = g * c + (1 - g) * self.activation(u)
-
+
+        # Compute output gate and hidden state
+        o_gate = functional.sigmoid(o)
+        next_hidden = o_gate * self.activation(next_cell)
+
+        # Update states
         self.h.value = next_hidden
         self.c.value = next_cell
+
         return next_hidden
```