brainstate 0.0.2.post20241009__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1360 -1318
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241009.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241009.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,453 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import math
|
19
|
+
import numbers
|
20
|
+
from functools import partial
|
21
|
+
from typing import Optional, Dict, Callable, Union, Sequence
|
22
|
+
|
23
|
+
import brainunit as u
|
24
|
+
import jax
|
25
|
+
import jax.numpy as jnp
|
26
|
+
import numpy as np
|
27
|
+
|
28
|
+
from brainstate import environ
|
29
|
+
from brainstate._state import ShortTermState, State
|
30
|
+
from brainstate.compile import jit_error_if
|
31
|
+
from brainstate.graph import Node
|
32
|
+
from brainstate.nn._collective_ops import call_order
|
33
|
+
from brainstate.nn._module import Module
|
34
|
+
from brainstate.typing import ArrayLike, PyTree
|
35
|
+
|
36
|
+
__all__ = [
|
37
|
+
'Delay', 'DelayAccess', 'StateWithDelay',
|
38
|
+
]
|
39
|
+
|
40
|
+
# Accepted string values for ``Delay``: the first two select ``delay_method``
# (ring-buffer rotation vs. concatenation), the last two select ``interp_method``
# (linear interpolation vs. rounding to the nearest step) in ``retrieve_at_time``.
_DELAY_ROTATE = 'rotation'
|
41
|
+
_DELAY_CONCAT = 'concat'
|
42
|
+
_INTERP_LINEAR = 'linear_interp'
|
43
|
+
_INTERP_ROUND = 'round'
|
44
|
+
|
45
|
+
|
46
|
+
def _get_delay(delay_time, delay_step):
|
47
|
+
if delay_time is None:
|
48
|
+
if delay_step is None:
|
49
|
+
return 0., 0
|
50
|
+
else:
|
51
|
+
assert isinstance(delay_step, int), '"delay_step" should be an integer.'
|
52
|
+
if delay_step == 0:
|
53
|
+
return 0., 0
|
54
|
+
delay_time = delay_step * environ.get_dt()
|
55
|
+
else:
|
56
|
+
assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
|
57
|
+
# assert isinstance(delay_time, (int, float))
|
58
|
+
delay_step = math.ceil(delay_time / environ.get_dt())
|
59
|
+
return delay_time, delay_step
|
60
|
+
|
61
|
+
|
62
|
+
class DelayAccess(Node):
    """
    Accessor bound to one named entry of a :py:class:`Delay`.

    Constructing the accessor registers the entry on the delay; calling
    :py:meth:`update` afterwards reads the delayed data of that entry.

    Args:
      delay: The delay instance.
      time: The delay time registered for the entry.
      delay_entry: The delay entry (the name the delay time is registered under).
      *indices: The indices of the delay data, forwarded to ``Delay.at``.
    """

    __module__ = 'brainstate.nn'

    def __init__(
        self,
        delay: 'Delay',
        time: Union[None, int, float],
        delay_entry: str,
        *indices,
    ):
        super().__init__()
        # The delay is held inside a plain dict instead of as a direct attribute.
        # NOTE(review): presumably so that it is not treated as a child node — confirm.
        self.refs = dict(delay=delay)
        assert isinstance(delay, Delay), 'The input delay should be an instance of Delay.'
        self.indices = indices
        self._delay_entry = delay_entry
        delay.register_entry(delay_entry, time)

    def update(self):
        """Return the delayed data of the bound entry, sliced by the stored indices."""
        source = self.refs['delay']
        return source.at(self._delay_entry, *self.indices)
