brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. brainstate/__init__.py +31 -11
  2. brainstate/_state.py +760 -316
  3. brainstate/_state_test.py +41 -12
  4. brainstate/_utils.py +31 -4
  5. brainstate/augment/__init__.py +40 -0
  6. brainstate/augment/_autograd.py +608 -0
  7. brainstate/augment/_autograd_test.py +1193 -0
  8. brainstate/augment/_eval_shape.py +102 -0
  9. brainstate/augment/_eval_shape_test.py +40 -0
  10. brainstate/augment/_mapping.py +525 -0
  11. brainstate/augment/_mapping_test.py +210 -0
  12. brainstate/augment/_random.py +99 -0
  13. brainstate/{transform → compile}/__init__.py +25 -13
  14. brainstate/compile/_ad_checkpoint.py +204 -0
  15. brainstate/compile/_ad_checkpoint_test.py +51 -0
  16. brainstate/compile/_conditions.py +259 -0
  17. brainstate/compile/_conditions_test.py +221 -0
  18. brainstate/compile/_error_if.py +94 -0
  19. brainstate/compile/_error_if_test.py +54 -0
  20. brainstate/compile/_jit.py +314 -0
  21. brainstate/compile/_jit_test.py +143 -0
  22. brainstate/compile/_loop_collect_return.py +516 -0
  23. brainstate/compile/_loop_collect_return_test.py +59 -0
  24. brainstate/compile/_loop_no_collection.py +185 -0
  25. brainstate/compile/_loop_no_collection_test.py +51 -0
  26. brainstate/compile/_make_jaxpr.py +756 -0
  27. brainstate/compile/_make_jaxpr_test.py +134 -0
  28. brainstate/compile/_progress_bar.py +111 -0
  29. brainstate/compile/_unvmap.py +159 -0
  30. brainstate/compile/_util.py +147 -0
  31. brainstate/environ.py +408 -381
  32. brainstate/environ_test.py +34 -32
  33. brainstate/{nn/event → event}/__init__.py +6 -6
  34. brainstate/event/_csr.py +308 -0
  35. brainstate/event/_csr_test.py +118 -0
  36. brainstate/event/_fixed_probability.py +271 -0
  37. brainstate/event/_fixed_probability_test.py +128 -0
  38. brainstate/event/_linear.py +219 -0
  39. brainstate/event/_linear_test.py +112 -0
  40. brainstate/{nn/event → event}/_misc.py +7 -7
  41. brainstate/functional/_activations.py +521 -511
  42. brainstate/functional/_activations_test.py +300 -300
  43. brainstate/functional/_normalization.py +43 -43
  44. brainstate/functional/_others.py +15 -15
  45. brainstate/functional/_spikes.py +49 -49
  46. brainstate/graph/__init__.py +33 -0
  47. brainstate/graph/_graph_context.py +443 -0
  48. brainstate/graph/_graph_context_test.py +65 -0
  49. brainstate/graph/_graph_convert.py +246 -0
  50. brainstate/graph/_graph_node.py +300 -0
  51. brainstate/graph/_graph_node_test.py +75 -0
  52. brainstate/graph/_graph_operation.py +1746 -0
  53. brainstate/graph/_graph_operation_test.py +724 -0
  54. brainstate/init/_base.py +28 -10
  55. brainstate/init/_generic.py +175 -172
  56. brainstate/init/_random_inits.py +470 -415
  57. brainstate/init/_random_inits_test.py +150 -0
  58. brainstate/init/_regular_inits.py +66 -69
  59. brainstate/init/_regular_inits_test.py +51 -0
  60. brainstate/mixin.py +236 -244
  61. brainstate/mixin_test.py +44 -46
  62. brainstate/nn/__init__.py +26 -51
  63. brainstate/nn/_collective_ops.py +199 -0
  64. brainstate/nn/_dyn_impl/__init__.py +46 -0
  65. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  66. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  67. brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
  68. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  69. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  70. brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
  71. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  72. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  73. brainstate/nn/_dyn_impl/_readout.py +128 -0
  74. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  75. brainstate/nn/_dynamics/__init__.py +37 -0
  76. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  77. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  78. brainstate/nn/_dynamics/_projection_base.py +346 -0
  79. brainstate/nn/_dynamics/_state_delay.py +453 -0
  80. brainstate/nn/_dynamics/_synouts.py +161 -0
  81. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  82. brainstate/nn/_elementwise/__init__.py +22 -0
  83. brainstate/nn/_elementwise/_dropout.py +418 -0
  84. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  85. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  86. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  87. brainstate/nn/_exp_euler.py +97 -0
  88. brainstate/nn/_exp_euler_test.py +36 -0
  89. brainstate/nn/_interaction/__init__.py +32 -0
  90. brainstate/nn/_interaction/_connections.py +726 -0
  91. brainstate/nn/_interaction/_connections_test.py +254 -0
  92. brainstate/nn/_interaction/_embedding.py +59 -0
  93. brainstate/nn/_interaction/_normalizations.py +388 -0
  94. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  95. brainstate/nn/_interaction/_poolings.py +1179 -0
  96. brainstate/nn/_interaction/_poolings_test.py +219 -0
  97. brainstate/nn/_module.py +328 -0
  98. brainstate/nn/_module_test.py +211 -0
  99. brainstate/nn/metrics.py +309 -309
  100. brainstate/optim/__init__.py +14 -2
  101. brainstate/optim/_base.py +66 -0
  102. brainstate/optim/_lr_scheduler.py +363 -400
  103. brainstate/optim/_lr_scheduler_test.py +25 -24
  104. brainstate/optim/_optax_optimizer.py +103 -176
  105. brainstate/optim/_optax_optimizer_test.py +41 -1
  106. brainstate/optim/_sgd_optimizer.py +950 -1025
  107. brainstate/random/_rand_funs.py +3269 -3268
  108. brainstate/random/_rand_funs_test.py +568 -0
  109. brainstate/random/_rand_seed.py +149 -117
  110. brainstate/random/_rand_seed_test.py +50 -0
  111. brainstate/random/_rand_state.py +1356 -1321
  112. brainstate/random/_random_for_unit.py +13 -13
  113. brainstate/surrogate.py +1262 -1243
  114. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  115. brainstate/typing.py +157 -130
  116. brainstate/util/__init__.py +52 -0
  117. brainstate/util/_caller.py +100 -0
  118. brainstate/util/_dict.py +734 -0
  119. brainstate/util/_dict_test.py +160 -0
  120. brainstate/util/_error.py +28 -0
  121. brainstate/util/_filter.py +178 -0
  122. brainstate/util/_others.py +497 -0
  123. brainstate/util/_pretty_repr.py +208 -0
  124. brainstate/util/_scaling.py +260 -0
  125. brainstate/util/_struct.py +524 -0
  126. brainstate/util/_tracers.py +75 -0
  127. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  128. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
  129. brainstate-0.1.0.dist-info/RECORD +135 -0
  130. brainstate/_module.py +0 -1637
  131. brainstate/_module_test.py +0 -207
  132. brainstate/nn/_base.py +0 -251
  133. brainstate/nn/_connections.py +0 -686
  134. brainstate/nn/_dynamics.py +0 -426
  135. brainstate/nn/_elementwise.py +0 -1438
  136. brainstate/nn/_embedding.py +0 -66
  137. brainstate/nn/_misc.py +0 -133
  138. brainstate/nn/_normalizations.py +0 -389
  139. brainstate/nn/_others.py +0 -101
  140. brainstate/nn/_poolings.py +0 -1229
  141. brainstate/nn/_poolings_test.py +0 -231
  142. brainstate/nn/_projection/_align_post.py +0 -546
  143. brainstate/nn/_projection/_align_pre.py +0 -599
  144. brainstate/nn/_projection/_delta.py +0 -241
  145. brainstate/nn/_projection/_vanilla.py +0 -101
  146. brainstate/nn/_rate_rnns.py +0 -410
  147. brainstate/nn/_readout.py +0 -136
  148. brainstate/nn/_synouts.py +0 -166
  149. brainstate/nn/event/csr.py +0 -312
  150. brainstate/nn/event/csr_test.py +0 -118
  151. brainstate/nn/event/fixed_probability.py +0 -276
  152. brainstate/nn/event/fixed_probability_test.py +0 -127
  153. brainstate/nn/event/linear.py +0 -220
  154. brainstate/nn/event/linear_test.py +0 -111
  155. brainstate/random/random_test.py +0 -593
  156. brainstate/transform/_autograd.py +0 -585
  157. brainstate/transform/_autograd_test.py +0 -1181
  158. brainstate/transform/_conditions.py +0 -334
  159. brainstate/transform/_conditions_test.py +0 -220
  160. brainstate/transform/_error_if.py +0 -94
  161. brainstate/transform/_error_if_test.py +0 -55
  162. brainstate/transform/_jit.py +0 -265
  163. brainstate/transform/_jit_test.py +0 -118
  164. brainstate/transform/_loop_collect_return.py +0 -502
  165. brainstate/transform/_loop_no_collection.py +0 -170
  166. brainstate/transform/_make_jaxpr.py +0 -739
  167. brainstate/transform/_make_jaxpr_test.py +0 -131
  168. brainstate/transform/_mapping.py +0 -109
  169. brainstate/transform/_progress_bar.py +0 -111
  170. brainstate/transform/_unvmap.py +0 -143
  171. brainstate/util.py +0 -746
  172. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  173. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
  174. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
  175. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
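
Reading the list as a whole: 0.1.0 is largely a package reorganization. `brainstate.transform` is split into `brainstate.compile` (jit, conditionals, loops, jaxpr construction) and a new `brainstate.augment` (autograd, mapping, eval_shape); event-driven operators move from `brainstate.nn.event` to a top-level `brainstate.event`; the single-file `util.py` becomes a `brainstate.util` package; and a new `brainstate.graph` package replaces the removed `_module.py` machinery. A minimal migration sketch, assuming the public names are re-exported from the new package roots (an inference from the file list, not something this diff shows directly):

    # Hypothetical import migration from 0.0.2 to 0.1.0, inferred from the renames above.
    # Before (0.0.2):
    #     from brainstate.transform import jit, grad
    #     import brainstate.nn.event
    # After (0.1.0), assuming package-root re-exports:
    from brainstate.compile import jit   # compilation-oriented transforms
    from brainstate.augment import grad  # autograd/mapping augmentations
    import brainstate.event              # event-driven operators, now top level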
brainstate/nn/_dyn_impl/_rate_rnns.py
@@ -0,0 +1,400 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ from typing import Callable, Union
+
+ import jax.numpy as jnp
+
+ from brainstate import random, init, functional
+ from brainstate._state import HiddenState, ParamState
+ from brainstate.nn._interaction._connections import Linear
+ from brainstate.nn._module import Module
+ from brainstate.typing import ArrayLike
+
+ __all__ = [
+     'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
+ ]
+
+
+ class RNNCell(Module):
+     """
+     Base class for RNN cells.
+     """
+     pass
+
+
+ class ValinaRNNCell(RNNCell):
+     """
+     Vanilla RNN cell.
+
+     Args:
+         num_in: int. The number of input units.
+         num_out: int. The number of hidden units.
+         state_init: callable, ArrayLike. The state initializer.
+         w_init: callable, ArrayLike. The input weight initializer.
+         b_init: optional, callable, ArrayLike. The bias weight initializer.
+         activation: str, callable. The activation function. It can be a string or a callable function.
+         name: optional, str. The name of the module.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         num_in: int,
+         num_out: int,
+         state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
+         b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         activation: str | Callable = 'relu',
+         name: str = None,
+     ):
+         super().__init__(name=name)
+
+         # parameters
+         self.num_out = num_out
+         self.num_in = num_in
+         self.in_size = (num_in,)
+         self.out_size = (num_out,)
+         self._state_initializer = state_init
+
+         # activation function
+         if isinstance(activation, str):
+             self.activation = getattr(functional, activation)
+         else:
+             assert callable(activation), "The activation function should be a string or a callable function."
+             self.activation = activation
+
+         # weights
+         self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
+
+     def init_state(self, batch_size: int = None, **kwargs):
+         self.h = HiddenState(init.param(self._state_initializer, self.num_out, batch_size))
+
+     def reset_state(self, batch_size: int = None, **kwargs):
+         self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
+
+     def update(self, x):
+         xh = jnp.concatenate([x, self.h.value], axis=-1)
+         h = self.W(xh)
+         self.h.value = self.activation(h)
+         return self.h.value
+
+
+ class GRUCell(RNNCell):
+     """
+     Gated Recurrent Unit (GRU) cell.
+
+     Args:
+         num_in: int. The number of input units.
+         num_out: int. The number of hidden units.
+         state_init: callable, ArrayLike. The state initializer.
+         w_init: callable, ArrayLike. The input weight initializer.
+         b_init: optional, callable, ArrayLike. The bias weight initializer.
+         activation: str, callable. The activation function. It can be a string or a callable function.
+         name: optional, str. The name of the module.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         num_in: int,
+         num_out: int,
+         w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
+         b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         activation: str | Callable = 'tanh',
+         name: str = None,
+     ):
+         super().__init__(name=name)
+
+         # parameters
+         self._state_initializer = state_init
+         self.num_out = num_out
+         self.num_in = num_in
+         self.in_size = (num_in,)
+         self.out_size = (num_out,)
+
+         # activation function
+         if isinstance(activation, str):
+             self.activation = getattr(functional, activation)
+         else:
+             assert callable(activation), "The activation function should be a string or a callable function."
+             self.activation = activation
+
+         # weights
+         self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
+         self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
+
+     def init_state(self, batch_size: int = None, **kwargs):
+         self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+
+     def reset_state(self, batch_size: int = None, **kwargs):
+         self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
+     def update(self, x):
+         old_h = self.h.value
+         xh = jnp.concatenate([x, old_h], axis=-1)
+         r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
+         rh = r * old_h
+         h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
+         h = (1 - z) * old_h + z * h
+         self.h.value = h
+         return h
+
+
+ class MGUCell(RNNCell):
+     r"""
+     Minimal Gated Recurrent Unit (MGU) cell.
+
+     .. math::
+
+        \begin{aligned}
+        f_{t} &= \sigma(W_{f} x_{t} + U_{f} h_{t-1} + b_{f}) \\
+        \hat{h}_{t} &= \phi(W_{h} x_{t} + U_{h}(f_{t} \odot h_{t-1}) + b_{h}) \\
+        h_{t} &= (1 - f_{t}) \odot h_{t-1} + f_{t} \odot \hat{h}_{t}
+        \end{aligned}
+
+     where:
+
+     - :math:`x_{t}`: input vector
+     - :math:`h_{t}`: output vector
+     - :math:`\hat{h}_{t}`: candidate activation vector
+     - :math:`f_{t}`: forget vector
+     - :math:`W, U, b`: parameter matrices and vectors
+
+     Args:
+         num_in: int. The number of input units.
+         num_out: int. The number of hidden units.
+         state_init: callable, ArrayLike. The state initializer.
+         w_init: callable, ArrayLike. The input weight initializer.
+         b_init: optional, callable, ArrayLike. The bias weight initializer.
+         activation: str, callable. The activation function. It can be a string or a callable function.
+         name: optional, str. The name of the module.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         num_in: int,
+         num_out: int,
+         w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
+         b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         activation: str | Callable = 'tanh',
+         name: str = None,
+     ):
+         super().__init__(name=name)
+
+         # parameters
+         self._state_initializer = state_init
+         self.num_out = num_out
+         self.num_in = num_in
+         self.in_size = (num_in,)
+         self.out_size = (num_out,)
+
+         # activation function
+         if isinstance(activation, str):
+             self.activation = getattr(functional, activation)
+         else:
+             assert callable(activation), "The activation function should be a string or a callable function."
+             self.activation = activation
+
+         # weights
+         self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
+         self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
+
+     def init_state(self, batch_size: int = None, **kwargs):
+         self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+
+     def reset_state(self, batch_size: int = None, **kwargs):
+         self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
+     def update(self, x):
+         old_h = self.h.value
+         xh = jnp.concatenate([x, old_h], axis=-1)
+         f = functional.sigmoid(self.Wf(xh))
+         fh = f * old_h
+         h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
+         self.h.value = (1 - f) * self.h.value + f * h
+         return self.h.value
+
+
+ class LSTMCell(RNNCell):
+     r"""Long short-term memory (LSTM) RNN core.
+
+     The implementation is based on (Zaremba et al., 2014) [1]_. Given
+     :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core
+     computes
+
+     .. math::
+
+        \begin{array}{ll}
+        i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
+        f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
+        g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
+        o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
+        c_t = f_t c_{t-1} + i_t g_t \\
+        h_t = o_t \tanh(c_t)
+        \end{array}
+
+     where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and
+     output gate activations, and :math:`g_t` is a vector of cell updates.
+
+     The output is equal to the new hidden state, :math:`h_t`.
+
+     Notes
+     -----
+
+     Forget gate initialization: Following (Jozefowicz et al., 2015) [2]_, we add 1.0
+     to :math:`b_f` after initialization in order to reduce the scale of forgetting at
+     the beginning of training.
+
+
+     Parameters
+     ----------
+     num_in: int
+         The dimension of the input vector.
+     num_out: int
+         The number of hidden units in the node.
+     state_init: callable, ArrayLike
+         The state initializer.
+     w_init: callable, ArrayLike
+         The input weight initializer.
+     b_init: optional, callable, ArrayLike
+         The bias weight initializer.
+     activation: str, callable
+         The activation function. It can be a string or a callable function.
+
+     References
+     ----------
+
+     .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural
+            network regularization." arXiv preprint arXiv:1409.2329 (2014).
+     .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical
+            exploration of recurrent network architectures." In International conference
+            on machine learning, pp. 2342-2350. PMLR, 2015.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         num_in: int,
+         num_out: int,
+         w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
+         b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         activation: str | Callable = 'tanh',
+         name: str = None,
+     ):
+         super().__init__(name=name)
+
+         # parameters
+         self.num_out = num_out
+         self.num_in = num_in
+         self.in_size = (num_in,)
+         self.out_size = (num_out,)
+
+         # initializers
+         self._state_initializer = state_init
+
+         # activation function
+         if isinstance(activation, str):
+             self.activation = getattr(functional, activation)
+         else:
+             assert callable(activation), "The activation function should be a string or a callable function."
+             self.activation = activation
+
+         # weights
+         self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)
+
+     def init_state(self, batch_size: int = None, **kwargs):
+         self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+         self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+
+     def reset_state(self, batch_size: int = None, **kwargs):
+         self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+         self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
+     def update(self, x):
+         h, c = self.h.value, self.c.value
+         xh = jnp.concat([x, h], axis=-1)
+         i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
+         c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
+         h = functional.sigmoid(o) * self.activation(c)
+         self.h.value = h
+         self.c.value = c
+         return h
+
+
+ class URLSTMCell(RNNCell):
+     def __init__(
+         self,
+         num_in: int,
+         num_out: int,
+         w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
+         state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
+         activation: str | Callable = 'tanh',
+         name: str = None,
+     ):
+         super().__init__(name=name)
+
+         # parameters
+         self.num_out = num_out
+         self.num_in = num_in
+         self.in_size = (num_in,)
+         self.out_size = (num_out,)
+
+         # initializers
+         self._state_initializer = state_init
+
+         # activation function
+         if isinstance(activation, str):
+             self.activation = getattr(functional, activation)
+         else:
+             assert callable(activation), "The activation function should be a string or a callable function."
+             self.activation = activation
+
+         # weights
+         self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
+         self.bias = ParamState(self._forget_bias())
+
+     def _forget_bias(self):
+         u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
+         return -jnp.log(1 / u - 1)
+
+     def init_state(self, batch_size: int = None, **kwargs):
+         self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+         self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
+
+     def reset_state(self, batch_size: int = None, **kwargs):
+         self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+         self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
+     def update(self, x: ArrayLike) -> ArrayLike:
+         h, c = self.h.value, self.c.value
+         xh = jnp.concat([x, h], axis=-1)
+         f, r, u, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
+         f_ = functional.sigmoid(f + self.bias.value)
+         r_ = functional.sigmoid(r - self.bias.value)
+         g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2
+         next_cell = g * c + (1 - g) * self.activation(u)
+         next_hidden = functional.sigmoid(o) * self.activation(next_cell)
+         self.h.value = next_hidden
+         self.c.value = next_cell
+         return next_hidden
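
A note on the last class (our annotation, not part of the diff): `URLSTMCell` matches the UR-LSTM of Gu et al., "Improving the Gating Mechanism of Recurrent Networks" (2020). `_forget_bias` draws u ~ U(1/d, 1 - 1/d) and returns logit(u), so the initial forget-gate activations sigmoid(bias) are spread uniformly rather than clustered near 0.5, and `update` fuses the forget gate f with a refine gate r into one effective gate. The code's `g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2` is the expanded form of the refine-gate rule:

    g = r(1 - (1 - f)^2) + (1 - r) f^2
      = r(2f - f^2) + (1 - r) f^2
      = 2rf + (1 - 2r) f^2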
brainstate/nn/_dyn_impl/_rate_rnns_test.py
@@ -0,0 +1,64 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import unittest
+
+ import jax.numpy as jnp
+
+ import brainstate as bst
+
+
+ class TestRateRNNModels(unittest.TestCase):
+     def setUp(self):
+         self.num_in = 3
+         self.num_out = 3
+         self.batch_size = 4
+         self.x = jnp.ones((self.batch_size, self.num_in))
+
+     def test_ValinaRNNCell(self):
+         model = bst.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
+         model.init_state(batch_size=self.batch_size)
+         output = model.update(self.x)
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+
+     def test_GRUCell(self):
+         model = bst.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
+         model.init_state(batch_size=self.batch_size)
+         output = model.update(self.x)
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+
+     def test_MGUCell(self):
+         model = bst.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
+         model.init_state(batch_size=self.batch_size)
+         output = model.update(self.x)
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+
+     def test_LSTMCell(self):
+         model = bst.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
+         model.init_state(batch_size=self.batch_size)
+         output = model.update(self.x)
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+
+     def test_URLSTMCell(self):
+         model = bst.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
+         model.init_state(batch_size=self.batch_size)
+         output = model.update(self.x)
+         self.assertEqual(output.shape, (self.batch_size, self.num_out))
+
+
+ if __name__ == '__main__':
+     unittest.main()
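
The tests above also document the intended life cycle of a cell: construct it, call `init_state` once for a given batch size, call `update` once per time step, and call `reset_state` between sequences. A minimal sequence-processing sketch under that assumption:

    # Assumed usage, mirroring the tests above; not part of the diff.
    import jax.numpy as jnp
    import brainstate as bst

    cell = bst.nn.GRUCell(num_in=3, num_out=5)
    cell.init_state(batch_size=4)      # allocates the hidden state cell.h
    xs = jnp.ones((10, 4, 3))          # (time, batch, features)
    for t in range(xs.shape[0]):
        out = cell.update(xs[t])       # one step; rewrites cell.h in place
    cell.reset_state(batch_size=4)     # zero the state before the next sequence
    assert out.shape == (4, 5)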
brainstate/nn/_dyn_impl/_readout.py
@@ -0,0 +1,128 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ import numbers
+ from typing import Callable
+
+ import brainunit as u
+ import jax
+
+ from brainstate import environ, init, surrogate
+ from brainstate._state import HiddenState, ParamState
+ from brainstate.nn._exp_euler import exp_euler_step
+ from brainstate.nn._module import Module
+ from brainstate.typing import Size, ArrayLike
+ from ._dynamics_neuron import Neuron
+
+ __all__ = [
+     'LeakyRateReadout',
+     'LeakySpikeReadout',
+ ]
+
+
+ class LeakyRateReadout(Module):
+     """
+     Leaky dynamics for the read-out module used in Real-Time Recurrent Learning (RTRL).
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         tau: ArrayLike = 5. * u.ms,
+         w_init: Callable = init.KaimingNormal(),
+         name: str = None,
+     ):
+         super().__init__(name=name)
+
+         # parameters
+         self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
+         self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
+         self.tau = init.param(tau, self.in_size)
+         self.decay = u.math.exp(-environ.get_dt() / self.tau)
+
+         # weights
+         self.weight = ParamState(init.param(w_init, (self.in_size[0], self.out_size[0])))
+
+     def init_state(self, batch_size=None, **kwargs):
+         self.r = HiddenState(init.param(init.Constant(0.), self.out_size, batch_size))
+
+     def reset_state(self, batch_size=None, **kwargs):
+         self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)
+
+     def update(self, x):
+         self.r.value = self.decay * self.r.value + x @ self.weight.value
+         return self.r.value
+
+
+ class LeakySpikeReadout(Neuron):
+     """
+     Integrate-and-fire neuron model with leaky dynamics.
+     """
+
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         tau: ArrayLike = 5. * u.ms,
+         V_th: ArrayLike = 1. * u.mV,
+         w_init: Callable = init.KaimingNormal(unit=u.mV),
+         V_initializer: ArrayLike = init.ZeroInit(unit=u.mV),
+         spk_fun: Callable = surrogate.ReluGrad(),
+         spk_reset: str = 'soft',
+         name: str = None,
+     ):
+         super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)
+
+         # parameters
+         self.tau = init.param(tau, (self.varshape,))
+         self.V_th = init.param(V_th, (self.varshape,))
+         self.V_initializer = V_initializer
+
+         # weights
+         self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1])))
+
+     def init_state(self, batch_size, **kwargs):
+         self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))
+
+     def reset_state(self, batch_size, **kwargs):
+         self.V.value = init.param(self.V_initializer, self.varshape, batch_size)
+
+     @property
+     def spike(self):
+         return self.get_spike(self.V.value)
+
+     def get_spike(self, V):
+         v_scaled = (V - self.V_th) / self.V_th
+         return self.spk_fun(v_scaled)
+
+     def update(self, spk):
+         # reset
+         last_V = self.V.value
+         last_spike = self.get_spike(last_V)
+         V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
+         V = last_V - V_th * last_spike
+         # membrane potential
+         x = spk @ self.weight.value
+         dv = lambda v: (-v + self.sum_current_inputs(x, v)) / self.tau
+         V = exp_euler_step(dv, V)
+         self.V.value = self.sum_delta_inputs(V)
+         return self.get_spike(V)
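
`LeakySpikeReadout.update` integrates the leak with `exp_euler_step`. Assuming `exp_euler_step` applies the standard exponential-Euler rule, one step of dV/dt = (-V + I)/tau with step size dt reduces to the closed form

    V <- V * exp(-dt/tau) + I * (1 - exp(-dt/tau))

which is exact for this linear leak, and uses the same decay factor that `LeakyRateReadout` precomputes as `self.decay = u.math.exp(-environ.get_dt() / self.tau)`.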
brainstate/nn/_dyn_impl/_readout_test.py
@@ -0,0 +1,54 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ import unittest
+
+ import jax.numpy as jnp
+
+ import brainstate as bst
+
+
+ class TestReadoutModels(unittest.TestCase):
+     def setUp(self):
+         self.in_size = 3
+         self.out_size = 3
+         self.batch_size = 4
+         self.tau = 5.0
+         self.V_th = 1.0
+         self.x = jnp.ones((self.batch_size, self.in_size))
+
+     def test_LeakyRateReadout(self):
+         with bst.environ.context(dt=0.1):
+             model = bst.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
+             model.init_state(batch_size=self.batch_size)
+             output = model.update(self.x)
+             self.assertEqual(output.shape, (self.batch_size, self.out_size))
+
+     def test_LeakySpikeReadout(self):
+         with bst.environ.context(dt=0.1):
+             model = bst.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
+                                              V_initializer=bst.init.ZeroInit(),
+                                              w_init=bst.init.KaimingNormal())
+             model.init_state(batch_size=self.batch_size)
+             with bst.environ.context(t=0.):
+                 output = model.update(self.x)
+             self.assertEqual(output.shape, (self.batch_size, self.out_size))
+
+
+ if __name__ == '__main__':
+     with bst.environ.context(dt=0.1):
+         unittest.main()
brainstate/nn/_dynamics/__init__.py
@@ -0,0 +1,37 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from ._dynamics_base import *
+ from ._dynamics_base import __all__ as dyn_all
+ from ._projection_base import *
+ from ._projection_base import __all__ as projection_all
+ from ._state_delay import *
+ from ._state_delay import __all__ as state_delay_all
+ from ._synouts import *
+ from ._synouts import __all__ as synouts_all
+
+ __all__ = (
+     dyn_all
+     + projection_all
+     + state_delay_all
+     + synouts_all
+ )
+
+ del (
+     dyn_all,
+     projection_all,
+     state_delay_all,
+     synouts_all
+ )
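
This `__init__.py` also illustrates the re-export pattern used across the reorganized subpackages: star-import each private module, aggregate the modules' `__all__` lists, then delete the temporaries so they do not linger in the package namespace. A toy sketch of the same pattern with a hypothetical package `mypkg`:

    # mypkg/_a.py (hypothetical)
    __all__ = ['foo']

    def foo():
        return 42

    # mypkg/__init__.py (hypothetical)
    from ._a import *
    from ._a import __all__ as _a_all

    __all__ = list(_a_all)  # expose only what the submodule chose to export
    del _a_all              # keep the package namespace clean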