brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,211 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import unittest
|
19
|
+
|
20
|
+
import jax.numpy as jnp
|
21
|
+
import jaxlib.xla_extension
|
22
|
+
|
23
|
+
import brainstate as bst
|
24
|
+
|
25
|
+
|
26
|
+
class TestDelay(unittest.TestCase):
    """Tests for ``bst.nn.Delay`` entries, buffer methods and interpolation."""

    # Time step used by every test in this class. The expected step counts
    # below (e.g. a 1.0 delay == 10 steps, 2.0 == 20 steps) assume this value.
    DT = 0.1

    def setUp(self):
        # Enter a per-test ``dt`` context so the tests do not depend on the
        # ``__main__`` block (which does not run under pytest or other runners)
        # and leave no environment state behind after each test.
        ctx = bst.environ.context(dt=self.DT)
        ctx.__enter__()
        self.addCleanup(ctx.__exit__, None, None, None)

    def _check_entry_delays(self, delay, expected_steps, n_updates=100):
        """Feed ``0, 1, ..., n_updates - 1`` into ``delay`` and check its entries.

        Args:
            delay: an initialised ``bst.nn.Delay`` with every key of
                ``expected_steps`` registered as an entry.
            expected_steps: mapping of entry name -> expected lag in steps.
                After feeding ``i`` at step ``i``, entry ``name`` must read
                ``max(i - lag, 0)`` (the test expects 0 before the lag fills).
            n_updates: number of update steps to run.
        """
        for i in range(n_updates):
            # Scoped instead of ``environ.set`` so ``i`` does not leak.
            with bst.environ.context(i=i):
                x = jnp.ones((1,)) * i
                delay.update(x)
                for name, lag in expected_steps.items():
                    self.assertTrue(
                        jnp.allclose(delay.at(name), jnp.maximum(x - lag, 0.)),
                        msg=f'entry {name!r} at step {i}',
                    )

    def _check_interp(self, interp_method, cases, n_updates=100):
        """Check ``retrieve_at_time`` for every buffer shape and delay method.

        Args:
            interp_method: ``interp_method`` passed to ``bst.nn.Delay``.
            cases: sequence of ``(delay_time, expected_lag)`` pairs. After
                feeding ``i`` at step ``i``, retrieving at ``t - delay_time``
                must return ``max(i - expected_lag, 0)``.
            n_updates: number of update steps to run.
        """
        for shape in [(1,), (1, 1), (1, 1, 1)]:
            for delay_method in ['rotation', 'concat']:
                with self.subTest(shape=shape, delay_method=delay_method):
                    delay = bst.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
                                         interp_method=interp_method)
                    delay.init_state()

                    # Defined per iteration and used only within it, so the
                    # closure always sees the current ``delay``.
                    @bst.compile.jit
                    def retrieve(td, i):
                        with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
                            return delay.retrieve_at_time(td)

                    for i in range(n_updates):
                        t = i * bst.environ.get_dt()
                        with bst.environ.context(i=i, t=t):
                            x = jnp.ones(shape) * i
                            delay.update(x)
                            for td, lag in cases:
                                self.assertTrue(
                                    jnp.allclose(retrieve(t - td, i), jnp.maximum(x - lag, 0.)),
                                    msg=f'delay {td} at step {i}',
                                )

    def test_delay1(self):
        """Registering an existing entry name again with a new delay raises ``KeyError``."""
        a = bst.State(bst.random.random(10, 20))
        delay = bst.nn.Delay(a.value)
        delay.register_entry('a', 1.)
        delay.register_entry('b', 2.)
        delay.register_entry('c', None)

        delay.init_state()
        with self.assertRaises(KeyError):
            delay.register_entry('c', 10.)

    def test_rotation_delay(self):
        """Default-method buffer (rotation, per the test name) lags entries by 0/10/20 steps."""
        delay = bst.nn.Delay(jnp.ones((1,)))
        delay.register_entry('a', 0.)
        delay.register_entry('b', 1.)
        # 'c2' is registered before the longest entry but its value is not
        # checked here.
        delay.register_entry('c2', 1.9)
        delay.register_entry('c', 2.)
        delay.init_state()

        self._check_entry_delays(delay, {'a': 0, 'b': 10, 'c': 20})

    def test_concat_delay(self):
        """Concat buffer lags entries by 0/10/20 steps."""
        delay = bst.nn.Delay(jnp.ones([1]), delay_method='concat')
        delay.register_entry('a', 0.)
        delay.register_entry('b', 1.)
        delay.register_entry('c', 2.)
        delay.init_state()

        self._check_entry_delays(delay, {'a': 0, 'b': 10, 'c': 20})

    def test_jit_erro(self):
        """Out-of-range retrievals raise at runtime when ``jit_error_check`` is on.

        With ``time=2.`` and round interpolation, delays of 2.0 and 2.01 are
        accepted, while 2.1 and 2.09 (presumably rounding past 20 steps) and
        future times (0.1, 0.01 at ``t=0``) raise ``XlaRuntimeError``.
        """
        delay = bst.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
        delay.init_state()

        with bst.environ.context(i=0, t=0, jit_error_check=True):
            delay.retrieve_at_time(-2.0)
            with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
                delay.retrieve_at_time(-2.1)
            delay.retrieve_at_time(-2.01)
            for bad_time in (-2.09, 0.1, 0.01):
                with self.subTest(time=bad_time):
                    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
                        delay.retrieve_at_time(bad_time)

    def test_round_interp(self):
        """Round interpolation snaps a delay to the nearest whole step."""
        # 0.01 -> 0 steps, 1.04 -> 10 steps, 1.06 -> 11 steps at dt=0.1.
        self._check_interp('round', [(0.01, 0), (1.04, 10), (1.06, 11)])

    def test_linear_interp(self):
        """Linear interpolation yields fractional-step lags."""
        self._check_interp('linear_interp', [(0.01, 0.1), (1.04, 10.4), (1.06, 10.6)])

    def test_rotation_and_concat_delay(self):
        """Default-method and concat buffers return identical values for the same entries."""
        rotation_delay = bst.nn.Delay(jnp.ones((1,)))
        concat_delay = bst.nn.Delay(jnp.ones([1]), delay_method='concat')
        entries = (('a', 0.), ('b', 1.), ('c', 2.))
        for delay in (rotation_delay, concat_delay):
            for name, delay_time in entries:
                delay.register_entry(name, delay_time)

        rotation_delay.init_state()
        concat_delay.init_state()

        for i in range(100):
            with bst.environ.context(i=i):
                new = jnp.ones((1,)) * i
                rotation_delay.update(new)
                concat_delay.update(new)
                for name, _ in entries:
                    self.assertTrue(
                        jnp.allclose(rotation_delay.at(name), concat_delay.at(name)),
                        msg=f'entry {name!r} at step {i}',
                    )