|
91
|
+
|
92
|
+
|
93
|
+
class Delay(Module):
    """
    Delay buffer recording the recent history of a target PyTree.

    The buffer lives in ``history`` (a :py:class:`~.ShortTermState`); every leaf
    carries a leading time axis of size ``max_length``. Conceptually, the data in
    this delay variable is arranged as::

         delay = 0             [ data
         delay = 1               data
         delay = 2               data
         ...                     ....
         ...                     ....
         delay = length-1        data
         delay = length          data ]

    With ``delay_method='concat'`` this is the literal layout (the newest sample
    is stored at index 0). With ``delay_method='rotation'`` the buffer is a ring:
    the sample of step ``i`` is written at ``i % max_length`` and a delay of ``k``
    steps is read from ``(i - k) % max_length``.

    Args:
      target_info: PyTree whose leaves expose ``shape`` and ``dtype``; it defines
        the structure of one sample stored in the buffer.
      time: int, float, Quantity. The delay time.
      init: Any. The delay data before the first update. A Python number or an
        array is broadcast to the whole buffer; a callable is invoked as
        ``init(shape, dtype=...)``; ``None`` fills the buffer with zeros.
      entries: optional, dict. The delay access entries, mapping an entry name to
        its delay time.
      delay_method: str. The method used for updating delay, either
        ``'rotation'`` (default) or ``'concat'``.
      interp_method: str. The interpolation used by :py:meth:`retrieve_at_time`,
        either ``'linear_interp'`` (default) or ``'round'``.
    """

    __module__ = 'brainstate.nn'

    max_time: float  #
    max_length: int
    history: Optional[ShortTermState]

    def __init__(
        self,
        target_info: PyTree,
        time: Optional[Union[int, float, u.Quantity]] = None,  # delay time
        init: Optional[Union[ArrayLike, Callable]] = None,  # delay data before t0
        entries: Optional[Dict] = None,  # delay access entry
        delay_method: Optional[str] = _DELAY_ROTATE,  # delay method
        interp_method: str = _INTERP_LINEAR,  # interpolation method
    ):
        # target information: keep only shape/dtype of every leaf
        self.target_info = jax.tree.map(lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), target_info)

        # delay method
        assert delay_method in [_DELAY_ROTATE, _DELAY_CONCAT], (f'Un-supported delay method {delay_method}. '
                                                                f'Only support {_DELAY_ROTATE} and {_DELAY_CONCAT}')
        self.delay_method = delay_method

        # interp method
        assert interp_method in [_INTERP_LINEAR, _INTERP_ROUND], (f'Un-supported interpolation method {interp_method}. '
                                                                  f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')
        self.interp_method = interp_method

        # delay length and time; one extra slot holds the current (delay = 0) sample
        self.max_time, delay_length = _get_delay(time, None)
        self.max_length = delay_length + 1

        super().__init__()

        # delay data
        if init is not None:
            if not isinstance(init, (numbers.Number, jax.Array, np.ndarray, Callable)):
                raise TypeError(f'init should be Array, Callable, or None. But got {init}')
        self._init = init
        self._history = None

        # registered entries: entry name -> delay step
        self._registered_entries = dict()

        # register the user-provided entries
        if entries is not None:
            for entry, delay_time in entries.items():
                self.register_entry(entry, delay_time)

    @property
    def history(self):
        """The state holding the delay buffer, or ``None`` before ``init_state``."""
        return self._history

    @history.setter
    def history(self, value):
        self._history = value

    def _f_to_init(self, a, batch_size, length):
        """Build the initial buffer for one leaf ``a`` with shape ``(length, [batch_size,] *a.shape)``."""
        shape = list(a.shape)
        if batch_size is not None:
            shape.insert(0, batch_size)
        shape.insert(0, length)
        if isinstance(self._init, (jax.Array, np.ndarray, numbers.Number)):
            data = jnp.broadcast_to(jnp.asarray(self._init, a.dtype), shape)
        elif callable(self._init):
            data = self._init(shape, dtype=a.dtype)
        else:
            assert self._init is None, f'init should be Array, Callable, or None. but got {self._init}'
            data = jnp.zeros(shape, dtype=a.dtype)
        return data

    @call_order(3)
    def init_state(self, batch_size: int = None, **kwargs):
        """Create the ``history`` state, sized by the current ``max_length``."""
        fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
        self.history = ShortTermState(jax.tree.map(fun, self.target_info))

    def reset_state(self, batch_size: int = None, **kwargs):
        """Overwrite the existing ``history`` value with freshly initialised data."""
        fun = partial(self._f_to_init, length=self.max_length, batch_size=batch_size)
        self.history.value = jax.tree.map(fun, self.target_info)

    def _register_delay_length(self, delay_time, delay_step) -> int:
        """Normalise a requested delay, grow ``max_length``/``max_time`` to cover it, and return its step."""
        if isinstance(delay_time, (np.ndarray, jax.Array)):
            assert delay_time.size == 1 and delay_time.ndim == 0
            delay_time = delay_time.item()

        # Use the time returned by ``_get_delay`` rather than the raw argument: when
        # the delay is given as a step (or not at all) the raw ``delay_time`` is None,
        # and storing it would leave ``max_time`` unusable in ``retrieve_at_time``.
        delay_time, delay_step = _get_delay(delay_time, delay_step)

        # Only the bounds are updated here; the buffer itself is (re)built from
        # ``max_length`` in ``init_state`` / ``reset_state``.
        if self.max_length <= delay_step + 1:
            self.max_length = delay_step + 1
            self.max_time = delay_time
        return delay_step

    def register_delay(
        self,
        delay_time: Optional[Union[int, float]] = None,
        delay_step: Optional[int] = None,
    ):
        """
        Make sure the buffer is long enough for the given delay, without naming an entry.

        Args:
          delay_time: The delay time (can be a float or a scalar array).
          delay_step: The delay step (must be an int). Give only one of the two.

        Returns:
          Return the self.
        """
        self._register_delay_length(delay_time, delay_step)
        return self

    def register_entry(
        self,
        entry: str,
        delay_time: Optional[Union[int, float]] = None,
        delay_step: Optional[int] = None,
    ) -> 'Delay':
        """
        Register an entry to access the delay data.

        Args:
          entry: str. The entry to access the delay data.
          delay_time: The delay time of the entry (can be a float).
          delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``.

        Returns:
          Return the self.

        Raises:
          KeyError: If ``entry`` has already been registered.
        """
        if entry in self._registered_entries:
            raise KeyError(f'Entry {entry} has been registered. '
                           f'The existing delay for the key {entry} is {self._registered_entries[entry]}. '
                           f'The new delay for the key {entry} is {delay_time}. '
                           f'You can use another key. ')

        self._registered_entries[entry] = self._register_delay_length(delay_time, delay_step)
        return self

    def access(
        self,
        entry: str = None,
        time: Sequence = None,
    ) -> DelayAccess:
        """Create a :py:class:`DelayAccess` that registers ``entry`` with delay ``time`` on this delay."""
        return DelayAccess(self, time, delay_entry=entry)

    def at(self, entry: str, *indices) -> ArrayLike:
        """
        Get the data at the given entry.

        Args:
          entry: str. The entry to access the data.
          *indices: The slicing indices. Not include the slice at the batch dimension.

        Returns:
          The data.

        Raises:
          KeyError: If ``entry`` has not been registered.
        """
        assert isinstance(entry, str), (f'entry should be a string for describing the '
                                        f'entry of the delay data. But we got {entry}.')
        if entry not in self._registered_entries:
            raise KeyError(f'Does not find delay entry "{entry}".')
        delay_step = self._registered_entries[entry]
        if delay_step is None:
            delay_step = 0
        return self.retrieve_at_step(delay_step, *indices)

    def retrieve_at_step(self, delay_step, *indices) -> PyTree:
        """
        Retrieve the delay data at the given delay time step (the integer to indicate the time step).

        Parameters
        ----------
        delay_step: int_like
          Retrieve the data at the given time step.
        indices: tuple
          The indices to slice the data.

        Returns
        -------
        delay_data: The delay data at the given delay step.

        """
        assert self.history is not None, 'The delay history is not initialized.'
        assert delay_step is not None, 'The delay step should be given.'

        # optional runtime bound check, enabled through the environment flag
        if environ.get(environ.JIT_ERROR_CHECK, False):
            def _check_delay(delay_len):
                raise ValueError(f'The request delay length should be less than the '
                                 f'maximum delay {self.max_length - 1}. But we got {delay_len}')

            jit_error_if(delay_step >= self.max_length, _check_delay, delay_step)

        # rotation method: step back from the current step index inside the ring
        if self.delay_method == _DELAY_ROTATE:
            i = environ.get(environ.I, desc='The time step index.')
            di = i - delay_step
            delay_idx = jnp.asarray(di % self.max_length, dtype=jnp.int32)
            delay_idx = jax.lax.stop_gradient(delay_idx)

        # concat method: position k of the buffer holds the sample delayed by k steps
        elif self.delay_method == _DELAY_CONCAT:
            delay_idx = delay_step

        else:
            raise ValueError(f'Unknown delay updating method "{self.delay_method}"')

        # the delay index
        if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer):
            raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}')
        indices = (delay_idx,) + indices

        # the delay data
        return jax.tree.map(lambda a: a[indices], self.history.value)

    def retrieve_at_time(self, delay_time, *indices) -> PyTree:
        """
        Retrieve the delay data at the given time point, interpolating between stored steps.

        The distance ``(current_time - delay_time) / dt`` is turned into a step
        according to ``interp_method``: linear interpolation between the two
        neighbouring steps, or rounding to the nearest step.

        Parameters
        ----------
        delay_time: float
          Retrieve the data at the given time.
        indices: tuple
          The indices to slice the data.

        Returns
        -------
        delay_data: The delay data at the given time.

        """
        assert self.history is not None, 'The delay history is not initialized.'
        assert delay_time is not None, 'The delay time should be given.'

        current_time = environ.get(environ.T, desc='The current time.')
        dt = environ.get_dt()

        # optional runtime bound check, enabled through the environment flag
        if environ.get(environ.JIT_ERROR_CHECK, False):
            def _check_delay(t_now, t_delay):
                raise ValueError(f'The request delay time should be within '
                                 f'[{t_now - self.max_time - dt}, {t_now}], '
                                 f'but we got {t_delay}')

            jit_error_if(
                jnp.logical_or(delay_time > current_time,
                               delay_time < current_time - self.max_time - dt),
                _check_delay,
                current_time,
                delay_time
            )

        diff = current_time - delay_time
        float_time_step = diff / dt

        if self.interp_method == _INTERP_LINEAR:  # "linear" interpolation
            data_at_t0 = self.retrieve_at_step(jnp.asarray(jnp.floor(float_time_step), dtype=jnp.int32), *indices)
            data_at_t1 = self.retrieve_at_step(jnp.asarray(jnp.ceil(float_time_step), dtype=jnp.int32), *indices)
            t_diff = float_time_step - jnp.floor(float_time_step)
            return jax.tree.map(lambda a, b: a * (1 - t_diff) + b * t_diff, data_at_t0, data_at_t1)

        elif self.interp_method == _INTERP_ROUND:  # "round" interpolation
            return self.retrieve_at_step(
                jnp.asarray(jnp.round(float_time_step), dtype=jnp.int32),
                *indices
            )

        else:  # raise error
            raise ValueError(f'Un-supported interpolation method {self.interp_method}, '
                             f'we only support: {[_INTERP_LINEAR, _INTERP_ROUND]}')

    def update(self, current: PyTree) -> None:
        """
        Update delay variable with the new data.

        Args:
          current: The newest sample, a PyTree matching the structure of the history.
        """
        assert self.history is not None, 'The delay history is not initialized.'

        # update the delay data at the rotation index
        if self.delay_method == _DELAY_ROTATE:
            i = environ.get(environ.I)
            idx = jnp.asarray(i % self.max_length, dtype=environ.dutype())
            idx = jax.lax.stop_gradient(idx)
            self.history.value = jax.tree.map(
                lambda hist, cur: hist.at[idx].set(cur),
                self.history.value,
                current
            )
        # update the delay data at the first position
        elif self.delay_method == _DELAY_CONCAT:
            current = jax.tree.map(lambda a: jnp.expand_dims(a, 0), current)
            if self.max_length > 1:
                # push the new sample to the front and drop the oldest one
                self.history.value = jax.tree.map(
                    lambda hist, cur: jnp.concatenate([cur, hist[:-1]], axis=0),
                    self.history.value,
                    current
                )
            else:
                self.history.value = current

        else:
            raise ValueError(f'Unknown updating method "{self.delay_method}"')
