brainstate 0.1.7__py2.py3-none-any.whl → 0.1.9__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -146
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +509 -470
- brainstate/nn/_delay_test.py +238 -0
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1361
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1120
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -208
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/METADATA +91 -99
- brainstate-0.1.9.dist-info/RECORD +130 -0
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/WHEEL +1 -1
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.7.dist-info/RECORD +0 -131
- {brainstate-0.1.7.dist-info → brainstate-0.1.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,238 @@
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
import unittest
|
18
|
+
|
19
|
+
import jax.numpy as jnp
|
20
|
+
|
21
|
+
import brainstate
|
22
|
+
|
23
|
+
# Global simulation time step shared by every test in this module. The expected
# values below assume delays given in time units map to ``time / dt`` steps,
# e.g. a delay of 1.0 corresponds to 10 steps.
brainstate.environ.set(dt=0.1)
|
24
|
+
|
25
|
+
|
26
|
+
class TestDelay(unittest.TestCase):
    """Tests for ``brainstate.nn.Delay``.

    The expectations assume ``dt = 0.1``: either the module-level setting or an
    explicit ``brainstate.environ.context(dt=0.1, ...)``. Under that assumption a
    delay of ``1.0`` corresponds to 10 steps. Until that many updates have been
    made, the asserted value is ``0``.

    Per-step environment values (``i``, ``t``) are supplied through
    ``brainstate.environ.context`` rather than ``brainstate.environ.set``, so
    they do not leak into other tests.
    """

    def _assert_allclose(self, actual, expected):
        """Fail, reporting both values, if ``actual`` is not close to ``expected``."""
        # Build the message only on failure. These checks run inside 100-step
        # loops, and formatting JAX arrays on every iteration would be wasteful.
        if not jnp.allclose(actual, expected):
            self.fail(f'{actual} is not close to {expected}')

    def _assert_entry_shape(self, target, index, expected_shape=(2, 2)):
        """Register ``index`` as entry ``'a'`` on a delay of ``target`` and check the shape of ``at('a')``.

        Args:
            target: Value passed to ``brainstate.nn.Delay`` as the delay target.
            index: Positional arguments forwarded to ``register_entry`` after the
                entry name (a delay time, optionally followed by index arrays).
            expected_shape: Shape that ``delay.at('a')`` must have.
        """
        with brainstate.environ.context(dt=0.1, i=0):
            delay = brainstate.nn.Delay(target)
            delay.register_entry('a', *index)
            delay.init_state()
            data = delay.at('a')
        self.assertEqual(data.shape, expected_shape)

    def test_delay1(self):
        """Re-registering an existing entry name with a different delay raises ``KeyError``."""
        a = brainstate.State(brainstate.random.random(10, 20))
        delay = brainstate.nn.Delay(a.value)
        delay.register_entry('a', 1.)
        delay.register_entry('b', 2.)
        delay.register_entry('c', None)

        delay.init_state()
        with self.assertRaises(KeyError):
            delay.register_entry('c', 10.)

    def test_rotation_delay(self):
        """With the default delay method, entry ``k`` returns the input from ``n_k`` steps ago (0 before that)."""
        rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
        t0 = 0.
        t1, n1 = 1., 10
        t2, n2 = 2., 20

        rotation_delay.register_entry('a', t0)
        rotation_delay.register_entry('b', t1)
        # Extra entry that is registered but not asserted on below.
        rotation_delay.register_entry('c2', 1.9)
        rotation_delay.register_entry('c', t2)

        rotation_delay.init_state()

        for i in range(100):
            with brainstate.environ.context(i=i):
                rotation_delay.update(jnp.ones((1,)) * i)
                self._assert_allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)
                self._assert_allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.))
                self._assert_allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.))

    def test_concat_delay(self):
        """The ``'concat'`` delay method yields the same per-entry values as ``test_rotation_delay`` expects."""
        rotation_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
        t0 = 0.
        t1, n1 = 1., 10
        t2, n2 = 2., 20

        rotation_delay.register_entry('a', t0)
        rotation_delay.register_entry('b', t1)
        rotation_delay.register_entry('c', t2)

        rotation_delay.init_state()

        for i in range(100):
            with brainstate.environ.context(i=i):
                rotation_delay.update(jnp.ones((1,)) * i)
                self._assert_allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)
                self._assert_allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1, 0.))
                self._assert_allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2, 0.))

    def test_jit_erro(self):
        """With ``jit_error_check`` enabled, out-of-range ``retrieve_at_time`` requests raise.

        The delay stores ``time=2.`` of history and the current time is ``t = 0``.
        Per the assertions below, ``-2.0`` and ``-2.01`` are accepted (with
        ``'round'`` interpolation). ``-2.1`` and ``-2.09`` (too far back) raise, and
        the future times ``0.1`` and ``0.01`` raise as well.
        """
        rotation_delay = brainstate.nn.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
        rotation_delay.init_state()

        with brainstate.environ.context(i=0, t=0, jit_error_check=True):
            rotation_delay.retrieve_at_time(-2.0)
            with self.assertRaises(Exception):
                rotation_delay.retrieve_at_time(-2.1)
            rotation_delay.retrieve_at_time(-2.01)
            with self.assertRaises(Exception):
                rotation_delay.retrieve_at_time(-2.09)
            with self.assertRaises(Exception):
                rotation_delay.retrieve_at_time(0.1)
            with self.assertRaises(Exception):
                rotation_delay.retrieve_at_time(0.01)

    def test_round_interp(self):
        """``interp_method='round'`` snaps a requested delay to the nearest whole step."""
        for shape in [(1,), (1, 1), (1, 1, 1)]:
            for delay_method in ['rotation', 'concat']:
                rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
                                                     interp_method='round')
                # (delay time, expected delay in steps after rounding delay / dt with dt = 0.1)
                t0, n0 = 0.01, 0
                t1, n1 = 1.04, 10
                t2, n2 = 1.06, 11
                rotation_delay.init_state()

                @brainstate.compile.jit
                def retrieve(td, i):
                    with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
                        return rotation_delay.retrieve_at_time(td)

                for i in range(100):
                    t = i * brainstate.environ.get_dt()
                    with brainstate.environ.context(i=i, t=t):
                        rotation_delay.update(jnp.ones(shape) * i)
                        self._assert_allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.))
                        self._assert_allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.))
                        self._assert_allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.))

    def test_linear_interp(self):
        """``interp_method='linear_interp'`` keeps the fractional part of ``delay / dt``."""
        for shape in [(1,), (1, 1), (1, 1, 1)]:
            for delay_method in ['rotation', 'concat']:
                rotation_delay = brainstate.nn.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
                                                     interp_method='linear_interp')
                # (delay time, expected fractional delay in steps, i.e. delay / dt with dt = 0.1)
                t0, n0 = 0.01, 0.1
                t1, n1 = 1.04, 10.4
                t2, n2 = 1.06, 10.6
                rotation_delay.init_state()

                @brainstate.compile.jit
                def retrieve(td, i):
                    with brainstate.environ.context(i=i, t=i * brainstate.environ.get_dt()):
                        return rotation_delay.retrieve_at_time(td)

                for i in range(100):
                    t = i * brainstate.environ.get_dt()
                    with brainstate.environ.context(i=i, t=t):
                        rotation_delay.update(jnp.ones(shape) * i)
                        self._assert_allclose(retrieve(t - t0, i), jnp.maximum(jnp.ones(shape) * i - n0, 0.))
                        self._assert_allclose(retrieve(t - t1, i), jnp.maximum(jnp.ones(shape) * i - n1, 0.))
                        self._assert_allclose(retrieve(t - t2, i), jnp.maximum(jnp.ones(shape) * i - n2, 0.))

    def test_rotation_and_concat_delay(self):
        """The default and ``'concat'`` delay methods return identical values for the same entries."""
        rotation_delay = brainstate.nn.Delay(jnp.ones((1,)))
        concat_delay = brainstate.nn.Delay(jnp.ones([1]), delay_method='concat')
        t0 = 0.
        t1 = 1.
        t2 = 2.

        for delay in (rotation_delay, concat_delay):
            delay.register_entry('a', t0)
            delay.register_entry('b', t1)
            delay.register_entry('c', t2)

        rotation_delay.init_state()
        concat_delay.init_state()

        for i in range(100):
            with brainstate.environ.context(i=i):
                new = jnp.ones((1,)) * i
                rotation_delay.update(new)
                concat_delay.update(new)
                for name in ('a', 'b', 'c'):
                    self._assert_allclose(rotation_delay.at(name), concat_delay.at(name))

    def test_delay_2d(self):
        """Per-element delay times plus per-element indices into a 1-D target give a (2, 2) result."""
        index = (brainstate.random.uniform(0., 10., (2, 2)),
                 brainstate.random.randint(0, 2, (2, 2)))
        self._assert_entry_shape(jnp.arange(2), index)

    def test_delay_time2(self):
        """Per-element delay times plus a scalar index into a 1-D target give a (2, 2) result."""
        index = (brainstate.random.uniform(0., 10., (2, 2)),
                 1)
        self._assert_entry_shape(jnp.arange(2), index)

    def test_delay_time3(self):
        """Per-element delay times with a scalar and an array index into a 2-D target give a (2, 2) result."""
        index = (brainstate.random.uniform(0., 10., (2, 2)),
                 1,
                 brainstate.random.randint(0, 2, (2, 2)))
        self._assert_entry_shape(jnp.zeros((2, 2)), index)

    def test_delay_time4(self):
        """Same inputs as ``test_delay_time3``."""
        # NOTE(review): this test is identical to test_delay_time3. It was
        # presumably meant to cover a different indexing case; TODO confirm.
        index = (brainstate.random.uniform(0., 10., (2, 2)),
                 1,
                 brainstate.random.randint(0, 2, (2, 2)))
        self._assert_entry_shape(jnp.zeros((2, 2)), index)
|