brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/_module_test.py
DELETED
@@ -1,207 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
|
18
|
-
import jax.numpy as jnp
|
19
|
-
import jaxlib.xla_extension
|
20
|
-
|
21
|
-
import brainstate as bst
|
22
|
-
|
23
|
-
|
24
|
-
class TestDelay(unittest.TestCase):
    """Tests for ``bst.Delay``: entry registration, delay methods and time interpolation."""

    def test_delay1(self):
        a = bst.State(bst.random.random(10, 20))
        delay = bst.Delay(a.value)
        delay.register_entry('a', 1.)
        delay.register_entry('b', 2.)
        delay.register_entry('c', None)

        delay.init_state()
        # Registering an entry name that already exists must raise KeyError.
        with self.assertRaises(KeyError):
            delay.register_entry('c', 10.)
        bst.util.clear_buffer_memory()

    def test_rotation_delay(self):
        rotation_delay = bst.Delay(jnp.ones((1,)))
        t0 = 0.
        t1, n1 = 1., 10
        t2, n2 = 2., 20

        rotation_delay.register_entry('a', t0)
        rotation_delay.register_entry('b', t1)
        # 'c2' is registered but never checked below.
        rotation_delay.register_entry('c2', 1.9)
        rotation_delay.register_entry('c', t2)

        rotation_delay.init_state()

        for i in range(100):
            bst.environ.set(i=i)
            current = jnp.ones((1,)) * i
            rotation_delay(current)
            self.assertTrue(jnp.allclose(rotation_delay.at('a'), current))
            self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(current - n1, 0.)))
            self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(current - n2, 0.)))
        bst.util.clear_buffer_memory()

    def test_concat_delay(self):
        rotation_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
        t0 = 0.
        t1, n1 = 1., 10
        t2, n2 = 2., 20

        rotation_delay.register_entry('a', t0)
        rotation_delay.register_entry('b', t1)
        rotation_delay.register_entry('c', t2)

        rotation_delay.init_state()

        for i in range(100):
            bst.environ.set(i=i)
            current = jnp.ones((1,)) * i
            rotation_delay(current)
            self.assertTrue(jnp.allclose(rotation_delay.at('a'), current))
            self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(current - n1, 0.)))
            self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(current - n2, 0.)))
        bst.util.clear_buffer_memory()

    def test_jit_erro(self):
        rotation_delay = bst.Delay(jnp.ones([1]), time=2., delay_method='concat', interp_method='round')
        rotation_delay.init_state()

        with bst.environ.context(i=0, t=0, jit_error_check=True):
            rotation_delay.retrieve_at_time(-2.0)
            with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
                rotation_delay.retrieve_at_time(-2.1)
            rotation_delay.retrieve_at_time(-2.01)
            with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
                rotation_delay.retrieve_at_time(-2.09)
            with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
                rotation_delay.retrieve_at_time(0.1)
            with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
                rotation_delay.retrieve_at_time(0.01)

    def test_round_interp(self):
        for shape in [(1,), (1, 1), (1, 1, 1)]:
            for delay_method in ['rotation', 'concat']:
                # subTest reports which (shape, delay_method) combination failed.
                with self.subTest(shape=shape, delay_method=delay_method):
                    rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
                                               interp_method='round')
                    # (delay time, expected lag in steps) pairs checked below.
                    t0, n0 = 0.01, 0
                    t1, n1 = 1.04, 10
                    t2, n2 = 1.06, 11
                    rotation_delay.init_state()

                    @bst.transform.jit
                    def retrieve(td, i):
                        with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
                            return rotation_delay.retrieve_at_time(td)

                    for i in range(100):
                        t = i * bst.environ.get_dt()
                        with bst.environ.context(i=i, t=t):
                            rotation_delay(jnp.ones(shape) * i)
                        current = jnp.ones(shape) * i
                        self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(current - n0, 0.)))
                        self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(current - n1, 0.)))
                        self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(current - n2, 0.)))
        bst.util.clear_buffer_memory()

    def test_linear_interp(self):
        for shape in [(1,), (1, 1), (1, 1, 1)]:
            for delay_method in ['rotation', 'concat']:
                # subTest reports which (shape, delay_method) combination failed.
                with self.subTest(shape=shape, delay_method=delay_method):
                    rotation_delay = bst.Delay(jnp.ones(shape), time=2., delay_method=delay_method,
                                               interp_method='linear_interp')
                    # (delay time, expected fractional lag in steps) pairs checked below.
                    t0, n0 = 0.01, 0.1
                    t1, n1 = 1.04, 10.4
                    t2, n2 = 1.06, 10.6
                    rotation_delay.init_state()

                    @bst.transform.jit
                    def retrieve(td, i):
                        with bst.environ.context(i=i, t=i * bst.environ.get_dt()):
                            return rotation_delay.retrieve_at_time(td)

                    for i in range(100):
                        t = i * bst.environ.get_dt()
                        with bst.environ.context(i=i, t=t):
                            rotation_delay(jnp.ones(shape) * i)
                        current = jnp.ones(shape) * i
                        self.assertTrue(jnp.allclose(retrieve(t - t0, i), jnp.maximum(current - n0, 0.)))
                        self.assertTrue(jnp.allclose(retrieve(t - t1, i), jnp.maximum(current - n1, 0.)))
                        self.assertTrue(jnp.allclose(retrieve(t - t2, i), jnp.maximum(current - n2, 0.)))
        bst.util.clear_buffer_memory()

    def test_rotation_and_concat_delay(self):
        rotation_delay = bst.Delay(jnp.ones((1,)))
        concat_delay = bst.Delay(jnp.ones([1]), delay_method='concat')
        t0 = 0.
        t1, n1 = 1., 10
        t2, n2 = 2., 20

        rotation_delay.register_entry('a', t0)
        rotation_delay.register_entry('b', t1)
        rotation_delay.register_entry('c', t2)
        concat_delay.register_entry('a', t0)
        concat_delay.register_entry('b', t1)
        concat_delay.register_entry('c', t2)

        rotation_delay.init_state()
        concat_delay.init_state()

        # Both delay methods must return identical values for every entry.
        for i in range(100):
            bst.environ.set(i=i)
            new = jnp.ones((1,)) * i
            rotation_delay(new)
            concat_delay(new)
            self.assertTrue(jnp.allclose(rotation_delay.at('a'), concat_delay.at('a'), ))
            self.assertTrue(jnp.allclose(rotation_delay.at('b'), concat_delay.at('b'), ))
            self.assertTrue(jnp.allclose(rotation_delay.at('c'), concat_delay.at('c'), ))
        bst.util.clear_buffer_memory()
