brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +167 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2297 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +2157 -1652
- brainstate/_state_test.py +1129 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1620 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1447 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +146 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +635 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +134 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +480 -477
- brainstate/nn/_dynamics.py +870 -1267
- brainstate/nn/_dynamics_test.py +53 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +391 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +675 -675
- brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
- brainstate/random/{_rand_state.py → _state.py} +1320 -1617
- brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
- brainstate/transform/__init__.py +56 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2176 -2016
- brainstate/transform/_make_jaxpr_test.py +1634 -1510
- brainstate/transform/_mapping.py +607 -529
- brainstate/transform/_mapping_test.py +104 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
- brainstate-0.2.2.dist-info/RECORD +111 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- brainstate-0.2.1.dist-info/RECORD +0 -111
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_rnns.py
CHANGED
@@ -1,946 +1,946 @@
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-


from typing import Callable, Union

import jax.numpy as jnp

from brainstate import random
from brainstate._state import HiddenState, ParamState
from brainstate.typing import ArrayLike
from . import _activations as functional
from . import init as init
from ._linear import Linear
from ._module import Module

__all__ = [
    'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
]


class RNNCell(Module):
    """
    Base class for all recurrent neural network (RNN) cell implementations.

    This abstract class serves as the foundation for implementing various RNN cell types
    such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
    Module class and provides common functionality and interface for recurrent units.

    All RNN cell implementations should inherit from this class and implement the required
    methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
    define the state initialization and recurrent dynamics.

    The RNNCell typically maintains hidden state(s) that are updated at each time step
    based on the current input and previous state values.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell state variables with appropriate dimensions.
    reset_state(batch_size=None, **kwargs)
        Reset the cell state variables to their initial values.
    update(x)
        Update the cell state for one time step based on input x and return output.

    See Also
    --------
    ValinaRNNCell : Vanilla RNN cell implementation
    GRUCell : Gated Recurrent Unit cell implementation
    LSTMCell : Long Short-Term Memory cell implementation
    URLSTMCell : LSTM with UR gating mechanism
    MGUCell : Minimal Gated Unit cell implementation
    """
    __module__ = 'brainstate.nn'
    pass


class ValinaRNNCell(RNNCell):
    r"""
    Vanilla Recurrent Neural Network (RNN) cell implementation.

    This class implements the basic RNN model that updates a hidden state based on
    the current input and previous hidden state. The standard RNN cell follows the
    mathematical formulation:

    .. math::

        h_t = \phi(W [x_t, h_{t-1}] + b)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`h_{t-1}` is the hidden state at previous time step
    - :math:`W` is the weight matrix for the combined input-hidden linear transformation
    - :math:`b` is the bias vector
    - :math:`\phi` is the activation function

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden units.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden state.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrix.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vector.
    activation : str or Callable, default='relu'
        Activation function to use. Can be a string (e.g., 'relu', 'tanh')
        or a callable function.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the RNN cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell hidden state.
    reset_state(batch_size=None, **kwargs)
        Reset the cell hidden state to its initial value.
    update(x)
        Update the hidden state for one time step and return the new state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a vanilla RNN cell
        >>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence of inputs
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)

    Notes
    -----
    Vanilla RNNs can suffer from vanishing or exploding gradient problems
    when processing long sequences. For better performance on long sequences,
    consider using gated architectures like GRU or LSTM.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'relu',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)
        self._state_initializer = state_init

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights
        self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """
        Initialize the hidden state.

        Parameters
        ----------
        batch_size : int, optional
            The batch size for state initialization.
        **kwargs
            Additional keyword arguments.
        """
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """
        Reset the hidden state to initial value.

        Parameters
        ----------
        batch_size : int, optional
            The batch size for state reset.
        **kwargs
            Additional keyword arguments.
        """
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        xh = jnp.concatenate([x, self.h.value], axis=-1)
        h = self.W(xh)
        self.h.value = self.activation(h)
        return self.h.value


class GRUCell(RNNCell):
    r"""
    Gated Recurrent Unit (GRU) cell implementation.

    The GRU is a gating mechanism in recurrent neural networks that aims to solve
    the vanishing gradient problem. It uses gating mechanisms to control information
    flow and has fewer parameters than LSTM as it combines the forget and input gates
    into a single update gate.

    The GRU cell follows the mathematical formulation:

    .. math::

        r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
        z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
        \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
        h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`r_t` is the reset gate vector
    - :math:`z_t` is the update gate vector
    - :math:`\tilde{h}_t` is the candidate hidden state
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden units.
    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden state.
    activation : str or Callable, default='tanh'
        Activation function to use. Can be a string (e.g., 'tanh', 'relu')
        or a callable function.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the GRU cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell hidden state.
    reset_state(batch_size=None, **kwargs)
        Reset the cell hidden state to its initial value.
    update(x)
        Update the hidden state for one time step and return the new state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a GRU cell
        >>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)
        >>>
        >>> # Reset state with different batch size
        >>> cell.reset_state(batch_size=16)
        >>> x_new = jnp.ones((16, 10))
        >>> output_new = cell.update(x_new)
        >>> print(output_new.shape)  # (16, 20)

    Notes
    -----
    GRU cells are computationally more efficient than LSTM cells due to having
    fewer parameters, while often achieving comparable performance on many tasks.

    References
    ----------
    .. [1] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F.,
           Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using
           RNN encoder-decoder for statistical machine translation.
           arXiv preprint arXiv:1406.1078.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights
        self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
        rh = r * old_h
        h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
        h = (1 - z) * old_h + z * h
        self.h.value = h
        return h


class MGUCell(RNNCell):
    r"""
    Minimal Gated Unit (MGU) cell implementation.

    MGU is a simplified version of GRU that uses a single forget gate instead of
    separate reset and update gates. This design significantly reduces the number
    of parameters while maintaining much of the gating capability. MGU provides
    a good trade-off between model complexity and performance.

    The MGU cell follows the mathematical formulation:

    .. math::

        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
        h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`f_t` is the forget gate vector
    - :math:`\tilde{h}_t` is the candidate hidden state
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden units.
    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden state.
    activation : str or Callable, default='tanh'
        Activation function to use. Can be a string (e.g., 'tanh', 'relu')
        or a callable function.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the MGU cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell hidden state.
    reset_state(batch_size=None, **kwargs)
        Reset the cell hidden state to its initial value.
    update(x)
        Update the hidden state for one time step and return the new state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create an MGU cell
        >>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)

    Notes
    -----
    MGU provides a lightweight alternative to GRU and LSTM, making it suitable
    for resource-constrained applications or when model simplicity is preferred.

    References
    ----------
    .. [1] Zhou, G. B., Wu, J., Zhang, C. L., & Zhou, Z. H. (2016). Minimal gated unit
           for recurrent neural networks. International Journal of Automation and Computing,
           13(3), 226-234.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights
        self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        f = functional.sigmoid(self.Wf(xh))
        fh = f * old_h
        h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
        self.h.value = (1 - f) * self.h.value + f * h
        return self.h.value


class LSTMCell(RNNCell):
    r"""
    Long Short-Term Memory (LSTM) cell implementation.

    LSTM is a type of RNN architecture designed to address the vanishing gradient
    problem and learn long-term dependencies. It uses a cell state to carry
    information across time steps and three gates (input, forget, output) to
    control information flow.

    The LSTM cell follows the mathematical formulation:

    .. math::

        i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\
        o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
        c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
        h_t &= o_t \odot \phi(c_t)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`c_t` is the cell state at time t
    - :math:`i_t` is the input gate activation
    - :math:`f_t` is the forget gate activation
    - :math:`o_t` is the output gate activation
    - :math:`g_t` is the cell update (candidate) vector
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden/cell units.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden and cell states.
    activation : str or Callable, default='tanh'
        Activation function to use. Can be a string (e.g., 'tanh', 'relu')
        or a callable function.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden/cell units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the LSTM cell.
    c : HiddenState
        Cell state of the LSTM cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell and hidden states.
    reset_state(batch_size=None, **kwargs)
        Reset the cell and hidden states to their initial values.
    update(x)
        Update the states for one time step and return the new hidden state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create an LSTM cell
        >>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize states for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)
        >>>
        >>> # Access cell state
        >>> print(cell.c.value.shape)  # (32, 20)
        >>> print(cell.h.value.shape)  # (32, 20)

    Notes
    -----
    - The forget gate bias is initialized with +1.0 following Jozefowicz et al. (2015)
      to reduce forgetting at the beginning of training.
    - LSTM cells are effective for learning long-term dependencies but require
      more parameters and computation than simpler RNN variants.

    References
    ----------
    .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
           Neural computation, 9(8), 1735-1780.
    .. [2] Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget:
           Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.
    .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
           exploration of recurrent network architectures. In International
           conference on machine learning (pp. 2342-2350).
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        h, c = self.h.value, self.c.value
        xh = jnp.concat([x, h], axis=-1)
        i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
        h = functional.sigmoid(o) * self.activation(c)
        self.h.value = h
        self.c.value = c
        return h


class URLSTMCell(RNNCell):
    r"""LSTM with UR gating mechanism.

    URLSTM is a modification of the standard LSTM that uses untied (separate) biases
    for the forget and retention mechanisms, allowing for more flexible gating control.
    This implementation is based on the paper "Improving the Gating Mechanism of
    Recurrent Neural Networks" by Gers et al.

    The URLSTM cell follows the mathematical formulation:

    .. math::

        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\
        g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
        \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\
        c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\
        o_t &= \sigma(W_o [x_t, h_{t-1}]) \\
        h_t &= o_t \odot \phi(c_t)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`c_t` is the cell state at time t
    - :math:`f_t` is the forget gate with positive bias
    - :math:`r_t` is the retention gate with negative bias
    - :math:`g_t` is the unified gate combining forget and retention
    - :math:`\tilde{c}_t` is the candidate cell state
    - :math:`o_t` is the output gate
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    The key innovation is the untied bias mechanism where the forget and retention
    gates use opposite biases, initialized using a uniform distribution to encourage
    diverse gating behavior across units.

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden/output units.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrix.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden and cell states.
    activation : str or Callable, default='tanh'
        Activation function to use. Can be a string (e.g., 'relu', 'tanh')
        or a callable function.
    name : str, optional
        Name of the module.

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the URLSTM cell.
    c : HiddenState
        Cell state of the URLSTM cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell and hidden states.
    reset_state(batch_size=None, **kwargs)
        Reset the cell and hidden states to their initial values.
    update(x)
        Update the cell and hidden states for one time step and return the hidden state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a URLSTM cell
        >>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize the state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a sequence
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process multiple time steps
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)

    References
    ----------
    .. [1] Gu, Albert, et al. "Improving the gating mechanism of recurrent neural networks."
           International conference on machine learning. PMLR, 2020.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights - 4 gates: forget, retention, candidate, output
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)

        # Initialize untied bias using uniform distribution
        self.bias = ParamState(self._forget_bias())

    def _forget_bias(self):
        """Initialize the forget gate bias using uniform distribution."""
        # Sample from uniform distribution to encourage diverse gating
        u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
        # Transform to logit space for initialization
        return -jnp.log(1 / u - 1)

    def init_state(self, batch_size: int = None, **kwargs):
        """
        Initialize the cell and hidden states.

        Parameters
        ----------
        batch_size : int, optional
            The batch size for state initialization.
        **kwargs
            Additional keyword arguments.
        """
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """
        Reset the cell and hidden states to their initial values.

        Parameters
        ----------
        batch_size : int, optional
            The batch size for state reset.
        **kwargs
            Additional keyword arguments.
        """
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x: ArrayLike) -> ArrayLike:
        """
        Update the URLSTM cell for one time step.

        Parameters
        ----------
        x : ArrayLike
            Input tensor with shape (batch_size, num_in).

        Returns
        -------
        ArrayLike
            Hidden state tensor with shape (batch_size, num_out).
        """
        h, c = self.h.value, self.c.value

        # Concatenate input and hidden state
        xh = jnp.concatenate([x, h], axis=-1)

        # Compute all gates in one pass
        gates = self.W(xh)
        f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)

        # Apply untied biases to forget and retention gates
        f_gate = functional.sigmoid(f + self.bias.value)
        r_gate = functional.sigmoid(r - self.bias.value)

        # Compute unified gate
        g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2

        # Update cell state
        next_cell = g * c + (1 - g) * self.activation(u)

        # Compute output gate and hidden state
        o_gate = functional.sigmoid(o)
        next_hidden = o_gate * self.activation(next_cell)

        # Update states
        self.h.value = next_hidden
        self.c.value = next_cell

        return next_hidden
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
from typing import Callable, Union
|
20
|
+
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
from brainstate import random
|
24
|
+
from brainstate._state import HiddenState, ParamState
|
25
|
+
from brainstate.typing import ArrayLike
|
26
|
+
from . import _activations as functional
|
27
|
+
from . import init as init
|
28
|
+
from ._linear import Linear
|
29
|
+
from ._module import Module
|
30
|
+
|
31
|
+
__all__ = [
|
32
|
+
'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
|
33
|
+
]
|
34
|
+
|
35
|
+
|
36
|
+
class RNNCell(Module):
|
37
|
+
"""
|
38
|
+
Base class for all recurrent neural network (RNN) cell implementations.
|
39
|
+
|
40
|
+
This abstract class serves as the foundation for implementing various RNN cell types
|
41
|
+
such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
|
42
|
+
Module class and provides common functionality and interface for recurrent units.
|
43
|
+
|
44
|
+
All RNN cell implementations should inherit from this class and implement the required
|
45
|
+
methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
|
46
|
+
define the state initialization and recurrent dynamics.
|
47
|
+
|
48
|
+
The RNNCell typically maintains hidden state(s) that are updated at each time step
|
49
|
+
based on the current input and previous state values.
|
50
|
+
|
51
|
+
Methods
|
52
|
+
-------
|
53
|
+
init_state(batch_size=None, **kwargs)
|
54
|
+
Initialize the cell state variables with appropriate dimensions.
|
55
|
+
reset_state(batch_size=None, **kwargs)
|
56
|
+
Reset the cell state variables to their initial values.
|
57
|
+
update(x)
|
58
|
+
Update the cell state for one time step based on input x and return output.
|
59
|
+
|
60
|
+
See Also
|
61
|
+
--------
|
62
|
+
ValinaRNNCell : Vanilla RNN cell implementation
|
63
|
+
GRUCell : Gated Recurrent Unit cell implementation
|
64
|
+
LSTMCell : Long Short-Term Memory cell implementation
|
65
|
+
URLSTMCell : LSTM with UR gating mechanism
|
66
|
+
MGUCell : Minimal Gated Unit cell implementation
|
67
|
+
"""
|
68
|
+
__module__ = 'brainstate.nn'
|
69
|
+
pass
|
70
|
+
|
71
|
+
|
72
|
+
class ValinaRNNCell(RNNCell):
|
73
|
+
r"""
|
74
|
+
Vanilla Recurrent Neural Network (RNN) cell implementation.
|
75
|
+
|
76
|
+
This class implements the basic RNN model that updates a hidden state based on
|
77
|
+
the current input and previous hidden state. The standard RNN cell follows the
|
78
|
+
mathematical formulation:
|
79
|
+
|
80
|
+
.. math::
|
81
|
+
|
82
|
+
h_t = \phi(W [x_t, h_{t-1}] + b)
|
83
|
+
|
84
|
+
where:
|
85
|
+
|
86
|
+
- :math:`x_t` is the input vector at time t
|
87
|
+
- :math:`h_t` is the hidden state at time t
|
88
|
+
- :math:`h_{t-1}` is the hidden state at previous time step
|
89
|
+
- :math:`W` is the weight matrix for the combined input-hidden linear transformation
|
90
|
+
- :math:`b` is the bias vector
|
91
|
+
- :math:`\phi` is the activation function
|
92
|
+
|
93
|
+
Parameters
|
94
|
+
----------
|
95
|
+
num_in : int
|
96
|
+
The number of input units.
|
97
|
+
num_out : int
|
98
|
+
The number of hidden units.
|
99
|
+
state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
100
|
+
Initializer for the hidden state.
|
101
|
+
w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
|
102
|
+
Initializer for the weight matrix.
|
103
|
+
b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
104
|
+
Initializer for the bias vector.
|
105
|
+
activation : str or Callable, default='relu'
|
106
|
+
Activation function to use. Can be a string (e.g., 'relu', 'tanh')
|
107
|
+
or a callable function.
|
108
|
+
name : str, optional
|
109
|
+
Name of the module.
|
110
|
+
|
111
|
+
Attributes
|
112
|
+
----------
|
113
|
+
num_in : int
|
114
|
+
Number of input features.
|
115
|
+
num_out : int
|
116
|
+
Number of hidden units.
|
117
|
+
in_size : tuple
|
118
|
+
Shape of input (num_in,).
|
119
|
+
out_size : tuple
|
120
|
+
Shape of output (num_out,).
|
121
|
+
|
122
|
+
State Variables
|
123
|
+
---------------
|
124
|
+
h : HiddenState
|
125
|
+
Hidden state of the RNN cell.
|
126
|
+
|
127
|
+
Methods
|
128
|
+
-------
|
129
|
+
init_state(batch_size=None, **kwargs)
|
130
|
+
Initialize the cell hidden state.
|
131
|
+
reset_state(batch_size=None, **kwargs)
|
132
|
+
Reset the cell hidden state to its initial value.
|
133
|
+
update(x)
|
134
|
+
Update the hidden state for one time step and return the new state.
|
135
|
+
|
136
|
+
Examples
|
137
|
+
--------
|
138
|
+
.. code-block:: python
|
139
|
+
|
140
|
+
>>> import brainstate as bs
|
141
|
+
>>> import jax.numpy as jnp
|
142
|
+
>>>
|
143
|
+
>>> # Create a vanilla RNN cell
|
144
|
+
>>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
|
145
|
+
>>>
|
146
|
+
>>> # Initialize state for batch size 32
|
147
|
+
>>> cell.init_state(batch_size=32)
|
148
|
+
>>>
|
149
|
+
>>> # Process a single time step
|
150
|
+
>>> x = jnp.ones((32, 10)) # batch_size x num_in
|
151
|
+
>>> output = cell.update(x)
|
152
|
+
>>> print(output.shape) # (32, 20)
|
153
|
+
>>>
|
154
|
+
>>> # Process a sequence of inputs
|
155
|
+
>>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
|
156
|
+
>>> outputs = []
|
157
|
+
>>> for t in range(100):
|
158
|
+
... output = cell.update(sequence[t])
|
159
|
+
... outputs.append(output)
|
160
|
+
>>> outputs = jnp.stack(outputs)
|
161
|
+
>>> print(outputs.shape) # (100, 32, 20)
|
162
|
+
|
163
|
+
Notes
|
164
|
+
-----
|
165
|
+
Vanilla RNNs can suffer from vanishing or exploding gradient problems
|
166
|
+
when processing long sequences. For better performance on long sequences,
|
167
|
+
consider using gated architectures like GRU or LSTM.
|
168
|
+
"""
|
169
|
+
__module__ = 'brainstate.nn'
|
170
|
+
|
171
|
+
def __init__(
|
172
|
+
self,
|
173
|
+
num_in: int,
|
174
|
+
num_out: int,
|
175
|
+
state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
176
|
+
w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
|
177
|
+
b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
178
|
+
activation: str | Callable = 'relu',
|
179
|
+
name: str = None,
|
180
|
+
):
|
181
|
+
super().__init__(name=name)
|
182
|
+
|
183
|
+
# parameters
|
184
|
+
self.num_out = num_out
|
185
|
+
self.num_in = num_in
|
186
|
+
self.in_size = (num_in,)
|
187
|
+
self.out_size = (num_out,)
|
188
|
+
self._state_initializer = state_init
|
189
|
+
|
190
|
+
# activation function
|
191
|
+
if isinstance(activation, str):
|
192
|
+
self.activation = getattr(functional, activation)
|
193
|
+
else:
|
194
|
+
assert callable(activation), "The activation function should be a string or a callable function. "
|
195
|
+
self.activation = activation
|
196
|
+
|
197
|
+
# weights
|
198
|
+
self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
|
199
|
+
|
200
|
+
def init_state(self, batch_size: int = None, **kwargs):
|
201
|
+
"""
|
202
|
+
Initialize the hidden state.
|
203
|
+
|
204
|
+
Parameters
|
205
|
+
----------
|
206
|
+
batch_size : int, optional
|
207
|
+
The batch size for state initialization.
|
208
|
+
**kwargs
|
209
|
+
Additional keyword arguments.
|
210
|
+
"""
|
211
|
+
self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
|
212
|
+
|
213
|
+
def reset_state(self, batch_size: int = None, **kwargs):
|
214
|
+
"""
|
215
|
+
Reset the hidden state to initial value.
|
216
|
+
|
217
|
+
Parameters
|
218
|
+
----------
|
219
|
+
batch_size : int, optional
|
220
|
+
The batch size for state reset.
|
221
|
+
**kwargs
|
222
|
+
Additional keyword arguments.
|
223
|
+
"""
|
224
|
+
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
|
225
|
+
|
226
|
+
def update(self, x):
|
227
|
+
xh = jnp.concatenate([x, self.h.value], axis=-1)
|
228
|
+
h = self.W(xh)
|
229
|
+
self.h.value = self.activation(h)
|
230
|
+
return self.h.value
|
231
|
+
|
232
|
+
|
233
|
+
class GRUCell(RNNCell):
|
234
|
+
r"""
|
235
|
+
Gated Recurrent Unit (GRU) cell implementation.
|
236
|
+
|
237
|
+
The GRU is a gating mechanism in recurrent neural networks that aims to solve
|
238
|
+
the vanishing gradient problem. It uses gating mechanisms to control information
|
239
|
+
flow and has fewer parameters than LSTM as it combines the forget and input gates
|
240
|
+
into a single update gate.
|
241
|
+
|
242
|
+
The GRU cell follows the mathematical formulation:
|
243
|
+
|
244
|
+
.. math::
|
245
|
+
|
246
|
+
r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
|
247
|
+
z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
|
248
|
+
\tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
|
249
|
+
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
|
250
|
+
|
251
|
+
where:
|
252
|
+
|
253
|
+
- :math:`x_t` is the input vector at time t
|
254
|
+
- :math:`h_t` is the hidden state at time t
|
255
|
+
- :math:`r_t` is the reset gate vector
|
256
|
+
- :math:`z_t` is the update gate vector
|
257
|
+
- :math:`\tilde{h}_t` is the candidate hidden state
|
258
|
+
- :math:`\odot` represents element-wise multiplication
|
259
|
+
- :math:`\sigma` is the sigmoid activation function
|
260
|
+
- :math:`\phi` is the activation function (typically tanh)
|
261
|
+
|
262
|
+
Parameters
|
263
|
+
----------
|
264
|
+
num_in : int
|
265
|
+
The number of input units.
|
266
|
+
num_out : int
|
267
|
+
The number of hidden units.
|
268
|
+
w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
|
269
|
+
Initializer for the weight matrices.
|
270
|
+
b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
271
|
+
Initializer for the bias vectors.
|
272
|
+
state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
273
|
+
Initializer for the hidden state.
|
274
|
+
activation : str or Callable, default='tanh'
|
275
|
+
Activation function to use. Can be a string (e.g., 'tanh', 'relu')
|
276
|
+
or a callable function.
|
277
|
+
name : str, optional
|
278
|
+
Name of the module.
|
279
|
+
|
280
|
+
Attributes
|
281
|
+
----------
|
282
|
+
num_in : int
|
283
|
+
Number of input features.
|
284
|
+
num_out : int
|
285
|
+
Number of hidden units.
|
286
|
+
in_size : tuple
|
287
|
+
Shape of input (num_in,).
|
288
|
+
out_size : tuple
|
289
|
+
Shape of output (num_out,).
|
290
|
+
|
291
|
+
State Variables
|
292
|
+
---------------
|
293
|
+
h : HiddenState
|
294
|
+
Hidden state of the GRU cell.
|
295
|
+
|
296
|
+
Methods
|
297
|
+
-------
|
298
|
+
init_state(batch_size=None, **kwargs)
|
299
|
+
Initialize the cell hidden state.
|
300
|
+
reset_state(batch_size=None, **kwargs)
|
301
|
+
Reset the cell hidden state to its initial value.
|
302
|
+
update(x)
|
303
|
+
Update the hidden state for one time step and return the new state.
|
304
|
+
|
305
|
+
Examples
|
306
|
+
--------
|
307
|
+
.. code-block:: python
|
308
|
+
|
309
|
+
>>> import brainstate as bs
|
310
|
+
>>> import jax.numpy as jnp
|
311
|
+
>>>
|
312
|
+
>>> # Create a GRU cell
|
313
|
+
>>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
|
314
|
+
>>>
|
315
|
+
>>> # Initialize state for batch size 32
|
316
|
+
>>> cell.init_state(batch_size=32)
|
317
|
+
>>>
|
318
|
+
>>> # Process a single time step
|
319
|
+
>>> x = jnp.ones((32, 10)) # batch_size x num_in
|
320
|
+
>>> output = cell.update(x)
|
321
|
+
>>> print(output.shape) # (32, 20)
|
322
|
+
>>>
|
323
|
+
>>> # Process a sequence
|
324
|
+
>>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
|
325
|
+
>>> outputs = []
|
326
|
+
>>> for t in range(100):
|
327
|
+
... output = cell.update(sequence[t])
|
328
|
+
... outputs.append(output)
|
329
|
+
>>> outputs = jnp.stack(outputs)
|
330
|
+
>>> print(outputs.shape) # (100, 32, 20)
|
331
|
+
>>>
|
332
|
+
>>> # Reset state with different batch size
|
333
|
+
>>> cell.reset_state(batch_size=16)
|
334
|
+
>>> x_new = jnp.ones((16, 10))
|
335
|
+
>>> output_new = cell.update(x_new)
|
336
|
+
>>> print(output_new.shape) # (16, 20)
|
337
|
+
|
338
|
+
Notes
|
339
|
+
-----
|
340
|
+
GRU cells are computationally more efficient than LSTM cells due to having
|
341
|
+
fewer parameters, while often achieving comparable performance on many tasks.
|
342
|
+
|
343
|
+
References
|
344
|
+
----------
|
345
|
+
.. [1] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F.,
|
346
|
+
Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using
|
347
|
+
RNN encoder-decoder for statistical machine translation.
|
348
|
+
arXiv preprint arXiv:1406.1078.
|
349
|
+
"""
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights
        self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
        rh = r * old_h
        h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
        h = (1 - z) * old_h + z * h
        self.h.value = h
        return h

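
# ----------------------------------------------------------------------------
# Illustrative sketch (added for exposition, not part of the brainstate
# sources): a plain-JAX reference step for the GRU equations realised by
# ``GRUCell.update`` above.  The argument names ``w_rz``, ``b_rz``, ``w_h``
# and ``b_h`` are hypothetical stand-ins for the parameters held by the
# ``Wrz`` and ``Wh`` Linear layers; only names already used in this module
# (``jnp``, ``functional``) are assumed.
def _gru_reference_step(x, h, w_rz, b_rz, w_h, b_h, phi=jnp.tanh):
    """One GRU step written directly from the gate equations."""
    xh = jnp.concatenate([x, h], axis=-1)
    # reset and update gates from a single fused projection
    r, z = jnp.split(functional.sigmoid(xh @ w_rz + b_rz), 2, axis=-1)
    # candidate hidden state built from the reset-gated previous state
    h_tilde = phi(jnp.concatenate([x, r * h], axis=-1) @ w_h + b_h)
    # convex combination of old state and candidate
    return (1 - z) * h + z * h_tilde
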
class MGUCell(RNNCell):
    r"""
    Minimal Gated Unit (MGU) cell implementation.

    MGU is a simplified version of GRU that uses a single forget gate instead of
    separate reset and update gates. This design significantly reduces the number
    of parameters while maintaining much of the gating capability. MGU provides
    a good trade-off between model complexity and performance.

    The MGU cell follows the mathematical formulation:

    .. math::

        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
        h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`f_t` is the forget gate vector
    - :math:`\tilde{h}_t` is the candidate hidden state
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden units.
    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden state.
    activation : str or Callable, default='tanh'
        Activation function to use. Can be a string (e.g., 'tanh', 'relu')
        or a callable function.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the MGU cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell hidden state.
    reset_state(batch_size=None, **kwargs)
        Reset the cell hidden state to its initial value.
    update(x)
        Update the hidden state for one time step and return the new state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create an MGU cell
        >>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)

    Notes
    -----
    MGU provides a lightweight alternative to GRU and LSTM, making it suitable
    for resource-constrained applications or when model simplicity is preferred.

    References
    ----------
    .. [1] Zhou, G. B., Wu, J., Zhang, C. L., & Zhou, Z. H. (2016). Minimal gated unit
           for recurrent neural networks. International Journal of Automation and Computing,
           13(3), 226-234.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights
        self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        f = functional.sigmoid(self.Wf(xh))
        fh = f * old_h
        h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
        self.h.value = (1 - f) * self.h.value + f * h
        return self.h.value

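
# ----------------------------------------------------------------------------
# Illustrative sketch (added for exposition, not part of the brainstate
# sources): a plain-JAX reference step for the single-forget-gate MGU update
# implemented by ``MGUCell.update`` above.  ``w_f``/``b_f`` and ``w_h``/``b_h``
# are hypothetical stand-ins for the parameters of the ``Wf`` and ``Wh``
# Linear layers.
def _mgu_reference_step(x, h, w_f, b_f, w_h, b_h, phi=jnp.tanh):
    """One MGU step written directly from the gate equations."""
    xh = jnp.concatenate([x, h], axis=-1)
    # a single forget gate replaces the GRU's reset/update pair
    f = functional.sigmoid(xh @ w_f + b_f)
    # candidate state from the forget-gated previous state
    h_tilde = phi(jnp.concatenate([x, f * h], axis=-1) @ w_h + b_h)
    # blend old state and candidate with the same gate
    return (1 - f) * h + f * h_tilde
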
class LSTMCell(RNNCell):
    r"""
    Long Short-Term Memory (LSTM) cell implementation.

    LSTM is a type of RNN architecture designed to address the vanishing gradient
    problem and learn long-term dependencies. It uses a cell state to carry
    information across time steps and three gates (input, forget, output) to
    control information flow.

    The LSTM cell follows the mathematical formulation:

    .. math::

        i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\
        o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
        c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
        h_t &= o_t \odot \phi(c_t)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`c_t` is the cell state at time t
    - :math:`i_t` is the input gate activation
    - :math:`f_t` is the forget gate activation
    - :math:`o_t` is the output gate activation
    - :math:`g_t` is the cell update (candidate) vector
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden/cell units.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden and cell states.
    activation : str or Callable, default='tanh'
        Activation function to use. Can be a string (e.g., 'tanh', 'relu')
        or a callable function.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden/cell units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the LSTM cell.
    c : HiddenState
        Cell state of the LSTM cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell and hidden states.
    reset_state(batch_size=None, **kwargs)
        Reset the cell and hidden states to their initial values.
    update(x)
        Update the states for one time step and return the new hidden state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create an LSTM cell
        >>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize states for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)
        >>>
        >>> # Access cell state
        >>> print(cell.c.value.shape)  # (32, 20)
        >>> print(cell.h.value.shape)  # (32, 20)

    Notes
    -----
    - The forget gate bias is initialized with +1.0 following Jozefowicz et al. (2015)
      to reduce forgetting at the beginning of training.
    - LSTM cells are effective for learning long-term dependencies but require
      more parameters and computation than simpler RNN variants.

    References
    ----------
    .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
           Neural computation, 9(8), 1735-1780.
    .. [2] Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget:
           Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.
    .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
           exploration of recurrent network architectures. In International
           conference on machine learning (pp. 2342-2350).
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        h, c = self.h.value, self.c.value
        xh = jnp.concat([x, h], axis=-1)
        i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
        h = functional.sigmoid(o) * self.activation(c)
        self.h.value = h
        self.c.value = c
        return h

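
# ----------------------------------------------------------------------------
# Illustrative sketch (added for exposition, not part of the brainstate
# sources): a plain-JAX reference step for the LSTM equations implemented by
# ``LSTMCell.update`` above, including the +1.0 forget-gate bias described in
# the class notes.  ``w`` and ``b`` are hypothetical stand-ins for the fused
# parameters of the ``W`` Linear layer (gate order i, g, f, o, as in the code).
def _lstm_reference_step(x, h, c, w, b, phi=jnp.tanh):
    """One LSTM step written directly from the gate equations."""
    xh = jnp.concatenate([x, h], axis=-1)
    i, g, f, o = jnp.split(xh @ w + b, 4, axis=-1)
    # cell state: forget old contents (with +1 bias) and write the gated candidate
    c = functional.sigmoid(f + 1.0) * c + functional.sigmoid(i) * phi(g)
    # hidden state: output-gated, squashed cell state
    h = functional.sigmoid(o) * phi(c)
    return h, c
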
class URLSTMCell(RNNCell):
    r"""LSTM with UR gating mechanism.

    URLSTM is a modification of the standard LSTM that uses untied (separate) biases
    for the forget and retention mechanisms, allowing for more flexible gating control.
    This implementation is based on the paper "Improving the Gating Mechanism of
    Recurrent Neural Networks" by Gu et al. (2020).

    The URLSTM cell follows the mathematical formulation:

    .. math::

        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\
        g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
        \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\
        c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\
        o_t &= \sigma(W_o [x_t, h_{t-1}]) \\
        h_t &= o_t \odot \phi(c_t)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`c_t` is the cell state at time t
    - :math:`f_t` is the forget gate with positive bias
    - :math:`r_t` is the retention gate with negative bias
    - :math:`g_t` is the unified gate combining forget and retention
    - :math:`\tilde{c}_t` is the candidate cell state
    - :math:`o_t` is the output gate
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    The key innovation is the untied bias mechanism where the forget and retention
    gates use opposite biases, initialized using a uniform distribution to encourage
    diverse gating behavior across units.

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden/output units.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrix.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden and cell states.
    activation : str or Callable, default='tanh'
        Activation function to use. Can be a string (e.g., 'relu', 'tanh')
        or a callable function.
    name : str, optional
        Name of the module.

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the URLSTM cell.
    c : HiddenState
        Cell state of the URLSTM cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell and hidden states.
    reset_state(batch_size=None, **kwargs)
        Reset the cell and hidden states to their initial values.
    update(x)
        Update the cell and hidden states for one time step and return the hidden state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a URLSTM cell
        >>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize the state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process multiple time steps
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)

    References
    ----------
    .. [1] Gu, Albert, et al. "Improving the gating mechanism of recurrent neural networks."
           International conference on machine learning. PMLR, 2020.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights - 4 gates: forget, retention, candidate, output
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)

        # Initialize untied bias using a uniform distribution
        self.bias = ParamState(self._forget_bias())

    def _forget_bias(self):
        """Initialize the untied forget/retention gate bias using a uniform distribution."""
        # Sample from a uniform distribution to encourage diverse gating
        u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
        # Transform to logit space for initialization
        return -jnp.log(1 / u - 1)

    def init_state(self, batch_size: int = None, **kwargs):
        """
        Initialize the cell and hidden states.

        Parameters
        ----------
        batch_size : int, optional
            The batch size for state initialization.
        **kwargs
            Additional keyword arguments.
        """
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """
        Reset the cell and hidden states to their initial values.

        Parameters
        ----------
        batch_size : int, optional
            The batch size for state reset.
        **kwargs
            Additional keyword arguments.
        """
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x: ArrayLike) -> ArrayLike:
        """
        Update the URLSTM cell for one time step.

        Parameters
        ----------
        x : ArrayLike
            Input tensor with shape (batch_size, num_in).

        Returns
        -------
        ArrayLike
            Hidden state tensor with shape (batch_size, num_out).
        """
        h, c = self.h.value, self.c.value

        # Concatenate input and hidden state
        xh = jnp.concatenate([x, h], axis=-1)

        # Compute all gates in one pass
        gates = self.W(xh)
        f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)

        # Apply untied biases to forget and retention gates
        f_gate = functional.sigmoid(f + self.bias.value)
        r_gate = functional.sigmoid(r - self.bias.value)

        # Compute unified gate
        g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2

        # Update cell state
        next_cell = g * c + (1 - g) * self.activation(u)

        # Compute output gate and hidden state
        o_gate = functional.sigmoid(o)
        next_hidden = o_gate * self.activation(next_cell)

        # Update states
        self.h.value = next_hidden
        self.c.value = next_cell

        return next_hidden
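
# ----------------------------------------------------------------------------
# Illustrative sketch (added for exposition, not part of the brainstate
# sources): a plain-JAX reference step for the UR-gated LSTM update
# implemented by ``URLSTMCell.update`` above.  ``w`` and ``bias`` are
# hypothetical stand-ins for the fused ``W`` Linear parameters and the untied
# bias stored in ``self.bias`` (gate order f, r, u, o, as in the code).
def _urlstm_reference_step(x, h, c, w, bias, phi=jnp.tanh):
    """One URLSTM step written directly from the gate equations."""
    xh = jnp.concatenate([x, h], axis=-1)
    f, r, u, o = jnp.split(xh @ w, 4, axis=-1)
    # untied biases: +bias on the forget gate, -bias on the retention gate
    f_gate = functional.sigmoid(f + bias)
    r_gate = functional.sigmoid(r - bias)
    # unified gate g_t = 2 r f + (1 - 2 r) f^2
    g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2
    c = g * c + (1 - g) * phi(u)
    h = functional.sigmoid(o) * phi(c)
    return h, c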