brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,400 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Callable, Union
21
+
22
+ import jax.numpy as jnp
23
+
24
+ from brainstate import random, init, functional
25
+ from brainstate._state import HiddenState, ParamState
26
+ from brainstate.nn._interaction._linear import Linear
27
+ from brainstate.nn._module import Module
28
+ from brainstate.typing import ArrayLike
29
+
30
+ __all__ = [
31
+ 'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
32
+ ]
33
+
34
+
35
class RNNCell(Module):
    """
    Common base class of every recurrent cell defined in this module.

    It adds no behavior on top of :class:`Module`; concrete cells create their
    recurrent state in ``init_state``/``reset_state`` and advance it by one
    step in ``update``.
    """
40
+
41
+
42
class ValinaRNNCell(RNNCell):
    r"""
    Vanilla RNN cell.

    The input and the previous hidden state are concatenated along the last
    axis and passed through one affine layer followed by the activation:

    .. math::

        h_t = \phi(W [x_t, h_{t-1}] + b)

    Args:
        num_in: int. The number of input units.
        num_out: int. The number of hidden units.
        state_init: callable, ArrayLike. The initializer of the hidden state ``h``.
        w_init: callable, ArrayLike. The weight initializer of the affine layer.
        b_init: optional, callable, ArrayLike. The bias initializer of the affine layer.
        activation: str, callable. The activation function :math:`\phi`. A string is
            looked up by name in ``brainstate.functional``.
        name: optional, str. The name of the module.

    Raises:
        TypeError: If ``activation`` is neither a string nor a callable.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'relu',
        name: str | None = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)
        self._state_initializer = state_init

        # activation function; an explicit check (rather than ``assert``) keeps
        # the validation active when Python runs with ``-O``.
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        elif callable(activation):
            self.activation = activation
        else:
            raise TypeError("The activation function should be a string or a callable function. ")

        # one affine map acting on the concatenation [x, h]
        self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int | None = None, **kwargs):
        """Create the hidden state ``h`` using the configured state initializer."""
        self.h = HiddenState(init.param(self._state_initializer, self.num_out, batch_size))

    def reset_state(self, batch_size: int | None = None, **kwargs):
        """Re-initialize the value of the hidden state ``h``."""
        self.h.value = init.param(self._state_initializer, self.num_out, batch_size)

    def update(self, x):
        """Advance one step with input ``x`` and return the new hidden state."""
        xh = jnp.concatenate([x, self.h.value], axis=-1)
        self.h.value = self.activation(self.W(xh))
        return self.h.value
97
+
98
+
99
class GRUCell(RNNCell):
    r"""
    Gated Recurrent Unit (GRU) cell.

    With :math:`[\cdot, \cdot]` denoting concatenation along the last axis, one
    step computes

    .. math::

        \begin{aligned}
        r_t, z_t &= \sigma(W_{rz} [x_t, h_{t-1}] + b_{rz}) \\
        \hat{h}_t &= \phi(W_h [x_t, r_t \odot h_{t-1}] + b_h) \\
        h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \hat{h}_t
        \end{aligned}

    where the reset gate :math:`r_t` and the update gate :math:`z_t` are the two
    halves of a single fused projection.

    Args:
        num_in: int. The number of input units.
        num_out: int. The number of hidden units.
        w_init: callable, ArrayLike. The weight initializer of both affine layers.
        b_init: optional, callable, ArrayLike. The bias initializer of both affine layers.
        state_init: callable, ArrayLike. The initializer of the hidden state ``h``.
        activation: str, callable. The candidate activation :math:`\phi`. A string is
            looked up by name in ``brainstate.functional``.
        name: optional, str. The name of the module.

    Raises:
        TypeError: If ``activation`` is neither a string nor a callable.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str | None = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function; an explicit check (rather than ``assert``) keeps
        # the validation active when Python runs with ``-O``.
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        elif callable(activation):
            self.activation = activation
        else:
            raise TypeError("The activation function should be a string or a callable function. ")

        # weights: fused reset/update gates, and the candidate projection
        self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int | None = None, **kwargs):
        """Create the hidden state ``h`` using the configured state initializer."""
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int | None = None, **kwargs):
        """Re-initialize the value of the hidden state ``h``."""
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        """Advance one step with input ``x`` and return the new hidden state."""
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        # first half of the fused projection is the reset gate, second half the update gate
        r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
        rh = r * old_h
        h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
        h = (1 - z) * old_h + z * h
        self.h.value = h
        return h
