brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
brainstate/nn/_poolings_test.py
DELETED
@@ -1,231 +0,0 @@
|
|
1
|
-
# -*- coding: utf-8 -*-
|
2
|
-
|
3
|
-
|
4
|
-
import jax
|
5
|
-
import numpy as np
|
6
|
-
from absl.testing import absltest
|
7
|
-
from absl.testing import parameterized
|
8
|
-
|
9
|
-
import brainstate as bst
|
10
|
-
import brainstate.nn as nn
|
11
|
-
|
12
|
-
|
13
|
-
class TestFlatten(parameterized.TestCase):
    """Output-shape tests for ``nn.Flatten``."""

    def test_flatten1(self):
        # start_axis=0 without in_size: all axes are merged into a single one.
        for size in [(16, 32, 32, 8), (32, 8), (10, 20, 30)]:
            with self.subTest(size=size):
                arr = bst.random.rand(*size)
                out = nn.Flatten(start_axis=0)(arr)
                self.assertEqual(out.shape, (int(np.prod(size)),))

    def test_flatten2(self):
        # start_axis=1 without in_size: axis 0 is kept, the rest are merged.
        for size in [(16, 32, 32, 8), (32, 8), (10, 20, 30)]:
            with self.subTest(size=size):
                arr = bst.random.rand(*size)
                out = nn.Flatten(start_axis=1)(arr)
                self.assertEqual(out.shape, (size[0], int(np.prod(size[1:]))))

    def test_flatten3(self):
        # With in_size given, the expected shape shows start_axis is counted
        # from the first in_size axis; the leading (16, 32) axes are untouched.
        arr = bst.random.rand(16, 32, 32, 8)
        out = nn.Flatten(start_axis=0, in_size=(32, 8))(arr)
        self.assertEqual(out.shape, (16, 32, 32 * 8))

    def test_flatten4(self):
        # Same as above with a 3-axis in_size: axis 1 of in_size is the first
        # merged one, so only the last two axes collapse.
        arr = bst.random.rand(16, 32, 32, 8)
        out = nn.Flatten(start_axis=1, in_size=(32, 32, 8))(arr)
        self.assertEqual(out.shape, (16, 32, 32 * 8))
|
49
|
-
|
50
|
-
|
51
|
-
class TestUnflatten(parameterized.TestCase):
    """Placeholder test case for ``nn.Unflatten``; no test methods are defined yet."""
