brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/{_rate_rnns.py → _rnns.py}
@@ -1,554 +1,946 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- from typing import Callable, Union
20
-
21
- import jax.numpy as jnp
22
-
23
- from brainstate import random, init, functional
24
- from brainstate._state import HiddenState, ParamState
25
- from brainstate.typing import ArrayLike
26
- from ._linear import Linear
27
- from ._module import Module
28
-
29
- __all__ = [
30
- 'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
31
- ]
32
-
33
-
34
- class RNNCell(Module):
35
- """
36
- Base class for all recurrent neural network (RNN) cell implementations.
37
-
38
- This abstract class serves as the foundation for implementing various RNN cell types
39
- such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
40
- Module class and provides common functionality and interface for recurrent units.
41
-
42
- All RNN cell implementations should inherit from this class and implement the required
43
- methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
44
- define the state initialization and recurrent dynamics.
45
-
46
- The RNNCell typically maintains hidden state(s) that are updated at each time step
47
- based on the current input and previous state values.
48
-
49
- Methods
50
- -------
51
- init_state(batch_size=None, **kwargs)
52
- Initialize the cell state variables with appropriate dimensions.
53
- reset_state(batch_size=None, **kwargs)
54
- Reset the cell state variables to their initial values.
55
- update(x)
56
- Update the cell state for one time step based on input x and return output.
57
- """
58
- pass
59
-
60
-
61
- class ValinaRNNCell(RNNCell):
62
- r"""
63
- Vanilla Recurrent Neural Network (RNN) cell implementation.
64
-
65
- This class implements the basic RNN model that updates a hidden state based on
66
- the current input and previous hidden state. The standard RNN cell follows the
67
- mathematical formulation:
68
-
69
- .. math::
70
-
71
- h_t = \phi(W [x_t, h_{t-1}] + b)
72
-
73
- where:
74
-
75
- - :math:`x_t` is the input vector at time t
76
- - :math:`h_t` is the hidden state at time t
77
- - :math:`h_{t-1}` is the hidden state at previous time step
78
- - :math:`W` is the weight matrix for the combined input-hidden linear transformation
79
- - :math:`b` is the bias vector
80
- - :math:`\phi` is the activation function
81
-
82
- Parameters
83
- ----------
84
- num_in : int
85
- The number of input units.
86
- num_out : int
87
- The number of hidden units.
88
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
89
- Initializer for the hidden state.
90
- w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
91
- Initializer for the weight matrix.
92
- b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
93
- Initializer for the bias vector.
94
- activation : str or Callable, default='relu'
95
- Activation function to use. Can be a string (e.g., 'relu', 'tanh')
96
- or a callable function.
97
- name : str, optional
98
- Name of the module.
99
-
100
- State Variables
101
- --------------
102
- h : HiddenState
103
- Hidden state of the RNN cell.
104
-
105
- Methods
106
- -------
107
- init_state(batch_size=None, **kwargs)
108
- Initialize the cell hidden state.
109
- reset_state(batch_size=None, **kwargs)
110
- Reset the cell hidden state to its initial value.
111
- update(x)
112
- Update the hidden state for one time step and return the new state.
113
- """
114
- __module__ = 'brainstate.nn'
115
-
116
- def __init__(
117
- self,
118
- num_in: int,
119
- num_out: int,
120
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
121
- w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
122
- b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
123
- activation: str | Callable = 'relu',
124
- name: str = None,
125
- ):
126
- super().__init__(name=name)
127
-
128
- # parameters
129
- self.num_out = num_out
130
- self.num_in = num_in
131
- self.in_size = (num_in,)
132
- self.out_size = (num_out,)
133
- self._state_initializer = state_init
134
-
135
- # activation function
136
- if isinstance(activation, str):
137
- self.activation = getattr(functional, activation)
138
- else:
139
- assert callable(activation), "The activation function should be a string or a callable function. "
140
- self.activation = activation
141
-
142
- # weights
143
- self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
144
-
145
- def init_state(self, batch_size: int = None, **kwargs):
146
- self.h = HiddenState(init.param(self._state_initializer, self.num_out, batch_size))
147
-
148
- def reset_state(self, batch_size: int = None, **kwargs):
149
- self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
150
-
151
- def update(self, x):
152
- xh = jnp.concatenate([x, self.h.value], axis=-1)
153
- h = self.W(xh)
154
- self.h.value = self.activation(h)
155
- return self.h.value
156
-
157
-
158
- class GRUCell(RNNCell):
159
- r"""
160
- Gated Recurrent Unit (GRU) cell implementation.
161
-
162
- This class implements the GRU model that uses gating mechanisms to control
163
- information flow. The GRU has fewer parameters than LSTM as it combines
164
- the forget and input gates into a single update gate. The GRU follows the
165
- mathematical formulation:
166
-
167
- .. math::
168
-
169
- r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
170
- z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
171
- \tilde{h}_t &= \tanh(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
172
- h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
173
-
174
- where:
175
-
176
- - :math:`x_t` is the input vector at time t
177
- - :math:`h_t` is the hidden state at time t
178
- - :math:`r_t` is the reset gate vector
179
- - :math:`z_t` is the update gate vector
180
- - :math:`\tilde{h}_t` is the candidate hidden state
181
- - :math:`\odot` represents element-wise multiplication
182
- - :math:`\sigma` is the sigmoid activation function
183
-
184
- Parameters
185
- ----------
186
- num_in : int
187
- The number of input units.
188
- num_out : int
189
- The number of hidden units.
190
- w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
191
- Initializer for the weight matrices.
192
- b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
193
- Initializer for the bias vectors.
194
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
195
- Initializer for the hidden state.
196
- activation : str or Callable, default='tanh'
197
- Activation function to use. Can be a string (e.g., 'tanh')
198
- or a callable function.
199
- name : str, optional
200
- Name of the module.
201
-
202
- State Variables
203
- --------------
204
- h : HiddenState
205
- Hidden state of the GRU cell.
206
-
207
- Methods
208
- -------
209
- init_state(batch_size=None, **kwargs)
210
- Initialize the cell hidden state.
211
- reset_state(batch_size=None, **kwargs)
212
- Reset the cell hidden state to its initial value.
213
- update(x)
214
- Update the hidden state for one time step and return the new state.
215
- """
216
- __module__ = 'brainstate.nn'
217
-
218
- def __init__(
219
- self,
220
- num_in: int,
221
- num_out: int,
222
- w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
223
- b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
224
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
225
- activation: str | Callable = 'tanh',
226
- name: str = None,
227
- ):
228
- super().__init__(name=name)
229
-
230
- # parameters
231
- self._state_initializer = state_init
232
- self.num_out = num_out
233
- self.num_in = num_in
234
- self.in_size = (num_in,)
235
- self.out_size = (num_out,)
236
-
237
- # activation function
238
- if isinstance(activation, str):
239
- self.activation = getattr(functional, activation)
240
- else:
241
- assert callable(activation), "The activation function should be a string or a callable function. "
242
- self.activation = activation
243
-
244
- # weights
245
- self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
246
- self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
247
-
248
- def init_state(self, batch_size: int = None, **kwargs):
249
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
250
-
251
- def reset_state(self, batch_size: int = None, **kwargs):
252
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
253
-
254
- def update(self, x):
255
- old_h = self.h.value
256
- xh = jnp.concatenate([x, old_h], axis=-1)
257
- r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
258
- rh = r * old_h
259
- h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
260
- h = (1 - z) * old_h + z * h
261
- self.h.value = h
262
- return h
263
-
264
-
265
- class MGUCell(RNNCell):
266
- r"""
267
- Minimal Gated Recurrent Unit (MGU) cell implementation.
268
-
269
- MGU is a simplified version of GRU that uses a single forget gate instead of
270
- separate reset and update gates. This results in fewer parameters while
271
- maintaining much of the gating capability. The MGU follows the mathematical
272
- formulation:
273
-
274
- .. math::
275
-
276
- f_t &= \\sigma(W_f [x_t, h_{t-1}] + b_f) \\\\
277
- \\tilde{h}_t &= \\phi(W_h [x_t, (f_t \\odot h_{t-1})] + b_h) \\\\
278
- h_t &= (1 - f_t) \\odot h_{t-1} + f_t \\odot \\tilde{h}_t
279
-
280
- where:
281
-
282
- - :math:`x_t` is the input vector at time t
283
- - :math:`h_t` is the hidden state at time t
284
- - :math:`f_t` is the forget gate vector
285
- - :math:`\\tilde{h}_t` is the candidate hidden state
286
- - :math:`\\odot` represents element-wise multiplication
287
- - :math:`\\sigma` is the sigmoid activation function
288
- - :math:`\\phi` is the activation function (typically tanh)
289
-
290
- Parameters
291
- ----------
292
- num_in : int
293
- The number of input units.
294
- num_out : int
295
- The number of hidden units.
296
- w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
297
- Initializer for the weight matrices.
298
- b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
299
- Initializer for the bias vectors.
300
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
301
- Initializer for the hidden state.
302
- activation : str or Callable, default='tanh'
303
- Activation function to use. Can be a string (e.g., 'tanh')
304
- or a callable function.
305
- name : str, optional
306
- Name of the module.
307
-
308
- State Variables
309
- --------------
310
- h : HiddenState
311
- Hidden state of the MGU cell.
312
-
313
- Methods
314
- -------
315
- init_state(batch_size=None, **kwargs)
316
- Initialize the cell hidden state.
317
- reset_state(batch_size=None, **kwargs)
318
- Reset the cell hidden state to its initial value.
319
- update(x)
320
- Update the hidden state for one time step and return the new state.
321
- """
322
- __module__ = 'brainstate.nn'
323
-
324
- def __init__(
325
- self,
326
- num_in: int,
327
- num_out: int,
328
- w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
329
- b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
330
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
331
- activation: str | Callable = 'tanh',
332
- name: str = None,
333
- ):
334
- super().__init__(name=name)
335
-
336
- # parameters
337
- self._state_initializer = state_init
338
- self.num_out = num_out
339
- self.num_in = num_in
340
- self.in_size = (num_in,)
341
- self.out_size = (num_out,)
342
-
343
- # activation function
344
- if isinstance(activation, str):
345
- self.activation = getattr(functional, activation)
346
- else:
347
- assert callable(activation), "The activation function should be a string or a callable function. "
348
- self.activation = activation
349
-
350
- # weights
351
- self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
352
- self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
353
-
354
- def init_state(self, batch_size: int = None, **kwargs):
355
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
356
-
357
- def reset_state(self, batch_size: int = None, **kwargs):
358
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
359
-
360
- def update(self, x):
361
- old_h = self.h.value
362
- xh = jnp.concatenate([x, old_h], axis=-1)
363
- f = functional.sigmoid(self.Wf(xh))
364
- fh = f * old_h
365
- h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
366
- self.h.value = (1 - f) * self.h.value + f * h
367
- return self.h.value
368
-
369
-
370
- class LSTMCell(RNNCell):
371
- r"""
372
- Long Short-Term Memory (LSTM) cell implementation.
373
-
374
- This class implements the LSTM architecture which uses multiple gating mechanisms
375
- to regulate information flow and address the vanishing gradient problem in RNNs.
376
- The LSTM follows the mathematical formulation:
377
-
378
- .. math::
379
-
380
- i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
381
- f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
382
- g_t &= \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
383
- o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
384
- c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
385
- h_t &= o_t \odot \tanh(c_t)
386
-
387
- where:
388
-
389
- - :math:`x_t` is the input vector at time t
390
- - :math:`h_t` is the hidden state at time t
391
- - :math:`c_t` is the cell state at time t
392
- - :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and output gate activations
393
- - :math:`g_t` is the cell update vector
394
- - :math:`\odot` represents element-wise multiplication
395
- - :math:`\sigma` is the sigmoid activation function
396
-
397
- Notes
398
- -----
399
- Forget gate initialization: Following Jozefowicz et al. (2015), we add 1.0
400
- to the forget gate bias after initialization to reduce forgetting at the
401
- beginning of training.
402
-
403
- Parameters
404
- ----------
405
- num_in : int
406
- The number of input units.
407
- num_out : int
408
- The number of hidden/cell units.
409
- w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
410
- Initializer for the weight matrices.
411
- b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
412
- Initializer for the bias vectors.
413
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
414
- Initializer for the hidden and cell states.
415
- activation : str or Callable, default='tanh'
416
- Activation function to use. Can be a string (e.g., 'tanh')
417
- or a callable function.
418
- name : str, optional
419
- Name of the module.
420
-
421
- State Variables
422
- --------------
423
- h : HiddenState
424
- Hidden state of the LSTM cell.
425
- c : HiddenState
426
- Cell state of the LSTM cell.
427
-
428
- Methods
429
- -------
430
- init_state(batch_size=None, **kwargs)
431
- Initialize the cell and hidden states.
432
- reset_state(batch_size=None, **kwargs)
433
- Reset the cell and hidden states to their initial values.
434
- update(x)
435
- Update the states for one time step and return the new hidden state.
436
-
437
- References
438
- ----------
439
- .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
440
- Neural computation, 9(8), 1735-1780.
441
- .. [2] Zaremba, W., Sutskever, I., & Vinyals, O. (2014). Recurrent neural
442
- network regularization. arXiv preprint arXiv:1409.2329.
443
- .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
444
- exploration of recurrent network architectures. In International
445
- conference on machine learning, pp. 2342-2350.
446
- """
447
- __module__ = 'brainstate.nn'
448
-
449
- def __init__(
450
- self,
451
- num_in: int,
452
- num_out: int,
453
- w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
454
- b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
455
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
456
- activation: str | Callable = 'tanh',
457
- name: str = None,
458
- ):
459
- super().__init__(name=name)
460
-
461
- # parameters
462
- self.num_out = num_out
463
- self.num_in = num_in
464
- self.in_size = (num_in,)
465
- self.out_size = (num_out,)
466
-
467
- # initializers
468
- self._state_initializer = state_init
469
-
470
- # activation function
471
- if isinstance(activation, str):
472
- self.activation = getattr(functional, activation)
473
- else:
474
- assert callable(activation), "The activation function should be a string or a callable function. "
475
- self.activation = activation
476
-
477
- # weights
478
- self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)
479
-
480
- def init_state(self, batch_size: int = None, **kwargs):
481
- self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
482
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
483
-
484
- def reset_state(self, batch_size: int = None, **kwargs):
485
- self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
486
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
487
-
488
- def update(self, x):
489
- h, c = self.h.value, self.c.value
490
- xh = jnp.concat([x, h], axis=-1)
491
- i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
492
- c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
493
- h = functional.sigmoid(o) * self.activation(c)
494
- self.h.value = h
495
- self.c.value = c
496
- return h
497
-
498
-
499
- class URLSTMCell(RNNCell):
500
- def __init__(
501
- self,
502
- num_in: int,
503
- num_out: int,
504
- w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
505
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
506
- activation: str | Callable = 'tanh',
507
- name: str = None,
508
- ):
509
- super().__init__(name=name)
510
-
511
- # parameters
512
- self.num_out = num_out
513
- self.num_in = num_in
514
- self.in_size = (num_in,)
515
- self.out_size = (num_out,)
516
-
517
- # initializers
518
- self._state_initializer = state_init
519
-
520
- # activation function
521
- if isinstance(activation, str):
522
- self.activation = getattr(functional, activation)
523
- else:
524
- assert callable(activation), "The activation function should be a string or a callable function. "
525
- self.activation = activation
526
-
527
- # weights
528
- self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
529
- self.bias = ParamState(self._forget_bias())
530
-
531
- def _forget_bias(self):
532
- u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
533
- return -jnp.log(1 / u - 1)
534
-
535
- def init_state(self, batch_size: int = None, **kwargs):
536
- self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
537
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
538
-
539
- def reset_state(self, batch_size: int = None, **kwargs):
540
- self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
541
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
542
-
543
- def update(self, x: ArrayLike) -> ArrayLike:
544
- h, c = self.h.value, self.c.value
545
- xh = jnp.concat([x, h], axis=-1)
546
- f, r, u, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
547
- f_ = functional.sigmoid(f + self.bias.value)
548
- r_ = functional.sigmoid(r - self.bias.value)
549
- g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2
550
- next_cell = g * c + (1 - g) * self.activation(u)
551
- next_hidden = functional.sigmoid(o) * self.activation(next_cell)
552
- self.h.value = next_hidden
553
- self.c.value = next_cell
554
- return next_hidden
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ from typing import Callable, Union
20
+
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate import random
24
+ from brainstate._state import HiddenState, ParamState
25
+ from brainstate.typing import ArrayLike
26
+ from . import _activations as functional
27
+ from . import init as init
28
+ from ._linear import Linear
29
+ from ._module import Module
30
+
31
+ __all__ = [
32
+ 'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
33
+ ]
34
+
35
+
36
+ class RNNCell(Module):
37
+ """
38
+ Base class for all recurrent neural network (RNN) cell implementations.
39
+
40
+ This abstract class serves as the foundation for implementing various RNN cell types
41
+ such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
42
+ Module class and provides common functionality and interface for recurrent units.
43
+
44
+ All RNN cell implementations should inherit from this class and implement the required
45
+ methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
46
+ define the state initialization and recurrent dynamics.
47
+
48
+ The RNNCell typically maintains hidden state(s) that are updated at each time step
49
+ based on the current input and previous state values.
50
+
51
+ Methods
52
+ -------
53
+ init_state(batch_size=None, **kwargs)
54
+ Initialize the cell state variables with appropriate dimensions.
55
+ reset_state(batch_size=None, **kwargs)
56
+ Reset the cell state variables to their initial values.
57
+ update(x)
58
+ Update the cell state for one time step based on input x and return output.
59
+
60
+ See Also
61
+ --------
62
+ ValinaRNNCell : Vanilla RNN cell implementation
63
+ GRUCell : Gated Recurrent Unit cell implementation
64
+ LSTMCell : Long Short-Term Memory cell implementation
65
+ URLSTMCell : LSTM with UR gating mechanism
66
+ MGUCell : Minimal Gated Unit cell implementation
67
+ """
68
+ __module__ = 'brainstate.nn'
69
+ pass
70
+
71
+
72
class ValinaRNNCell(RNNCell):
    r"""
    Vanilla (Elman-style) recurrent neural network cell.

    At every step the current input is concatenated with the previous hidden
    state, passed through one linear layer, and squashed by the activation
    function:

    .. math::

        h_t = \phi(W [x_t, h_{t-1}] + b)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`h_{t-1}` is the hidden state at the previous time step
    - :math:`W` is the weight matrix of the combined input-hidden linear map
    - :math:`b` is the bias vector
    - :math:`\phi` is the activation function

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden units.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden state.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrix.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vector.
    activation : str or Callable, default='relu'
        Activation function. A string (e.g., 'relu', 'tanh') is looked up by
        name in the activations module; anything else must be callable.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the RNN cell (created by ``init_state``).

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell hidden state.
    reset_state(batch_size=None, **kwargs)
        Reset the cell hidden state to its initial value.
    update(x)
        Update the hidden state for one time step and return the new state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a vanilla RNN cell
        >>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence of inputs
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)

    Notes
    -----
    Vanilla RNNs can suffer from vanishing or exploding gradient problems
    when processing long sequences. For better performance on long sequences,
    consider using gated architectures like GRU or LSTM.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'relu',
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.num_in = num_in
        self.num_out = num_out
        self.in_size = (num_in,)
        self.out_size = (num_out,)
        self._state_initializer = state_init

        # Resolve the activation: callables are used as given, strings are
        # looked up by attribute name on the activations module.
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(functional, activation)

        # One linear layer acting on the concatenated [input, hidden] vector.
        self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """
        Create the hidden state ``h``.

        Parameters
        ----------
        batch_size : int, optional
            Batch size forwarded to ``init.param`` when building the state.
        **kwargs
            Additional keyword arguments (unused).
        """
        initial_h = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h = HiddenState(initial_h)

    def reset_state(self, batch_size: int = None, **kwargs):
        """
        Overwrite the hidden state with a freshly initialized value.

        ``init_state`` must have been called before, since this assigns to
        the existing ``self.h``.

        Parameters
        ----------
        batch_size : int, optional
            Batch size forwarded to ``init.param`` when rebuilding the state.
        **kwargs
            Additional keyword arguments (unused).
        """
        fresh_h = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = fresh_h

    def update(self, x):
        """
        Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for this step. It is concatenated with the hidden state
            along the last axis, e.g. shape ``(batch_size, num_in)``.

        Returns
        -------
        ArrayLike
            The updated hidden state.
        """
        stacked = jnp.concatenate([x, self.h.value], axis=-1)
        self.h.value = self.activation(self.W(stacked))
        return self.h.value
231
+
232
+
233
class GRUCell(RNNCell):
    r"""
    Gated Recurrent Unit (GRU) cell.

    The GRU is a gated recurrent architecture designed to mitigate the
    vanishing gradient problem. It has fewer parameters than an LSTM because
    the forget and input gates are merged into a single update gate.

    The cell computes:

    .. math::

        r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
        z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
        \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
        h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`r_t` is the reset gate vector
    - :math:`z_t` is the update gate vector
    - :math:`\tilde{h}_t` is the candidate hidden state
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    The reset and update gates are produced by one fused linear layer whose
    output is split in two along the last axis.

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden units.
    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden state.
    activation : str or Callable, default='tanh'
        Activation function. A string (e.g., 'tanh', 'relu') is looked up by
        name in the activations module; anything else must be callable.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the GRU cell (created by ``init_state``).

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell hidden state.
    reset_state(batch_size=None, **kwargs)
        Reset the cell hidden state to its initial value.
    update(x)
        Update the hidden state for one time step and return the new state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a GRU cell
        >>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)
        >>>
        >>> # Reset state with different batch size
        >>> cell.reset_state(batch_size=16)
        >>> x_new = jnp.ones((16, 10))
        >>> output_new = cell.update(x_new)
        >>> print(output_new.shape)  # (16, 20)

    Notes
    -----
    GRU cells are computationally more efficient than LSTM cells due to having
    fewer parameters, while often achieving comparable performance on many tasks.

    References
    ----------
    .. [1] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F.,
           Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using
           RNN encoder-decoder for statistical machine translation.
           arXiv preprint arXiv:1406.1078.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.num_in = num_in
        self.num_out = num_out
        self.in_size = (num_in,)
        self.out_size = (num_out,)
        self._state_initializer = state_init

        # Resolve the activation: callables are used as given, strings are
        # looked up by attribute name on the activations module.
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(functional, activation)

        # Wrz: fused projection for the reset and update gates (2 * num_out).
        # Wh: projection for the candidate hidden state.
        fused_in = num_in + num_out
        self.Wrz = Linear(fused_in, num_out * 2, w_init=w_init, b_init=b_init)
        self.Wh = Linear(fused_in, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """
        Create the hidden state ``h``.

        Parameters
        ----------
        batch_size : int, optional
            Batch size forwarded to ``init.param`` when building the state.
        **kwargs
            Additional keyword arguments (unused).
        """
        initial_h = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h = HiddenState(initial_h)

    def reset_state(self, batch_size: int = None, **kwargs):
        """
        Overwrite the hidden state with a freshly initialized value.

        ``init_state`` must have been called before, since this assigns to
        the existing ``self.h``.

        Parameters
        ----------
        batch_size : int, optional
            Batch size forwarded to ``init.param`` when rebuilding the state.
        **kwargs
            Additional keyword arguments (unused).
        """
        fresh_h = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = fresh_h

    def update(self, x):
        """
        Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for this step. It is concatenated with the hidden state
            along the last axis, e.g. shape ``(batch_size, num_in)``.

        Returns
        -------
        ArrayLike
            The updated hidden state.
        """
        prev_h = self.h.value

        # Both gates come from one projection, split in half on the last axis:
        # first half is the reset gate, second half the update gate.
        gate_acts = functional.sigmoid(self.Wrz(jnp.concatenate([x, prev_h], axis=-1)))
        reset_gate, update_gate = jnp.split(gate_acts, indices_or_sections=2, axis=-1)

        # The candidate state sees the previous state scaled by the reset gate.
        gated_prev = reset_gate * prev_h
        candidate = self.activation(self.Wh(jnp.concatenate([x, gated_prev], axis=-1)))

        # Convex blend of the previous state and the candidate.
        new_h = (1 - update_gate) * prev_h + update_gate * candidate
        self.h.value = new_h
        return new_h
397
+
398
+
399
class MGUCell(RNNCell):
    r"""
    Minimal Gated Unit (MGU) cell.

    MGU simplifies the GRU by using a single forget gate in place of separate
    reset and update gates. This reduces the number of parameters while
    retaining much of the gating capability, giving a trade-off between model
    complexity and performance.

    The cell computes:

    .. math::

        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
        h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`f_t` is the forget gate vector
    - :math:`\tilde{h}_t` is the candidate hidden state
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden units.
    w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden state.
    activation : str or Callable, default='tanh'
        Activation function. A string (e.g., 'tanh', 'relu') is looked up by
        name in the activations module; anything else must be callable.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the MGU cell (created by ``init_state``).

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell hidden state.
    reset_state(batch_size=None, **kwargs)
        Reset the cell hidden state to its initial value.
    update(x)
        Update the hidden state for one time step and return the new state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create an MGU cell
        >>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)

    Notes
    -----
    MGU provides a lightweight alternative to GRU and LSTM, making it suitable
    for resource-constrained applications or when model simplicity is preferred.

    References
    ----------
    .. [1] Zhou, G. B., Wu, J., Zhang, C. L., & Zhou, Z. H. (2016). Minimal gated unit
           for recurrent neural networks. International Journal of Automation and Computing,
           13(3), 226-234.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # sizes and state initializer
        self.num_in = num_in
        self.num_out = num_out
        self.in_size = (num_in,)
        self.out_size = (num_out,)
        self._state_initializer = state_init

        # Resolve the activation: callables are used as given, strings are
        # looked up by attribute name on the activations module.
        if not isinstance(activation, str):
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation
        else:
            self.activation = getattr(functional, activation)

        # Wf: forget-gate projection; Wh: candidate-state projection.
        fused_in = num_in + num_out
        self.Wf = Linear(fused_in, num_out, w_init=w_init, b_init=b_init)
        self.Wh = Linear(fused_in, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """
        Create the hidden state ``h``.

        Parameters
        ----------
        batch_size : int, optional
            Batch size forwarded to ``init.param`` when building the state.
        **kwargs
            Additional keyword arguments (unused).
        """
        initial_h = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h = HiddenState(initial_h)

    def reset_state(self, batch_size: int = None, **kwargs):
        """
        Overwrite the hidden state with a freshly initialized value.

        ``init_state`` must have been called before, since this assigns to
        the existing ``self.h``.

        Parameters
        ----------
        batch_size : int, optional
            Batch size forwarded to ``init.param`` when rebuilding the state.
        **kwargs
            Additional keyword arguments (unused).
        """
        fresh_h = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = fresh_h

    def update(self, x):
        """
        Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for this step. It is concatenated with the hidden state
            along the last axis, e.g. shape ``(batch_size, num_in)``.

        Returns
        -------
        ArrayLike
            The updated hidden state.
        """
        prev_h = self.h.value

        # The single forget gate both masks the previous state fed to the
        # candidate and sets the blend ratio between old state and candidate.
        forget_gate = functional.sigmoid(self.Wf(jnp.concatenate([x, prev_h], axis=-1)))
        gated_prev = forget_gate * prev_h
        candidate = self.activation(self.Wh(jnp.concatenate([x, gated_prev], axis=-1)))

        self.h.value = (1 - forget_gate) * self.h.value + forget_gate * candidate
        return self.h.value
553
+
554
+
555
class LSTMCell(RNNCell):
    r"""
    Long Short-Term Memory (LSTM) cell implementation.

    LSTM is a type of RNN architecture designed to address the vanishing gradient
    problem and learn long-term dependencies. It uses a cell state to carry
    information across time steps and three gates (input, forget, output) to
    control information flow.

    The LSTM cell follows the mathematical formulation:

    .. math::

        i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f + 1) \\
        g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\
        o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
        c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
        h_t &= o_t \odot \phi(c_t)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`c_t` is the cell state at time t
    - :math:`i_t` is the input gate activation
    - :math:`f_t` is the forget gate activation
    - :math:`o_t` is the output gate activation
    - :math:`g_t` is the cell update (candidate) vector
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden/cell units.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrices.
    b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the bias vectors.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden and cell states.
    activation : str or Callable, default='tanh'
        Activation function to use. A string (e.g., 'tanh', 'relu') is looked
        up by name in the activations module; anything else must be callable.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden/cell units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the LSTM cell.
    c : HiddenState
        Cell state of the LSTM cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell and hidden states.
    reset_state(batch_size=None, **kwargs)
        Reset the cell and hidden states to their initial values.
    update(x)
        Update the states for one time step and return the new hidden state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create an LSTM cell
        >>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize states for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process a sequence
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)
        >>>
        >>> # Access cell state
        >>> print(cell.c.value.shape)  # (32, 20)
        >>> print(cell.h.value.shape)  # (32, 20)

    Notes
    -----
    - A constant +1.0 is added to the forget-gate pre-activation on every call
      to ``update`` (it is not folded into the bias initializer), following
      Jozefowicz et al. (2015), to reduce forgetting at the beginning of
      training.
    - All four gate pre-activations are produced by a single fused linear
      layer; its output is split along the last axis into four equal chunks
      in the order input gate, candidate, forget gate, output gate.
    - LSTM cells are effective for learning long-term dependencies but require
      more parameters and computation than simpler RNN variants.

    References
    ----------
    .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
           Neural computation, 9(8), 1735-1780.
    .. [2] Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget:
           Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.
    .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
           exploration of recurrent network architectures. In International
           conference on machine learning (pp. 2342-2350).
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function: strings are looked up by attribute name on the
        # activations module, anything else must already be callable.
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function. "
            self.activation = activation

        # weights: one fused projection producing all four gate
        # pre-activations (hence 4 * num_out output features).
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        """
        Create the cell state ``c`` and the hidden state ``h``.

        Parameters
        ----------
        batch_size : int, optional
            Batch size forwarded to ``init.param`` when building the states.
        **kwargs
            Additional keyword arguments (unused).
        """
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """
        Overwrite the cell and hidden states with freshly initialized values.

        ``init_state`` must have been called before, since this assigns to
        the existing ``self.c`` and ``self.h``.

        Parameters
        ----------
        batch_size : int, optional
            Batch size forwarded to ``init.param`` when rebuilding the states.
        **kwargs
            Additional keyword arguments (unused).
        """
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        """
        Advance the cell by one time step.

        Parameters
        ----------
        x : ArrayLike
            Input for this step. It is concatenated with the hidden state
            along the last axis, e.g. shape ``(batch_size, num_in)``.

        Returns
        -------
        ArrayLike
            The updated hidden state.
        """
        h, c = self.h.value, self.c.value
        # ``jnp.concatenate`` (not the newer ``jnp.concat`` alias) is used for
        # consistency with the other cells in this module and so the cell also
        # works on JAX releases that predate the alias.
        xh = jnp.concatenate([x, h], axis=-1)
        # Chunk order of the fused projection: input, candidate, forget, output.
        i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        # The constant +1 biases the forget gate towards remembering.
        c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
        h = functional.sigmoid(o) * self.activation(c)
        self.h.value = h
        self.c.value = c
        return h