|
185
|
+
|
186
|
+
|
187
|
+
class TestModule(unittest.TestCase):
    """Smoke tests for collecting states from nested ``bst.nn.Module`` trees."""

    def test_states(self):
        """Build a two-level module tree and print its collected states.

        ``states()`` and ``states(level=0)`` are each printed twice, presumably
        so that repeated collection can be compared by eye.
        NOTE(review): nothing is asserted; this only checks the calls do not raise.
        """

        class A(bst.nn.Module):
            # Leaf module owning two random states.
            def __init__(self):
                super().__init__()
                self.a = bst.State(bst.random.random(10, 20))
                self.b = bst.State(bst.random.random(10, 20))

        class B(bst.nn.Module):
            # Parent module holding a child ``A`` plus one state of its own.
            def __init__(self):
                super().__init__()
                self.a = A()
                self.b = bst.State(bst.random.random(10, 20))

        model = B()
        print()
        call_kwargs = ({}, {}, {'level': 0}, {'level': 0})
        for kwargs in call_kwargs:
            print(model.states(**kwargs))
|
207
|
+
|
208
|
+
|
209
|
+
if __name__ == '__main__':
    # Run the suite with a 0.1 time step: the delay tests' expected step
    # counts (e.g. a 1.0 delay == 10 steps) are written for this value.
    dt_environment = bst.environ.context(dt=0.1)
    with dt_environment:
        unittest.main()
|