|
186
|
-
|
187
|
-
|
188
|
-
class TestModule(unittest.TestCase):
    def test_states(self):
        """Smoke test: build a two-level module tree and print the states it collects (no assertions)."""

        class A(bst.Module):
            def __init__(self):
                super().__init__()
                self.a = bst.State(bst.random.random(10, 20))
                self.b = bst.State(bst.random.random(10, 20))

        class B(bst.Module):
            def __init__(self):
                super().__init__()
                self.a = A()
                self.b = bst.State(bst.random.random(10, 20))

        model = B()
        print()
        # Query twice with the default arguments, then twice with level=0.
        for query_kwargs in ({}, {}, dict(level=0), dict(level=0)):
            print(model.states(**query_kwargs))
|
brainstate/nn/_base.py
DELETED
@@ -1,251 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from __future__ import annotations
|
19
|
-
|
20
|
-
import inspect
|
21
|
-
from typing import Sequence, Optional, Tuple, Union
|
22
|
-
|
23
|
-
from brainstate._module import Module, UpdateReturn, Container, visible_module_dict
|
24
|
-
from brainstate.mixin import Mixin, DelayedInitializer, DelayedInit
|
25
|
-
|
26
|
-
# Public names of ``brainstate/nn/_base.py`` (used by ``from ... import *``).
__all__ = ['ExplicitInOutSize', 'ElementWiseBlock', 'Sequential', 'DnnLayer']
|
32
|
-
|
33
|
-
|
34
|
-
# -------------------------------------------------------------------------------------- #
|
35
|
-
# Network Related Concepts
|
36
|
-
# -------------------------------------------------------------------------------------- #
|
37
|
-
|
38
|
-
|
39
|
-
class ExplicitInOutSize(Mixin):
    """
    Mix-in class with the explicit input and output shape.

    Attributes
    ----------
    in_size: tuple[int, ...] or None
        The input shape, without the batch size. This argument is important, since it is
        used to evaluate the shape of the output. ``None`` until it is set.
    out_size: tuple[int, ...] or None
        The output shape, without the batch size. ``None`` until it is set.

    Both setters accept an ``int`` (stored as a 1-tuple) or a ``tuple``/``list``
    (stored as a ``tuple``); any other type raises ``TypeError``.
    """
    __module__ = 'brainstate.nn'

    _in_size: Optional[Tuple[int, ...]] = None
    _out_size: Optional[Tuple[int, ...]] = None

    @staticmethod
    def _as_size_tuple(size: Sequence[int] | int, name: str) -> Tuple[int, ...]:
        """Normalize an int or a tuple/list into a tuple; ``name`` is only used in the error message."""
        if isinstance(size, int):
            return (size,)
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if not isinstance(size, (tuple, list)):
            raise TypeError(f"Invalid type of {name}: {type(size)}")
        return tuple(size)

    @property
    def in_size(self) -> Optional[Tuple[int, ...]]:
        return self._in_size

    @in_size.setter
    def in_size(self, in_size: Sequence[int] | int):
        self._in_size = self._as_size_tuple(in_size, 'in_size')

    @property
    def out_size(self) -> Optional[Tuple[int, ...]]:
        return self._out_size

    @out_size.setter
    def out_size(self, out_size: Sequence[int] | int):
        self._out_size = self._as_size_tuple(out_size, 'out_size')