731
+
732
+
733
class URLSTMCell(RNNCell):
    r"""LSTM with UR gating mechanism.

    URLSTM is a modification of the standard LSTM in which the forget gate is
    combined with a second ("retention") gate into a single effective gate.
    Both gates share one learnable bias vector, which is added to the forget
    pre-activation and subtracted from the retention pre-activation. This
    implementation is based on the paper "Improving the Gating Mechanism of
    Recurrent Neural Networks" by Gu et al. [1]_.

    The URLSTM cell follows the mathematical formulation:

    .. math::

        f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
        r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\
        g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
        \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\
        c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\
        o_t &= \sigma(W_o [x_t, h_{t-1}]) \\
        h_t &= o_t \odot \phi(c_t)

    where:

    - :math:`x_t` is the input vector at time t
    - :math:`h_t` is the hidden state at time t
    - :math:`c_t` is the cell state at time t
    - :math:`f_t` is the forget gate (bias added)
    - :math:`r_t` is the retention gate (same bias, subtracted)
    - :math:`g_t` is the unified gate combining forget and retention
    - :math:`\tilde{c}_t` is the candidate cell state
    - :math:`o_t` is the output gate
    - :math:`\odot` represents element-wise multiplication
    - :math:`\sigma` is the sigmoid activation function
    - :math:`\phi` is the activation function (typically tanh)

    The shared bias :math:`b_f` is initialized by sampling values uniformly
    from :math:`[1/n, 1 - 1/n]` (with :math:`n` = ``num_out``) and mapping
    them through the logit function, to encourage diverse gating behavior
    across units.

    Parameters
    ----------
    num_in : int
        The number of input units.
    num_out : int
        The number of hidden/output units.
    w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
        Initializer for the weight matrix.
    state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
        Initializer for the hidden and cell states.
    activation : str or Callable, default='tanh'
        Activation function to use. A string (e.g., 'relu', 'tanh') is looked
        up by name in the activations module; anything else must be callable.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    num_in : int
        Number of input features.
    num_out : int
        Number of hidden/cell units.
    in_size : tuple
        Shape of input (num_in,).
    out_size : tuple
        Shape of output (num_out,).
    bias : ParamState
        Bias vector of shape (num_out,) shared, with opposite signs, by the
        forget and retention gates.

    State Variables
    ---------------
    h : HiddenState
        Hidden state of the URLSTM cell.
    c : HiddenState
        Cell state of the URLSTM cell.

    Methods
    -------
    init_state(batch_size=None, **kwargs)
        Initialize the cell and hidden states.
    reset_state(batch_size=None, **kwargs)
        Reset the cell and hidden states to their initial values.
    update(x)
        Update the cell and hidden states for one time step and return the hidden state.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bs
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a URLSTM cell
        >>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
        >>>
        >>> # Initialize the state for batch size 32
        >>> cell.init_state(batch_size=32)
        >>>
        >>> # Process a single time step
        >>> x = jnp.ones((32, 10))  # batch_size x num_in
        >>> output = cell.update(x)
        >>> print(output.shape)  # (32, 20)
        >>>
        >>> # Process multiple time steps
        >>> sequence = jnp.ones((100, 32, 10))  # time_steps x batch_size x num_in
        >>> outputs = []
        >>> for t in range(100):
        ...     output = cell.update(sequence[t])
        ...     outputs.append(output)
        >>> outputs = jnp.stack(outputs)
        >>> print(outputs.shape)  # (100, 32, 20)

    References
    ----------
    .. [1] Gu, Albert, et al. "Improving the gating mechanism of recurrent neural networks."
           International conference on machine learning. PMLR, 2020.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function: strings are looked up by attribute name on the
        # activations module, anything else must already be callable.
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights - 4 gates: forget, retention, candidate, output.
        # b_init=None: presumably disables the Linear layer's own bias, since
        # the gate bias is handled separately below — TODO confirm in Linear.
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)

        # Single learnable bias shared (with opposite sign) by the forget and
        # retention gates; must be created after ``self.num_out`` is set.
        self.bias = ParamState(self._forget_bias())

    def _forget_bias(self):
        """
        Sample the initial value of the shared forget/retention bias.

        Values ``u`` are drawn from ``random.uniform`` with bounds
        ``1 / num_out`` and ``1 - 1 / num_out`` and mapped to logit space, so
        that ``sigmoid(bias) == u``.

        Returns
        -------
        ArrayLike
            Bias vector of shape ``(num_out,)``.

        Notes
        -----
        NOTE(review): for ``num_out == 2`` both bounds equal 0.5, and for
        ``num_out == 1`` the bounds are inverted (1 and 0); the result then
        depends on how ``random.uniform`` treats such bounds and may be
        non-finite — confirm whether small ``num_out`` needs to be supported.
        """
        # Sample from uniform distribution to encourage diverse gating
        u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
        # logit(u) = log(u / (1 - u)) = -log(1 / u - 1)
        return -jnp.log(1 / u - 1)

    def init_state(self, batch_size: int = None, **kwargs):
        """
        Initialize the cell and hidden states.

        Parameters
        ----------
        batch_size : int, optional
            The batch size for state initialization.
        **kwargs
            Additional keyword arguments (unused).
        """
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        """
        Reset the cell and hidden states to their initial values.

        ``init_state`` must have been called before, since this assigns to
        the existing ``self.c`` and ``self.h``.

        Parameters
        ----------
        batch_size : int, optional
            The batch size for state reset.
        **kwargs
            Additional keyword arguments (unused).
        """
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x: ArrayLike) -> ArrayLike:
        """
        Update the URLSTM cell for one time step.

        Parameters
        ----------
        x : ArrayLike
            Input tensor with shape (batch_size, num_in).

        Returns
        -------
        ArrayLike
            Hidden state tensor with shape (batch_size, num_out).
        """
        h, c = self.h.value, self.c.value

        # Concatenate input and hidden state
        xh = jnp.concatenate([x, h], axis=-1)

        # Compute all gate pre-activations in one pass; chunk order is
        # forget, retention, candidate, output.
        gates = self.W(xh)
        f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)

        # Apply the shared bias with opposite signs to the two gates
        f_gate = functional.sigmoid(f + self.bias.value)
        r_gate = functional.sigmoid(r - self.bias.value)

        # Compute unified gate: interpolates between f^2 (r = 0) and
        # 2f - f^2 (r = 1)
        g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2

        # Update cell state: the unified gate keeps the old cell state and
        # its complement admits the candidate
        next_cell = g * c + (1 - g) * self.activation(u)

        # Compute output gate and hidden state
        o_gate = functional.sigmoid(o)
        next_hidden = o_gate * self.activation(next_cell)

        # Update states
        self.h.value = next_hidden
        self.c.value = next_cell

        return next_hidden