brainstate-0.1.9-py2.py3-none-any.whl → brainstate-0.2.0-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
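The renames above reflect a broad reorganization: the former brainstate.compile and brainstate.augment packages appear to be merged into a single brainstate.transform package, the standalone brainstate.init and brainstate.functional packages are folded into brainstate.nn, and the brainstate.optim and brainstate.surrogate modules are deleted from this wheel. A minimal migration sketch, assuming the 0.2.0 public API mirrors this file layout; the symbol names `jit` and `grad` are illustrative guesses taken from the module file names (_jit.py, _autograd.py) and are not confirmed by this diff:

# Hypothetical import shim based on the file moves listed above;
# verify the actual re-exports against the 0.2.0 release.
try:
    from brainstate.transform import jit, grad   # assumed 0.2.0 layout (transform/ package)
except ImportError:                              # fall back to the 0.1.9 layout
    from brainstate.compile import jit           # compile/_jit.py
    from brainstate.augment import grad          # augment/_autograd.py

The hunks below come from item 56 in the list, brainstate/nn/{_rate_rnns.py → _rnns.py}, which rewrites the recurrent cell classes (RNNCell, ValinaRNNCell, GRUCell, MGUCell, LSTMCell, URLSTMCell).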
@@ -1,4 +1,4 @@
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -20,9 +20,11 @@ from typing import Callable, Union
 
  import jax.numpy as jnp
 
- from brainstate import random, init, functional
+ from brainstate import random
  from brainstate._state import HiddenState, ParamState
  from brainstate.typing import ArrayLike
+ from . import _activations as functional
+ from . import init as init
  from ._linear import Linear
  from ._module import Module
 
@@ -54,7 +56,16 @@ class RNNCell(Module):
  Reset the cell state variables to their initial values.
  update(x)
  Update the cell state for one time step based on input x and return output.
+
+ See Also
+ --------
+ ValinaRNNCell : Vanilla RNN cell implementation
+ GRUCell : Gated Recurrent Unit cell implementation
+ LSTMCell : Long Short-Term Memory cell implementation
+ URLSTMCell : LSTM with UR gating mechanism
+ MGUCell : Minimal Gated Unit cell implementation
  """
+ __module__ = 'brainstate.nn'
  pass
 
 
@@ -97,8 +108,19 @@ class ValinaRNNCell(RNNCell):
  name : str, optional
  Name of the module.
 
+ Attributes
+ ----------
+ num_in : int
+ Number of input features.
+ num_out : int
+ Number of hidden units.
+ in_size : tuple
+ Shape of input (num_in,).
+ out_size : tuple
+ Shape of output (num_out,).
+
  State Variables
- --------------
+ ---------------
  h : HiddenState
  Hidden state of the RNN cell.
 
@@ -110,6 +132,39 @@ class ValinaRNNCell(RNNCell):
  Reset the cell hidden state to its initial value.
  update(x)
  Update the hidden state for one time step and return the new state.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as bs
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create a vanilla RNN cell
+ >>> cell = bs.nn.ValinaRNNCell(num_in=10, num_out=20)
+ >>>
+ >>> # Initialize state for batch size 32
+ >>> cell.init_state(batch_size=32)
+ >>>
+ >>> # Process a single time step
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
+ >>> output = cell.update(x)
+ >>> print(output.shape) # (32, 20)
+ >>>
+ >>> # Process a sequence of inputs
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
+ >>> outputs = []
+ >>> for t in range(100):
+ ... output = cell.update(sequence[t])
+ ... outputs.append(output)
+ >>> outputs = jnp.stack(outputs)
+ >>> print(outputs.shape) # (100, 32, 20)
+
+ Notes
+ -----
+ Vanilla RNNs can suffer from vanishing or exploding gradient problems
+ when processing long sequences. For better performance on long sequences,
+ consider using gated architectures like GRU or LSTM.
  """
  __module__ = 'brainstate.nn'
 
@@ -143,10 +198,30 @@ class ValinaRNNCell(RNNCell):
  self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
 
  def init_state(self, batch_size: int = None, **kwargs):
- self.h = HiddenState(init.param(self._state_initializer, self.num_out, batch_size))
+ """
+ Initialize the hidden state.
+
+ Parameters
+ ----------
+ batch_size : int, optional
+ The batch size for state initialization.
+ **kwargs
+ Additional keyword arguments.
+ """
+ self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
 
  def reset_state(self, batch_size: int = None, **kwargs):
- self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
+ """
+ Reset the hidden state to initial value.
+
+ Parameters
+ ----------
+ batch_size : int, optional
+ The batch size for state reset.
+ **kwargs
+ Additional keyword arguments.
+ """
+ self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
 
  def update(self, x):
  xh = jnp.concatenate([x, self.h.value], axis=-1)
@@ -159,16 +234,18 @@ class GRUCell(RNNCell):
  r"""
  Gated Recurrent Unit (GRU) cell implementation.
 
- This class implements the GRU model that uses gating mechanisms to control
- information flow. The GRU has fewer parameters than LSTM as it combines
- the forget and input gates into a single update gate. The GRU follows the
- mathematical formulation:
+ The GRU is a gating mechanism in recurrent neural networks that aims to solve
+ the vanishing gradient problem. It uses gating mechanisms to control information
+ flow and has fewer parameters than LSTM as it combines the forget and input gates
+ into a single update gate.
+
+ The GRU cell follows the mathematical formulation:
 
  .. math::
 
  r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\
  z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\
- \tilde{h}_t &= \tanh(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
+ \tilde{h}_t &= \phi(W_h [x_t, (r_t \odot h_{t-1})] + b_h) \\
  h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
 
  where:
@@ -180,6 +257,7 @@ class GRUCell(RNNCell):
  - :math:`\tilde{h}_t` is the candidate hidden state
  - :math:`\odot` represents element-wise multiplication
  - :math:`\sigma` is the sigmoid activation function
+ - :math:`\phi` is the activation function (typically tanh)
 
  Parameters
  ----------
@@ -194,13 +272,24 @@ class GRUCell(RNNCell):
  state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
  Initializer for the hidden state.
  activation : str or Callable, default='tanh'
- Activation function to use. Can be a string (e.g., 'tanh')
+ Activation function to use. Can be a string (e.g., 'tanh', 'relu')
  or a callable function.
  name : str, optional
  Name of the module.
 
+ Attributes
+ ----------
+ num_in : int
+ Number of input features.
+ num_out : int
+ Number of hidden units.
+ in_size : tuple
+ Shape of input (num_in,).
+ out_size : tuple
+ Shape of output (num_out,).
+
  State Variables
- --------------
+ ---------------
  h : HiddenState
  Hidden state of the GRU cell.
 
@@ -212,6 +301,51 @@ class GRUCell(RNNCell):
  Reset the cell hidden state to its initial value.
  update(x)
  Update the hidden state for one time step and return the new state.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as bs
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create a GRU cell
+ >>> cell = bs.nn.GRUCell(num_in=10, num_out=20)
+ >>>
+ >>> # Initialize state for batch size 32
+ >>> cell.init_state(batch_size=32)
+ >>>
+ >>> # Process a single time step
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
+ >>> output = cell.update(x)
+ >>> print(output.shape) # (32, 20)
+ >>>
+ >>> # Process a sequence
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
+ >>> outputs = []
+ >>> for t in range(100):
+ ... output = cell.update(sequence[t])
+ ... outputs.append(output)
+ >>> outputs = jnp.stack(outputs)
+ >>> print(outputs.shape) # (100, 32, 20)
+ >>>
+ >>> # Reset state with different batch size
+ >>> cell.reset_state(batch_size=16)
+ >>> x_new = jnp.ones((16, 10))
+ >>> output_new = cell.update(x_new)
+ >>> print(output_new.shape) # (16, 20)
+
+ Notes
+ -----
+ GRU cells are computationally more efficient than LSTM cells due to having
+ fewer parameters, while often achieving comparable performance on many tasks.
+
+ References
+ ----------
+ .. [1] Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F.,
+ Schwenk, H., & Bengio, Y. (2014). Learning phrase representations using
+ RNN encoder-decoder for statistical machine translation.
+ arXiv preprint arXiv:1406.1078.
  """
  __module__ = 'brainstate.nn'
 
@@ -264,28 +398,30 @@ class GRUCell(RNNCell):
 
  class MGUCell(RNNCell):
  r"""
- Minimal Gated Recurrent Unit (MGU) cell implementation.
+ Minimal Gated Unit (MGU) cell implementation.
 
  MGU is a simplified version of GRU that uses a single forget gate instead of
- separate reset and update gates. This results in fewer parameters while
- maintaining much of the gating capability. The MGU follows the mathematical
- formulation:
+ separate reset and update gates. This design significantly reduces the number
+ of parameters while maintaining much of the gating capability. MGU provides
+ a good trade-off between model complexity and performance.
+
+ The MGU cell follows the mathematical formulation:
 
  .. math::
 
- f_t &= \\sigma(W_f [x_t, h_{t-1}] + b_f) \\\\
- \\tilde{h}_t &= \\phi(W_h [x_t, (f_t \\odot h_{t-1})] + b_h) \\\\
- h_t &= (1 - f_t) \\odot h_{t-1} + f_t \\odot \\tilde{h}_t
+ f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+ \tilde{h}_t &= \phi(W_h [x_t, (f_t \odot h_{t-1})] + b_h) \\
+ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
 
  where:
 
  - :math:`x_t` is the input vector at time t
  - :math:`h_t` is the hidden state at time t
  - :math:`f_t` is the forget gate vector
- - :math:`\\tilde{h}_t` is the candidate hidden state
- - :math:`\\odot` represents element-wise multiplication
- - :math:`\\sigma` is the sigmoid activation function
- - :math:`\\phi` is the activation function (typically tanh)
+ - :math:`\tilde{h}_t` is the candidate hidden state
+ - :math:`\odot` represents element-wise multiplication
+ - :math:`\sigma` is the sigmoid activation function
+ - :math:`\phi` is the activation function (typically tanh)
 
  Parameters
  ----------
@@ -300,13 +436,24 @@ class MGUCell(RNNCell):
  state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
  Initializer for the hidden state.
  activation : str or Callable, default='tanh'
- Activation function to use. Can be a string (e.g., 'tanh')
+ Activation function to use. Can be a string (e.g., 'tanh', 'relu')
  or a callable function.
  name : str, optional
  Name of the module.
 
+ Attributes
+ ----------
+ num_in : int
+ Number of input features.
+ num_out : int
+ Number of hidden units.
+ in_size : tuple
+ Shape of input (num_in,).
+ out_size : tuple
+ Shape of output (num_out,).
+
  State Variables
- --------------
+ ---------------
  h : HiddenState
  Hidden state of the MGU cell.
 
@@ -318,6 +465,44 @@ class MGUCell(RNNCell):
  Reset the cell hidden state to its initial value.
  update(x)
  Update the hidden state for one time step and return the new state.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as bs
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create an MGU cell
+ >>> cell = bs.nn.MGUCell(num_in=10, num_out=20)
+ >>>
+ >>> # Initialize state for batch size 32
+ >>> cell.init_state(batch_size=32)
+ >>>
+ >>> # Process a single time step
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
+ >>> output = cell.update(x)
+ >>> print(output.shape) # (32, 20)
+ >>>
+ >>> # Process a sequence
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
+ >>> outputs = []
+ >>> for t in range(100):
+ ... output = cell.update(sequence[t])
+ ... outputs.append(output)
+ >>> outputs = jnp.stack(outputs)
+ >>> print(outputs.shape) # (100, 32, 20)
+
+ Notes
+ -----
+ MGU provides a lightweight alternative to GRU and LSTM, making it suitable
+ for resource-constrained applications or when model simplicity is preferred.
+
+ References
+ ----------
+ .. [1] Zhou, G. B., Wu, J., Zhang, C. L., & Zhou, Z. H. (2016). Minimal gated unit
+ for recurrent neural networks. International Journal of Automation and Computing,
+ 13(3), 226-234.
  """
  __module__ = 'brainstate.nn'
 
@@ -371,34 +556,34 @@ class LSTMCell(RNNCell):
  r"""
  Long Short-Term Memory (LSTM) cell implementation.
 
- This class implements the LSTM architecture which uses multiple gating mechanisms
- to regulate information flow and address the vanishing gradient problem in RNNs.
- The LSTM follows the mathematical formulation:
+ LSTM is a type of RNN architecture designed to address the vanishing gradient
+ problem and learn long-term dependencies. It uses a cell state to carry
+ information across time steps and three gates (input, forget, output) to
+ control information flow.
+
+ The LSTM cell follows the mathematical formulation:
 
  .. math::
 
- i_t &= \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
- f_t &= \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
- g_t &= \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
- o_t &= \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
+ i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\
+ f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+ g_t &= \phi(W_g [x_t, h_{t-1}] + b_g) \\
+ o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\
  c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
- h_t &= o_t \odot \tanh(c_t)
+ h_t &= o_t \odot \phi(c_t)
 
  where:
 
  - :math:`x_t` is the input vector at time t
  - :math:`h_t` is the hidden state at time t
  - :math:`c_t` is the cell state at time t
- - :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and output gate activations
- - :math:`g_t` is the cell update vector
+ - :math:`i_t` is the input gate activation
+ - :math:`f_t` is the forget gate activation
+ - :math:`o_t` is the output gate activation
+ - :math:`g_t` is the cell update (candidate) vector
  - :math:`\odot` represents element-wise multiplication
  - :math:`\sigma` is the sigmoid activation function
-
- Notes
- -----
- Forget gate initialization: Following Jozefowicz et al. (2015), we add 1.0
- to the forget gate bias after initialization to reduce forgetting at the
- beginning of training.
+ - :math:`\phi` is the activation function (typically tanh)
 
  Parameters
  ----------
@@ -413,13 +598,24 @@ class LSTMCell(RNNCell):
  state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
  Initializer for the hidden and cell states.
  activation : str or Callable, default='tanh'
- Activation function to use. Can be a string (e.g., 'tanh')
+ Activation function to use. Can be a string (e.g., 'tanh', 'relu')
  or a callable function.
  name : str, optional
  Name of the module.
 
+ Attributes
+ ----------
+ num_in : int
+ Number of input features.
+ num_out : int
+ Number of hidden/cell units.
+ in_size : tuple
+ Shape of input (num_in,).
+ out_size : tuple
+ Shape of output (num_out,).
+
  State Variables
- --------------
+ ---------------
  h : HiddenState
  Hidden state of the LSTM cell.
  c : HiddenState
@@ -434,15 +630,53 @@ class LSTMCell(RNNCell):
  update(x)
  Update the states for one time step and return the new hidden state.
 
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as bs
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create an LSTM cell
+ >>> cell = bs.nn.LSTMCell(num_in=10, num_out=20)
+ >>>
+ >>> # Initialize states for batch size 32
+ >>> cell.init_state(batch_size=32)
+ >>>
+ >>> # Process a single time step
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
+ >>> output = cell.update(x)
+ >>> print(output.shape) # (32, 20)
+ >>>
+ >>> # Process a sequence
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
+ >>> outputs = []
+ >>> for t in range(100):
+ ... output = cell.update(sequence[t])
+ ... outputs.append(output)
+ >>> outputs = jnp.stack(outputs)
+ >>> print(outputs.shape) # (100, 32, 20)
+ >>>
+ >>> # Access cell state
+ >>> print(cell.c.value.shape) # (32, 20)
+ >>> print(cell.h.value.shape) # (32, 20)
+
+ Notes
+ -----
+ - The forget gate bias is initialized with +1.0 following Jozefowicz et al. (2015)
+ to reduce forgetting at the beginning of training.
+ - LSTM cells are effective for learning long-term dependencies but require
+ more parameters and computation than simpler RNN variants.
+
  References
  ----------
  .. [1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.
  Neural computation, 9(8), 1735-1780.
- .. [2] Zaremba, W., Sutskever, I., & Vinyals, O. (2014). Recurrent neural
- network regularization. arXiv preprint arXiv:1409.2329.
+ .. [2] Gers, F. A., Schmidhuber, J., & Cummins, F. (2000). Learning to forget:
+ Continual prediction with LSTM. Neural computation, 12(10), 2451-2471.
  .. [3] Jozefowicz, R., Zaremba, W., & Sutskever, I. (2015). An empirical
  exploration of recurrent network architectures. In International
- conference on machine learning, pp. 2342-2350.
+ conference on machine learning (pp. 2342-2350).
  """
  __module__ = 'brainstate.nn'
 
@@ -497,6 +731,109 @@ class LSTMCell(RNNCell):
 
 
  class URLSTMCell(RNNCell):
+ r"""LSTM with UR gating mechanism.
+
+ URLSTM is a modification of the standard LSTM that uses untied (separate) biases
+ for the forget and retention mechanisms, allowing for more flexible gating control.
+ This implementation is based on the paper "Improving the Gating Mechanism of
+ Recurrent Neural Networks" by Gers et al.
+
+ The URLSTM cell follows the mathematical formulation:
+
+ .. math::
+
+ f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\
+ r_t &= \sigma(W_r [x_t, h_{t-1}] - b_f) \\
+ g_t &= 2 r_t \odot f_t + (1 - 2 r_t) \odot f_t^2 \\
+ \tilde{c}_t &= \phi(W_c [x_t, h_{t-1}]) \\
+ c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \tilde{c}_t \\
+ o_t &= \sigma(W_o [x_t, h_{t-1}]) \\
+ h_t &= o_t \odot \phi(c_t)
+
+ where:
+
+ - :math:`x_t` is the input vector at time t
+ - :math:`h_t` is the hidden state at time t
+ - :math:`c_t` is the cell state at time t
+ - :math:`f_t` is the forget gate with positive bias
+ - :math:`r_t` is the retention gate with negative bias
+ - :math:`g_t` is the unified gate combining forget and retention
+ - :math:`\tilde{c}_t` is the candidate cell state
+ - :math:`o_t` is the output gate
+ - :math:`\odot` represents element-wise multiplication
+ - :math:`\sigma` is the sigmoid activation function
+ - :math:`\phi` is the activation function (typically tanh)
+
+ The key innovation is the untied bias mechanism where the forget and retention
+ gates use opposite biases, initialized using a uniform distribution to encourage
+ diverse gating behavior across units.
+
+ Parameters
+ ----------
+ num_in : int
+ The number of input units.
+ num_out : int
+ The number of hidden/output units.
+ w_init : Union[ArrayLike, Callable], default=init.XavierNormal()
+ Initializer for the weight matrix.
+ state_init : Union[ArrayLike, Callable], default=init.ZeroInit()
+ Initializer for the hidden and cell states.
+ activation : str or Callable, default='tanh'
+ Activation function to use. Can be a string (e.g., 'relu', 'tanh')
+ or a callable function.
+ name : str, optional
+ Name of the module.
+
+ State Variables
+ ---------------
+ h : HiddenState
+ Hidden state of the URLSTM cell.
+ c : HiddenState
+ Cell state of the URLSTM cell.
+
+ Methods
+ -------
+ init_state(batch_size=None, **kwargs)
+ Initialize the cell and hidden states.
+ reset_state(batch_size=None, **kwargs)
+ Reset the cell and hidden states to their initial values.
+ update(x)
+ Update the cell and hidden states for one time step and return the hidden state.
+
+ Examples
+ --------
+ .. code-block:: python
+
+ >>> import brainstate as bs
+ >>> import jax.numpy as jnp
+ >>>
+ >>> # Create a URLSTM cell
+ >>> cell = bs.nn.URLSTMCell(num_in=10, num_out=20)
+ >>>
+ >>> # Initialize the state for batch size 32
+ >>> cell.init_state(batch_size=32)
+ >>>
+ >>> # Process a sequence
+ >>> x = jnp.ones((32, 10)) # batch_size x num_in
+ >>> output = cell.update(x)
+ >>> print(output.shape) # (32, 20)
+ >>>
+ >>> # Process multiple time steps
+ >>> sequence = jnp.ones((100, 32, 10)) # time_steps x batch_size x num_in
+ >>> outputs = []
+ >>> for t in range(100):
+ ... output = cell.update(sequence[t])
+ ... outputs.append(output)
+ >>> outputs = jnp.stack(outputs)
+ >>> print(outputs.shape) # (100, 32, 20)
+
+ References
+ ----------
+ .. [1] Gu, Albert, et al. "Improving the gating mechanism of recurrent neural networks."
+ International conference on machine learning. PMLR, 2020.
+ """
+ __module__ = 'brainstate.nn'
+
  def __init__(
  self,
  num_in: int,
@@ -521,34 +858,89 @@ class URLSTMCell(RNNCell):
  if isinstance(activation, str):
  self.activation = getattr(functional, activation)
  else:
- assert callable(activation), "The activation function should be a string or a callable function. "
+ assert callable(activation), "The activation function should be a string or a callable function."
  self.activation = activation
 
- # weights
+ # weights - 4 gates: forget, retention, candidate, output
  self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
+
+ # Initialize untied bias using uniform distribution
  self.bias = ParamState(self._forget_bias())
 
  def _forget_bias(self):
+ """Initialize the forget gate bias using uniform distribution."""
+ # Sample from uniform distribution to encourage diverse gating
  u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
+ # Transform to logit space for initialization
  return -jnp.log(1 / u - 1)
 
  def init_state(self, batch_size: int = None, **kwargs):
+ """
+ Initialize the cell and hidden states.
+
+ Parameters
+ ----------
+ batch_size : int, optional
+ The batch size for state initialization.
+ **kwargs
+ Additional keyword arguments.
+ """
  self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
  self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
 
  def reset_state(self, batch_size: int = None, **kwargs):
+ """
+ Reset the cell and hidden states to their initial values.
+
+ Parameters
+ ----------
+ batch_size : int, optional
+ The batch size for state reset.
+ **kwargs
+ Additional keyword arguments.
+ """
  self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
  self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
 
  def update(self, x: ArrayLike) -> ArrayLike:
+ """
+ Update the URLSTM cell for one time step.
+
+ Parameters
+ ----------
+ x : ArrayLike
+ Input tensor with shape (batch_size, num_in).
+
+ Returns
+ -------
+ ArrayLike
+ Hidden state tensor with shape (batch_size, num_out).
+ """
  h, c = self.h.value, self.c.value
- xh = jnp.concat([x, h], axis=-1)
- f, r, u, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
- f_ = functional.sigmoid(f + self.bias.value)
- r_ = functional.sigmoid(r - self.bias.value)
- g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2
+
+ # Concatenate input and hidden state
+ xh = jnp.concatenate([x, h], axis=-1)
+
+ # Compute all gates in one pass
+ gates = self.W(xh)
+ f, r, u, o = jnp.split(gates, indices_or_sections=4, axis=-1)
+
+ # Apply untied biases to forget and retention gates
+ f_gate = functional.sigmoid(f + self.bias.value)
+ r_gate = functional.sigmoid(r - self.bias.value)
+
+ # Compute unified gate
+ g = 2 * r_gate * f_gate + (1 - 2 * r_gate) * f_gate ** 2
+
+ # Update cell state
  next_cell = g * c + (1 - g) * self.activation(u)
- next_hidden = functional.sigmoid(o) * self.activation(next_cell)
+
+ # Compute output gate and hidden state
+ o_gate = functional.sigmoid(o)
+ next_hidden = o_gate * self.activation(next_cell)
+
+ # Update states
  self.h.value = next_hidden
  self.c.value = next_cell
+
  return next_hidden
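The _forget_bias helper in the last hunk stores the untied bias in logit space, so that sigmoid(+bias) recovers the sampled value u for the forget gate while sigmoid(-bias) gives 1 - u for the retention gate. A standalone sketch of that identity, using numpy for the sampling (rather than brainstate.random) purely for illustration:

import numpy as np

num_out = 8
rng = np.random.default_rng(0)
# Same distribution as _forget_bias: u ~ Uniform(1/num_out, 1 - 1/num_out)
u = rng.uniform(1 / num_out, 1 - 1 / num_out, size=num_out)
bias = -np.log(1 / u - 1)                    # logit(u), as returned by _forget_bias

sigmoid = lambda z: 1 / (1 + np.exp(-z))
assert np.allclose(sigmoid(bias), u)         # forget gate f_t starts near u
assert np.allclose(sigmoid(-bias), 1 - u)    # retention gate r_t starts near 1 - u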