brainstate-0.0.2.post20241010-py2.py3-none-any.whl → brainstate-0.1.0.post20241122-py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
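
The headline change in this release is a reorganization of the program-transformation modules: `brainstate/transform/` is split into `brainstate/compile/` (jit, conditionals, loops, jaxpr construction) and a new `brainstate/augment/` package (autograd, eval_shape, mapping), while `brainstate/nn/` is regrouped into `_dyn_impl/`, `_dynamics/`, `_elementwise/`, and `_interaction/` subpackages. A minimal compatibility sketch, assuming the public packages mirror the file moves listed above (the exact re-exports are not visible in this diff, so treat the shim itself as hypothetical):

```python
# Hypothetical import shim for code spanning both layouts. Assumption:
# `brainstate.compile` and `brainstate.augment` exist in 0.1.0 as the file
# list suggests, and 0.0.2 shipped everything under `brainstate.transform`.
try:
    # 0.1.0 layout
    from brainstate import augment          # _autograd.py, _mapping.py, ...
    from brainstate import compile          # _jit.py, _conditions.py, _loop_*.py, ...
except ImportError:
    # 0.0.2 layout: a single `transform` package held both roles
    from brainstate import transform as augment
    compile = augment
```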
brainstate/nn/_rate_rnns.py
DELETED
@@ -1,410 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# -*- coding: utf-8 -*-
-
-from __future__ import annotations
-
-from typing import Callable, Union
-
-import jax.numpy as jnp
-
-from ._base import ExplicitInOutSize
-from ._connections import Linear
-from .. import random, init, functional
-from .._module import Module
-from .._state import ShortTermState, ParamState
-from ..mixin import DelayedInit, Mode
-from brainstate.typing import ArrayLike
-
-__all__ = [
-  'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
-]
-
-
-class RNNCell(Module, ExplicitInOutSize, DelayedInit):
-  """
-  Base class for RNN cells.
-  """
-  pass
-
-
-class ValinaRNNCell(RNNCell):
-  """
-  Vanilla RNN cell.
-
-  Args:
-    num_in: int. The number of input units.
-    num_out: int. The number of hidden units.
-    state_init: callable, ArrayLike. The state initializer.
-    w_init: callable, ArrayLike. The input weight initializer.
-    b_init: optional, callable, ArrayLike. The bias weight initializer.
-    activation: str, callable. The activation function. It can be a string or a callable function.
-    mode: optional, Mode. The mode of the module.
-    name: optional, str. The name of the module.
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      num_in: int,
-      num_out: int,
-      state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
-      b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      activation: str | Callable = 'relu',
-      mode: Mode = None,
-      name: str = None,
-  ):
-    super().__init__(mode=mode, name=name)
-
-    # parameters
-    self._state_initializer = state_init
-    self.num_out = num_out
-    self.num_in = num_in
-    self.in_size = (num_in,)
-    self.out_size = (num_out,)
-
-    # activation function
-    if isinstance(activation, str):
-      self.activation = getattr(functional, activation)
-    else:
-      assert callable(activation), "The activation function should be a string or a callable function. "
-      self.activation = activation
-
-    # weights
-    self.W = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init, name=self.name + '_W')
-
-  def init_state(self, batch_size: int = None, **kwargs):
-    self.h = ShortTermState(init.param(self._state_initializer, self.num_out, batch_size))
-
-  def reset_state(self, batch_size: int = None, **kwargs):
-    self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
-
-  def update(self, x):
-    xh = jnp.concatenate([x, self.h.value], axis=-1)
-    h = self.W(xh)
-    self.h.value = self.activation(h)
-    return self.h.value
-
-
-class GRUCell(RNNCell):
-  """
-  Gated Recurrent Unit (GRU) cell.
-
-  Args:
-    num_in: int. The number of input units.
-    num_out: int. The number of hidden units.
-    state_init: callable, ArrayLike. The state initializer.
-    w_init: callable, ArrayLike. The input weight initializer.
-    b_init: optional, callable, ArrayLike. The bias weight initializer.
-    activation: str, callable. The activation function. It can be a string or a callable function.
-    mode: optional, Mode. The mode of the module.
-    name: optional, str. The name of the module.
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      num_in: int,
-      num_out: int,
-      w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
-      b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      activation: str | Callable = 'tanh',
-      mode: Mode = None,
-      name: str = None,
-  ):
-    super().__init__(mode=mode, name=name)
-
-    # parameters
-    self._state_initializer = state_init
-    self.num_out = num_out
-    self.num_in = num_in
-    self.in_size = (num_in,)
-    self.out_size = (num_out,)
-
-    # activation function
-    if isinstance(activation, str):
-      self.activation = getattr(functional, activation)
-    else:
-      assert callable(activation), "The activation function should be a string or a callable function. "
-      self.activation = activation
-
-    # weights
-    self.Wrz = Linear(num_in + num_out, num_out * 2, w_init=w_init, b_init=b_init, name=self.name + '_Wrz')
-    self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init, name=self.name + '_Wh')
-
-  def init_state(self, batch_size: int = None, **kwargs):
-    self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
-
-  def reset_state(self, batch_size: int = None, **kwargs):
-    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
-
-  def update(self, x):
-    old_h = self.h.value
-    xh = jnp.concatenate([x, old_h], axis=-1)
-    r, z = jnp.split(functional.sigmoid(self.Wrz(xh)), indices_or_sections=2, axis=-1)
-    rh = r * old_h
-    h = self.activation(self.Wh(jnp.concatenate([x, rh], axis=-1)))
-    h = (1 - z) * old_h + z * h
-    self.h.value = h
-    return h
-
-
-class MGUCell(RNNCell):
-  r"""
-  Minimal Gated Recurrent Unit (MGU) cell.
-
-  .. math::
-
-     \begin{aligned}
-     f_{t}&=\sigma (W_{f}x_{t}+U_{f}h_{t-1}+b_{f})\\
-     {\hat {h}}_{t}&=\phi (W_{h}x_{t}+U_{h}(f_{t}\odot h_{t-1})+b_{h})\\
-     h_{t}&=(1-f_{t})\odot h_{t-1}+f_{t}\odot {\hat {h}}_{t}
-     \end{aligned}
-
-  where:
-
-  - :math:`x_{t}`: input vector
-  - :math:`h_{t}`: output vector
-  - :math:`{\hat {h}}_{t}`: candidate activation vector
-  - :math:`f_{t}`: forget vector
-  - :math:`W, U, b`: parameter matrices and vector
-
-  Args:
-    num_in: int. The number of input units.
-    num_out: int. The number of hidden units.
-    state_init: callable, ArrayLike. The state initializer.
-    w_init: callable, ArrayLike. The input weight initializer.
-    b_init: optional, callable, ArrayLike. The bias weight initializer.
-    activation: str, callable. The activation function. It can be a string or a callable function.
-    mode: optional, Mode. The mode of the module.
-    name: optional, str. The name of the module.
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      num_in: int,
-      num_out: int,
-      w_init: Union[ArrayLike, Callable] = init.Orthogonal(),
-      b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      activation: str | Callable = 'tanh',
-      mode: Mode = None,
-      name: str = None,
-  ):
-    super().__init__(mode=mode, name=name)
-
-    # parameters
-    self._state_initializer = state_init
-    self.num_out = num_out
-    self.num_in = num_in
-    self.in_size = (num_in,)
-    self.out_size = (num_out,)
-
-    # activation function
-    if isinstance(activation, str):
-      self.activation = getattr(functional, activation)
-    else:
-      assert callable(activation), "The activation function should be a string or a callable function. "
-      self.activation = activation
-
-    # weights
-    self.Wf = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init, name=self.name + '_Wf')
-    self.Wh = Linear(num_in + num_out, num_out, w_init=w_init, b_init=b_init, name=self.name + '_Wh')
-
-  def init_state(self, batch_size: int = None, **kwargs):
-    self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
-
-  def reset_state(self, batch_size: int = None, **kwargs):
-    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
-
-  def update(self, x):
-    old_h = self.h.value
-    xh = jnp.concatenate([x, old_h], axis=-1)
-    f = functional.sigmoid(self.Wf(xh))
-    fh = f * old_h
-    h = self.activation(self.Wh(jnp.concatenate([x, fh], axis=-1)))
-    self.h.value = (1 - f) * self.h.value + f * h
-    return self.h.value
-
-
-class LSTMCell(RNNCell):
-  r"""Long short-term memory (LSTM) RNN core.
-
-  The implementation is based on (zaremba, et al., 2014) [1]_. Given
-  :math:`x_t` and the previous state :math:`(h_{t-1}, c_{t-1})` the core
-  computes
-
-  .. math::
-
-     \begin{array}{ll}
-     i_t = \sigma(W_{ii} x_t + W_{hi} h_{t-1} + b_i) \\
-     f_t = \sigma(W_{if} x_t + W_{hf} h_{t-1} + b_f) \\
-     g_t = \tanh(W_{ig} x_t + W_{hg} h_{t-1} + b_g) \\
-     o_t = \sigma(W_{io} x_t + W_{ho} h_{t-1} + b_o) \\
-     c_t = f_t c_{t-1} + i_t g_t \\
-     h_t = o_t \tanh(c_t)
-     \end{array}
-
-  where :math:`i_t`, :math:`f_t`, :math:`o_t` are input, forget and
-  output gate activations, and :math:`g_t` is a vector of cell updates.
-
-  The output is equal to the new hidden, :math:`h_t`.
-
-  Notes
-  -----
-
-  Forget gate initialization: Following (Jozefowicz, et al., 2015) [2]_ we add 1.0
-  to :math:`b_f` after initialization in order to reduce the scale of forgetting in
-  the beginning of the training.
-
-
-  Parameters
-  ----------
-  num_in: int
-    The dimension of the input vector
-  num_out: int
-    The number of hidden unit in the node.
-  state_init: callable, ArrayLike
-    The state initializer.
-  w_init: callable, ArrayLike
-    The input weight initializer.
-  b_init: optional, callable, ArrayLike
-    The bias weight initializer.
-  activation: str, callable
-    The activation function. It can be a string or a callable function.
-
-  References
-  ----------
-
-  .. [1] Zaremba, Wojciech, Ilya Sutskever, and Oriol Vinyals. "Recurrent neural
-         network regularization." arXiv preprint arXiv:1409.2329 (2014).
-  .. [2] Jozefowicz, Rafal, Wojciech Zaremba, and Ilya Sutskever. "An empirical
-         exploration of recurrent network architectures." In International conference
-         on machine learning, pp. 2342-2350. PMLR, 2015.
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      num_in: int,
-      num_out: int,
-      w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
-      b_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      activation: str | Callable = 'tanh',
-      mode: Mode = None,
-      name: str = None,
-  ):
-    super().__init__(mode=mode, name=name)
-
-    # parameters
-    self.num_out = num_out
-    self.num_in = num_in
-    self.in_size = (num_in,)
-    self.out_size = (num_out,)
-
-    # initializers
-    self._state_initializer = state_init
-
-    # activation function
-    if isinstance(activation, str):
-      self.activation = getattr(functional, activation)
-    else:
-      assert callable(activation), "The activation function should be a string or a callable function. "
-      self.activation = activation
-
-    # weights
-    self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=b_init, name=self.name + '_W')
-
-  def init_state(self, batch_size: int = None, **kwargs):
-    self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
-    self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
-
-  def reset_state(self, batch_size: int = None, **kwargs):
-    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
-    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
-
-  def update(self, x):
-    h, c = self.h.value, self.c.value
-    xh = jnp.concat([x, h], axis=-1)
-    i, g, f, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
-    c = functional.sigmoid(f + 1.) * c + functional.sigmoid(i) * self.activation(g)
-    h = functional.sigmoid(o) * self.activation(c)
-    self.h.value = h
-    self.c.value = c
-    return h
-
-
-class URLSTMCell(RNNCell):
-  def __init__(
-      self,
-      num_in: int,
-      num_out: int,
-      w_init: Union[ArrayLike, Callable] = init.XavierNormal(),
-      state_init: Union[ArrayLike, Callable] = init.ZeroInit(),
-      activation: str | Callable = 'tanh',
-      mode: Mode = None,
-      name: str = None,
-  ):
-    super().__init__(mode=mode, name=name)
-
-    # parameters
-    self.num_out = num_out
-    self.num_in = num_in
-    self.in_size = (num_in,)
-    self.out_size = (num_out,)
-
-    # initializers
-    self._state_initializer = state_init
-
-    # activation function
-    if isinstance(activation, str):
-      self.activation = getattr(functional, activation)
-    else:
-      assert callable(activation), "The activation function should be a string or a callable function. "
-      self.activation = activation
-
-    # weights
-    self.W = Linear(num_in + num_out, num_out * 4, w_init=w_init, b_init=None, name=self.name + '_Wg')
-    self.bias = ParamState(self._forget_bias())
-
-  def _forget_bias(self):
-    u = random.uniform(1 / self.num_out, 1 - 1 / self.num_out, (self.num_out,))
-    return -jnp.log(1 / u - 1)
-
-  def init_state(self, batch_size: int = None, **kwargs):
-    self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
-    self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
-
-  def reset_state(self, batch_size: int = None, **kwargs):
-    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
-    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
-
-  def update(self, x: ArrayLike) -> ArrayLike:
-    h, c = self.h.value, self.c.value
-    xh = jnp.concat([x, h], axis=-1)
-    f, r, u, o = jnp.split(self.W(xh), indices_or_sections=4, axis=-1)
-    f_ = functional.sigmoid(f + self.bias.value)
-    r_ = functional.sigmoid(r - self.bias.value)
-    g = 2 * r_ * f_ + (1 - 2 * r_) * f_ ** 2
-    next_cell = g * c + (1 - g) * self.activation(u)
-    next_hidden = functional.sigmoid(o) * self.activation(next_cell)
-    self.h.value = next_hidden
-    self.c.value = next_cell
-    return next_hidden
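
The recurrent cells deleted here reappear in expanded form at `brainstate/nn/_dyn_impl/_rate_rnns.py` (+400 lines in the file list above). For reference, a self-contained sketch of the update rule implemented by the deleted `GRUCell.update`, written in plain `jax.numpy` with explicit weight matrices standing in for the `Linear` modules (the function name and shapes are illustrative, and the biases folded into `Linear` are omitted):

```python
import jax.numpy as jnp
from jax.nn import sigmoid

def gru_step(x, h, Wrz, Wh):
    # Wrz: (num_in + num_out, 2 * num_out) -> reset/update gates
    # Wh:  (num_in + num_out, num_out)     -> candidate hidden state
    xh = jnp.concatenate([x, h], axis=-1)
    r, z = jnp.split(sigmoid(xh @ Wrz), 2, axis=-1)
    cand = jnp.tanh(jnp.concatenate([x, r * h], axis=-1) @ Wh)
    # h_t = (1 - z) * h_{t-1} + z * h_hat, exactly as in the deleted update()
    return (1 - z) * h + z * cand
```

The same pattern (concatenate input with state, one fused matmul, split into gates) underlies the deleted `MGUCell`, `LSTMCell`, and `URLSTMCell` as well; the LSTM additionally adds 1.0 to the forget-gate preactivation, per the Jozefowicz et al. (2015) note in its docstring.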
brainstate/nn/_readout.py
DELETED
@@ -1,136 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# -*- coding: utf-8 -*-
-
-from __future__ import annotations
-
-import numbers
-from typing import Callable
-
-import jax
-import jax.numpy as jnp
-
-from ._base import DnnLayer
-from ._dynamics import Neuron
-from ._misc import exp_euler_step
-from .. import environ, init, surrogate
-from .._state import ShortTermState, ParamState
-from ..mixin import Mode
-from brainstate.typing import Size, ArrayLike, DTypeLike
-
-__all__ = [
-  'LeakyRateReadout',
-  'LeakySpikeReadout',
-]
-
-
-class LeakyRateReadout(DnnLayer):
-  """
-  Leaky dynamics for the read-out module used in the Real-Time Recurrent Learning.
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      in_size: Size,
-      out_size: Size,
-      tau: ArrayLike = 5.,
-      w_init: Callable = init.KaimingNormal(),
-      mode: Mode = None,
-      name: str = None,
-  ):
-    super().__init__(mode=mode, name=name)
-
-    # parameters
-    self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
-    self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
-    self.tau = init.param(tau, self.in_size)
-    self.decay = jnp.exp(-environ.get_dt() / self.tau)
-
-    # weights
-    self.weight = ParamState(init.param(w_init, (self.in_size[0], self.out_size[0])))
-
-  def init_state(self, batch_size=None, **kwargs):
-    self.r = ShortTermState(init.param(init.Constant(0.), self.out_size, batch_size))
-
-  def reset_state(self, batch_size=None, **kwargs):
-    self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)
-
-  def update(self, x):
-    r = self.decay * self.r.value + x @ self.weight.value
-    self.r.value = r
-    return r
-
-
-class LeakySpikeReadout(Neuron):
-  """
-  Integrate-and-fire neuron model.
-  """
-
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      in_size: Size,
-      keep_size: bool = False,
-      tau: ArrayLike = 5.,
-      V_th: ArrayLike = 1.,
-      w_init: Callable = init.KaimingNormal(),
-      spk_fun: Callable = surrogate.ReluGrad(),
-      spk_dtype: DTypeLike = None,
-      spk_reset: str = 'soft',
-      mode: Mode = None,
-      name: str = None,
-  ):
-    super().__init__(in_size, keep_size=keep_size, name=name, mode=mode,
-                     spk_fun=spk_fun, spk_dtype=spk_dtype, spk_reset=spk_reset)
-
-    # parameters
-    self.tau = init.param(tau, (self.num,))
-    self.V_th = init.param(V_th, (self.num,))
-
-    # weights
-    self.weight = ParamState(init.param(w_init, (self.in_size[0], self.out_size[0])))
-
-  def dv(self, v, t, x):
-    x = self.sum_current_inputs(v, init=x)
-    return (-v + x) / self.tau
-
-  def init_state(self, batch_size, **kwargs):
-    self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
-
-  def reset_state(self, batch_size, **kwargs):
-    self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
-
-  @property
-  def spike(self):
-    return self.get_spike(self.V.value)
-
-  def get_spike(self, V):
-    v_scaled = (V - self.V_th) / self.V_th
-    return self.spk_fun(v_scaled)
-
-  def update(self, x):
-    # reset
-    last_V = self.V.value
-    last_spike = self.get_spike(last_V)
-    V_th = self.V_th if self.spk_reset == 'soft' else jax.lax.stop_gradient(last_V)
-    V = last_V - V_th * last_spike
-    # membrane potential
-    V = exp_euler_step(self.dv, V, environ.get('t'), x @ self.weight.value)
-    V = V + self.sum_delta_inputs()
-    self.V.value = V
-    return self.get_spike(V)
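
The readout layers removed here have a replacement at `brainstate/nn/_dyn_impl/_readout.py` in the new layout (see the file list). The core of the deleted `LeakyRateReadout` is a one-line leaky integrator; a standalone sketch of that recurrence, with illustrative `dt`/`tau` values in place of `environ.get_dt()` and a hypothetical `readout_step` name:

```python
import jax.numpy as jnp

dt, tau = 0.1, 5.0
decay = jnp.exp(-dt / tau)  # the same per-step leak factor the deleted layer precomputes

def readout_step(r, x, W):
    # r_t = exp(-dt / tau) * r_{t-1} + x_t @ W, as in LeakyRateReadout.update
    return decay * r + x @ W
```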
brainstate/nn/_synouts.py
DELETED
@@ -1,166 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# -*- coding: utf-8 -*-
-
-from __future__ import annotations
-
-from typing import Optional
-
-import jax.numpy as jnp
-
-from .._module import Module
-from ..mixin import DelayedInit, BindCondData
-from brainstate.typing import ArrayLike
-
-__all__ = [
-  'SynOut', 'COBA', 'CUBA', 'MgBlock',
-]
-
-
-class SynOut(Module, DelayedInit, BindCondData):
-  """
-  Base class for synaptic outputs.
-
-  :py:class:`~.SynOut` is also subclass of :py:class:`~.ParamDesc` and :py:class:`~.BindCondData`.
-  """
-
-  __module__ = 'brainstate.nn'
-
-  def __init__(self, name: Optional[str] = None):
-    super().__init__(name=name)
-    self._conductance = None
-
-  def __call__(self, *args, **kwargs):
-    if self._conductance is None:
-      raise ValueError(f'Please first pack conductance data at the current step using '
-                       f'".{BindCondData.bind_cond.__name__}(data)". {self}')
-    ret = self.update(self._conductance, *args, **kwargs)
-    return ret
-
-
-class COBA(SynOut):
-  r"""
-  Conductance-based synaptic output.
-
-  Given the synaptic conductance, the model output the post-synaptic current with
-
-  .. math::
-
-     I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t))
-
-  Parameters
-  ----------
-  E: ArrayLike
-    The reversal potential.
-  name: str
-    The model name.
-
-  See Also
-  --------
-  CUBA
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(self, E: ArrayLike, name: Optional[str] = None):
-    super().__init__(name=name)
-
-    self.E = E
-
-  def update(self, conductance, potential):
-    return conductance * (self.E - potential)
-
-
-class CUBA(SynOut):
-  r"""Current-based synaptic output.
-
-  Given the conductance, this model outputs the post-synaptic current with a identity function:
-
-  .. math::
-
-     I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t)
-
-  Parameters
-  ----------
-  name: str
-    The model name.
-
-  See Also
-  --------
-  COBA
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(self, name: Optional[str] = None, ):
-    super().__init__(name=name)
-
-  def update(self, conductance, potential=None):
-    return conductance
-
-
-class MgBlock(SynOut):
-  r"""Synaptic output based on Magnesium blocking.
-
-  Given the synaptic conductance, the model output the post-synaptic current with
-
-  .. math::
-
-     I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o})
-
-  where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to
-
-  .. math::
-
-     g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}
-
-  Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration.
-
-  Parameters
-  ----------
-  E: ArrayLike
-    The reversal potential for the synaptic current. [mV]
-  alpha: ArrayLike
-    Binding constant. Default 0.062
-  beta: ArrayLike
-    Unbinding constant. Default 3.57
-  cc_Mg: ArrayLike
-    Concentration of Magnesium ion. Default 1.2 [mM].
-  V_offset: ArrayLike
-    The offset potential. Default 0. [mV]
-  name: str
-    The model name.
-  """
-  __module__ = 'brainstate.nn'
-
-  def __init__(
-      self,
-      E: ArrayLike = 0.,
-      cc_Mg: ArrayLike = 1.2,
-      alpha: ArrayLike = 0.062,
-      beta: ArrayLike = 3.57,
-      V_offset: ArrayLike = 0.,
-      name: Optional[str] = None,
-  ):
-    super().__init__(name=name)
-
-    self.E = E
-    self.V_offset = V_offset
-    self.cc_Mg = cc_Mg
-    self.alpha = alpha
-    self.beta = beta
-
-  def update(self, conductance, potential):
-    norm = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - potential)))
-    return conductance * (self.E - potential) / norm
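
These synaptic output rules continue at `brainstate/nn/_dynamics/_synouts.py` in the new layout (see the file list). For reference, the three deleted `update` methods reduce to a few lines of `jax.numpy`; the function names below are hypothetical stand-ins, and the default parameters mirror the deleted `MgBlock` signature:

```python
import jax.numpy as jnp

def coba(g, V, E=0.0):
    # conductance-based output: I_syn = g * (E - V)
    return g * (E - V)

def cuba(g):
    # current-based output: the conductance passes through unchanged
    return g

def mg_block(g, V, E=0.0, cc_Mg=1.2, alpha=0.062, beta=3.57, V_offset=0.0):
    # NMDA-style magnesium block:
    # I_syn = g * (E - V) / (1 + [Mg]/beta * exp(alpha * (V_offset - V)))
    norm = 1.0 + cc_Mg / beta * jnp.exp(alpha * (V_offset - V))
    return g * (E - V) / norm
```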