|
77
|
-
|
78
|
-
|
79
|
-
class ElementWiseBlock(Mixin):
    """
    Mix-in marking a module as element-wise.

    ``Sequential`` treats instances of this mix-in as shape-preserving: their
    output size is taken to be equal to their input size.
    """
    __module__ = 'brainstate.nn'
|
84
|
-
|
85
|
-
|
86
|
-
class Sequential(Module, UpdateReturn, Container, ExplicitInOutSize):
    """
    A sequential `input-output` module.

    Modules are chained in the order they are passed to the constructor: first the
    mandatory ``first`` module, then the positional modules, then the keyword modules
    (in keyword order). The ``update()`` method of ``Sequential`` forwards its input to
    the first module. It then "chains" outputs to inputs sequentially for each
    subsequent module, finally returning the output of the last module.

    ``first`` must be an ``ExplicitInOutSize``. Every later module must be one of:

    - a ``DelayedInitializer``: it is instantiated with ``in_size`` set to the output
      size of the preceding module, and the result must be an ``ExplicitInOutSize``;
    - an ``ElementWiseBlock``: its output size is taken to equal its input size;
    - an ``ExplicitInOutSize``: its declared ``out_size`` is used.

    Any other object is rejected with ``TypeError``.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    A :py:class:`Container` only stores modules. A ``Sequential`` also connects them
    in a cascade and tracks the input/output sizes of the whole chain.

    Examples
    --------

    >>> import brainstate as bst
    >>> import brainstate.nn as nn
    >>>
    >>> # composing ANN models
    >>> l = nn.Sequential(nn.Linear(100, 10),
    >>>                   nn.Linear(10, 2))
    >>> l(bst.random.random((256, 100)))
    >>>
    >>> # Keyword modules are chained after ``first`` in keyword order.
    >>> l = nn.Sequential(nn.Linear(100, 10),
    >>>                   l2=nn.Linear(10, 2))
    >>> l(bst.random.random((256, 100)))

    Args:
      first: The first module; must be an ``ExplicitInOutSize``.
      modules_as_tuple: Further modules, chained in positional order.
      modules_as_dict: Further modules, chained in keyword order after the positional ones.

    Raises:
      TypeError: If ``first`` or any later module has an unsupported type.
    """
    __module__ = 'brainstate.nn'

    def __init__(self, first: ExplicitInOutSize, *modules_as_tuple, **modules_as_dict):
        super().__init__()

        if not isinstance(first, ExplicitInOutSize):
            raise TypeError(f"The first module must be an instance of "
                            f"{ExplicitInOutSize.__name__}, but got {type(first)}.")
        # ``size`` tracks the output size of the most recently chained module.
        size = first.out_size

        tuple_modules = []
        for module in modules_as_tuple:
            module, size = self._format_module(module, size)
            tuple_modules.append(module)

        dict_modules = dict()
        for key, module in modules_as_dict.items():
            module, size = self._format_module(module, size)
            dict_modules[key] = module

        # Attribute of "Container"
        self.children = visible_module_dict(self.format_elements(object, first, *tuple_modules, **dict_modules))

        # the input and output shape
        if first.in_size is not None:
            self.in_size = first.in_size
        # The chained size can be ``None`` (e.g. ``first.out_size`` unset);
        # ``tuple(None)`` would crash, so leave ``out_size`` unset in that case.
        if size is not None:
            self.out_size = tuple(size)

    def _format_module(self, module, in_size):
        """Instantiate ``module`` if needed and return ``(module, out_size)`` given its input size."""
        if isinstance(module, DelayedInitializer):
            module = module(in_size=in_size)
            if not isinstance(module, ExplicitInOutSize):
                raise TypeError(f"A delayed initializer must build an instance of "
                                f"{ExplicitInOutSize.__name__}, but got {type(module)}.")
            out_size = module.out_size
        elif isinstance(module, ElementWiseBlock):
            out_size = in_size
        elif isinstance(module, ExplicitInOutSize):
            out_size = module.out_size
        else:
            raise TypeError(f"Unsupported type {type(module)}. ")
        return module, out_size

    def update(self, x):
        """Update function of a sequential model: feed ``x`` through every child in order."""
        for m in self.children.values():
            x = m(x)
        return x

    def _last_update_return(self):
        """Return the last child, which must implement ``UpdateReturn``."""
        last = self[-1]
        if not isinstance(last, UpdateReturn):
            raise NotImplementedError(f'The last element in the sequence is not an instance of {UpdateReturn.__name__}')
        return last

    def update_return(self):
        """
        The return information of the sequence according to the final model.
        """
        return self._last_update_return().update_return()

    def update_return_info(self):
        """
        The return information of the sequence according to the final model.
        """
        return self._last_update_return().update_return_info()

    @staticmethod
    def _from_items(items):
        """
        Build a new ``Sequential`` from ``(name, module)`` pairs.

        The first module is passed positionally (``__init__`` requires it) and the rest
        as keyword modules. Note that the first module's original name is not kept.
        """
        if len(items) == 0:
            raise IndexError('Cannot build a Sequential from an empty selection of modules.')
        (_, first), *rest = items
        return Sequential(first, **dict(rest))

    def __getitem__(self, key: Union[int, slice, str, Sequence[int]]):
        if isinstance(key, str):
            if key in self.children:
                return self.children[key]
            raise KeyError(f'Does not find a component named {key} in\n {str(self)}')
        elif isinstance(key, slice):
            return self._from_items(tuple(self.children.items())[key])
        elif isinstance(key, int):
            return tuple(self.children.values())[key]
        elif isinstance(key, (tuple, list)):
            _all_nodes = tuple(self.children.items())
            return self._from_items([_all_nodes[k] for k in key])
        else:
            raise KeyError(f'Unknown type of key: {type(key)}')

    def __repr__(self):
        nodes = self.children.values()
        entries = '\n'.join(f' [{i}] {_repr_object(x)}' for i, x in enumerate(nodes))
        return f'{self.__class__.__name__}(\n{entries}\n)'