159
+
160
+
161
class MGUCell(RNNCell):
    r"""
    Minimal Gated Recurrent Unit (MGU) cell.

    .. math::

       \begin{aligned}
       f_{t}&=\sigma (W_{f}x_{t}+U_{f}h_{t-1}+b_{f})\\
       {\hat {h}}_{t}&=\phi (W_{h}x_{t}+U_{h}(f_{t}\odot h_{t-1})+b_{h})\\
       h_{t}&=(1-f_{t})\odot h_{t-1}+f_{t}\odot {\hat {h}}_{t}
       \end{aligned}

    where:

    - :math:`x_{t}`: input vector
    - :math:`h_{t}`: output vector
    - :math:`{\hat {h}}_{t}`: candidate activation vector
    - :math:`f_{t}`: forget vector
    - :math:`W, U, b`: parameter matrices and vector

    In the implementation, :math:`W` and :math:`U` of each equation are fused into
    one affine layer applied to the concatenation of the input and the hidden state.

    Args:
        num_in: int. The number of input units.
        num_out: int. The number of hidden units.
        w_init: callable, ArrayLike. The weight initializer of both affine layers.
        b_init: optional, callable, ArrayLike. The bias initializer of both affine layers.
        state_init: callable, ArrayLike. The initializer of the hidden state ``h``.
        activation: str, callable. The candidate activation :math:`\phi`. A string is
            looked up by name in ``brainstate.functional``.
        name: optional, str. The name of the module.

    Raises:
        TypeError: If ``activation`` is neither a string nor a callable.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str | None = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function; an explicit check (rather than ``assert``) keeps
        # the validation active when Python runs with ``-O``.
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        elif callable(activation):
            self.activation = activation
        else:
            raise TypeError("The activation function should be a string or a callable function. ")

        # weights: forget gate and candidate projection
        self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int | None = None, **kwargs):
        """Create the hidden state ``h`` using the configured state initializer."""
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int | None = None, **kwargs):
        """Re-initialize the value of the hidden state ``h``."""
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        """Advance one step with input ``x`` and return the new hidden state."""
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        f = functional.sigmoid(self.Wf(xh))
        fh = f * old_h
        h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
        # blend with the captured previous state (h has not been written yet)
        self.h.value = (1 - f) * old_h + f * h
        return self.h.value
236
+
237
+
238
class LSTMCell(RNNCell):
    r"""Long short-term memory (LSTM) RNN core.

    The implementation is based on (zaremba, et al., 2014) [1]_. Given
    :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core
    computes

    .. math::

       \begin{array}{ll}
       i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
       f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
       g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
       o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
       c_t = f_t c_{t-1} + i_t g_t \\
       h_t = o_t \tanh(c_t)
       \end{array}

    where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and
    output gate activations, and :math:`g_t` is a vector of cell updates.
    The :math:`\tanh` above is the configurable ``activation``. All four gates
    come from one fused affine layer on :math:`[x_t, h_{t-1}]`, laid out in the
    order ``[i, g, f, o]`` along the last axis.

    The output is equal to the new hidden, :math:`h_t`.

    Notes
    -----

    Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ the
    forget gate is biased by +1.0 to reduce the scale of forgetting at the
    beginning of training. This is implemented by adding the constant 1.0 to the
    forget-gate pre-activation at every step, which is equivalent to shifting
    :math:`b_f` by +1.0.

    Parameters
    ----------
    num_in: int
      The dimension of the input vector
    num_out: int
      The number of hidden unit in the node.
    w_init: callable, ArrayLike
      The weight initializer of the fused affine layer.
    b_init: optional, callable, ArrayLike
      The bias initializer of the fused affine layer.
    state_init: callable, ArrayLike
      The initializer of both the cell state ``c`` and the hidden state ``h``.
    activation: str, callable
      The activation function. It can be a string (looked up by name in
      ``brainstate.functional``) or a callable function.
    name: optional, str
      The name of the module.

    Raises
    ------
    TypeError
      If ``activation`` is neither a string nor a callable.

    References
    ----------

    .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural
           network regularization." arXiv preprint arXiv:1409.2329 (2014).
    .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical
           exploration of recurrent network architectures." In International conference
           on machine learning, pp. 2342-2350. PMLR, 2015.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str | None = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function; an explicit check (rather than ``assert``) keeps
        # the validation active when Python runs with ``-O``.
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        elif callable(activation):
            self.activation = activation
        else:
            raise TypeError("The activation function should be a string or a callable function. ")

        # weights: one fused projection producing the four gates
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int | None = None, **kwargs):
        """Create the cell state ``c`` and the hidden state ``h``."""
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int | None = None, **kwargs):
        """Re-initialize the values of the cell state ``c`` and the hidden state ``h``."""
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        """Advance one step with input ``x`` and return the new hidden state."""
        h, c = self.h.value, self.c.value
        # ``jnp.concatenate`` (not ``jnp.concat``) matches the sibling cells and
        # also works on JAX versions that predate the ``concat`` alias.
        xh = jnp.concatenate([x, h], axis=-1)
        i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        # +1. implements the forget-gate bias shift described in the class notes
        c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
        h = functional.sigmoid(o) * self.activation(c)
        self.h.value = h
        self.c.value = c
        return h