|
415
|
+
|
416
|
+
|
417
|
+
class StateWithDelay(Delay):
    """
    A delay bound to a ``State`` attribute of a target node.

    The state is looked up by name on the target every time it is needed; its
    value defines the structure of the delay buffer and is the data recorded
    on each :py:meth:`update`.

    Args:
      target: The node holding the state.
      item: The attribute name of the state on ``target``.
    """

    __module__ = 'brainstate.nn'

    state: State  # state

    def __init__(self, target: Node, item: str):
        # The target information is not known yet; it is filled in by ``init_state``.
        super().__init__(None)

        self._target = target
        self._target_term = item

    @property
    def state(self) -> State:
        """The bound state; raises ``TypeError`` if the attribute is not a ``State``."""
        found = getattr(self._target, self._target_term)
        if isinstance(found, State):
            return found
        raise TypeError(f'The term "{self._target_term}" in the module "{self._target}" is not a State.')

    @call_order(3)
    def init_state(self, *args, **kwargs):
        """
        State initialization function.

        Reads shape and dtype from the current value of the bound state, then
        builds the delay buffer through the parent implementation.
        """
        current_value = self.state.value
        self.target_info = jax.tree.map(
            lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype),
            current_value,
        )
        super().init_state(*args, **kwargs)

    def update(self, *args) -> None:
        """
        Update the delay variable with the new data.

        The recorded data is always the current value of the bound state;
        positional arguments are accepted but not used.
        """
        return super().update(self.state.value)
