brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
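The headline change is structural: the old `brainstate.transform` package is split into `brainstate.compile` (JIT, conditionals, loops, jaxpr tracing) and `brainstate.augment` (autograd, mapping, shape evaluation); `brainstate/nn/event` becomes a top-level `brainstate/event`; the monolithic `_module.py` and `util.py` give way to the new `graph/` and `util/` packages; and the `nn` layers are regrouped under `_dyn_impl/`, `_dynamics/`, `_elementwise/`, and `_interaction/`. A minimal migration sketch for downstream code, assuming the public names `compile.jit` and `augment.grad` mirror the `_jit.py` and `_autograd.py` moves above (verify against the released 0.1.0 API):

```python
# hypothetical migration sketch: brainstate.transform.* split into
# brainstate.compile.* and brainstate.augment.* in 0.1.0; the exact
# re-exported names and signatures are assumptions inferred from the
# file moves in the listing above
import brainstate as bst

w = bst.ParamState(1.0)

# 0.0.2:  bst.transform.jit(...) / bst.transform.grad(...)
# 0.1.0:
@bst.compile.jit
def loss(x):
    return (w.value * x - 1.0) ** 2

grads = bst.augment.grad(loss, w)(2.0)  # assumed signature: grad(fn, states)
```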
brainstate/nn/_dyn_impl/_rate_rnns.py (new file, `@@ -0,0 +1,400 @@`):

```python
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import Callable, Union

import jax.numpy as jnp

from brainstate import random, init, functional
from brainstate._state import HiddenState, ParamState
from brainstate.nn._interaction._connections import Linear
from brainstate.nn._module import Module
from brainstate.typing import ArrayLike

__all__ = [
    'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
]


class RNNCell(Module):
    """
    Base class for RNN cells.
    """
    pass


class ValinaRNNCell(RNNCell):
    """
    Vanilla RNN cell.

    Args:
        num_in: int. The number of input units.
        num_out: int. The number of hidden units.
        state_init: callable, ArrayLike. The state initializer.
        w_init: callable, ArrayLike. The input weight initializer.
        b_init: optional, callable, ArrayLike. The bias weight initializer.
        activation: str, callable. The activation function. It can be a string or a callable function.
        name: optional, str. The name of the module.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'relu',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)
        self._state_initializer = state_init

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights
        self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.h = HiddenState(init.param(self._state_initializer, self.num_out, batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.h.value = init.param(self._state_initializer, self.num_out, batch_size)

    def update(self, x):
        xh = jnp.concatenate([x, self.h.value], axis=-1)
        h = self.W(xh)
        self.h.value = self.activation(h)
        return self.h.value


class GRUCell(RNNCell):
    """
    Gated Recurrent Unit (GRU) cell.

    Args:
        num_in: int. The number of input units.
        num_out: int. The number of hidden units.
        state_init: callable, ArrayLike. The state initializer.
        w_init: callable, ArrayLike. The input weight initializer.
        b_init: optional, callable, ArrayLike. The bias weight initializer.
        activation: str, callable. The activation function. It can be a string or a callable function.
        name: optional, str. The name of the module.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights
        self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
        rh = r * old_h
        h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
        h = (1 - z) * old_h + z * h
        self.h.value = h
        return h


class MGUCell(RNNCell):
    r"""
    Minimal Gated Recurrent Unit (MGU) cell.

    .. math::

       \begin{aligned}
       f_{t} &= \sigma(W_{f} x_{t} + U_{f} h_{t-1} + b_{f}) \\
       \hat{h}_{t} &= \phi(W_{h} x_{t} + U_{h}(f_{t} \odot h_{t-1}) + b_{h}) \\
       h_{t} &= (1 - f_{t}) \odot h_{t-1} + f_{t} \odot \hat{h}_{t}
       \end{aligned}

    where:

    - :math:`x_{t}`: input vector
    - :math:`h_{t}`: output vector
    - :math:`\hat{h}_{t}`: candidate activation vector
    - :math:`f_{t}`: forget vector
    - :math:`W, U, b`: parameter matrices and vector

    Args:
        num_in: int. The number of input units.
        num_out: int. The number of hidden units.
        state_init: callable, ArrayLike. The state initializer.
        w_init: callable, ArrayLike. The input weight initializer.
        b_init: optional, callable, ArrayLike. The bias weight initializer.
        activation: str, callable. The activation function. It can be a string or a callable function.
        name: optional, str. The name of the module.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self._state_initializer = state_init
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights
        self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)
        self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        old_h = self.h.value
        xh = jnp.concatenate([x, old_h], axis=-1)
        f = functional.sigmoid(self.Wf(xh))
        fh = f * old_h
        h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
        self.h.value = (1 - f) * self.h.value + f * h
        return self.h.value


class LSTMCell(RNNCell):
    r"""Long short-term memory (LSTM) RNN core.

    The implementation is based on (Zaremba, et al., 2014) [1]_. Given
    :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core
    computes

    .. math::

       \begin{array}{ll}
       i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
       f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
       g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
       o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
       c_t = f_t c_{t-1} + i_t g_t \\
       h_t = o_t \tanh(c_t)
       \end{array}

    where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and
    output gate activations, and :math:`g_t` is a vector of cell updates.

    The output is equal to the new hidden state, :math:`h_t`.

    Notes
    -----
    Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0
    to :math:`b_f` after initialization in order to reduce the scale of forgetting at
    the beginning of training.

    Parameters
    ----------
    num_in: int
        The dimension of the input vector.
    num_out: int
        The number of hidden units in the node.
    state_init: callable, ArrayLike
        The state initializer.
    w_init: callable, ArrayLike
        The input weight initializer.
    b_init: optional, callable, ArrayLike
        The bias weight initializer.
    activation: str, callable
        The activation function. It can be a string or a callable function.

    References
    ----------
    .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural
           network regularization." arXiv preprint arXiv:1409.2329 (2014).
    .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical
           exploration of recurrent network architectures." In International conference
           on machine learning, pp. 2342-2350. PMLR, 2015.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init)

    def init_state(self, batch_size: int = None, **kwargs):
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x):
        h, c = self.h.value, self.c.value
        xh = jnp.concat([x, h], axis=-1)
        i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
        h = functional.sigmoid(o) * self.activation(c)
        self.h.value = h
        self.c.value = c
        return h


class URLSTMCell(RNNCell):
    def __init__(
        self,
        num_in: int,
        num_out: int,
        w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
        state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
        activation: str | Callable = 'tanh',
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.num_out = num_out
        self.num_in = num_in
        self.in_size = (num_in,)
        self.out_size = (num_out,)

        # initializers
        self._state_initializer = state_init

        # activation function
        if isinstance(activation, str):
            self.activation = getattr(functional, activation)
        else:
            assert callable(activation), "The activation function should be a string or a callable function."
            self.activation = activation

        # weights
        self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None)
        self.bias = ParamState(self._forget_bias())

    def _forget_bias(self):
        u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
        return -jnp.log(1 / u - 1)

    def init_state(self, batch_size: int = None, **kwargs):
        self.c = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))
        self.h = HiddenState(init.param(self._state_initializer, [self.num_out], batch_size))

    def reset_state(self, batch_size: int = None, **kwargs):
        self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
        self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)

    def update(self, x: ArrayLike) -> ArrayLike:
        h, c = self.h.value, self.c.value
        xh = jnp.concat([x, h], axis=-1)
        f, r, u, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
        f_ = functional.sigmoid(f + self.bias.value)
        r_ = functional.sigmoid(r - self.bias.value)
        g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2
        next_cell = g * c + (1 - g) * self.activation(u)
        next_hidden = functional.sigmoid(o) * self.activation(next_cell)
        self.h.value = next_hidden
        self.c.value = next_cell
        return next_hidden
```
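All five cells share the `init_state` / `update` protocol shown above, so a rollout is a plain loop over time steps. A short usage sketch; the `(time, batch, feature)` shapes are illustrative:

```python
# rolling a GRUCell over a toy sequence with a plain Python loop;
# the shapes here are illustrative, not taken from the package
import jax.numpy as jnp
import brainstate as bst

cell = bst.nn.GRUCell(num_in=8, num_out=16)
cell.init_state(batch_size=2)               # allocates the HiddenState h

xs = jnp.ones((10, 2, 8))                   # (time, batch, num_in)
ys = jnp.stack([cell.update(x) for x in xs])
print(ys.shape)                             # (10, 2, 16)
```

Note the `URLSTMCell._forget_bias` trick above: it draws `u ~ Uniform(1/d, 1 - 1/d)` and stores `-log(1/u - 1)`, i.e. `logit(u)`, so at zero pre-activation the forget gate `sigmoid(f + bias)` starts out spread across `(1/d, 1 - 1/d)` rather than saturated.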
brainstate/nn/_dyn_impl/_rate_rnns_test.py (new file, `@@ -0,0 +1,64 @@`):

```python
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import unittest

import jax.numpy as jnp

import brainstate as bst


class TestRateRNNModels(unittest.TestCase):
    def setUp(self):
        self.num_in = 3
        self.num_out = 3
        self.batch_size = 4
        self.x = jnp.ones((self.batch_size, self.num_in))

    def test_ValinaRNNCell(self):
        model = bst.nn.ValinaRNNCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_GRUCell(self):
        model = bst.nn.GRUCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_MGUCell(self):
        model = bst.nn.MGUCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_LSTMCell(self):
        model = bst.nn.LSTMCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))

    def test_URLSTMCell(self):
        model = bst.nn.URLSTMCell(num_in=self.num_in, num_out=self.num_out)
        model.init_state(batch_size=self.batch_size)
        output = model.update(self.x)
        self.assertEqual(output.shape, (self.batch_size, self.num_out))


if __name__ == '__main__':
    unittest.main()
```
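The tests drive the cells one step at a time; for long sequences the same rollout can be handed to the new `compile` machinery. A sketch assuming `brainstate.compile.for_loop` (from `_loop_collect_return.py` in the listing above) scans a stateful step function over the leading axis and stacks the per-step returns; the exact signature is an assumption to check against the released API:

```python
# hypothetical sketch: compiling the rollout with brainstate.compile.for_loop,
# assumed (from _loop_collect_return.py above) to scan cell.update over the
# leading axis while tracking the cell's hidden State, collecting the returns
import jax.numpy as jnp
import brainstate as bst

cell = bst.nn.ValinaRNNCell(num_in=8, num_out=16)
cell.init_state(batch_size=2)

xs = jnp.ones((100, 2, 8))                  # (time, batch, num_in)
ys = bst.compile.for_loop(cell.update, xs)  # expected shape (100, 2, 16)
```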
brainstate/nn/_dyn_impl/_readout.py (new file, `@@ -0,0 +1,128 @@`):

```python
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from __future__ import annotations

import numbers
from typing import Callable

import brainunit as u
import jax

from brainstate import environ, init, surrogate
from brainstate._state import HiddenState, ParamState
from brainstate.nn._exp_euler import exp_euler_step
from brainstate.nn._module import Module
from brainstate.typing import Size, ArrayLike
from ._dynamics_neuron import Neuron

__all__ = [
    'LeakyRateReadout',
    'LeakySpikeReadout',
]


class LeakyRateReadout(Module):
    """
    Leaky dynamics for the read-out module used in Real-Time Recurrent Learning.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        tau: ArrayLike = 5. * u.ms,
        w_init: Callable = init.KaimingNormal(),
        name: str = None,
    ):
        super().__init__(name=name)

        # parameters
        self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
        self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
        self.tau = init.param(tau, self.in_size)
        self.decay = u.math.exp(-environ.get_dt() / self.tau)

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[0], self.out_size[0])))

    def init_state(self, batch_size=None, **kwargs):
        self.r = HiddenState(init.param(init.Constant(0.), self.out_size, batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)

    def update(self, x):
        self.r.value = self.decay * self.r.value + x @ self.weight.value
        return self.r.value


class LeakySpikeReadout(Neuron):
    """
    Integrate-and-fire neuron model with leaky dynamics.
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        tau: ArrayLike = 5. * u.ms,
        V_th: ArrayLike = 1. * u.mV,
        w_init: Callable = init.KaimingNormal(unit=u.mV),
        V_initializer: ArrayLike = init.ZeroInit(unit=u.mV),
        spk_fun: Callable = surrogate.ReluGrad(),
        spk_reset: str = 'soft',
        name: str = None,
    ):
        super().__init__(in_size, name=name, spk_fun=spk_fun, spk_reset=spk_reset)

        # parameters
        self.tau = init.param(tau, (self.varshape,))
        self.V_th = init.param(V_th, (self.varshape,))
        self.V_initializer = V_initializer

        # weights
        self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1])))

    def init_state(self, batch_size, **kwargs):
        self.V = HiddenState(init.param(self.V_initializer, self.varshape, batch_size))

    def reset_state(self, batch_size, **kwargs):
        self.V.value = init.param(self.V_initializer, self.varshape, batch_size)

    @property
    def spike(self):
        return self.get_spike(self.V.value)

    def get_spike(self, V):
        v_scaled = (V - self.V_th) / self.V_th
        return self.spk_fun(v_scaled)

    def update(self, spk):
        # reset
        last_V = self.V.value
        last_spike = self.get_spike(last_V)
        V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
        V = last_V - V_th * last_spike
        # membrane potential
        x = spk @ self.weight.value
        dv = lambda v: (-v + self.sum_current_inputs(x, v)) / self.tau
        V = exp_euler_step(dv, V)
        self.V.value = self.sum_delta_inputs(V)
        return self.get_spike(V)
```
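`LeakyRateReadout` folds the leak into a single decay factor computed once at construction time from the global `dt`: `r_t = exp(-dt/tau) * r_{t-1} + x_t @ W`. A quick numeric check of that factor, with `dt` and `tau` as plain floats for illustration (the module itself uses `brainunit` quantities):

```python
# numeric check of the per-step decay used by LeakyRateReadout,
# with dt and tau as unitless floats for illustration only
import math

dt, tau = 0.1, 5.0
decay = math.exp(-dt / tau)
print(round(decay, 4))   # 0.9802 -> each step keeps ~98% of the readout
# update rule from the module above: r = decay * r + x @ weight
```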
brainstate/nn/_dyn_impl/_readout_test.py (new file, `@@ -0,0 +1,54 @@`):

```python
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import annotations

import unittest

import jax.numpy as jnp

import brainstate as bst


class TestReadoutModels(unittest.TestCase):
    def setUp(self):
        self.in_size = 3
        self.out_size = 3
        self.batch_size = 4
        self.tau = 5.0
        self.V_th = 1.0
        self.x = jnp.ones((self.batch_size, self.in_size))

    def test_LeakyRateReadout(self):
        with bst.environ.context(dt=0.1):
            model = bst.nn.LeakyRateReadout(in_size=self.in_size, out_size=self.out_size, tau=self.tau)
            model.init_state(batch_size=self.batch_size)
            output = model.update(self.x)
            self.assertEqual(output.shape, (self.batch_size, self.out_size))

    def test_LeakySpikeReadout(self):
        with bst.environ.context(dt=0.1):
            model = bst.nn.LeakySpikeReadout(in_size=self.in_size, tau=self.tau, V_th=self.V_th,
                                             V_initializer=bst.init.ZeroInit(),
                                             w_init=bst.init.KaimingNormal())
            model.init_state(batch_size=self.batch_size)
            with bst.environ.context(t=0.):
                output = model.update(self.x)
            self.assertEqual(output.shape, (self.batch_size, self.out_size))


if __name__ == '__main__':
    with bst.environ.context(dt=0.1):
        unittest.main()
```
brainstate/nn/_dynamics/__init__.py (new file, `@@ -0,0 +1,37 @@`):

```python
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from ._dynamics_base import *
from ._dynamics_base import __all__ as dyn_all
from ._projection_base import *
from ._projection_base import __all__ as projection_all
from ._state_delay import *
from ._state_delay import __all__ as state_delay_all
from ._synouts import *
from ._synouts import __all__ as synouts_all

__all__ = (
    dyn_all
    + projection_all
    + state_delay_all
    + synouts_all
)

del (
    dyn_all,
    projection_all,
    state_delay_all,
    synouts_all
)
```