brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +588 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
- brainstate-0.1.10.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
brainstate/nn/_rate_rnns.py
CHANGED
@@ -1,554 +1,554 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
|
19
|
-
from typing import Callable, Union
|
20
|
-
|
21
|
-
import jax.numpy as jnp
|
22
|
-
|
23
|
-
from brainstate import random, init, functional
|
24
|
-
from brainstate._state import HiddenState, ParamState
|
25
|
-
from brainstate.typing import ArrayLike
|
26
|
-
from ._linear import Linear
|
27
|
-
from ._module import Module
|
28
|
-
|
29
|
-
__all__ = [
|
30
|
-
'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
|
31
|
-
]
|
32
|
-
|
33
|
-
|
34
|
-
class RNNCell(Module):
|
35
|
-
"""
|
36
|
-
Base class for all recurrent neural network (RNN) cell implementations.
|
37
|
-
|
38
|
-
This abstract class serves as the foundation for implementing various RNN cell types
|
39
|
-
such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
|
40
|
-
Module class and provides common functionality and interface for recurrent units.
|
41
|
-
|
42
|
-
All RNN cell implementations should inherit from this class and implement the required
|
43
|
-
methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
|
44
|
-
define the state initialization and recurrent dynamics.
|
45
|
-
|
46
|
-
The RNNCell typically maintains hidden state(s) that are updated at each time step
|
47
|
-
based on the current input and previous state values.
|
48
|
-
|
49
|
-
Methods
|
50
|
-
-------
|
51
|
-
init_state(batch_size=None, **kwargs)
|
52
|
-
Initialize the cell state variables with appropriate dimensions.
|
53
|
-
reset_state(batch_size=None, **kwargs)
|
54
|
-
Reset the cell state variables to their initial values.
|
55
|
-
update(x)
|
56
|
-
Update the cell state for one time step based on input x and return output.
|
57
|
-
"""
|
58
|
-
pass
|
59
|
-
|
60
|
-
|
61
|
-
class ValinaRNNCell(RNNCell):
|
62
|
-
r"""
|
63
|
-
Vanilla Recurrent Neural Network (RNN) cell implementation.
|
64
|
-
|
65
|
-
This class implements the basic RNN model that updates a hidden state based on
|
66
|
-
the current input and previous hidden state. The standard RNN cell follows the
|
67
|
-
mathematical formulation:
|
68
|
-
|
69
|
-
.. math::
|
70
|
-
|
71
|
-
h_t = \phi(W [x_t, h_{t-1}] + b)
|
72
|
-
|
73
|
-
where:
|
74
|
-
|
75
|
-
- :math:`x_t` is the input vector at time t
|
76
|
-
- :math:`h_t` is the hidden state at time t
|
77
|
-
- :math:`h_{t-1}` is the hidden state at previous time step
|
78
|
-
- :math:`W` is the weight matrix for the combined input-hidden linear transformation
|
79
|
-
- :math:`b` is the bias vector
|
80
|
-
- :math:`\phi` is the activation function
|
81
|
-
|
82
|
-
Parameters
|
83
|
-
----------
|
84
|
-
num_in : int
|
85
|
-
The number of input units.
|
86
|
-
num_out : int
|
87
|
-
The number of hidden units.
|
88
|
-
state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
89
|
-
Initializer for the hidden state.
|
90
|
-
w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
|
91
|
-
Initializer for the weight matrix.
|
92
|
-
b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
93
|
-
Initializer for the bias vector.
|
94
|
-
activation : str or Callable, default='relu'
|
95
|
-
Activation function to use. Can be a string (e.g., 'relu', 'tanh')
|
96
|
-
or a callable function.
|
97
|
-
name : str, optional
|
98
|
-
Name of the module.
|
99
|
-
|
100
|
-
State Variables
|
101
|
-
--------------
|
102
|
-
h : HiddenState
|
103
|
-
Hidden state of the RNN cell.
|
104
|
-
|
105
|
-
Methods
|
106
|
-
-------
|
107
|
-
init_state(batch_size=None, **kwargs)
|
108
|
-
Initialize the cell hidden state.
|
109
|
-
reset_state(batch_size=None, **kwargs)
|
110
|
-
Reset the cell hidden state to its initial value.
|
111
|
-
update(x)
|
112
|
-
Update the hidden state for one time step and return the new state.
|
113
|
-
"""
|
114
|
-
__module__ = 'brainstate.nn'
|
115
|
-
|
116
|
-
def __init__(
|
117
|
-
self,
|
118
|
-
num_in: int,
|
119
|
-
num_out: int,
|
120
|
-
state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
121
|
-
w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
|
122
|
-
b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
123
|
-
activation: str | Callable = 'relu',
|
124
|
-
name: str = None,
|
125
|
-
):
|
126
|
-
super().__init__(name=name)
|
127
|
-
|
128
|
-
# parameters
|
129
|
-
self.num_out = num_out
|
130
|
-
self.num_in = num_in
|
131
|
-
self.in_size = (num_in,)
|
132
|
-
self.out_size = (num_out,)
|
133
|
-
self._state_initializer = state_init
|
134
|
-
|
135
|
-
# activation function
|
136
|
-
if isinstance(activation, str):
|
137
|
-
self.activation = getattr(functional, activation)
|
138
|
-
else:
|
139
|
-
assert callable(activation), "The activation function should be a string or a callable function. "
|
140
|
-
self.activation = activation
|
141
|
-
|
142
|
-
# weights
|
143
|
-
self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
|
144
|
-
|
145
|
-
def init_state(self, batch_size: int = None, **kwargs):
|
146
|
-
self.h = HiddenState(init.param(self._state_initializer, self.num_out, batch_size))
|
147
|
-
|
148
|
-
def reset_state(self, batch_size: int = None, **kwargs):
|
149
|
-
self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
|
150
|
-
|
151
|
-
def update(self, x):
|
152
|
-
xh = jnp.concatenate([x, self.h.value], axis=-1)
|
153
|
-
h = self.W(xh)
|
154
|
-
self.h.value = self.activation(h)
|
155
|
-
return self.h.value
|
156
|
-
|
157
|
-
|
158
|
-
class GRUCell(RNNCell):
|
159
|
-
r"""
|
160
|
-
Gated Recurrent Unit (GRU) cell implementation.
|
161
|
-
|
162
|
-
This class implements the GRU model that uses gating mechanisms to control
|
163
|
-
information flow. The GRU has fewer parameters than LSTM as it combines
|
164
|
-
the forget and input gates into a single update gate. The GRU follows the
|
165
|
-
mathematical formulation:
|
166
|
-
|
167
|
-
.. math::
|
168
|
-
|
169
|
-
r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
|
170
|
-
z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
|
171
|
-
\tilde{h}_t &= \tanh(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
|
172
|
-
h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
|
173
|
-
|
174
|
-
where:
|
175
|
-
|
176
|
-
- :math:`x_t` is the input vector at time t
|
177
|
-
- :math:`h_t` is the hidden state at time t
|
178
|
-
- :math:`r_t` is the reset gate vector
|
179
|
-
- :math:`z_t` is the update gate vector
|
180
|
-
- :math:`\tilde{h}_t` is the candidate hidden state
|
181
|
-
- :math:`\odot` represents element-wise multiplication
|
182
|
-
- :math:`\sigma` is the sigmoid activation function
|
183
|
-
|
184
|
-
Parameters
|
185
|
-
----------
|
186
|
-
num_in : int
|
187
|
-
The number of input units.
|
188
|
-
num_out : int
|
189
|
-
The number of hidden units.
|
190
|
-
w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
|
191
|
-
Initializer for the weight matrices.
|
192
|
-
b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
193
|
-
Initializer for the bias vectors.
|
194
|
-
state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
195
|
-
Initializer for the hidden state.
|
196
|
-
activation : str or Callable, default='tanh'
|
197
|
-
Activation function to use. Can be a string (e.g., 'tanh')
|
198
|
-
or a callable function.
|
199
|
-
name : str, optional
|
200
|
-
Name of the module.
|
201
|
-
|
202
|
-
State Variables
|
203
|
-
--------------
|
204
|
-
h : HiddenState
|
205
|
-
Hidden state of the GRU cell.
|
206
|
-
|
207
|
-
Methods
|
208
|
-
-------
|
209
|
-
init_state(batch_size=None, **kwargs)
|
210
|
-
Initialize the cell hidden state.
|
211
|
-
reset_state(batch_size=None, **kwargs)
|
212
|
-
Reset the cell hidden state to its initial value.
|
213
|
-
update(x)
|
214
|
-
Update the hidden state for one time step and return the new state.
|
215
|
-
"""
|
216
|
-
__module__ = 'brainstate.nn'
|
217
|
-
|
218
|
-
def __init__(
|
219
|
-
self,
|
220
|
-
num_in: int,
|
221
|
-
num_out: int,
|
222
|
-
w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
|
223
|
-
b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
224
|
-
state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
225
|
-
activation: str | Callable = 'tanh',
|
226
|
-
name: str = None,
|
227
|
-
):
|
228
|
-
super().__init__(name=name)
|
229
|
-
|
230
|
-
# parameters
|
231
|
-
self._state_initializer = state_init
|
232
|
-
self.num_out = num_out
|
233
|
-
self.num_in = num_in
|
234
|
-
self.in_size = (num_in,)
|
235
|
-
self.out_size = (num_out,)
|
236
|
-
|
237
|
-
# activation function
|
238
|
-
if isinstance(activation, str):
|
239
|
-
self.activation = getattr(functional, activation)
|
240
|
-
else:
|
241
|
-
assert callable(activation), "The activation function should be a string or a callable function. "
|
242
|
-
self.activation = activation
|
243
|
-
|
244
|
-
# weights
|
245
|
-
self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
|
246
|
-
self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
|
247
|
-
|
248
|
-
def init_state(self, batch_size: int = None, **kwargs):
|
249
|
-
self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
|
250
|
-
|
251
|
-
def reset_state(self, batch_size: int = None, **kwargs):
|
252
|
-
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
|
253
|
-
|
254
|
-
def update(self, x):
|
255
|
-
old_h = self.h.value
|
256
|
-
xh = jnp.concatenate([x, old_h], axis=-1)
|
257
|
-
r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
|
258
|
-
rh = r * old_h
|
259
|
-
h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
|
260
|
-
h = (1 - z) * old_h + z * h
|
261
|
-
self.h.value = h
|
262
|
-
return h
|
263
|
-
|
264
|
-
|
265
|
-
class MGUCell(RNNCell):
|
266
|
-
r"""
|
267
|
-
Minimal Gated Recurrent Unit (MGU) cell implementation.
|
268
|
-
|
269
|
-
MGU is a simplified version of GRU that uses a single forget gate instead of
|
270
|
-
separate reset and update gates. This results in fewer parameters while
|
271
|
-
maintaining much of the gating capability. The MGU follows the mathematical
|
272
|
-
formulation:
|
273
|
-
|
274
|
-
.. math::
|
275
|
-
|
276
|
-
f_t &= \\sigma(W_f [x_t, h_{t-1}] + b_f) \\\\
|
277
|
-
\\tilde{h}_t &= \\phi(W_h [x_t, (f_t \\odot h_{t-1})] + b_h) \\\\
|
278
|
-
h_t &= (1 - f_t) \\odot h_{t-1} + f_t \\odot \\tilde{h}_t
|
279
|
-
|
280
|
-
where:
|
281
|
-
|
282
|
-
- :math:`x_t` is the input vector at time t
|
283
|
-
- :math:`h_t` is the hidden state at time t
|
284
|
-
- :math:`f_t` is the forget gate vector
|
285
|
-
- :math:`\\tilde{h}_t` is the candidate hidden state
|
286
|
-
- :math:`\\odot` represents element-wise multiplication
|
287
|
-
- :math:`\\sigma` is the sigmoid activation function
|
288
|
-
- :math:`\\phi` is the activation function (typically tanh)
|
289
|
-
|
290
|
-
Parameters
|
291
|
-
----------
|
292
|
-
num_in : int
|
293
|
-
The number of input units.
|
294
|
-
num_out : int
|
295
|
-
The number of hidden units.
|
296
|
-
w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
|
297
|
-
Initializer for the weight matrices.
|
298
|
-
b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
299
|
-
Initializer for the bias vectors.
|
300
|
-
state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
301
|
-
Initializer for the hidden state.
|
302
|
-
activation : str or Callable, default='tanh'
|
303
|
-
Activation function to use. Can be a string (e.g., 'tanh')
|
304
|
-
or a callable function.
|
305
|
-
name : str, optional
|
306
|
-
Name of the module.
|
307
|
-
|
308
|
-
State Variables
|
309
|
-
--------------
|
310
|
-
h : HiddenState
|
311
|
-
Hidden state of the MGU cell.
|
312
|
-
|
313
|
-
Methods
|
314
|
-
-------
|
315
|
-
init_state(batch_size=None, **kwargs)
|
316
|
-
Initialize the cell hidden state.
|
317
|
-
reset_state(batch_size=None, **kwargs)
|
318
|
-
Reset the cell hidden state to its initial value.
|
319
|
-
update(x)
|
320
|
-
Update the hidden state for one time step and return the new state.
|
321
|
-
"""
|
322
|
-
__module__ = 'brainstate.nn'
|
323
|
-
|
324
|
-
def __init__(
|
325
|
-
self,
|
326
|
-
num_in: int,
|
327
|
-
num_out: int,
|
328
|
-
w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
|
329
|
-
b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
330
|
-
state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
331
|
-
activation: str | Callable = 'tanh',
|
332
|
-
name: str = None,
|
333
|
-
):
|
334
|
-
super().__init__(name=name)
|
335
|
-
|
336
|
-
# parameters
|
337
|
-
self._state_initializer = state_init
|
338
|
-
self.num_out = num_out
|
339
|
-
self.num_in = num_in
|
340
|
-
self.in_size = (num_in,)
|
341
|
-
self.out_size = (num_out,)
|
342
|
-
|
343
|
-
# activation function
|
344
|
-
if isinstance(activation, str):
|
345
|
-
self.activation = getattr(functional, activation)
|
346
|
-
else:
|
347
|
-
assert callable(activation), "The activation function should be a string or a callable function. "
|
348
|
-
self.activation = activation
|
349
|
-
|
350
|
-
# weights
|
351
|
-
self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
|
352
|
-
self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
|
353
|
-
|
354
|
-
def init_state(self, batch_size: int = None, **kwargs):
|
355
|
-
self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
|
356
|
-
|
357
|
-
def reset_state(self, batch_size: int = None, **kwargs):
|
358
|
-
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
|
359
|
-
|
360
|
-
def update(self, x):
|
361
|
-
old_h = self.h.value
|
362
|
-
xh = jnp.concatenate([x, old_h], axis=-1)
|
363
|
-
f = functional.sigmoid(self.Wf(xh))
|
364
|
-
fh = f * old_h
|
365
|
-
h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
|
366
|
-
self.h.value = (1 - f) * self.h.value + f * h
|
367
|
-
return self.h.value
|
368
|
-
|
369
|
-
|
370
|
-
class LSTMCell(RNNCell):
|
371
|
-
r"""
|
372
|
-
Long Short-Term Memory (LSTM) cell implementation.
|
373
|
-
|
374
|
-
This class implements the LSTM architecture which uses multiple gating mechanisms
|
375
|
-
to regulate information flow and address the vanishing gradient problem in RNNs.
|
376
|
-
The LSTM follows the mathematical formulation:
|
377
|
-
|
378
|
-
.. math::
|
379
|
-
|
380
|
-
i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
|
381
|
-
f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
|
382
|
-
g_t &= \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
|
383
|
-
o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
|
384
|
-
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
|
385
|
-
h_t &= o_t \odot \tanh(c_t)
|
386
|
-
|
387
|
-
where:
|
388
|
-
|
389
|
-
- :math:`x_t` is the input vector at time t
|
390
|
-
- :math:`h_t` is the hidden state at time t
|
391
|
-
- :math:`c_t` is the cell state at time t
|
392
|
-
- :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and output gate activations
|
393
|
-
- :math:`g_t` is the cell update vector
|
394
|
-
- :math:`\odot` represents element-wise multiplication
|
395
|
-
- :math:`\sigma` is the sigmoid activation function
|
396
|
-
|
397
|
-
Notes
|
398
|
-
-----
|
399
|
-
Forget gate initialization: Following Jozefowicz et al. (2015), we add 1.0
|
400
|
-
to the forget gate bias after initialization to reduce forgetting at the
|
401
|
-
beginning of training.
|
402
|
-
|
403
|
-
Parameters
|
404
|
-
----------
|
405
|
-
num_in : int
|
406
|
-
The number of input units.
|
407
|
-
num_out : int
|
408
|
-
The number of hidden/cell units.
|
409
|
-
w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
|
410
|
-
Initializer for the weight matrices.
|
411
|
-
b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
412
|
-
Initializer for the bias vectors.
|
413
|
-
state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
|
414
|
-
Initializer for the hidden and cell states.
|
415
|
-
activation : str or Callable, default='tanh'
|
416
|
-
Activation function to use. Can be a string (e.g., 'tanh')
|
417
|
-
or a callable function.
|
418
|
-
name : str, optional
|
419
|
-
Name of the module.
|
420
|
-
|
421
|
-
State Variables
|
422
|
-
--------------
|
423
|
-
h : HiddenState
|
424
|
-
Hidden state of the LSTM cell.
|
425
|
-
c : HiddenState
|
426
|
-
Cell state of the LSTM cell.
|
427
|
-
|
428
|
-
Methods
|
429
|
-
-------
|
430
|
-
init_state(batch_size=None, **kwargs)
|
431
|
-
Initialize the cell and hidden states.
|
432
|
-
reset_state(batch_size=None, **kwargs)
|
433
|
-
Reset the cell and hidden states to their initial values.
|
434
|
-
update(x)
|
435
|
-
Update the states for one time step and return the new hidden state.
|
436
|
-
|
437
|
-
References
|
438
|
-
----------
|
439
|
-
.. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
|
440
|
-
Neural computation, 9(8), 1735-1780.
|
441
|
-
.. [2] Zaremba, W., Sutskever, I., & Vinyals, O. (2014). Recurrent neural
|
442
|
-
network regularization. arXiv preprint arXiv:1409.2329.
|
443
|
-
.. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
|
444
|
-
exploration of recurrent network architectures. In International
|
445
|
-
conference on machine learning, pp. 2342-2350.
|
446
|
-
"""
|
447
|
-
__module__ = 'brainstate.nn'
|
448
|
-
|
449
|
-
def __init__(
|
450
|
-
self,
|
451
|
-
num_in: int,
|
452
|
-
num_out: int,
|
453
|
-
w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
|
454
|
-
b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
455
|
-
state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
456
|
-
activation: str | Callable = 'tanh',
|
457
|
-
name: str = None,
|
458
|
-
):
|
459
|
-
super().__init__(name=name)
|
460
|
-
|
461
|
-
# parameters
|
462
|
-
self.num_out = num_out
|
463
|
-
self.num_in = num_in
|
464
|
-
self.in_size = (num_in,)
|
465
|
-
self.out_size = (num_out,)
|
466
|
-
|
467
|
-
# initializers
|
468
|
-
self._state_initializer = state_init
|
469
|
-
|
470
|
-
# activation function
|
471
|
-
if isinstance(activation, str):
|
472
|
-
self.activation = getattr(functional, activation)
|
473
|
-
else:
|
474
|
-
assert callable(activation), "The activation function should be a string or a callable function. "
|
475
|
-
self.activation = activation
|
476
|
-
|
477
|
-
# weights
|
478
|
-
self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)
|
479
|
-
|
480
|
-
def init_state(self, batch_size: int = None, **kwargs):
|
481
|
-
self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
|
482
|
-
self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
|
483
|
-
|
484
|
-
def reset_state(self, batch_size: int = None, **kwargs):
|
485
|
-
self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
|
486
|
-
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
|
487
|
-
|
488
|
-
def update(self, x):
|
489
|
-
h, c = self.h.value, self.c.value
|
490
|
-
xh = jnp.concat([x, h], axis=-1)
|
491
|
-
i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
|
492
|
-
c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
|
493
|
-
h = functional.sigmoid(o) * self.activation(c)
|
494
|
-
self.h.value = h
|
495
|
-
self.c.value = c
|
496
|
-
return h
|
497
|
-
|
498
|
-
|
499
|
-
class URLSTMCell(RNNCell):
|
500
|
-
def __init__(
|
501
|
-
self,
|
502
|
-
num_in: int,
|
503
|
-
num_out: int,
|
504
|
-
w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
|
505
|
-
state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
|
506
|
-
activation: str | Callable = 'tanh',
|
507
|
-
name: str = None,
|
508
|
-
):
|
509
|
-
super().__init__(name=name)
|
510
|
-
|
511
|
-
# parameters
|
512
|
-
self.num_out = num_out
|
513
|
-
self.num_in = num_in
|
514
|
-
self.in_size = (num_in,)
|
515
|
-
self.out_size = (num_out,)
|
516
|
-
|
517
|
-
# initializers
|
518
|
-
self._state_initializer = state_init
|
519
|
-
|
520
|
-
# activation function
|
521
|
-
if isinstance(activation, str):
|
522
|
-
self.activation = getattr(functional, activation)
|
523
|
-
else:
|
524
|
-
assert callable(activation), "The activation function should be a string or a callable function. "
|
525
|
-
self.activation = activation
|
526
|
-
|
527
|
-
# weights
|
528
|
-
self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
|
529
|
-
self.bias = ParamState(self._forget_bias())
|
530
|
-
|
531
|
-
def _forget_bias(self):
|
532
|
-
u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
|
533
|
-
return -jnp.log(1 / u - 1)
|
534
|
-
|
535
|
-
def init_state(self, batch_size: int = None, **kwargs):
|
536
|
-
self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
|
537
|
-
self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
|
538
|
-
|
539
|
-
def reset_state(self, batch_size: int = None, **kwargs):
|
540
|
-
self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
|
541
|
-
self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
|
542
|
-
|
543
|
-
def update(self, x: ArrayLike) -> ArrayLike:
|
544
|
-
h, c = self.h.value, self.c.value
|
545
|
-
xh = jnp.concat([x, h], axis=-1)
|
546
|
-
f, r, u, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
|
547
|
-
f_ = functional.sigmoid(f + self.bias.value)
|
548
|
-
r_ = functional.sigmoid(r - self.bias.value)
|
549
|
-
g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2
|
550
|
-
next_cell = g * c + (1 - g) * self.activation(u)
|
551
|
-
next_hidden = functional.sigmoid(o) * self.activation(next_cell)
|
552
|
-
self.h.value = next_hidden
|
553
|
-
self.c.value = next_cell
|
554
|
-
return next_hidden
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
|
19
|
+
from typing import Callable, Union
|
20
|
+
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
from brainstate import random, init, functional
|
24
|
+
from brainstate._state import HiddenState, ParamState
|
25
|
+
from brainstate.typing import ArrayLike
|
26
|
+
from ._linear import Linear
|
27
|
+
from ._module import Module
|
28
|
+
|
29
|
+
__all__ = [
|
30
|
+
'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
|
31
|
+
]
|
32
|
+
|
33
|
+
|
34
|
+
class RNNCell(Module):
|
35
|
+
"""
|
36
|
+
Base class for all recurrent neural network (RNN) cell implementations.
|
37
|
+
|
38
|
+
This abstract class serves as the foundation for implementing various RNN cell types
|
39
|
+
such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
|
40
|
+
Module class and provides common functionality and interface for recurrent units.
|
41
|
+
|
42
|
+
All RNN cell implementations should inherit from this class and implement the required
|
43
|
+
methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
|
44
|
+
define the state initialization and recurrent dynamics.
|
45
|
+
|
46
|
+
The RNNCell typically maintains hidden state(s) that are updated at each time step
|
47
|
+
based on the current input and previous state values.
|
48
|
+
|
49
|
+
Methods
|
50
|
+
-------
|
51
|
+
init_state(batch_size=None, **kwargs)
|
52
|
+
Initialize the cell state variables with appropriate dimensions.
|
53
|
+
reset_state(batch_size=None, **kwargs)
|
54
|
+
Reset the cell state variables to their initial values.
|
55
|
+
update(x)
|
56
|
+
Update the cell state for one time step based on input x and return output.
|
57
|
+
"""
|
58
|
+
pass
|
59
|
+
|
60
|
+
|
61
|
+
class ValinaRNNCell(RNNCell):
    r"""
    Vanilla recurrent cell built around one fused input/hidden projection.

    At every step the input is concatenated with the previous hidden state,
    passed through a single linear layer, and squashed by the activation:

    .. math::

        h_t = \phi(W [x_t, h_{t-1}] + b)

    with

    - :math:`x_t` the input at step :math:`t`,
    - :math:`h_{t-1}` and :math:`h_t` the previous and the new hidden state,
    - :math:`W`, :math:`b` the weight and bias of the fused linear layer,
    - :math:`\phi` the activation function.

    Parameters
    ----------
    num_in : int
        Size of the input feature dimension.
    num_out : int
        Size of the hidden state.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        How the hidden state is initialized.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        How the weight of the linear layer is initialized.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        How the bias of the linear layer is initialized.
    activation : str or Callable, default='relu'
        Either the name of a function looked up on ``brainstate.functional``
        (e.g. ``'relu'``, ``'tanh'``) or a callable used directly.
    name : str, optional
        Name of the module.

    State Variables
    ---------------
    h : HiddenState
        Hidden state, created by :meth:`init_state`.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Create the hidden state.
    reset_state(batch_size=None, **kwargs)
        Re-initialize the value of the existing hidden state.
    update(x)
        Advance the hidden state by one step and return it.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'relu',
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes, plus the recipe later used to build the hidden state
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)
        self._state_initializer = state_init

        # a callable is used as given; a string is resolved by name on `functional`
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(functional, activation)

        # one linear layer acting on the concatenation [x, h]
        self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``self.h`` from the state initializer (extra kwargs are ignored)."""
        initial_h = init.param(self._state_initializer, self.num_out, batch_size)
        self.h = HiddenState(initial_h)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of ``self.h`` with a freshly initialized one."""
        fresh_h = init.param(self._state_initializer, self.num_out, batch_size)
        self.h.value = fresh_h

    def update(self, x):
        """Run one step on input ``x``; store and return the new hidden state."""
        joined = jnp.concatenate((x, self.h.value), axis=-1)
        pre_activation = self.W(joined)
        self.h.value = self.activation(pre_activation)
        return self.h.value
|
156
|
+
|
157
|
+
|
158
|
+
class GRUCell(RNNCell):
    r"""
    Gated Recurrent Unit (GRU) cell.

    The cell computes a reset gate and an update gate from the concatenated
    input and previous hidden state, builds a candidate state from the input
    and the reset-scaled previous state, and blends old and candidate state
    with the update gate:

    .. math::

        r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
        z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
        \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
        h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

    with

    - :math:`x_t` the input at step :math:`t`,
    - :math:`h_t` the hidden state at step :math:`t`,
    - :math:`r_t` and :math:`z_t` the reset and update gates,
    - :math:`\tilde{h}_t` the candidate hidden state,
    - :math:`\odot` element-wise multiplication,
    - :math:`\sigma` the sigmoid function,
    - :math:`\phi` the configurable activation (``tanh`` by default).

    Parameters
    ----------
    num_in : int
        Size of the input feature dimension.
    num_out : int
        Size of the hidden state.
    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
        How the weights of both linear layers are initialized.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        How the biases of both linear layers are initialized.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        How the hidden state is initialized.
    activation : str or Callable, default='tanh'
        Either the name of a function looked up on ``brainstate.functional``
        or a callable used directly.
    name : str, optional
        Name of the module.

    State Variables
    ---------------
    h : HiddenState
        Hidden state, created by :meth:`init_state`.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Create the hidden state.
    reset_state(batch_size=None, **kwargs)
        Re-initialize the value of the existing hidden state.
    update(x)
        Advance the hidden state by one step and return it.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # state recipe and sizes
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # a callable is used as given; a string is resolved by name on `functional`
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(functional, activation)

        # Wrz yields the reset and update gates in one projection (2 * num_out
        # outputs); Wh yields the candidate state.
        fused_in = num_in + num_out
        self.Wrz = Linear(fused_in, num_out * 2, w_init=w_init, b_init=b_init)
        self.Wh = Linear(fused_in, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``self.h`` from the state initializer (extra kwargs are ignored)."""
        initial_h = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h = HiddenState(initial_h)

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of ``self.h`` with a freshly initialized one."""
        fresh_h = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = fresh_h

    def update(self, x):
        """Run one GRU step on input ``x``; store and return the new hidden state."""
        prev = self.h.value

        # both gates come from a single projection, split in half on the last axis
        gates = functional.sigmoid(self.Wrz(jnp.concatenate((x, prev), axis=-1)))
        reset, blend = jnp.split(gates, indices_or_sections=2, axis=-1)

        # candidate state sees the previous state scaled by the reset gate
        gated_prev = reset * prev
        candidate = self.activation(self.Wh(jnp.concatenate((x, gated_prev), axis=-1)))

        new_h = (1 - blend) * prev + blend * candidate
        self.h.value = new_h
        return new_h
|
263
|
+
|
264
|
+
|
265
|
+
class MGUCell(RNNCell):
    r"""
    Minimal Gated Recurrent Unit (MGU) cell implementation.

    MGU is a simplified version of GRU that uses a single forget gate instead of
    separate reset and update gates. This results in fewer parameters while
    maintaining much of the gating capability. The MGU follows the mathematical
    formulation:

    .. math::

        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
        h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`f_t` is the forget gate vector
    - :math:`\tilde{h}_t` is the candidate hidden state
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (``tanh`` by default)

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden units.
    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden state.
    activation : str or Callable, default='tanh'
        Activation function to use. A string is looked up by name on
        ``brainstate.functional`` (e.g., ``'tanh'``); a callable is used as is.
    name : str, optional
        Name of the module.

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the MGU cell, created by :meth:`init_state`.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell hidden state.
    reset_state(batch_size=None, **kwargs)
        Reset the cell hidden state to its initial value.
    update(x)
        Update the hidden state for one time step and return the new state.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function: a string is resolved by name on `functional`,
        # anything else must already be callable
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights: Wf produces the forget gate, Wh the candidate state; both
        # act on the concatenation of the input and a hidden-state-sized vector
        self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``self.h`` from the state initializer (extra kwargs are ignored)."""
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the value of ``self.h`` with a freshly initialized one."""
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        """Run one MGU step on input ``x``; store and return the new hidden state."""
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        # single gate shared by the "reset" and "update" roles
        f = functional.sigmoid(self.Wf(xh))
        fh = f * old_h
        # candidate state computed from the input and the gated previous state
        h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
        # convex blend of the previous state and the candidate
        self.h.value = (1 - f) * self.h.value + f * h
        return self.h.value
|
368
|
+
|
369
|
+
|
370
|
+
class LSTMCell(RNNCell):
    r"""
    Long Short-Term Memory (LSTM) cell implementation.

    This class implements the LSTM architecture which uses multiple gating mechanisms
    to regulate information flow and address the vanishing gradient problem in RNNs.
    The LSTM follows the mathematical formulation:

    .. math::

        i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
        f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f + 1) \\
        g_t &= \phi(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
        o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
        c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
        h_t &= o_t \odot \phi(c_t)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`c_t` is the cell state at time t
    - :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and output gate activations
    - :math:`g_t` is the cell update vector
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the configurable activation function (``tanh`` by default)

    Notes
    -----
    Forget gate offset: following Jozefowicz et al. (2015), a constant 1.0 is
    added to the forget gate pre-activation on every step, which acts like a
    forget bias of +1 and reduces forgetting at the beginning of training.

    All four gate pre-activations are produced by one linear layer applied to
    the concatenation :math:`[x_t, h_{t-1}]`; its output is split into four
    equal parts in the order input, cell update, forget, output.

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden/cell units.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden and cell states.
    activation : str or Callable, default='tanh'
        Activation function to use. A string is looked up by name on
        ``brainstate.functional`` (e.g., ``'tanh'``); a callable is used as is.
    name : str, optional
        Name of the module.

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the LSTM cell.
    c : HiddenState
        Cell state of the LSTM cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell and hidden states.
    reset_state(batch_size=None, **kwargs)
        Reset the cell and hidden states to their initial values.
    update(x)
        Update the states for one time step and return the new hidden state.

    References
    ----------
    .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
           Neural computation, 9(8), 1735-1780.
    .. [2] Zaremba, W., Sutskever, I., & Vinyals, O. (2014). Recurrent neural
           network regularization. arXiv preprint arXiv:1409.2329.
    .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
           exploration of recurrent network architectures. In International
           conference on machine learning, pp. 2342-2350.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function: a string is resolved by name on `functional`,
        # anything else must already be callable
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights: one fused projection producing all four gate pre-activations
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``self.c`` and ``self.h`` from the state initializer (extra kwargs are ignored)."""
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the values of ``self.c`` and ``self.h`` with freshly initialized ones."""
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        """Run one LSTM step on input ``x``; store both states and return the new hidden state."""
        h, c = self.h.value, self.c.value
        # `jnp.concatenate` (rather than the newer `jnp.concat` alias) keeps this
        # cell consistent with the other cells and usable on older JAX releases.
        xh = jnp.concatenate([x, h], axis=-1)
        # split order: input gate, cell update, forget gate, output gate
        i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        # the `+ 1.` is the constant forget-gate offset described in the Notes
        c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
        h = functional.sigmoid(o) * self.activation(c)
        self.h.value = h
        self.c.value = c
        return h
|
497
|
+
|
498
|
+
|
499
|
+
class URLSTMCell(RNNCell):
    r"""
    LSTM-style cell with a refined forget gate and a tied, randomly
    initialized gate bias (UR-LSTM).

    One bias-free linear layer maps the concatenation :math:`[x_t, h_{t-1}]`
    to four pre-activations :math:`(\hat f_t, \hat r_t, \hat u_t, \hat o_t)`.
    With a learnable bias vector :math:`b` the update is:

    .. math::

        f_t &= \sigma(\hat f_t + b) \\
        r_t &= \sigma(\hat r_t - b) \\
        g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
        c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \phi(\hat u_t) \\
        h_t &= \sigma(\hat o_t) \odot \phi(c_t)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` and :math:`c_t` are the hidden and cell states at time t
    - :math:`f_t` is the forget gate and :math:`r_t` the refine gate
    - :math:`g_t` is the effective gate; it is used both to keep the old
      cell state and, as :math:`1 - g_t`, to admit the new candidate
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the configurable activation function (``tanh`` by default)

    The bias :math:`b` is initialized as :math:`-\log(1/u - 1)` with
    :math:`u` drawn uniformly between :math:`1/N` and :math:`1 - 1/N`,
    where :math:`N` is ``num_out``.

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden/cell units.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrix of the fused linear layer.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden and cell states.
    activation : str or Callable, default='tanh'
        Activation function to use. A string is looked up by name on
        ``brainstate.functional`` (e.g., ``'tanh'``); a callable is used as is.
    name : str, optional
        Name of the module.

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the cell.
    c : HiddenState
        Cell state of the cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell and hidden states.
    reset_state(batch_size=None, **kwargs)
        Reset the cell and hidden states to their initial values.
    update(x)
        Update the states for one time step and return the new hidden state.

    Notes
    -----
    NOTE(review): with ``num_out == 1`` the sampling bounds become 1 and 0,
    so the bias initialization is degenerate -- confirm whether this size
    needs to be supported.
    """

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function: a string is resolved by name on `functional`,
        # anything else must already be callable
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights: the fused projection is built with `b_init=None`; the only
        # bias of the cell is the separate `self.bias` parameter below, which
        # is shared (with opposite sign) by the forget and refine gates
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
        self.bias = ParamState(self._forget_bias())

    def _forget_bias(self):
        """Return the initial gate bias, ``-log(1 / u - 1)`` for a random ``u``."""
        # presumably the arguments are (low, high, size) -- confirm against
        # `brainstate.random.uniform`
        u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
        return -jnp.log(1 / u - 1)

    def init_state(self, batch_size: int = None, **kwargs):
        """Create ``self.c`` and ``self.h`` from the state initializer (extra kwargs are ignored)."""
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the values of ``self.c`` and ``self.h`` with freshly initialized ones."""
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x: ArrayLike) -> ArrayLike:
        """Run one step on input ``x``; store both states and return the new hidden state."""
        h, c = self.h.value, self.c.value
        # `jnp.concatenate` (rather than the newer `jnp.concat` alias) keeps this
        # cell consistent with the other cells and usable on older JAX releases.
        xh = jnp.concatenate([x, h], axis=-1)
        # split order: forget, refine, candidate, output pre-activations
        f, r, u, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        # the same bias enters the two gates with opposite signs
        f_ = functional.sigmoid(f + self.bias.value)
        r_ = functional.sigmoid(r - self.bias.value)
        # effective gate: the forget gate adjusted by the refine gate
        g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2
        # keep-gate and input-gate are tied: g and (1 - g)
        next_cell = g * c + (1 - g) * self.activation(u)
        next_hidden = functional.sigmoid(o) * self.activation(next_cell)
        self.h.value = next_hidden
        self.c.value = next_cell
        return next_hidden
|