343
+
344
+
345
class URLSTMCell(RNNCell):
    r"""
    LSTM cell with uniform gate initialization and a refine gate (UR-LSTM).

    The forget gate :math:`f_t` is combined with a refine gate :math:`r_t` into an
    effective gate :math:`g_t` [1]_. With :math:`[\cdot, \cdot]` denoting
    concatenation along the last axis:

    .. math::

        \begin{aligned}
        f_t &= \sigma(W_f [x_t, h_{t-1}] + b) \\
        r_t &= \sigma(W_r [x_t, h_{t-1}] - b) \\
        g_t &= 2 r_t f_t + (1 - 2 r_t) f_t^2 \\
        c_t &= g_t \odot c_{t-1} + (1 - g_t) \odot \phi(W_u [x_t, h_{t-1}]) \\
        h_t &= \sigma(W_o [x_t, h_{t-1}]) \odot \phi(c_t)
        \end{aligned}

    The four projections are fused into one affine layer without bias, laid out
    in the order ``[f, r, u, o]``. The single trainable bias :math:`b` is shared
    by the forget and refine gates (with opposite sign) and is initialized as
    :math:`b = -\log(1/u - 1)` with :math:`u \sim \mathcal{U}(1/d, 1 - 1/d)`,
    where :math:`d` is ``num_out``.

    Parameters
    ----------
    num_in: int
      The dimension of the input vector.
    num_out: int
      The number of hidden units.
    w_init: callable, ArrayLike
      The weight initializer of the fused affine layer.
    state_init: callable, ArrayLike
      The initializer of both the cell state ``c`` and the hidden state ``h``.
    activation: str, callable
      The activation :math:`\phi`. It can be a string (looked up by name in
      ``brainstate.functional``) or a callable function.
    name: optional, str
      The name of the module.

    Raises
    ------
    TypeError
      If ``activation`` is neither a string nor a callable.

    References
    ----------
    .. [1] Gu, Albert, et al. "Improving the Gating Mechanism of Recurrent
           Neural Networks." International Conference on Machine Learning, 2020.
    """
    # Same public module path as the other cells in this file.
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str | None = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function; an explicit check (rather than ``assert``) keeps
        # the validation active when Python runs with ``-O``.
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        elif callable(activation):
            self.activation = activation
        else:
            raise TypeError("The activation function should be a string or a callable function. ")

        # weights: bias-free fused projection plus the shared gate bias
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
        self.bias = ParamState(self._forget_bias())

    def _forget_bias(self):
        """Sample the initial gate bias ``-log(1/u - 1)`` with ``u ~ U(1/d, 1 - 1/d)``."""
        u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
        return -jnp.log(1 / u - 1)

    def init_state(self, batch_size: int | None = None, **kwargs):
        """Create the cell state ``c`` and the hidden state ``h``."""
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int | None = None, **kwargs):
        """Re-initialize the values of the cell state ``c`` and the hidden state ``h``."""
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x: ArrayLike) -> ArrayLike:
        """Advance one step with input ``x`` and return the new hidden state."""
        h, c = self.h.value, self.c.value
        # ``jnp.concatenate`` matches the sibling cells and works on older JAX too
        xh = jnp.concatenate([x, h], axis=-1)
        f, r, u, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        f_ = functional.sigmoid(f + self.bias.value)
        r_ = functional.sigmoid(r - self.bias.value)
        # effective gate: interpolates between f^2 and 1 - (1 - f)^2 depending on r
        g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2
        next_cell = g * c + (1 - g) * self.activation(u)
        next_hidden = functional.sigmoid(o) * self.activation(next_cell)
        self.h.value = next_hidden
        self.c.value = next_cell
        return next_hidden