|
@@ -0,0 +1,161 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
import brainunit as u
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
from brainstate.mixin import BindCondData
|
24
|
+
from brainstate.nn._module import Module
|
25
|
+
from brainstate.typing import ArrayLike
|
26
|
+
|
27
|
+
__all__ = [
|
28
|
+
'SynOut', 'COBA', 'CUBA', 'MgBlock',
|
29
|
+
]
|
30
|
+
|
31
|
+
|
32
|
+
class SynOut(Module, BindCondData):
    """
    Base class for synaptic outputs.

    :py:class:`~.SynOut` is a subclass of :py:class:`~.Module` and
    :py:class:`~.BindCondData`. Calling an instance forwards the stored
    conductance, followed by the call arguments, to ``update``.
    """

    __module__ = 'brainstate.nn'

    def __init__(self):
        super().__init__()
        self._conductance = None

    def __call__(self, *args, **kwargs):
        """Apply ``update`` to the stored conductance; raises ``ValueError`` when none is stored."""
        cond = self._conductance
        if cond is not None:
            return self.update(cond, *args, **kwargs)
        raise ValueError(f'Please first pack conductance data at the current step using '
                         f'".{BindCondData.bind_cond.__name__}(data)". {self}')
|
51
|
+
|
52
|
+
|
53
|
+
class COBA(SynOut):
    r"""
    Conductance-based synaptic output.

    The post-synaptic current is the conductance times the driving force:

    .. math::

       I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t))

    Parameters
    ----------
    E: ArrayLike
      The reversal potential.

    See Also
    --------
    CUBA
    """
    __module__ = 'brainstate.nn'

    def __init__(self, E: ArrayLike):
        super().__init__()

        self.E = E

    def update(self, conductance, potential):
        """Return ``conductance * (E - potential)``."""
        driving_force = self.E - potential
        return conductance * driving_force
