brainstate-0.2.0-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. brainstate/__init__.py +169 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2319 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +1652 -1652
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1624 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1433 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +137 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +633 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +154 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +477 -477
  32. brainstate/nn/_dynamics.py +1267 -1267
  33. brainstate/nn/_dynamics_test.py +67 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +384 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/_rand_funs.py +3938 -3938
  64. brainstate/random/_rand_funs_test.py +640 -640
  65. brainstate/random/_rand_seed.py +675 -675
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1617
  68. brainstate/random/_rand_state_test.py +551 -551
  69. brainstate/transform/__init__.py +59 -59
  70. brainstate/transform/_ad_checkpoint.py +176 -176
  71. brainstate/transform/_ad_checkpoint_test.py +49 -49
  72. brainstate/transform/_autograd.py +1025 -1025
  73. brainstate/transform/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -316
  75. brainstate/transform/_conditions_test.py +220 -220
  76. brainstate/transform/_error_if.py +94 -94
  77. brainstate/transform/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -145
  79. brainstate/transform/_eval_shape_test.py +38 -38
  80. brainstate/transform/_jit.py +399 -399
  81. brainstate/transform/_jit_test.py +143 -143
  82. brainstate/transform/_loop_collect_return.py +675 -675
  83. brainstate/transform/_loop_collect_return_test.py +58 -58
  84. brainstate/transform/_loop_no_collection.py +283 -283
  85. brainstate/transform/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -2016
  87. brainstate/transform/_make_jaxpr_test.py +1510 -1510
  88. brainstate/transform/_mapping.py +529 -529
  89. brainstate/transform/_mapping_test.py +194 -194
  90. brainstate/transform/_progress_bar.py +255 -255
  91. brainstate/transform/_random.py +171 -171
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate-0.2.0.dist-info/RECORD +0 -111
  111. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  112. {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_rnns.py CHANGED
@@ -1,946 +1,946 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
-
19
- from typing import Callable, Union
20
-
21
- import jax.numpy as jnp
22
-
23
- from brainstate import random
24
- from brainstate._state import HiddenState, ParamState
25
- from brainstate.typing import ArrayLike
26
- from . import _activations as functional
27
- from . import init as init
28
- from ._linear import Linear
29
- from ._module import Module
30
-
31
- __all__ = [
32
- 'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
33
- ]
34
-
35
-
36
- class RNNCell(Module):
37
- """
38
- Base class for all recurrent neural network (RNN) cell implementations.
39
-
40
- This abstract class serves as the foundation for implementing various RNN cell types
41
- such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
42
- Module class and provides common functionality and interface for recurrent units.
43
-
44
- All RNN cell implementations should inherit from this class and implement the required
45
- methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
46
- define the state initialization and recurrent dynamics.
47
-
48
- The RNNCell typically maintains hidden state(s) that are updated at each time step
49
- based on the current input and previous state values.
50
-
51
- Methods
52
- -------
53
- init_state(batch_size=None, **kwargs)
54
- Initialize the cell state variables with appropriate dimensions.
55
- reset_state(batch_size=None, **kwargs)
56
- Reset the cell state variables to their initial values.
57
- update(x)
58
- Update the cell state for one time step based on input x and return output.
59
-
60
- See Also
61
- --------
62
- ValinaRNNCell : Vanilla RNN cell implementation
63
- GRUCell : Gated Recurrent Unit cell implementation
64
- LSTMCell : Long Short-Term Memory cell implementation
65
- URLSTMCell : LSTM with UR gating mechanism
66
- MGUCell : Minimal Gated Unit cell implementation
67
- """
68
- __module__ = 'brainstate.nn'
69
- pass
70
-
71
-
72
- class ValinaRNNCell(RNNCell):
73
- r"""
74
- Vanilla Recurrent Neural Network (RNN) cell implementation.
75
-
76
- This class implements the basic RNN model that updates a hidden state based on
77
- the current input and previous hidden state. The standard RNN cell follows the
78
- mathematical formulation:
79
-
80
- .. math::
81
-
82
- h_t = \phi(W [x_t, h_{t-1}] + b)
83
-
84
- where:
85
-
86
- - :math:`x_t` is the input vector at time t
87
- - :math:`h_t` is the hidden state at time t
88
- - :math:`h_{t-1}` is the hidden state at previous time step
89
- - :math:`W` is the weight matrix for the combined input-hidden linear transformation
90
- - :math:`b` is the bias vector
91
- - :math:`\phi` is the activation function
92
-
93
- Parameters
94
- ----------
95
- num_in : int
96
- The number of input units.
97
- num_out : int
98
- The number of hidden units.
99
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
100
- Initializer for the hidden state.
101
- w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
102
- Initializer for the weight matrix.
103
- b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
104
- Initializer for the bias vector.
105
- activation : str or Callable, default='relu'
106
- Activation function to use. Can be a string (e.g., 'relu', 'tanh')
107
- or a callable function.
108
- name : str, optional
109
- Name of the module.
110
-
111
- Attributes
112
- ----------
113
- num_in : int
114
- Number of input features.
115
- num_out : int
116
- Number of hidden units.
117
- in_size : tuple
118
- Shape of input (num_in,).
119
- out_size : tuple
120
- Shape of output (num_out,).
121
-
122
- State Variables
123
- ---------------
124
- h : HiddenState
125
- Hidden state of the RNN cell.
126
-
127
- Methods
128
- -------
129
- init_state(batch_size=None, **kwargs)
130
- Initialize the cell hidden state.
131
- reset_state(batch_size=None, **kwargs)
132
- Reset the cell hidden state to its initial value.
133
- update(x)
134
- Update the hidden state for one time step and return the new state.
135
-
136
- Examples
137
- --------
138
- .. code-block:: python
139
-
140
- >>> import brainstate as bs
141
- >>> import jax.numpy as jnp
142
- >>>
143
- >>> # Create a vanilla RNN cell
144
- >>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
145
- >>>
146
- >>> # Initialize state for batch size 32
147
- >>> cell.init_state(batch_size=32)
148
- >>>
149
- >>> # Process a single time step
150
- >>> x = jnp.ones((32, 10)) # batch_size x num_in
151
- >>> output = cell.update(x)
152
- >>> print(output.shape) # (32, 20)
153
- >>>
154
- >>> # Process a sequence of inputs
155
- >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
156
- >>> outputs = []
157
- >>> for t in range(100):
158
- ... output = cell.update(sequence[t])
159
- ... outputs.append(output)
160
- >>> outputs = jnp.stack(outputs)
161
- >>> print(outputs.shape) # (100, 32, 20)
162
-
163
- Notes
164
- -----
165
- Vanilla RNNs can suffer from vanishing or exploding gradient problems
166
- when processing long sequences. For better performance on long sequences,
167
- consider using gated architectures like GRU or LSTM.
168
- """
169
- __module__ = 'brainstate.nn'
170
-
171
- def __init__(
172
- self,
173
- num_in: int,
174
- num_out: int,
175
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
176
- w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
177
- b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
178
- activation: str | Callable = 'relu',
179
- name: str = None,
180
- ):
181
- super().__init__(name=name)
182
-
183
- # parameters
184
- self.num_out = num_out
185
- self.num_in = num_in
186
- self.in_size = (num_in,)
187
- self.out_size = (num_out,)
188
- self._state_initializer = state_init
189
-
190
- # activation function
191
- if isinstance(activation, str):
192
- self.activation = getattr(functional, activation)
193
- else:
194
- assert callable(activation), "The activation function should be a string or a callable function. "
195
- self.activation = activation
196
-
197
- # weights
198
- self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
199
-
200
- def init_state(self, batch_size: int = None, **kwargs):
201
- """
202
- Initialize the hidden state.
203
-
204
- Parameters
205
- ----------
206
- batch_size : int, optional
207
- The batch size for state initialization.
208
- **kwargs
209
- Additional keyword arguments.
210
- """
211
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
212
-
213
- def reset_state(self, batch_size: int = None, **kwargs):
214
- """
215
- Reset the hidden state to its initial value.
216
-
217
- Parameters
218
- ----------
219
- batch_size : int, optional
220
- The batch size for state reset.
221
- **kwargs
222
- Additional keyword arguments.
223
- """
224
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
225
-
226
- def update(self, x):
227
- xh = jnp.concatenate([x, self.h.value], axis=-1)
228
- h = self.W(xh)
229
- self.h.value = self.activation(h)
230
- return self.h.value
231
-
232
-
233
- class GRUCell(RNNCell):
234
- r"""
235
- Gated Recurrent Unit (GRU) cell implementation.
236
-
237
- The GRU is a gating mechanism in recurrent neural networks that aims to solve
238
- the vanishing gradient problem. It uses gating mechanisms to control information
239
- flow and has fewer parameters than LSTM as it combines the forget and input gates
240
- into a single update gate.
241
-
242
- The GRU cell follows the mathematical formulation:
243
-
244
- .. math::
245
-
246
- r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
247
- z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
248
- \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
249
- h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
250
-
251
- where:
252
-
253
- - :math:`x_t` is the input vector at time t
254
- - :math:`h_t` is the hidden state at time t
255
- - :math:`r_t` is the reset gate vector
256
- - :math:`z_t` is the update gate vector
257
- - :math:`\tilde{h}_t` is the candidate hidden state
258
- - :math:`\odot` represents element-wise multiplication
259
- - :math:`\sigma` is the sigmoid activation function
260
- - :math:`\phi` is the activation function (typically tanh)
261
-
262
- Parameters
263
- ----------
264
- num_in : int
265
- The number of input units.
266
- num_out : int
267
- The number of hidden units.
268
- w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
269
- Initializer for the weight matrices.
270
- b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
271
- Initializer for the bias vectors.
272
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
273
- Initializer for the hidden state.
274
- activation : str or Callable, default='tanh'
275
- Activation function to use. Can be a string (e.g., 'tanh', 'relu')
276
- or a callable function.
277
- name : str, optional
278
- Name of the module.
279
-
280
- Attributes
281
- ----------
282
- num_in : int
283
- Number of input features.
284
- num_out : int
285
- Number of hidden units.
286
- in_size : tuple
287
- Shape of input (num_in,).
288
- out_size : tuple
289
- Shape of output (num_out,).
290
-
291
- State Variables
292
- ---------------
293
- h : HiddenState
294
- Hidden state of the GRU cell.
295
-
296
- Methods
297
- -------
298
- init_state(batch_size=None, **kwargs)
299
- Initialize the cell hidden state.
300
- reset_state(batch_size=None, **kwargs)
301
- Reset the cell hidden state to its initial value.
302
- update(x)
303
- Update the hidden state for one time step and return the new state.
304
-
305
- Examples
306
- --------
307
- .. code-block:: python
308
-
309
- >>> import brainstate as bs
310
- >>> import jax.numpy as jnp
311
- >>>
312
- >>> # Create a GRU cell
313
- >>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
314
- >>>
315
- >>> # Initialize state for batch size 32
316
- >>> cell.init_state(batch_size=32)
317
- >>>
318
- >>> # Process a single time step
319
- >>> x = jnp.ones((32, 10)) # batch_size x num_in
320
- >>> output = cell.update(x)
321
- >>> print(output.shape) # (32, 20)
322
- >>>
323
- >>> # Process a sequence
324
- >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
325
- >>> outputs = []
326
- >>> for t in range(100):
327
- ... output = cell.update(sequence[t])
328
- ... outputs.append(output)
329
- >>> outputs = jnp.stack(outputs)
330
- >>> print(outputs.shape) # (100, 32, 20)
331
- >>>
332
- >>> # Reset state with different batch size
333
- >>> cell.reset_state(batch_size=16)
334
- >>> x_new = jnp.ones((16, 10))
335
- >>> output_new = cell.update(x_new)
336
- >>> print(output_new.shape) # (16, 20)
337
-
338
- Notes
339
- -----
340
- GRU cells are computationally more efficient than LSTM cells due to having
341
- fewer parameters, while often achieving comparable performance on many tasks.
342
-
343
- References
344
- ----------
345
- .. [1] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F.,
346
- Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using
347
- RNN encoder-decoder for statistical machine translation.
348
- arXiv preprint arXiv:1406.1078.
349
- """
350
- __module__ = 'brainstate.nn'
351
-
352
- def __init__(
353
- self,
354
- num_in: int,
355
- num_out: int,
356
- w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
357
- b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
358
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
359
- activation: str | Callable = 'tanh',
360
- name: str = None,
361
- ):
362
- super().__init__(name=name)
363
-
364
- # parameters
365
- self._state_initializer = state_init
366
- self.num_out = num_out
367
- self.num_in = num_in
368
- self.in_size = (num_in,)
369
- self.out_size = (num_out,)
370
-
371
- # activation function
372
- if isinstance(activation, str):
373
- self.activation = getattr(functional, activation)
374
- else:
375
- assert callable(activation), "The activation function should be a string or a callable function. "
376
- self.activation = activation
377
-
378
- # weights
379
- self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
380
- self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
381
-
382
- def init_state(self, batch_size: int = None, **kwargs):
383
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
384
-
385
- def reset_state(self, batch_size: int = None, **kwargs):
386
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
387
-
388
- def update(self, x):
389
- old_h = self.h.value
390
- xh = jnp.concatenate([x, old_h], axis=-1)
391
- r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
392
- rh = r * old_h
393
- h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
394
- h = (1 - z) * old_h + z * h
395
- self.h.value = h
396
- return h
397
-
398
-
399
- class MGUCell(RNNCell):
400
- r"""
401
- Minimal Gated Unit (MGU) cell implementation.
402
-
403
- MGU is a simplified version of GRU that uses a single forget gate instead of
404
- separate reset and update gates. This design significantly reduces the number
405
- of parameters while maintaining much of the gating capability. MGU provides
406
- a good trade-off between model complexity and performance.
407
-
408
- The MGU cell follows the mathematical formulation:
409
-
410
- .. math::
411
-
412
- f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
413
- \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
414
- h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
415
-
416
- where:
417
-
418
- - :math:`x_t` is the input vector at time t
419
- - :math:`h_t` is the hidden state at time t
420
- - :math:`f_t` is the forget gate vector
421
- - :math:`\tilde{h}_t` is the candidate hidden state
422
- - :math:`\odot` represents element-wise multiplication
423
- - :math:`\sigma` is the sigmoid activation function
424
- - :math:`\phi` is the activation function (typically tanh)
425
-
426
- Parameters
427
- ----------
428
- num_in : int
429
- The number of input units.
430
- num_out : int
431
- The number of hidden units.
432
- w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
433
- Initializer for the weight matrices.
434
- b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
435
- Initializer for the bias vectors.
436
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
437
- Initializer for the hidden state.
438
- activation : str or Callable, default='tanh'
439
- Activation function to use. Can be a string (e.g., 'tanh', 'relu')
440
- or a callable function.
441
- name : str, optional
442
- Name of the module.
443
-
444
- Attributes
445
- ----------
446
- num_in : int
447
- Number of input features.
448
- num_out : int
449
- Number of hidden units.
450
- in_size : tuple
451
- Shape of input (num_in,).
452
- out_size : tuple
453
- Shape of output (num_out,).
454
-
455
- State Variables
456
- ---------------
457
- h : HiddenState
458
- Hidden state of the MGU cell.
459
-
460
- Methods
461
- -------
462
- init_state(batch_size=None, **kwargs)
463
- Initialize the cell hidden state.
464
- reset_state(batch_size=None, **kwargs)
465
- Reset the cell hidden state to its initial value.
466
- update(x)
467
- Update the hidden state for one time step and return the new state.
468
-
469
- Examples
470
- --------
471
- .. code-block:: python
472
-
473
- >>> import brainstate as bs
474
- >>> import jax.numpy as jnp
475
- >>>
476
- >>> # Create an MGU cell
477
- >>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
478
- >>>
479
- >>> # Initialize state for batch size 32
480
- >>> cell.init_state(batch_size=32)
481
- >>>
482
- >>> # Process a single time step
483
- >>> x = jnp.ones((32, 10)) # batch_size x num_in
484
- >>> output = cell.update(x)
485
- >>> print(output.shape) # (32, 20)
486
- >>>
487
- >>> # Process a sequence
488
- >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
489
- >>> outputs = []
490
- >>> for t in range(100):
491
- ... output = cell.update(sequence[t])
492
- ... outputs.append(output)
493
- >>> outputs = jnp.stack(outputs)
494
- >>> print(outputs.shape) # (100, 32, 20)
495
-
496
- Notes
497
- -----
498
- MGU provides a lightweight alternative to GRU and LSTM, making it suitable
499
- for resource-constrained applications or when model simplicity is preferred.
500
-
501
- References
502
- ----------
503
- .. [1] Zhou, G. B., Wu, J., Zhang, C. L., & Zhou, Z. H. (2016). Minimal gated unit
504
- for recurrent neural networks. International Journal of Automation and Computing,
505
- 13(3), 226-234.
506
- """
507
- __module__ = 'brainstate.nn'
508
-
509
- def __init__(
510
- self,
511
- num_in: int,
512
- num_out: int,
513
- w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
514
- b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
515
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
516
- activation: str | Callable = 'tanh',
517
- name: str = None,
518
- ):
519
- super().__init__(name=name)
520
-
521
- # parameters
522
- self._state_initializer = state_init
523
- self.num_out = num_out
524
- self.num_in = num_in
525
- self.in_size = (num_in,)
526
- self.out_size = (num_out,)
527
-
528
- # activation function
529
- if isinstance(activation, str):
530
- self.activation = getattr(functional, activation)
531
- else:
532
- assert callable(activation), "The activation function should be a string or a callable function. "
533
- self.activation = activation
534
-
535
- # weights
536
- self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
537
- self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
538
-
539
- def init_state(self, batch_size: int = None, **kwargs):
540
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
541
-
542
- def reset_state(self, batch_size: int = None, **kwargs):
543
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
544
-
545
- def update(self, x):
546
- old_h = self.h.value
547
- xh = jnp.concatenate([x, old_h], axis=-1)
548
- f = functional.sigmoid(self.Wf(xh))
549
- fh = f * old_h
550
- h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
551
- self.h.value = (1 - f) * self.h.value + f * h
552
- return self.h.value
553
-
554
-
555
- class LSTMCell(RNNCell):
556
- r"""
557
- Long Short-Term Memory (LSTM) cell implementation.
558
-
559
- LSTM is a type of RNN architecture designed to address the vanishing gradient
560
- problem and learn long-term dependencies. It uses a cell state to carry
561
- information across time steps and three gates (input, forget, output) to
562
- control information flow.
563
-
564
- The LSTM cell follows the mathematical formulation:
565
-
566
- .. math::
567
-
568
- i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
569
- f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
570
- g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\
571
- o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
572
- c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
573
- h_t &= o_t \odot \phi(c_t)
574
-
575
- where:
576
-
577
- - :math:`x_t` is the input vector at time t
578
- - :math:`h_t` is the hidden state at time t
579
- - :math:`c_t` is the cell state at time t
580
- - :math:`i_t` is the input gate activation
581
- - :math:`f_t` is the forget gate activation
582
- - :math:`o_t` is the output gate activation
583
- - :math:`g_t` is the cell update (candidate) vector
584
- - :math:`\odot` represents element-wise multiplication
585
- - :math:`\sigma` is the sigmoid activation function
586
- - :math:`\phi` is the activation function (typically tanh)
587
-
588
- Parameters
589
- ----------
590
- num_in : int
591
- The number of input units.
592
- num_out : int
593
- The number of hidden/cell units.
594
- w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
595
- Initializer for the weight matrices.
596
- b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
597
- Initializer for the bias vectors.
598
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
599
- Initializer for the hidden and cell states.
600
- activation : str or Callable, default='tanh'
601
- Activation function to use. Can be a string (e.g., 'tanh', 'relu')
602
- or a callable function.
603
- name : str, optional
604
- Name of the module.
605
-
606
- Attributes
607
- ----------
608
- num_in : int
609
- Number of input features.
610
- num_out : int
611
- Number of hidden/cell units.
612
- in_size : tuple
613
- Shape of input (num_in,).
614
- out_size : tuple
615
- Shape of output (num_out,).
616
-
617
- State Variables
618
- ---------------
619
- h : HiddenState
620
- Hidden state of the LSTM cell.
621
- c : HiddenState
622
- Cell state of the LSTM cell.
623
-
624
- Methods
625
- -------
626
- init_state(batch_size=None, **kwargs)
627
- Initialize the cell and hidden states.
628
- reset_state(batch_size=None, **kwargs)
629
- Reset the cell and hidden states to their initial values.
630
- update(x)
631
- Update the states for one time step and return the new hidden state.
632
-
633
- Examples
634
- --------
635
- .. code-block:: python
636
-
637
- >>> import brainstate as bs
638
- >>> import jax.numpy as jnp
639
- >>>
640
- >>> # Create an LSTM cell
641
- >>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
642
- >>>
643
- >>> # Initialize states for batch size 32
644
- >>> cell.init_state(batch_size=32)
645
- >>>
646
- >>> # Process a single time step
647
- >>> x = jnp.ones((32, 10)) # batch_size x num_in
648
- >>> output = cell.update(x)
649
- >>> print(output.shape) # (32, 20)
650
- >>>
651
- >>> # Process a sequence
652
- >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
653
- >>> outputs = []
654
- >>> for t in range(100):
655
- ... output = cell.update(sequence[t])
656
- ... outputs.append(output)
657
- >>> outputs = jnp.stack(outputs)
658
- >>> print(outputs.shape) # (100, 32, 20)
659
- >>>
660
- >>> # Access cell state
661
- >>> print(cell.c.value.shape) # (32, 20)
662
- >>> print(cell.h.value.shape) # (32, 20)
663
-
664
- Notes
665
- -----
666
- - The forget gate bias is initialized with +1.0 following Jozefowicz et al. (2015)
667
- to reduce forgetting at the beginning of training.
668
- - LSTM cells are effective for learning long-term dependencies but require
669
- more parameters and computation than simpler RNN variants.
670
-
671
- References
672
- ----------
673
- .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
674
- Neural computation, 9(8), 1735-1780.
675
- .. [2] Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget:
676
- Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.
677
- .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
678
- exploration of recurrent network architectures. In International
679
- conference on machine learning (pp. 2342-2350).
680
- """
681
- __module__ = 'brainstate.nn'
682
-
683
- def __init__(
684
- self,
685
- num_in: int,
686
- num_out: int,
687
- w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
688
- b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
689
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
690
- activation: str | Callable = 'tanh',
691
- name: str = None,
692
- ):
693
- super().__init__(name=name)
694
-
695
- # parameters
696
- self.num_out = num_out
697
- self.num_in = num_in
698
- self.in_size = (num_in,)
699
- self.out_size = (num_out,)
700
-
701
- # initializers
702
- self._state_initializer = state_init
703
-
704
- # activation function
705
- if isinstance(activation, str):
706
- self.activation = getattr(functional, activation)
707
- else:
708
- assert callable(activation), "The activation function should be a string or a callable function. "
709
- self.activation = activation
710
-
711
- # weights
712
- self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)
713
-
714
- def init_state(self, batch_size: int = None, **kwargs):
715
- self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
716
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
717
-
718
- def reset_state(self, batch_size: int = None, **kwargs):
719
- self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
720
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
721
-
722
- def update(self, x):
723
- h, c = self.h.value, self.c.value
724
- xh = jnp.concat([x, h], axis=-1)
725
- i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
726
- c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
727
- h = functional.sigmoid(o) * self.activation(c)
728
- self.h.value = h
729
- self.c.value = c
730
- return h
731
-
732
-
733
- class URLSTMCell(RNNCell):
734
- r"""LSTM with UR gating mechanism.
735
-
736
- URLSTM is a modification of the standard LSTM that uses untied (separate) biases
737
- for the forget and retention mechanisms, allowing for more flexible gating control.
738
- This implementation is based on the paper "Improving the Gating Mechanism of
739
- Recurrent Neural Networks" by Gers et al.
740
-
741
- The URLSTM cell follows the mathematical formulation:
742
-
743
- .. math::
744
-
745
- f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
746
- r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\
747
- g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
748
- \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\
749
- c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\
750
- o_t &= \sigma(W_o [x_t, h_{t-1}]) \\
751
- h_t &= o_t \odot \phi(c_t)
752
-
753
- where:
754
-
755
- - :math:`x_t` is the input vector at time t
756
- - :math:`h_t` is the hidden state at time t
757
- - :math:`c_t` is the cell state at time t
758
- - :math:`f_t` is the forget gate with positive bias
759
- - :math:`r_t` is the retention gate with negative bias
760
- - :math:`g_t` is the unified gate combining forget and retention
761
- - :math:`\tilde{c}_t` is the candidate cell state
762
- - :math:`o_t` is the output gate
763
- - :math:`\odot` represents element-wise multiplication
764
- - :math:`\sigma` is the sigmoid activation function
765
- - :math:`\phi` is the activation function (typically tanh)
766
-
767
- The key innovation is the untied bias mechanism where the forget and retention
768
- gates use opposite biases, initialized using a uniform distribution to encourage
769
- diverse gating behavior across units.
770
-
771
- Parameters
772
- ----------
773
- num_in : int
774
- The number of input units.
775
- num_out : int
776
- The number of hidden/output units.
777
- w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
778
- Initializer for the weight matrix.
779
- state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
780
- Initializer for the hidden and cell states.
781
- activation : str or Callable, default='tanh'
782
- Activation function to use. Can be a string (e.g., 'relu', 'tanh')
783
- or a callable function.
784
- name : str, optional
785
- Name of the module.
786
-
787
- State Variables
788
- ---------------
789
- h : HiddenState
790
- Hidden state of the URLSTM cell.
791
- c : HiddenState
792
- Cell state of the URLSTM cell.
793
-
794
- Methods
795
- -------
796
- init_state(batch_size=None, **kwargs)
797
- Initialize the cell and hidden states.
798
- reset_state(batch_size=None, **kwargs)
799
- Reset the cell and hidden states to their initial values.
800
- update(x)
801
- Update the cell and hidden states for one time step and return the hidden state.
802
-
803
- Examples
804
- --------
805
- .. code-block:: python
806
-
807
- >>> import brainstate as bs
808
- >>> import jax.numpy as jnp
809
- >>>
810
- >>> # Create a URLSTM cell
811
- >>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
812
- >>>
813
- >>> # Initialize the state for batch size 32
814
- >>> cell.init_state(batch_size=32)
815
- >>>
816
- >>> # Process a sequence
817
- >>> x = jnp.ones((32, 10)) # batch_size x num_in
818
- >>> output = cell.update(x)
819
- >>> print(output.shape) # (32, 20)
820
- >>>
821
- >>> # Process multiple time steps
822
- >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
823
- >>> outputs = []
824
- >>> for t in range(100):
825
- ... output = cell.update(sequence[t])
826
- ... outputs.append(output)
827
- >>> outputs = jnp.stack(outputs)
828
- >>> print(outputs.shape) # (100, 32, 20)
829
-
830
- References
831
- ----------
832
- .. [1] Gu, Albert, et al. "Improving the gating mechanism of recurrent neural networks."
833
- International conference on machine learning. PMLR, 2020.
834
- """
835
- __module__ = 'brainstate.nn'
836
-
837
- def __init__(
838
- self,
839
- num_in: int,
840
- num_out: int,
841
- w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
842
- state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
843
- activation: str | Callable = 'tanh',
844
- name: str = None,
845
- ):
846
- super().__init__(name=name)
847
-
848
- # parameters
849
- self.num_out = num_out
850
- self.num_in = num_in
851
- self.in_size = (num_in,)
852
- self.out_size = (num_out,)
853
-
854
- # initializers
855
- self._state_initializer = state_init
856
-
857
- # activation function
858
- if isinstance(activation, str):
859
- self.activation = getattr(functional, activation)
860
- else:
861
- assert callable(activation), "The activation function should be a string or a callable function."
862
- self.activation = activation
863
-
864
- # weights - 4 gates: forget, retention, candidate, output
865
- self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
866
-
867
- # Initialize untied bias using uniform distribution
868
- self.bias = ParamState(self._forget_bias())
869
-
870
- def _forget_bias(self):
871
- """Initialize the forget gate bias using uniform distribution."""
872
- # Sample from uniform distribution to encourage diverse gating
873
- u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
874
- # Transform to logit space for initialization
875
- return -jnp.log(1 / u - 1)
876
-
877
- def init_state(self, batch_size: int = None, **kwargs):
878
- """
879
- Initialize the cell and hidden states.
880
-
881
- Parameters
882
- ----------
883
- batch_size : int, optional
884
- The batch size for state initialization.
885
- **kwargs
886
- Additional keyword arguments.
887
- """
888
- self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
889
- self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
890
-
891
- def reset_state(self, batch_size: int = None, **kwargs):
892
- """
893
- Reset the cell and hidden states to their initial values.
894
-
895
- Parameters
896
- ----------
897
- batch_size : int, optional
898
- The batch size for state reset.
899
- **kwargs
900
- Additional keyword arguments.
901
- """
902
- self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
903
- self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
904
-
905
- def update(self, x: ArrayLike) -> ArrayLike:
906
- """
907
- Update the URLSTM cell for one time step.
908
-
909
- Parameters
910
- ----------
911
- x : ArrayLike
912
- Input tensor with shape (batch_size, num_in).
913
-
914
- Returns
915
- -------
916
- ArrayLike
917
- Hidden state tensor with shape (batch_size, num_out).
918
- """
919
- h, c = self.h.value, self.c.value
920
-
921
- # Concatenate input and hidden state
922
- xh = jnp.concatenate([x, h], axis=-1)
923
-
924
- # Compute all gates in one pass
925
- gates = self.W(xh)
926
- f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)
927
-
928
- # Apply untied biases to forget and retention gates
929
- f_gate = functional.sigmoid(f + self.bias.value)
930
- r_gate = functional.sigmoid(r - self.bias.value)
931
-
932
- # Compute unified gate
933
- g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2
934
-
935
- # Update cell state
936
- next_cell = g * c + (1 - g) * self.activation(u)
937
-
938
- # Compute output gate and hidden state
939
- o_gate = functional.sigmoid(o)
940
- next_hidden = o_gate * self.activation(next_cell)
941
-
942
- # Update states
943
- self.h.value = next_hidden
944
- self.c.value = next_cell
945
-
946
- return next_hidden
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+
19
+ from typing import Callable, Union
20
+
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate import random
24
+ from brainstate._state import HiddenState, ParamState
25
+ from brainstate.typing import ArrayLike
26
+ from . import _activations as functional
27
+ from . import init as init
28
+ from ._linear import Linear
29
+ from ._module import Module
30
+
31
+ __all__ = [
32
+ 'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
33
+ ]
34
+
35
+
36
+ class RNNCell(Module):
37
+ """
38
+ Base class for all recurrent neural network (RNN) cell implementations.
39
+
40
+ This abstract class serves as the foundation for implementing various RNN cell types
41
+ such as vanilla RNN, GRU, LSTM, and other recurrent architectures. It extends the
42
+ Module class and provides common functionality and interface for recurrent units.
43
+
44
+ All RNN cell implementations should inherit from this class and implement the required
45
+ methods, particularly the `init_state()`, `reset_state()`, and `update()` methods that
46
+ define the state initialization and recurrent dynamics.
47
+
48
+ The RNNCell typically maintains hidden state(s) that are updated at each time step
49
+ based on the current input and previous state values.
50
+
51
+ Methods
52
+ -------
53
+ init_state(batch_size=None, **kwargs)
54
+ Initialize the cell state variables with appropriate dimensions.
55
+ reset_state(batch_size=None, **kwargs)
56
+ Reset the cell state variables to their initial values.
57
+ update(x)
58
+ Update the cell state for one time step based on input x and return output.
59
+
60
+ See Also
61
+ --------
62
+ ValinaRNNCell : Vanilla RNN cell implementation
63
+ GRUCell : Gated Recurrent Unit cell implementation
64
+ LSTMCell : Long Short-Term Memory cell implementation
65
+ URLSTMCell : LSTM with UR gating mechanism
66
+ MGUCell : Minimal Gated Unit cell implementation
67
+ """
68
+ __module__ = 'brainstate.nn'
69
+ pass
70
+
71
+
72
+ class ValinaRNNCell(RNNCell):
73
+ r"""
74
+ Vanilla Recurrent Neural Network (RNN) cell implementation.
75
+
76
+ This class implements the basic RNN model that updates a hidden state based on
77
+ the current input and previous hidden state. The standard RNN cell follows the
78
+ mathematical formulation:
79
+
80
+ .. math::
81
+
82
+ h_t = \phi(W [x_t, h_{t-1}] + b)
83
+
84
+ where:
85
+
86
+ - :math:`x_t` is the input vector at time t
87
+ - :math:`h_t` is the hidden state at time t
88
+ - :math:`h_{t-1}` is the hidden state at previous time step
89
+ - :math:`W` is the weight matrix for the combined input-hidden linear transformation
90
+ - :math:`b` is the bias vector
91
+ - :math:`\phi` is the activation function
92
+
93
+ Parameters
94
+ ----------
95
+ num_in : int
96
+ The number of input units.
97
+ num_out : int
98
+ The number of hidden units.
99
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
100
+ Initializer for the hidden state.
101
+ w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
102
+ Initializer for the weight matrix.
103
+ b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
104
+ Initializer for the bias vector.
105
+ activation : str or Callable, default='relu'
106
+ Activation function to use. Can be a string (e.g., 'relu', 'tanh')
107
+ or a callable function.
108
+ name : str, optional
109
+ Name of the module.
110
+
111
+ Attributes
112
+ ----------
113
+ num_in : int
114
+ Number of input features.
115
+ num_out : int
116
+ Number of hidden units.
117
+ in_size : tuple
118
+ Shape of input (num_in,).
119
+ out_size : tuple
120
+ Shape of output (num_out,).
121
+
122
+ State Variables
123
+ ---------------
124
+ h : HiddenState
125
+ Hidden state of the RNN cell.
126
+
127
+ Methods
128
+ -------
129
+ init_state(batch_size=None, **kwargs)
130
+ Initialize the cell hidden state.
131
+ reset_state(batch_size=None, **kwargs)
132
+ Reset the cell hidden state to its initial value.
133
+ update(x)
134
+ Update the hidden state for one time step and return the new state.
135
+
136
+ Examples
137
+ --------
138
+ .. code-block:: python
139
+
140
+ >>> import brainstate as bs
141
+ >>> import jax.numpy as jnp
142
+ >>>
143
+ >>> # Create a vanilla RNN cell
144
+ >>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
145
+ >>>
146
+ >>> # Initialize state for batch size 32
147
+ >>> cell.init_state(batch_size=32)
148
+ >>>
149
+ >>> # Process a single time step
150
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
151
+ >>> output = cell.update(x)
152
+ >>> print(output.shape) # (32, 20)
153
+ >>>
154
+ >>> # Process a sequence of inputs
155
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
156
+ >>> outputs = []
157
+ >>> for t in range(100):
158
+ ... output = cell.update(sequence[t])
159
+ ... outputs.append(output)
160
+ >>> outputs = jnp.stack(outputs)
161
+ >>> print(outputs.shape) # (100, 32, 20)
162
+
163
+ Notes
164
+ -----
165
+ Vanilla RNNs can suffer from vanishing or exploding gradient problems
166
+ when processing long sequences. For better performance on long sequences,
167
+ consider using gated architectures like GRU or LSTM.
168
+ """
169
+ __module__ = 'brainstate.nn'
170
+
171
+ def __init__(
172
+ self,
173
+ num_in: int,
174
+ num_out: int,
175
+ state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
176
+ w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
177
+ b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
178
+ activation: str | Callable = 'relu',
179
+ name: str = None,
180
+ ):
181
+ super().__init__(name=name)
182
+
183
+ # parameters
184
+ self.num_out = num_out
185
+ self.num_in = num_in
186
+ self.in_size = (num_in,)
187
+ self.out_size = (num_out,)
188
+ self._state_initializer = state_init
189
+
190
+ # activation function
191
+ if isinstance(activation, str):
192
+ self.activation = getattr(functional, activation)
193
+ else:
194
+ assert callable(activation), "The activation function should be a string or a callable function. "
195
+ self.activation = activation
196
+
197
+ # weights
198
+ self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
199
+
200
+ def init_state(self, batch_size: int = None, **kwargs):
201
+ """
202
+ Initialize the hidden state.
203
+
204
+ Parameters
205
+ ----------
206
+ batch_size : int, optional
207
+ The batch size for state initialization.
208
+ **kwargs
209
+ Additional keyword arguments.
210
+ """
211
+ self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
212
+
213
+ def reset_state(self, batch_size: int = None, **kwargs):
214
+ """
215
+ Reset the hidden state to its initial value.
216
+
217
+ Parameters
218
+ ----------
219
+ batch_size : int, optional
220
+ The batch size for state reset.
221
+ **kwargs
222
+ Additional keyword arguments.
223
+ """
224
+ self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
225
+
226
+ def update(self, x):
227
+ xh = jnp.concatenate([x, self.h.value], axis=-1)
228
+ h = self.W(xh)
229
+ self.h.value = self.activation(h)
230
+ return self.h.value
231
+
232
+
233
+ class GRUCell(RNNCell):
234
+ r"""
235
+ Gated Recurrent Unit (GRU) cell implementation.
236
+
237
+ The GRU is a gating mechanism in recurrent neural networks that aims to solve
238
+ the vanishing gradient problem. It uses gating mechanisms to control information
239
+ flow and has fewer parameters than LSTM as it combines the forget and input gates
240
+ into a single update gate.
241
+
242
+ The GRU cell follows the mathematical formulation:
243
+
244
+ .. math::
245
+
246
+ r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
247
+ z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
248
+ \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
249
+ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
250
+
251
+ where:
252
+
253
+ - :math:`x_t` is the input vector at time t
254
+ - :math:`h_t` is the hidden state at time t
255
+ - :math:`r_t` is the reset gate vector
256
+ - :math:`z_t` is the update gate vector
257
+ - :math:`\tilde{h}_t` is the candidate hidden state
258
+ - :math:`\odot` represents element-wise multiplication
259
+ - :math:`\sigma` is the sigmoid activation function
260
+ - :math:`\phi` is the activation function (typically tanh)
261
+
262
+ Parameters
263
+ ----------
264
+ num_in : int
265
+ The number of input units.
266
+ num_out : int
267
+ The number of hidden units.
268
+ w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
269
+ Initializer for the weight matrices.
270
+ b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
271
+ Initializer for the bias vectors.
272
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
273
+ Initializer for the hidden state.
274
+ activation : str or Callable, default='tanh'
275
+ Activation function to use. Can be a string (e.g., 'tanh', 'relu')
276
+ or a callable function.
277
+ name : str, optional
278
+ Name of the module.
279
+
280
+ Attributes
281
+ ----------
282
+ num_in : int
283
+ Number of input features.
284
+ num_out : int
285
+ Number of hidden units.
286
+ in_size : tuple
287
+ Shape of input (num_in,).
288
+ out_size : tuple
289
+ Shape of output (num_out,).
290
+
291
+ State Variables
292
+ ---------------
293
+ h : HiddenState
294
+ Hidden state of the GRU cell.
295
+
296
+ Methods
297
+ -------
298
+ init_state(batch_size=None, **kwargs)
299
+ Initialize the cell hidden state.
300
+ reset_state(batch_size=None, **kwargs)
301
+ Reset the cell hidden state to its initial value.
302
+ update(x)
303
+ Update the hidden state for one time step and return the new state.
304
+
305
+ Examples
306
+ --------
307
+ .. code-block:: python
308
+
309
+ >>> import brainstate as bs
310
+ >>> import jax.numpy as jnp
311
+ >>>
312
+ >>> # Create a GRU cell
313
+ >>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
314
+ >>>
315
+ >>> # Initialize state for batch size 32
316
+ >>> cell.init_state(batch_size=32)
317
+ >>>
318
+ >>> # Process a single time step
319
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
320
+ >>> output = cell.update(x)
321
+ >>> print(output.shape) # (32, 20)
322
+ >>>
323
+ >>> # Process a sequence
324
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
325
+ >>> outputs = []
326
+ >>> for t in range(100):
327
+ ... output = cell.update(sequence[t])
328
+ ... outputs.append(output)
329
+ >>> outputs = jnp.stack(outputs)
330
+ >>> print(outputs.shape) # (100, 32, 20)
331
+ >>>
332
+ >>> # Reset state with different batch size
333
+ >>> cell.reset_state(batch_size=16)
334
+ >>> x_new = jnp.ones((16, 10))
335
+ >>> output_new = cell.update(x_new)
336
+ >>> print(output_new.shape) # (16, 20)
337
+
338
+ Notes
339
+ -----
340
+ GRU cells are computationally more efficient than LSTM cells due to having
341
+ fewer parameters, while often achieving comparable performance on many tasks.
342
+
343
+ References
344
+ ----------
345
+ .. [1] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F.,
346
+ Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using
347
+ RNN encoder-decoder for statistical machine translation.
348
+ arXiv preprint arXiv:1406.1078.
349
+ """
350
+ __module__ = 'brainstate.nn'
351
+
352
+ def __init__(
353
+ self,
354
+ num_in: int,
355
+ num_out: int,
356
+ w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
357
+ b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
358
+ state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
359
+ activation: str | Callable = 'tanh',
360
+ name: str = None,
361
+ ):
362
+ super().__init__(name=name)
363
+
364
+ # parameters
365
+ self._state_initializer = state_init
366
+ self.num_out = num_out
367
+ self.num_in = num_in
368
+ self.in_size = (num_in,)
369
+ self.out_size = (num_out,)
370
+
371
+ # activation function
372
+ if isinstance(activation, str):
373
+ self.activation = getattr(functional, activation)
374
+ else:
375
+ assert callable(activation), "The activation function should be a string or a callable function. "
376
+ self.activation = activation
377
+
378
+ # weights
379
+ self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
380
+ self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
381
+
382
+ def init_state(self, batch_size: int = None, **kwargs):
383
+ self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
384
+
385
+ def reset_state(self, batch_size: int = None, **kwargs):
386
+ self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
387
+
388
+ def update(self, x):
389
+ old_h = self.h.value
390
+ xh = jnp.concatenate([x, old_h], axis=-1)
391
+ r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
392
+ rh = r * old_h
393
+ h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
394
+ h = (1 - z) * old_h + z * h
395
+ self.h.value = h
396
+ return h
397
+
398
+
399
+ class MGUCell(RNNCell):
400
+ r"""
401
+ Minimal Gated Unit (MGU) cell implementation.
402
+
403
+ MGU is a simplified version of GRU that uses a single forget gate instead of
404
+ separate reset and update gates. This design significantly reduces the number
405
+ of parameters while maintaining much of the gating capability. MGU provides
406
+ a good trade-off between model complexity and performance.
407
+
408
+ The MGU cell follows the mathematical formulation:
409
+
410
+ .. math::
411
+
412
+ f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
413
+ \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
414
+ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
415
+
416
+ where:
417
+
418
+ - :math:`x_t` is the input vector at time t
419
+ - :math:`h_t` is the hidden state at time t
420
+ - :math:`f_t` is the forget gate vector
421
+ - :math:`\tilde{h}_t` is the candidate hidden state
422
+ - :math:`\odot` represents element-wise multiplication
423
+ - :math:`\sigma` is the sigmoid activation function
424
+ - :math:`\phi` is the activation function (typically tanh)
425
+
426
+ Parameters
427
+ ----------
428
+ num_in : int
429
+ The number of input units.
430
+ num_out : int
431
+ The number of hidden units.
432
+ w_init : Union[ArrayLike, Callable], default=init.Orthogonal()
433
+ Initializer for the weight matrices.
434
+ b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
435
+ Initializer for the bias vectors.
436
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
437
+ Initializer for the hidden state.
438
+ activation : str or Callable, default='tanh'
439
+ Activation function to use. Can be a string (e.g., 'tanh', 'relu')
440
+ or a callable function.
441
+ name : str, optional
442
+ Name of the module.
443
+
444
+ Attributes
445
+ ----------
446
+ num_in : int
447
+ Number of input features.
448
+ num_out : int
449
+ Number of hidden units.
450
+ in_size : tuple
451
+ Shape of input (num_in,).
452
+ out_size : tuple
453
+ Shape of output (num_out,).
454
+
455
+ State Variables
456
+ ---------------
457
+ h : HiddenState
458
+ Hidden state of the MGU cell.
459
+
460
+ Methods
461
+ -------
462
+ init_state(batch_size=None, **kwargs)
463
+ Initialize the cell hidden state.
464
+ reset_state(batch_size=None, **kwargs)
465
+ Reset the cell hidden state to its initial value.
466
+ update(x)
467
+ Update the hidden state for one time step and return the new state.
468
+
469
+ Examples
470
+ --------
471
+ .. code-block:: python
472
+
473
+ >>> import brainstate as bs
474
+ >>> import jax.numpy as jnp
475
+ >>>
476
+ >>> # Create an MGU cell
477
+ >>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
478
+ >>>
479
+ >>> # Initialize state for batch size 32
480
+ >>> cell.init_state(batch_size=32)
481
+ >>>
482
+ >>> # Process a single time step
483
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
484
+ >>> output = cell.update(x)
485
+ >>> print(output.shape) # (32, 20)
486
+ >>>
487
+ >>> # Process a sequence
488
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
489
+ >>> outputs = []
490
+ >>> for t in range(100):
491
+ ... output = cell.update(sequence[t])
492
+ ... outputs.append(output)
493
+ >>> outputs = jnp.stack(outputs)
494
+ >>> print(outputs.shape) # (100, 32, 20)
495
+
496
+ Notes
497
+ -----
498
+ MGU provides a lightweight alternative to GRU and LSTM, making it suitable
499
+ for resource-constrained applications or when model simplicity is preferred.
500
+
+ References
+ ----------
+ .. [1] Zhou, G. B., Wu, J., Zhang, C. L., & Zhou, Z. H. (2016). Minimal gated unit
+ for recurrent neural networks. International Journal of Automation and Computing,
+ 13(3), 226-234.
+ """
+ __module__ = 'brainstate.nn'
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
+ b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+ state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+ activation: str | Callable = 'tanh',
+ name: str = None,
+ ):
+ super().__init__(name=name)
+
+ # parameters
+ self._state_initializer = state_init
+ self.num_out = num_out
+ self.num_in = num_in
+ self.in_size = (num_in,)
+ self.out_size = (num_out,)
+
+ # activation function
+ if isinstance(activation, str):
+ self.activation = getattr(functional, activation)
+ else:
+ assert callable(activation), "The activation function should be a string or a callable function."
+ self.activation = activation
+
+ # weights
+ self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
+ self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
+
+ def init_state(self, batch_size: int = None, **kwargs):
+ self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+
+ def reset_state(self, batch_size: int = None, **kwargs):
+ self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
+ def update(self, x):
+ old_h = self.h.value
+ xh = jnp.concatenate([x, old_h], axis=-1)
+ # forget gate: f_t = sigmoid(W_f [x_t, h_{t-1}] + b_f)
+ f = functional.sigmoid(self.Wf(xh))
+ fh = f * old_h
+ # candidate state: h~_t = activation(W_h [x_t, f_t * h_{t-1}] + b_h)
+ h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
+ # leaky update: h_t = (1 - f_t) * h_{t-1} + f_t * h~_t
+ self.h.value = (1 - f) * self.h.value + f * h
+ return self.h.value
+
+
+ class LSTMCell(RNNCell):
+ r"""
+ Long Short-Term Memory (LSTM) cell implementation.
+
+ LSTM is a type of RNN architecture designed to address the vanishing gradient
+ problem and learn long-term dependencies. It uses a cell state to carry
+ information across time steps and three gates (input, forget, output) to
+ control information flow.
+
+ The LSTM cell follows the mathematical formulation:
+
+ .. math::
+
+ i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
+ f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+ g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\
+ o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
+ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
+ h_t &= o_t \odot \phi(c_t)
+
+ where:
+
+ - :math:`x_t` is the input vector at time t
+ - :math:`h_t` is the hidden state at time t
+ - :math:`c_t` is the cell state at time t
+ - :math:`i_t` is the input gate activation
+ - :math:`f_t` is the forget gate activation
+ - :math:`o_t` is the output gate activation
+ - :math:`g_t` is the cell update (candidate) vector
+ - :math:`\odot` represents element-wise multiplication
+ - :math:`\sigma` is the sigmoid activation function
+ - :math:`\phi` is the activation function (typically tanh)
+
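+ In this implementation the four gate projections share one fused weight matrix
+ over ``[x_t, h_{t-1}]`` (built below as ``Linear(num_in + num_out, num_out * 4)``)
+ and are split afterwards. As a minimal sketch in plain ``jax.numpy``, with a
+ hypothetical fused weight ``W`` of shape ``(num_in + num_out, 4 * num_out)`` and bias ``b``:
+
+ .. code-block:: python
+
+     import jax.numpy as jnp
+     from jax.nn import sigmoid
+
+     def lstm_step(x_t, h_tm1, c_tm1, W, b):
+         z = jnp.concatenate([x_t, h_tm1], axis=-1) @ W + b
+         i, g, f, o = jnp.split(z, 4, axis=-1)   # same gate order as ``update``
+         c_t = sigmoid(f + 1.0) * c_tm1 + sigmoid(i) * jnp.tanh(g)
+         h_t = sigmoid(o) * jnp.tanh(c_t)
+         return h_t, c_t
+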
+ Parameters
+ ----------
+ num_in : int
+ The number of input units.
+ num_out : int
+ The number of hidden/cell units.
+ w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
+ Initializer for the weight matrices.
+ b_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the bias vectors.
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the hidden and cell states.
+ activation : str or Callable, default='tanh'
+ Activation function to use. Can be a string (e.g., 'tanh', 'relu')
+ or a callable function.
+ name : str, optional
+ Name of the module.
+
+ Attributes
+ ----------
+ num_in : int
+ Number of input features.
+ num_out : int
+ Number of hidden/cell units.
+ in_size : tuple
+ Shape of input (num_in,).
+ out_size : tuple
+ Shape of output (num_out,).
+
+ State Variables
+ ---------------
+ h : HiddenState
+ Hidden state of the LSTM cell.
+ c : HiddenState
+ Cell state of the LSTM cell.
+
+ Methods
+ -------
+ init_state(batch_size=None, **kwargs)
+ Initialize the cell and hidden states.
+ reset_state(batch_size=None, **kwargs)
+ Reset the cell and hidden states to their initial values.
+ update(x)
+ Update the states for one time step and return the new hidden state.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as bs
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create an LSTM cell
+ >>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
+ >>>
+ >>> # Initialize states for batch size 32
+ >>> cell.init_state(batch_size=32)
+ >>>
+ >>> # Process a single time step
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
+ >>> output = cell.update(x)
+ >>> print(output.shape) # (32, 20)
+ >>>
+ >>> # Process a sequence
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
+ >>> outputs = []
+ >>> for t in range(100):
+ ... output = cell.update(sequence[t])
+ ... outputs.append(output)
+ >>> outputs = jnp.stack(outputs)
+ >>> print(outputs.shape) # (100, 32, 20)
+ >>>
+ >>> # Access cell state
+ >>> print(cell.c.value.shape) # (32, 20)
+ >>> print(cell.h.value.shape) # (32, 20)
+
+ Notes
+ -----
+ - A constant +1.0 is added to the forget-gate pre-activation, following
+ Jozefowicz et al. (2015), to reduce forgetting at the beginning of training
+ (see the sketch below).
+ - LSTM cells are effective for learning long-term dependencies but require
+ more parameters and computation than simpler RNN variants.
+
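+ As a quick illustration of that choice (assuming, for the example only,
+ pre-activations near zero early in training):
+
+ .. code-block:: python
+
+     import jax.numpy as jnp
+     from jax.nn import sigmoid
+
+     f_preact = jnp.zeros(())           # roughly the pre-activation at the start of training
+     print(sigmoid(f_preact))           # ~0.5  -> half the cell state kept per step
+     print(sigmoid(f_preact + 1.0))     # ~0.73 -> the +1.0 offset favours remembering
+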
+ References
+ ----------
+ .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
+ Neural computation, 9(8), 1735-1780.
+ .. [2] Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget:
+ Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.
+ .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
+ exploration of recurrent network architectures. In International
+ conference on machine learning (pp. 2342-2350).
+ """
+ __module__ = 'brainstate.nn'
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
+ b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+ state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+ activation: str | Callable = 'tanh',
+ name: str = None,
+ ):
+ super().__init__(name=name)
+
+ # parameters
+ self.num_out = num_out
+ self.num_in = num_in
+ self.in_size = (num_in,)
+ self.out_size = (num_out,)
+
+ # initializers
+ self._state_initializer = state_init
+
+ # activation function
+ if isinstance(activation, str):
+ self.activation = getattr(functional, activation)
+ else:
+ assert callable(activation), "The activation function should be a string or a callable function."
+ self.activation = activation
+
+ # weights
+ self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)
+
+ def init_state(self, batch_size: int = None, **kwargs):
+ self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+ self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+
+ def reset_state(self, batch_size: int = None, **kwargs):
+ self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+ self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
+ def update(self, x):
+ h, c = self.h.value, self.c.value
+ xh = jnp.concat([x, h], axis=-1)
+ # fused projection of [x_t, h_{t-1}], split into input, candidate, forget, output gates
+ i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
+ # cell state update; the constant +1. biases the forget gate towards remembering early in training
+ c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
+ # hidden state: output-gated, squashed cell state
+ h = functional.sigmoid(o) * self.activation(c)
+ self.h.value = h
+ self.c.value = c
+ return h
+
+
+ class URLSTMCell(RNNCell):
+ r"""LSTM with UR gating mechanism.
+
+ URLSTM is a modification of the standard LSTM that uses untied (separate) biases
+ for the forget and retention mechanisms, allowing for more flexible gating control.
+ This implementation is based on the paper "Improving the Gating Mechanism of
+ Recurrent Neural Networks" by Gu et al. (2020).
+
+ The URLSTM cell follows the mathematical formulation:
+
+ .. math::
+
+ f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+ r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\
+ g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
+ \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\
+ c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\
+ o_t &= \sigma(W_o [x_t, h_{t-1}]) \\
+ h_t &= o_t \odot \phi(c_t)
+
+ where:
+
+ - :math:`x_t` is the input vector at time t
+ - :math:`h_t` is the hidden state at time t
+ - :math:`c_t` is the cell state at time t
+ - :math:`f_t` is the forget gate with positive bias
+ - :math:`r_t` is the retention gate with negative bias
+ - :math:`g_t` is the unified gate combining forget and retention
+ - :math:`\tilde{c}_t` is the candidate cell state
+ - :math:`o_t` is the output gate
+ - :math:`\odot` represents element-wise multiplication
+ - :math:`\sigma` is the sigmoid activation function
+ - :math:`\phi` is the activation function (typically tanh)
+
+ The key innovation is the untied bias mechanism where the forget and retention
+ gates use opposite biases, initialized using a uniform distribution to encourage
+ diverse gating behavior across units.
+
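+ A minimal sketch of that bias scheme (illustration only; the cell stores the
+ vector as a ``ParamState`` and the names below are local to the example): the
+ bias is drawn so that ``sigmoid(b)`` is uniform on ``(1/num_out, 1 - 1/num_out)``,
+ then added to the forget pre-activation and subtracted from the retention one:
+
+ .. code-block:: python
+
+     import jax
+     import jax.numpy as jnp
+     from jax.nn import sigmoid
+
+     num_out = 20
+     u = jax.random.uniform(jax.random.PRNGKey(0), (num_out,),
+                            minval=1 / num_out, maxval=1 - 1 / num_out)
+     b = -jnp.log(1 / u - 1)             # logit(u), so sigmoid(b) == u
+
+     # given pre-activations f_pre and r_pre from the fused projection:
+     # f_gate = sigmoid(f_pre + b)       # forget gate, biased towards remembering
+     # r_gate = sigmoid(r_pre - b)       # retention gate, opposite (untied) bias
+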
+ Parameters
+ ----------
+ num_in : int
+ The number of input units.
+ num_out : int
+ The number of hidden/output units.
+ w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
+ Initializer for the weight matrix.
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the hidden and cell states.
+ activation : str or Callable, default='tanh'
+ Activation function to use. Can be a string (e.g., 'relu', 'tanh')
+ or a callable function.
+ name : str, optional
+ Name of the module.
+
+ State Variables
+ ---------------
+ h : HiddenState
+ Hidden state of the URLSTM cell.
+ c : HiddenState
+ Cell state of the URLSTM cell.
+
+ Methods
+ -------
+ init_state(batch_size=None, **kwargs)
+ Initialize the cell and hidden states.
+ reset_state(batch_size=None, **kwargs)
+ Reset the cell and hidden states to their initial values.
+ update(x)
+ Update the cell and hidden states for one time step and return the hidden state.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as bs
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create a URLSTM cell
+ >>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
+ >>>
+ >>> # Initialize the state for batch size 32
+ >>> cell.init_state(batch_size=32)
+ >>>
+ >>> # Process a single time step
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
+ >>> output = cell.update(x)
+ >>> print(output.shape) # (32, 20)
+ >>>
+ >>> # Process multiple time steps
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
+ >>> outputs = []
+ >>> for t in range(100):
+ ... output = cell.update(sequence[t])
+ ... outputs.append(output)
+ >>> outputs = jnp.stack(outputs)
+ >>> print(outputs.shape) # (100, 32, 20)
+
+ References
+ ----------
+ .. [1] Gu, Albert, et al. "Improving the gating mechanism of recurrent neural networks."
+ International conference on machine learning. PMLR, 2020.
+ """
+ __module__ = 'brainstate.nn'
+
+ def __init__(
+ self,
+ num_in: int,
+ num_out: int,
+ w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
+ state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+ activation: str | Callable = 'tanh',
+ name: str = None,
+ ):
+ super().__init__(name=name)
+
+ # parameters
+ self.num_out = num_out
+ self.num_in = num_in
+ self.in_size = (num_in,)
+ self.out_size = (num_out,)
+
+ # initializers
+ self._state_initializer = state_init
+
+ # activation function
+ if isinstance(activation, str):
+ self.activation = getattr(functional, activation)
+ else:
+ assert callable(activation), "The activation function should be a string or a callable function."
+ self.activation = activation
+
+ # weights - 4 gates: forget, retention, candidate, output
+ self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
+
+ # Initialize untied bias using uniform distribution
+ self.bias = ParamState(self._forget_bias())
+
+ def _forget_bias(self):
+ """Initialize the forget gate bias using uniform distribution."""
+ # Sample from uniform distribution to encourage diverse gating
+ u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
+ # Transform to logit space for initialization
+ return -jnp.log(1 / u - 1)
+
+ def init_state(self, batch_size: int = None, **kwargs):
+ """
+ Initialize the cell and hidden states.
+
+ Parameters
+ ----------
+ batch_size : int, optional
+ The batch size for state initialization.
+ **kwargs
+ Additional keyword arguments.
+ """
+ self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+ self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+
+ def reset_state(self, batch_size: int = None, **kwargs):
+ """
+ Reset the cell and hidden states to their initial values.
+
+ Parameters
+ ----------
+ batch_size : int, optional
+ The batch size for state reset.
+ **kwargs
+ Additional keyword arguments.
+ """
+ self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+ self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
+ def update(self, x: ArrayLike) -> ArrayLike:
+ """
+ Update the URLSTM cell for one time step.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input tensor with shape (batch_size, num_in).
+
+ Returns
+ -------
+ ArrayLike
+ Hidden state tensor with shape (batch_size, num_out).
+ """
+ h, c = self.h.value, self.c.value
+
+ # Concatenate input and hidden state
+ xh = jnp.concatenate([x, h], axis=-1)
+
+ # Compute all gates in one pass
+ gates = self.W(xh)
+ f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)
+
+ # Apply untied biases to forget and retention gates
+ f_gate = functional.sigmoid(f + self.bias.value)
+ r_gate = functional.sigmoid(r - self.bias.value)
+
+ # Compute unified gate
+ g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2
+
+ # Update cell state
+ next_cell = g * c + (1 - g) * self.activation(u)
+
+ # Compute output gate and hidden state
+ o_gate = functional.sigmoid(o)
+ next_hidden = o_gate * self.activation(next_cell)
+
+ # Update states
+ self.h.value = next_hidden
+ self.c.value = next_cell
+
+ return next_hidden