@@ -0,0 +1,64 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+
20
+ import jax.numpy as jnp
21
+
22
+ import brainstate as bst
23
+
24
+
25
class TestRateRNNModels(unittest.TestCase):
    """Output-shape smoke tests for the rate-based recurrent cells."""

    def setUp(self):
        self.num_in = 3
        self.num_out = 3
        self.batch_size = 4
        self.x = jnp.ones((self.batch_size, self.num_in))

    def _check_output_shape(self, cell_cls):
        """Build ``cell_cls``, create a batched state, run one step and check the shape."""
        cell = cell_cls(num_in=self.num_in, num_out=self.num_out)
        cell.init_state(batch_size=self.batch_size)
        out = cell.update(self.x)
        self.assertEqual(out.shape, (self.batch_size, self.num_out))

    def test_ValinaRNNCell(self):
        self._check_output_shape(bst.nn.ValinaRNNCell)

    def test_GRUCell(self):
        self._check_output_shape(bst.nn.GRUCell)

    def test_MGUCell(self):
        self._check_output_shape(bst.nn.MGUCell)

    def test_LSTMCell(self):
        self._check_output_shape(bst.nn.LSTMCell)

    def test_URLSTMCell(self):
        self._check_output_shape(bst.nn.URLSTMCell)


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,128 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from __future__ import annotations
19
+
20
+ import numbers
21
+ from typing import Callable
22
+
23
+ import brainunit as u
24
+ import jax
25
+
26
+ from brainstate import environ, init, surrogate
27
+ from brainstate._state import HiddenState, ParamState
28
+ from brainstate.nn._exp_euler import exp_euler_step
29
+ from brainstate.nn._module import Module
30
+ from brainstate.typing import Size, ArrayLike
31
+ from ._dynamics_neuron import Neuron
32
+
33
+ __all__ = [
34
+ 'LeakyRateReadout',
35
+ 'LeakySpikeReadout',
36
+ ]
37
+
38
+
39
class LeakyRateReadout(Module):
    r"""
    Leaky dynamics for the read-out module used in the Real-Time Recurrent Learning.

    The readout state integrates the projected input with an exponential leak:

    .. math::

        r_t = e^{-\Delta t / \tau} \, r_{t-1} + x_t W

    Args:
        in_size: Size. Shape of the input; an integer is treated as ``(in_size,)``.
        out_size: Size. Shape of the readout state; an integer is treated as ``(out_size,)``.
        tau: ArrayLike. Leak time constant. It scales the readout state ``r``,
            so any non-scalar value must match ``out_size``.
        w_init: Callable. Initializer of the ``(in_size[0], out_size[0])`` weight matrix.
        name: optional, str. The name of the module.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        tau: ArrayLike = 5. * u.ms,
        w_init: Callable = init.KaimingNormal(),
        name: str | None = None,
    ):
        super().__init__(name=name)

        # parameters
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
        # ``decay`` multiplies the readout state ``r`` (shape ``out_size``), so
        # ``tau`` is sized by ``out_size``; sizing it by ``in_size`` breaks any
        # per-unit ``tau`` whenever ``in_size != out_size``.
        self.tau = init.param(tau, self.out_size)
        # computed once, from the environment ``dt`` active at construction time
        self.decay = u.math.exp(-environ.get_dt() / self.tau)

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[0], self.out_size[0])))

    def init_state(self, batch_size=None, **kwargs):
        """Create the readout state ``r``, initialized to zero."""
        self.r = HiddenState(init.param(init.Constant(0.), self.out_size, batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        """Reset the readout state ``r`` to zero."""
        self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)

    def update(self, x):
        """Leak the state, add the projected input ``x @ W`` and return the new state."""
        self.r.value = self.decay * self.r.value + x @ self.weight.value
        return self.r.value
73
+
74
+
75
class LeakySpikeReadout(Neuron):
    r"""
    Integrate-and-fire neuron model with leaky dynamics.

    Incoming spikes are projected by a trainable weight and integrated by a
    leaky membrane (explicit exponential-Euler step):

    .. math::

        \tau \frac{dV}{dt} = -V + I, \qquad I \ni s W

    where ``s @ W`` is combined with any other registered current inputs via
    ``sum_current_inputs``. The spike output is
    ``spk_fun((V - V_th) / V_th)``. Before integration the previous spike
    resets the potential: with ``spk_reset == 'soft'`` it subtracts
    ``V_th * spike``; otherwise it subtracts ``stop_gradient(V) * spike``
    (a hard reset to zero when ``spike == 1``).

    Args:
        in_size: Size. Shape of the neuron group.
        tau: ArrayLike. Membrane time constant.
        V_th: ArrayLike. Firing threshold.
        w_init: Callable. Initializer of the ``(in_size[-1], out_size[-1])`` weight.
        V_initializer: ArrayLike. Initializer of the membrane potential ``V``.
        spk_fun: Callable. Surrogate spike function.
        spk_reset: str. ``'soft'`` or hard reset mode (see above).
        name: optional, str. The name of the module.
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        w_init: Callable = init.KaimingNormal(unit=u.mV),
        V_initializer: ArrayLike = init.ZeroInit(unit=u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str | None = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters; ``varshape`` is already the size used for the state in
        # ``init_state``, so it is passed as-is (wrapping it as ``(varshape,)``
        # would produce a nested shape tuple).
        self.tau = init.param(tau, self.varshape)
        self.V_th = init.param(V_th, self.varshape)
        self.V_initializer = V_initializer

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1])))

    def init_state(self, batch_size=None, **kwargs):
        """Create the membrane potential ``V``."""
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        """Re-initialize the membrane potential ``V``."""
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)

    @property
    def spike(self):
        """Spike output for the current membrane potential."""
        return self.get_spike(self.V.value)

    def get_spike(self, V):
        """Apply ``spk_fun`` to the threshold-normalized potential ``(V - V_th) / V_th``."""
        v_scaled = (V - self.V_th) / self.V_th
        return self.spk_fun(v_scaled)

    def update(self, spk):
        """Integrate input spikes ``spk`` for one step and return the new spike output."""
        # reset
        last_V = self.V.value
        last_spike = self.get_spike(last_V)
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
        V = last_V - V_th * last_spike
        # membrane potential
        x = spk @ self.weight.value
        dv = lambda v: (-v + self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, V)
        V = self.sum_delta_inputs(V)
        self.V.value = V
        # spike of the stored potential, so the result agrees with ``self.spike``
        return self.get_spike(V)
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+
18
+ import unittest
19
+
20
+ import jax.numpy as jnp
21
+
22
+ import brainstate as bst
23
+
24
+
25
class TestReadoutModels(unittest.TestCase):
    """Output-shape smoke tests for the readout layers."""

    def setUp(self):
        self.in_size = 3
        self.out_size = 3
        self.batch_size = 4
        self.tau = 5.0
        self.V_th = 1.0
        self.x = jnp.ones((self.batch_size, self.in_size))

    def test_LeakyRateReadout(self):
        with bst.environ.context(dt=0.1):
            readout = bst.nn.LeakyRateReadout(
                in_size=self.in_size,
                out_size=self.out_size,
                tau=self.tau,
            )
            readout.init_state(batch_size=self.batch_size)
            result = readout.update(self.x)
            expected_shape = (self.batch_size, self.out_size)
            self.assertEqual(result.shape, expected_shape)

    def test_LeakySpikeReadout(self):
        with bst.environ.context(dt=0.1):
            readout = bst.nn.LeakySpikeReadout(
                in_size=self.in_size,
                tau=self.tau,
                V_th=self.V_th,
                V_initializer=bst.init.ZeroInit(),
                w_init=bst.init.KaimingNormal(),
            )
            readout.init_state(batch_size=self.batch_size)
            # the step runs with an explicit current time ``t``
            with bst.environ.context(t=0.):
                result = readout.update(self.x)
                self.assertEqual(result.shape, (self.batch_size, self.out_size))


if __name__ == '__main__':
    # run the whole suite under a default time step
    with bst.environ.context(dt=0.1):
        unittest.main()
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from ._dynamics_base import *
17
+ from ._dynamics_base import __all__ as dyn_all
18
+ from ._projection_base import *
19
+ from ._projection_base import __all__ as projection_all
20
+ from ._state_delay import *
21
+ from ._state_delay import __all__ as state_delay_all
22
+ from ._synouts import *
23
+ from ._synouts import __all__ as synouts_all
24
+
25
# Public API of the package: the union of every sub-module's exported names.
__all__ = dyn_all + projection_all + state_delay_all + synouts_all

# The per-module name lists were only needed to assemble ``__all__``.
del dyn_all, projection_all, state_delay_all, synouts_all
+ )