|
81
|
+
|
82
|
+
|
83
|
+
class CUBA(SynOut):
    r"""Current-based synaptic output.

    The post-synaptic current is the conductance multiplied by a constant
    scaling factor, independent of the membrane potential:

    .. math::

       I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) \cdot \mathrm{scale}

    Parameters
    ----------
    scale: ArrayLike
      The scaling factor for the conductance. Default ``u.volt``.

    See Also
    --------
    COBA
    """
    __module__ = 'brainstate.nn'

    def __init__(self, scale: ArrayLike = u.volt):
        super().__init__()
        self.scale = scale

    def update(self, conductance, potential=None):
        """Return ``conductance * scale``; ``potential`` is accepted but not used."""
        scaled = conductance * self.scale
        return scaled
|
109
|
+
|
110
|
+
|
111
|
+
class MgBlock(SynOut):
    r"""Synaptic output based on Magnesium blocking.

    Given the synaptic conductance, the model outputs the post-synaptic current with

    .. math::

       I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o})

    where the fraction of channels :math:`g_{\infty}` that are not blocked by
    magnesium is computed as

    .. math::

       g_{\infty}(V,[{Mg}^{2+}]_{o}) =
       (1+{e}^{\alpha (V_{\mathrm{offset}} - V)} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1}

    Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration.

    Parameters
    ----------
    E: ArrayLike
      The reversal potential for the synaptic current. [mV]
    cc_Mg: ArrayLike
      Concentration of Magnesium ion. Default 1.2 [mM].
    alpha: ArrayLike
      Binding constant. Default 0.062
    beta: ArrayLike
      Unbinding constant. Default 3.57
    V_offset: ArrayLike
      The offset potential. Default 0. [mV]
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        E: ArrayLike = 0.,
        cc_Mg: ArrayLike = 1.2,
        alpha: ArrayLike = 0.062,
        beta: ArrayLike = 3.57,
        V_offset: ArrayLike = 0.,
    ):
        super().__init__()

        self.E = E
        self.V_offset = V_offset
        self.cc_Mg = cc_Mg
        self.alpha = alpha
        self.beta = beta

    def update(self, conductance, potential):
        """Return the conductance-based current divided by the magnesium-block factor."""
        mg_ratio = self.cc_Mg / self.beta
        voltage_term = jnp.exp(self.alpha * (self.V_offset - potential))
        block = 1 + mg_ratio * voltage_term
        return conductance * (self.E - potential) / block
|
@@ -0,0 +1,58 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import unittest
|
19
|
+
|
20
|
+
import brainunit as u
|
21
|
+
import jax.numpy as jnp
|
22
|
+
import numpy as np
|
23
|
+
|
24
|
+
import brainstate as bst
|
25
|
+
|
26
|
+
|
27
|
+
class TestSynOutModels(unittest.TestCase):
    """Check each synaptic-output model's ``update`` against its closed-form expression."""

    def setUp(self):
        # shared inputs
        self.conductance = jnp.array([0.5, 1.0, 1.5])
        self.potential = jnp.array([-70.0, -65.0, -60.0])
        # model parameters
        self.E = jnp.array([-70.0])
        self.alpha = jnp.array([0.062])
        self.beta = jnp.array([3.57])
        self.cc_Mg = jnp.array([1.2])
        self.V_offset = jnp.array([0.0])

    def test_COBA(self):
        synout = bst.nn.COBA(E=self.E)
        expected = self.conductance * (self.E - self.potential)
        actual = synout.update(self.conductance, self.potential)
        np.testing.assert_array_almost_equal(actual, expected)

    def test_CUBA(self):
        synout = bst.nn.CUBA()
        actual = synout.update(self.conductance)
        expected = self.conductance * synout.scale
        self.assertTrue(u.math.allclose(actual, expected))

    def test_MgBlock(self):
        synout = bst.nn.MgBlock(
            E=self.E,
            cc_Mg=self.cc_Mg,
            alpha=self.alpha,
            beta=self.beta,
            V_offset=self.V_offset,
        )
        actual = synout.update(self.conductance, self.potential)
        block = (1 + self.cc_Mg / self.beta * jnp.exp(self.alpha * (self.V_offset - self.potential)))
        expected = self.conductance * (self.E - self.potential) / block
        np.testing.assert_array_almost_equal(actual, expected)
|
55
|
+
|
56
|
+
|
57
|
+
if __name__ == '__main__':
|
58
|
+
unittest.main()
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from ._dropout import *
|
17
|
+
from ._dropout import __all__ as dropout_all
|
18
|
+
from ._elementwise import *
|
19
|
+
from ._elementwise import __all__ as elementwise_all
|
20
|
+
|
21
|
+
__all__ = dropout_all + elementwise_all
|
22
|
+
del dropout_all, elementwise_all
|