brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,254 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import jax.numpy as jnp
+import pytest
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import brainstate as bst
+
+
+class TestConv(parameterized.TestCase):
+    def test_Conv2D_img(self):
+        img = jnp.zeros((2, 200, 198, 4))
+        for k in range(4):
+            x = 30 + 60 * k
+            y = 20 + 60 * k
+            img = img.at[0, x:x + 10, y:y + 10, k].set(1.0)
+            img = img.at[1, x:x + 20, y:y + 20, k].set(3.0)
+
+        net = bst.nn.Conv2d((200, 198, 4), out_channels=32, kernel_size=(3, 3),
+                            stride=(2, 1), padding='VALID', groups=4)
+        out = net(img)
+        print("out shape: ", out.shape)
+        self.assertEqual(out.shape, (2, 99, 196, 32))
+        # print("First output channel:")
+        # plt.figure(figsize=(10, 10))
+        # plt.imshow(np.array(img)[0, :, :, 0])
+        # plt.show()
+
+    def test_conv1D(self):
+        model = bst.nn.Conv1d((5, 3), out_channels=32, kernel_size=(3,))
+        input = jnp.ones((2, 5, 3))
+        out = model(input)
+        print("out shape: ", out.shape)
+        self.assertEqual(out.shape, (2, 5, 32))
+        # print("First output channel:")
+        # plt.figure(figsize=(10, 10))
+        # plt.imshow(np.array(out)[0, :, :])
+        # plt.show()
+
+    def test_conv2D(self):
+        model = bst.nn.Conv2d((5, 5, 3), out_channels=32, kernel_size=(3, 3))
+        input = jnp.ones((2, 5, 5, 3))
+
+        out = model(input)
+        print("out shape: ", out.shape)
+        self.assertEqual(out.shape, (2, 5, 5, 32))
+
+    def test_conv3D(self):
+        model = bst.nn.Conv3d((5, 5, 5, 3), out_channels=32, kernel_size=(3, 3, 3))
+        input = jnp.ones((2, 5, 5, 5, 3))
+        out = model(input)
+        print("out shape: ", out.shape)
+        self.assertEqual(out.shape, (2, 5, 5, 5, 32))
+
+
+@pytest.mark.skip(reason="not implemented yet")
+class TestConvTranspose1d(parameterized.TestCase):
+    def test_conv_transpose(self):
+
+        x = jnp.ones((1, 8, 3))
+        for use_bias in [True, False]:
+            conv_transpose_module = bst.nn.ConvTranspose1d(
+                in_channels=3,
+                out_channels=4,
+                kernel_size=(3,),
+                padding='VALID',
+                w_initializer=bst.init.Constant(1.),
+                b_initializer=bst.init.Constant(1.) if use_bias else None,
+            )
+            self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
+            y = conv_transpose_module(x)
+            print(y.shape)
+            correct_ans = jnp.array([[[4., 4., 4., 4.],
+                                      [7., 7., 7., 7.],
+                                      [10., 10., 10., 10.],
+                                      [10., 10., 10., 10.],
+                                      [10., 10., 10., 10.],
+                                      [10., 10., 10., 10.],
+                                      [10., 10., 10., 10.],
+                                      [10., 10., 10., 10.],
+                                      [7., 7., 7., 7.],
+                                      [4., 4., 4., 4.]]])
+            if not use_bias:
+                correct_ans -= 1.
+            self.assertTrue(jnp.allclose(y, correct_ans))
+
+    def test_single_input_masked_conv_transpose(self):
+
+        x = jnp.ones((1, 8, 3))
+        m = jnp.tril(jnp.ones((3, 3, 4)))
+        conv_transpose_module = bst.nn.ConvTranspose1d(
+            in_channels=3,
+            out_channels=4,
+            kernel_size=(3,),
+            padding='VALID',
+            mask=m,
+            w_initializer=bst.init.Constant(),
+            b_initializer=bst.init.Constant(),
+        )
+        self.assertEqual(conv_transpose_module.w.shape, (3, 3, 4))
+        y = conv_transpose_module(x)
+        print(y.shape)
+        correct_ans = jnp.array([[[4., 3., 2., 1.],
+                                  [7., 5., 3., 1.],
+                                  [10., 7., 4., 1.],
+                                  [10., 7., 4., 1.],
+                                  [10., 7., 4., 1.],
+                                  [10., 7., 4., 1.],
+                                  [10., 7., 4., 1.],
+                                  [10., 7., 4., 1.],
+                                  [7., 5., 3., 1.],
+                                  [4., 3., 2., 1.]]])
+        self.assertTrue(jnp.allclose(y, correct_ans))
+
+    def test_computation_padding_same(self):
+
+        data = jnp.ones([1, 3, 1])
+        for use_bias in [True, False]:
+            net = bst.nn.ConvTranspose1d(
+                in_channels=1,
+                out_channels=1,
+                kernel_size=3,
+                stride=1,
+                padding="SAME",
+                w_initializer=bst.init.Constant(),
+                b_initializer=bst.init.Constant() if use_bias else None,
+            )
+            out = net(data)
+            self.assertEqual(out.shape, (1, 3, 1))
+            out = jnp.squeeze(out, axis=(0, 2))
+            expected_out = jnp.asarray([2, 3, 2])
+            if use_bias:
+                expected_out += 1
+            self.assertTrue(jnp.allclose(out, expected_out, rtol=1e-5))
+
+
+@pytest.mark.skip(reason="not implemented yet")
+class TestConvTranspose2d(parameterized.TestCase):
+    def test_conv_transpose(self):
+
+        x = jnp.ones((1, 8, 8, 3))
+        for use_bias in [True, False]:
+            conv_transpose_module = bst.nn.ConvTranspose2d(
+                in_channels=3,
+                out_channels=4,
+                kernel_size=(3, 3),
+                padding='VALID',
+                w_initializer=bst.init.Constant(),
+                b_initializer=bst.init.Constant() if use_bias else None,
+            )
+            self.assertEqual(conv_transpose_module.w.shape, (3, 3, 3, 4))
+            y = conv_transpose_module(x)
+            print(y.shape)
+
+    def test_single_input_masked_conv_transpose(self):
+
+        x = jnp.ones((1, 8, 8, 3))
+        m = jnp.tril(jnp.ones((3, 3, 3, 4)))
+        conv_transpose_module = bst.nn.ConvTranspose2d(
+            in_channels=3,
+            out_channels=4,
+            kernel_size=(3, 3),
+            padding='VALID',
+            mask=m,
+            w_initializer=bst.init.Constant(),
+        )
+        y = conv_transpose_module(x)
+        print(y.shape)
+
+    def test_computation_padding_same(self):
+
+        x = jnp.ones((1, 8, 8, 3))
+        for use_bias in [True, False]:
+            conv_transpose_module = bst.nn.ConvTranspose2d(
+                in_channels=3,
+                out_channels=4,
+                kernel_size=(3, 3),
+                stride=1,
+                padding='SAME',
+                w_initializer=bst.init.Constant(),
+                b_initializer=bst.init.Constant() if use_bias else None,
+            )
+            y = conv_transpose_module(x)
+            print(y.shape)
+
+
+@pytest.mark.skip(reason="not implemented yet")
+class TestConvTranspose3d(parameterized.TestCase):
+    def test_conv_transpose(self):
+
+        x = jnp.ones((1, 8, 8, 8, 3))
+        for use_bias in [True, False]:
+            conv_transpose_module = bst.nn.ConvTranspose3d(
+                in_channels=3,
+                out_channels=4,
+                kernel_size=(3, 3, 3),
+                padding='VALID',
+                w_initializer=bst.init.Constant(),
+                b_initializer=bst.init.Constant() if use_bias else None,
+            )
+            y = conv_transpose_module(x)
+            print(y.shape)
+
+    def test_single_input_masked_conv_transpose(self):
+
+        x = jnp.ones((1, 8, 8, 8, 3))
+        m = jnp.tril(jnp.ones((3, 3, 3, 3, 4)))
+        conv_transpose_module = bst.nn.ConvTranspose3d(
+            in_channels=3,
+            out_channels=4,
+            kernel_size=(3, 3, 3),
+            padding='VALID',
+            mask=m,
+            w_initializer=bst.init.Constant(),
+        )
+        y = conv_transpose_module(x)
+        print(y.shape)
+
+    def test_computation_padding_same(self):
+
+        x = jnp.ones((1, 8, 8, 8, 3))
+        for use_bias in [True, False]:
+            conv_transpose_module = bst.nn.ConvTranspose3d(
+                in_channels=3,
+                out_channels=4,
+                kernel_size=(3, 3, 3),
+                stride=1,
+                padding='SAME',
+                w_initializer=bst.init.Constant(),
+                b_initializer=bst.init.Constant() if use_bias else None,
+            )
+            y = conv_transpose_module(x)
+            print(y.shape)
+
+
+class TestDense(parameterized.TestCase):
+    @parameterized.product(
+        size=[(10,),
+              (20, 10),
+              (5, 8, 10)],
+        num_out=[20, ]
+    )
+    def test_Dense1(self, size, num_out):
+        f = bst.nn.Linear(10, num_out)
+        x = bst.random.random(size)
+        y = f(x)
+        self.assertTrue(y.shape == size[:-1] + (num_out,))
+
+
+if __name__ == '__main__':
+    absltest.main()
@@ -0,0 +1,59 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import annotations
+
+from typing import Optional, Callable, Union
+
+from brainstate import init
+from brainstate._state import ParamState
+from brainstate.nn._module import Module
+from brainstate.typing import ArrayLike
+
+__all__ = [
+    'Embedding',
+]
+
+
+class Embedding(Module):
+    r"""
+    A simple lookup table that stores embeddings of a fixed size.
+
+    Args:
+        num_embeddings: Size of embedding dictionary. Must be non-negative.
+        embedding_size: Size of each embedding vector. Must be non-negative.
+        embedding_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_size: int,
+        embedding_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+        name: Optional[str] = None,
+    ):
+        super().__init__(name=name)
+        if num_embeddings < 0:
+            raise ValueError("num_embeddings must not be negative.")
+        if embedding_size < 0:
+            raise ValueError("embedding_size must not be negative.")
+        self.num_embeddings = num_embeddings
+        self.embedding_size = embedding_size
+        self.out_size = (embedding_size,)
+
+        weight = init.param(embedding_init, (self.num_embeddings, self.embedding_size))
+        self.weight = ParamState(weight)
+
+    def update(self, indices: ArrayLike):
+        return self.weight.value[indices]
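The hunk above adds an `Embedding` lookup module whose `update` method simply indexes its `ParamState` weight table with the given integer indices. A minimal usage sketch (illustrative only, not part of the diff; it assumes the class is re-exported as `bst.nn.Embedding`, consistent with the module's placement under `brainstate/nn/` in the file list above):

```python
# Illustrative sketch, assuming Embedding is reachable via brainstate.nn.
import jax.numpy as jnp
import brainstate as bst

emb = bst.nn.Embedding(num_embeddings=100, embedding_size=16)  # 100 x 16 lookup table
tokens = jnp.array([3, 7, 7, 42])   # integer indices into the table
vectors = emb.update(tokens)        # shape (4, 16): one weight row per index
print(vectors.shape)
```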
@@ -0,0 +1,388 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import numbers
+from typing import Callable, Union, Sequence, Optional, Any
+
+import jax
+import jax.numpy as jnp
+
+from brainstate import environ, init
+from brainstate._state import LongTermState, ParamState
+from brainstate.nn._module import Module
+from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
+
+__all__ = [
+    'BatchNorm0d', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
+]
+
+
+def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
+    axes = []
+    for axis in feature_axes:
+        if axis < 0:
+            axis += ndim
+        if axis < 0 or axis >= ndim:
+            raise ValueError(f'Invalid axis {axis} for {ndim}D input')
+        axes.append(axis)
+    return tuple(axes)
+
+
+def _abs_sq(x):
+    """Computes the elementwise square of the absolute value |x|^2."""
+    if jnp.iscomplexobj(x):
+        return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
+    else:
+        return jax.lax.square(x)
+
+
+def _compute_stats(
+    x: ArrayLike,
+    axes: Sequence[int],
+    dtype: DTypeLike,
+    axis_name: Optional[str] = None,
+    axis_index_groups: Optional[Sequence[int]] = None,
+    use_mean: bool = True,
+):
+    """Computes mean and variance statistics.
+
+    This implementation takes care of a few important details:
+    - Computes in float32 precision for stability in half precision training.
+    - mean and variance are computable in a single XLA fusion,
+      by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]).
+    - Clips negative variances to zero which can happen due to
+      roundoff errors. This avoids downstream NaNs.
+    - Supports averaging across a parallel axis and subgroups of a parallel axis
+      with a single `lax.pmean` call to avoid latency.
+
+    Arguments:
+      x: Input array.
+      axes: The axes in ``x`` to compute mean and variance statistics for.
+      dtype: tp.Optional dtype specifying the minimal precision. Statistics
+        are always at least float32 for stability (default: dtype of x).
+      axis_name: tp.Optional name for the pmapped axis to compute mean over.
+      axis_index_groups: tp.Optional axis indices.
+      use_mean: If true, calculate the mean from the input and use it when
+        computing the variance. If false, set the mean to zero and compute
+        the variance without subtracting the mean.
+
+    Returns:
+      A pair ``(mean, val)``.
+    """
+    if dtype is None:
+        dtype = jax.numpy.result_type(x)
+    # promote x to at least float32, this avoids half precision computation
+    # but preserves double or complex floating points
+    dtype = jax.numpy.promote_types(dtype, environ.dftype())
+    x = jnp.asarray(x, dtype)
+
+    # Compute mean and mean of squared values.
+    mean2 = jnp.mean(_abs_sq(x), axes)
+    if use_mean:
+        mean = jnp.mean(x, axes)
+    else:
+        mean = jnp.zeros(mean2.shape, dtype=dtype)
+
+    # If axis_name is provided, we need to average the mean and mean2 across
+    if axis_name is not None:
+        concatenated_mean = jnp.concatenate([mean, mean2])
+        mean, mean2 = jnp.split(
+            jax.lax.pmean(
+                concatenated_mean,
+                axis_name=axis_name,
+                axis_index_groups=axis_index_groups,
+            ),
+            2,
+        )
+
+    # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
+    # to floating point round-off errors.
+    var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
+    return mean, var
+
+
+def _normalize(
+    x: ArrayLike,
+    mean: Optional[ArrayLike],
+    var: Optional[ArrayLike],
+    weights: Optional[ParamState],
+    reduction_axes: Sequence[int],
+    dtype: DTypeLike,
+    epsilon: Union[numbers.Number, jax.Array],
+):
+    """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
+
+    Arguments:
+      x: The input.
+      mean: Mean to use for normalization.
+      var: Variance to use for normalization.
+      weights: The scale and bias parameters.
+      reduction_axes: The axes in ``x`` to reduce.
+      dtype: The dtype of the result (default: infer from input and params).
+      epsilon: Normalization epsilon.
+
+    Returns:
+      The normalized input.
+    """
+    if mean is not None:
+        assert var is not None, 'mean and val must be both None or not None.'
+        stats_shape = list(x.shape)
+        for axis in reduction_axes:
+            stats_shape[axis] = 1
+        mean = mean.reshape(stats_shape)
+        var = var.reshape(stats_shape)
+        y = x - mean
+        mul = jax.lax.rsqrt(var + jnp.asarray(epsilon, dtype))
+        y = y * mul
+        if weights is not None:
+            y = _scale_operation(y, weights.value)
+    else:
+        assert var is None, 'mean and val must be both None or not None.'
+        assert weights is None, 'scale and bias are not supported without mean and val'
+        y = x
+    return jnp.asarray(y, dtype)
+
+
+def _scale_operation(x, param):
+    if 'scale' in param:
+        x = x * param['scale']
+    if 'bias' in param:
+        x = x + param['bias']
+    return x
+
+
+class _BatchNorm(Module):
+    __module__ = 'brainstate.nn'
+    num_spatial_dims: int
+
+    def __init__(
+        self,
+        in_size: Size,
+        feature_axis: Axes = -1,
+        track_running_stats: bool = True,
+        epsilon: float = 1e-5,
+        momentum: float = 0.99,
+        affine: bool = True,
+        bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
+        scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
+        axis_name: Optional[Union[str, Sequence[str]]] = None,
+        axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
+        name: Optional[str] = None,
+        dtype: Any = None,
+    ):
+        super().__init__(name=name)
+
+        # parameters
+        self.in_size = tuple(in_size)
+        self.out_size = tuple(in_size)
+        self.affine = affine
+        self.bias_initializer = bias_initializer
+        self.scale_initializer = scale_initializer
+        self.dtype = dtype or environ.dftype()
+        self.track_running_stats = track_running_stats
+        self.momentum = jnp.asarray(momentum, dtype=self.dtype)
+        self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
+
+        # parameters about axis
+        feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
+        self.feature_axis = _canonicalize_axes(len(in_size), feature_axis)
+        self.axis_name = axis_name
+        self.axis_index_groups = axis_index_groups
+
+        # variables
+        feature_shape = tuple([ax if i in self.feature_axis else 1 for i, ax in enumerate(in_size)])
+        if self.track_running_stats:
+            self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
+            self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
+        else:
+            self.running_mean = None
+            self.running_var = None
+
+        # parameters
+        if self.affine:
+            assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
+            bias = init.param(self.bias_initializer, feature_shape)
+            scale = init.param(self.scale_initializer, feature_shape)
+            self.weight = ParamState(dict(bias=bias, scale=scale))
+        else:
+            self.weight = None
+
+    def update(self, x):
+        # input shape and batch mode or not
+        if x.ndim == self.num_spatial_dims + 2:
+            x_shape = x.shape[1:]
+            batch = True
+        elif x.ndim == self.num_spatial_dims + 1:
+            x_shape = x.shape
+            batch = False
+        else:
+            raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
+                             f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
+        if self.in_size != x_shape:
+            raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
+
+        # reduce the feature axis
+        if batch:
+            reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axis)
+        else:
+            reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axis)
+
+        # fitting phase
+        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
+
+        # compute the running mean and variance
+        if self.track_running_stats:
+            if fit_phase:
+                mean, var = _compute_stats(
+                    x,
+                    reduction_axes,
+                    dtype=self.dtype,
+                    axis_name=self.axis_name,
+                    axis_index_groups=self.axis_index_groups,
+                )
+                self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
+                self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
+            else:
+                mean = self.running_mean.value
+                var = self.running_var.value
+        else:
+            mean, var = None, None
+
+        # normalize
+        return _normalize(x, mean, var, self.weight, reduction_axes, self.dtype, self.epsilon)
+
+
+class BatchNorm0d(_BatchNorm):
+    r"""1-D batch normalization [1]_.
+
+    The data should be of `(b, l, c)`, where `b` is the batch dimension,
+    `l` is the layer dimension, and `c` is the channel dimension.
+
+    %s
+    """
+    __module__ = 'brainstate.nn'
+    num_spatial_dims: int = 0
+
+
+class BatchNorm1d(_BatchNorm):
+    r"""1-D batch normalization [1]_.
+
+    The data should be of `(b, l, c)`, where `b` is the batch dimension,
+    `l` is the layer dimension, and `c` is the channel dimension.
+
+    %s
+    """
+    __module__ = 'brainstate.nn'
+    num_spatial_dims: int = 1
+
+
+class BatchNorm2d(_BatchNorm):
+    r"""2-D batch normalization [1]_.
+
+    The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
+    `h` is the height dimension, `w` is the width dimension, and `c` is the
+    channel dimension.
+
+    %s
+    """
+    __module__ = 'brainstate.nn'
+    num_spatial_dims: int = 2
+
+
+class BatchNorm3d(_BatchNorm):
+    r"""3-D batch normalization [1]_.
+
+    The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
+    `h` is the height dimension, `w` is the width dimension, `d` is the depth
+    dimension, and `c` is the channel dimension.
+
+    %s
+    """
+    __module__ = 'brainstate.nn'
+    num_spatial_dims: int = 3
+
+
+_bn_doc = r'''
+
+    This layer aims to reduce the internal covariant shift of data. It
+    normalizes a batch of data by fixing the mean and variance of inputs
+    on each feature (channel). Most commonly, the first axis of the data
+    is the batch, and the last is the channel. However, users can specify
+    the axes to be normalized.
+
+    .. math::
+        y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
+
+    .. note::
+        This :attr:`momentum` argument is different from one used in optimizer
+        classes and the conventional notion of momentum. Mathematically, the
+        update rule for running statistics here is
+        :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
+        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
+        new observed value.
+
+    Parameters
+    ----------
+    in_size: sequence of int
+        The input shape, without batch size.
+    feature_axis: int, tuple, list
+        The feature or non-batch axis of the input.
+    track_running_stats: bool
+        A boolean value that when set to ``True``, this module tracks the running mean and variance,
+        and when set to ``False``, this module does not track such statistics, and initializes
+        statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
+        this module always uses batch statistics. in both training and eval modes. Default: ``True``.
+    momentum: float
+        The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
+    epsilon: float
+        A value added to the denominator for numerical stability. Default: 1e-5
+    affine: bool
+        A boolean value that when set to ``True``, this module has
+        learnable affine parameters. Default: ``True``
+    bias_initializer: ArrayLike, Callable
+        An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
+        Default: ``init.Constant(0.)``
+    scale_initializer: ArrayLike, Callable
+        An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
+        Default: ``init.Constant(1.)``
+    axis_name: optional, str, sequence of str
+        If not ``None``, it should be a string (or sequence of
+        strings) representing the axis name(s) over which this module is being
+        run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
+        argument means that batch statistics are calculated across all replicas
+        on the named axes.
+    axis_index_groups: optional, sequence
+        Specifies how devices are grouped. Valid
+        only within ``jax.pmap`` collectives.
+        Groups of axis indices within that named axis
+        representing subsets of devices to reduce over (default: None). For
+        example, `[[0, 1], [2, 3]]` would independently batch-normalize over
+        the examples on the first two and last two devices. See `jax.lax.psum`
+        for more details.
+
+    References
+    ----------
+    .. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
+           by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
+
+'''
+
+BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
+BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
+BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
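The `_BatchNorm.update` path above computes batch statistics as `Var = E[|x|^2] - |E[x]|^2`, clips negative variances to zero, and blends them into the running buffers with `momentum * running + (1 - momentum) * batch`. A small self-contained sketch of that arithmetic (plain JAX, independent of brainstate; the input values are made up for illustration and this is not the package API itself):

```python
# Illustrative sketch of the statistics used by _compute_stats / _normalize above.
import jax
import jax.numpy as jnp

x = jnp.array([[0.0, 2.0], [4.0, 6.0]])          # (batch, feature)
axes = (0,)                                        # reduce over the batch axis

mean = jnp.mean(x, axes)                           # E[x]      -> [2., 4.]
mean2 = jnp.mean(jnp.square(x), axes)              # E[|x|^2]  -> [8., 20.]
var = jnp.maximum(0.0, mean2 - jnp.square(mean))   # clipped variance -> [4., 4.]

momentum = 0.99
running_mean = jnp.zeros_like(mean)
running_mean = momentum * running_mean + (1 - momentum) * mean  # -> [0.02, 0.04]

y = (x - mean) * jax.lax.rsqrt(var + 1e-5)         # normalized batch, as in _normalize
```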