|
219
|
-
|
220
|
-
|
221
|
-
def _repr_object(x):
    """
    Return a display string for one element of a ``Sequential``.

    - ``Module`` instances use their own ``repr``.
    - Other callables are shown by name. If they have parameters with default values,
      these are listed as ``name(*, k=v, ...)``. Objects without ``__name__`` are
      unwrapped through their ``func`` attribute (e.g. ``functools.partial``); if no
      ``__name__`` is found, the class name is returned.
    - Anything else uses ``repr``, with continuation lines indented.
    """
    if isinstance(x, Module):
        return repr(x)
    elif callable(x):
        try:
            signature = inspect.signature(x)
        except (ValueError, TypeError):
            # Some builtins/extension callables have no introspectable signature;
            # show them without defaults rather than failing inside ``repr``.
            args = ''
        else:
            args = ', '.join(f'{k}={v.default}' for k, v in signature.parameters.items()
                             if v.default is not inspect.Parameter.empty)
        while not hasattr(x, '__name__'):
            if not hasattr(x, 'func'):
                break
            x = x.func  # Handle functools.partial
        if not hasattr(x, '__name__') and hasattr(x, '__class__'):
            return x.__class__.__name__
        if args:
            return f'{x.__name__}(*, {args})'
        return x.__name__
    else:
        lines = repr(x).split('\n')
        lines = [lines[0]] + [' ' + y for y in lines[1:]]
        return '\n'.join(lines)
|
242
|
-
|
243
|
-
|
244
|
-
class DnnLayer(Module, ExplicitInOutSize, DelayedInit):
    """
    A DNN layer: a ``Module`` with explicit input/output sizes that supports delayed initialization.
    """
    __module__ = 'brainstate.nn'

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f"{cls_name}(in_size={self.in_size}, out_size={self.out_size})"
|