|
53
|
-
|
54
|
-
|
55
|
-
class TestPool(parameterized.TestCase):
    """Output-shape tests for the fixed-window and adaptive pooling layers."""

    def tearDown(self):
        # Free device buffers after every test -- also after a failing one,
        # which the previous per-test call at the end of each body skipped.
        bst.util.clear_buffer_memory()
        super().tearDown()

    @staticmethod
    def _adaptive_out_shape(in_shape, target_size, channel_axis):
        """Return the output shape expected from an adaptive pooling layer.

        The pooled axes are the trailing ``len(target_size)`` axes left after
        removing ``channel_axis`` (when it is not ``None``); every other axis
        keeps its size. This is the rule encoded by the explicit shape
        assertions in ``test_AdaptiveAvgPool2d_v1/_v2`` and
        ``test_AdaptiveAvgPool3d_v1/_v2``.

        Args:
            in_shape: Shape of the layer input.
            target_size: An int (1-D pooling) or a sequence of target sizes.
            channel_axis: Channel axis index (negative allowed) or ``None``.

        Returns:
            The expected output shape as a tuple.
        """
        targets = (target_size,) if isinstance(target_size, int) else tuple(target_size)
        ndim = len(in_shape)
        spatial_axes = [
            ax for ax in range(ndim)
            if channel_axis is None or ax != channel_axis % ndim
        ]
        out = list(in_shape)
        for ax, size in zip(spatial_axes[-len(targets):], targets):
            out[ax] = size
        return tuple(out)

    def test_MaxPool2d_v1(self):
        arr = bst.random.rand(16, 32, 32, 8)

        out = nn.MaxPool2d(2, 2, channel_axis=-1)(arr)
        self.assertEqual(out.shape, (16, 16, 16, 8))

        out = nn.MaxPool2d(2, 2, channel_axis=None)(arr)
        self.assertEqual(out.shape, (16, 32, 16, 4))

        out = nn.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr)
        self.assertEqual(out.shape, (16, 32, 17, 5))

        out = nn.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
        self.assertEqual(out.shape, (16, 32, 18, 5))

        out = nn.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 17, 8))

        out = nn.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 32, 5))

    def test_AvgPool2d_v1(self):
        arr = bst.random.rand(16, 32, 32, 8)

        out = nn.AvgPool2d(2, 2, channel_axis=-1)(arr)
        self.assertEqual(out.shape, (16, 16, 16, 8))

        out = nn.AvgPool2d(2, 2, channel_axis=None)(arr)
        self.assertEqual(out.shape, (16, 32, 16, 4))

        out = nn.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr)
        self.assertEqual(out.shape, (16, 32, 17, 5))

        out = nn.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr)
        self.assertEqual(out.shape, (16, 32, 18, 5))

        out = nn.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 17, 8))

        out = nn.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr)
        self.assertEqual(out.shape, (16, 17, 32, 5))

    @parameterized.named_parameters(
        dict(testcase_name=f'target_size={target_size}', target_size=target_size)
        for target_size in [10, 9, 8, 7, 6]
    )
    def test_adaptive_pool1d(self, target_size):
        from brainstate.nn._poolings import _adaptive_pool1d

        arr = bst.random.rand(100)
        # The original repeated the identical mean-pooling call; exercise a
        # second reduction so the repeat actually covers something new.
        for op_name, op in (('mean', jax.numpy.mean), ('max', jax.numpy.max)):
            with self.subTest(op=op_name):
                out = _adaptive_pool1d(arr, target_size, op)
                self.assertEqual(out.shape, (target_size,))

    def test_AdaptiveAvgPool2d_v1(self):
        x = bst.random.randn(64, 8, 9)

        out = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(x)
        self.assertEqual(out.shape, (64, 5, 7))

        out = nn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(x)
        self.assertEqual(out.shape, (64, 2, 3))

        out = nn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(x)
        self.assertEqual(out.shape, (2, 3, 9))

        out = nn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(x)
        self.assertEqual(out.shape, (2, 8, 3))

        out = nn.AdaptiveAvgPool2d((2, 3), channel_axis=None)(x)
        self.assertEqual(out.shape, (64, 2, 3))

    def test_AdaptiveAvgPool2d_v2(self):
        bst.random.seed()
        x = bst.random.randn(128, 64, 32, 16)

        out = nn.AdaptiveAvgPool2d((5, 7), channel_axis=0)(x)
        self.assertEqual(out.shape, (128, 64, 5, 7))

        out = nn.AdaptiveAvgPool2d((2, 3), channel_axis=0)(x)
        self.assertEqual(out.shape, (128, 64, 2, 3))

        out = nn.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(x)
        self.assertEqual(out.shape, (128, 2, 3, 16))

        out = nn.AdaptiveAvgPool2d((2, 3), channel_axis=1)(x)
        self.assertEqual(out.shape, (128, 64, 2, 3))

    def test_AdaptiveAvgPool3d_v1(self):
        x = bst.random.randn(10, 128, 64, 32)
        out = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3], channel_axis=0)(x)
        self.assertEqual(out.shape, (10, 6, 5, 3))

    def test_AdaptiveAvgPool3d_v2(self):
        x = bst.random.randn(10, 20, 128, 64, 32)
        out = nn.AdaptiveAvgPool3d(target_size=[6, 5, 3])(x)
        self.assertEqual(out.shape, (10, 6, 5, 3, 32))

    @parameterized.product(axis=(-1, 0, 1))
    def test_AdaptiveMaxPool1d_v1(self, axis):
        x = bst.random.randn(32, 16)
        out = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)(x)
        self.assertEqual(out.shape, self._adaptive_out_shape(x.shape, 4, axis))

    @parameterized.product(axis=(-1, 0, 1, 2))
    def test_AdaptiveMaxPool1d_v2(self, axis):
        x = bst.random.randn(2, 32, 16)
        out = nn.AdaptiveMaxPool1d(target_size=4, channel_axis=axis)(x)
        self.assertEqual(out.shape, self._adaptive_out_shape(x.shape, 4, axis))

    @parameterized.product(axis=(-1, 0, 1, 2))
    def test_AdaptiveMaxPool2d_v1(self, axis):
        # Previously built AdaptiveAvgPool2d, so the Max layer was never tested.
        x = bst.random.randn(32, 16, 12)
        out = nn.AdaptiveMaxPool2d(target_size=[5, 4], channel_axis=axis)(x)
        self.assertEqual(out.shape, self._adaptive_out_shape(x.shape, [5, 4], axis))

    @parameterized.product(axis=(-1, 0, 1, 2, 3))
    def test_AdaptiveMaxPool2d_v2(self, axis):
        x = bst.random.randn(2, 32, 16, 12)
        out = nn.AdaptiveMaxPool2d(target_size=[5, 4], channel_axis=axis)(x)
        self.assertEqual(out.shape, self._adaptive_out_shape(x.shape, [5, 4], axis))

    @parameterized.product(axis=(-1, 0, 1, 2, 3))
    def test_AdaptiveMaxPool3d_v1(self, axis):
        x = bst.random.randn(2, 128, 64, 32)
        out = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)(x)
        self.assertEqual(out.shape, self._adaptive_out_shape(x.shape, [6, 5, 4], axis))

    @parameterized.product(axis=(-1, 0, 1, 2, 3, 4))
    def test_AdaptiveMaxPool3d_v2(self, axis):
        # Renamed from a duplicate ``test_AdaptiveMaxPool3d_v1``, which silently
        # replaced the 4-D test above in the class namespace.
        x = bst.random.randn(2, 128, 64, 32, 16)
        out = nn.AdaptiveMaxPool3d(target_size=[6, 5, 4], channel_axis=axis)(x)
        self.assertEqual(out.shape, self._adaptive_out_shape(x.shape, [6, 5, 4], axis))
|
228
|
-
|
229
|
-
|
230
|
-
if __name__ == '__main__':
    # Script entry point: hand off to absl's test runner when this module is
    # executed directly rather than collected by an external test runner.
    absltest.main()
|