brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241010__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -2
- brainstate/_module.py +102 -67
- brainstate/_state.py +2 -2
- brainstate/_visualization.py +47 -0
- brainstate/environ.py +116 -9
- brainstate/environ_test.py +56 -0
- brainstate/functional/_activations.py +134 -56
- brainstate/functional/_activations_test.py +331 -0
- brainstate/functional/_normalization.py +21 -10
- brainstate/init/_generic.py +4 -2
- brainstate/mixin.py +1 -1
- brainstate/nn/__init__.py +7 -2
- brainstate/nn/_base.py +2 -2
- brainstate/nn/_connections.py +4 -4
- brainstate/nn/_dynamics.py +5 -5
- brainstate/nn/_elementwise.py +9 -9
- brainstate/nn/_embedding.py +3 -3
- brainstate/nn/_normalizations.py +3 -3
- brainstate/nn/_others.py +2 -2
- brainstate/nn/_poolings.py +6 -6
- brainstate/nn/_rate_rnns.py +1 -1
- brainstate/nn/_readout.py +1 -1
- brainstate/nn/_synouts.py +1 -1
- brainstate/nn/event/__init__.py +25 -0
- brainstate/nn/event/_misc.py +34 -0
- brainstate/nn/event/csr.py +312 -0
- brainstate/nn/event/csr_test.py +118 -0
- brainstate/nn/event/fixed_probability.py +276 -0
- brainstate/nn/event/fixed_probability_test.py +127 -0
- brainstate/nn/event/linear.py +220 -0
- brainstate/nn/event/linear_test.py +111 -0
- brainstate/nn/metrics.py +390 -0
- brainstate/optim/__init__.py +5 -1
- brainstate/optim/_optax_optimizer.py +208 -0
- brainstate/optim/_optax_optimizer_test.py +14 -0
- brainstate/random/__init__.py +24 -0
- brainstate/{random.py → random/_rand_funs.py} +7 -1596
- brainstate/random/_rand_seed.py +169 -0
- brainstate/random/_rand_state.py +1498 -0
- brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
- brainstate/{random_test.py → random/random_test.py} +208 -191
- brainstate/transform/_jit.py +1 -1
- brainstate/transform/_jit_test.py +19 -0
- brainstate/transform/_make_jaxpr.py +1 -1
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/METADATA +1 -1
- brainstate-0.0.2.post20241010.dist-info/RECORD +87 -0
- brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241010.dist-info}/top_level.txt +0 -0
brainstate/functional/_activations_test.py
ADDED
@@ -0,0 +1,331 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for nn module."""
+
+import itertools
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import scipy.stats
+from absl.testing import parameterized
+from jax._src import test_util as jtu
+from jax.test_util import check_grads
+
+import brainstate as bst
+
+
+class NNFunctionsTest(jtu.JaxTestCase):
+  @jtu.skip_on_flag("jax_skip_slow_tests", True)
+  def testSoftplusGrad(self):
+    check_grads(bst.functional.softplus, (1e-8,), order=4, )
+
+  def testSoftplusGradZero(self):
+    check_grads(bst.functional.softplus, (0.,), order=1)
+
+  def testSoftplusGradInf(self):
+    self.assertAllClose(1., jax.grad(bst.functional.softplus)(float('inf')))
+
+  def testSoftplusGradNegInf(self):
+    check_grads(bst.functional.softplus, (-float('inf'),), order=1)
+
+  def testSoftplusGradNan(self):
+    check_grads(bst.functional.softplus, (float('nan'),), order=1)
+
+  @parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
+  def testSoftplusZero(self, dtype):
+    self.assertEqual(jnp.log(dtype(2)), bst.functional.softplus(dtype(0)))
+
+  def testSparseplusGradZero(self):
+    check_grads(bst.functional.sparse_plus, (-2.,), order=1)
+
+  def testSparseplusGrad(self):
+    check_grads(bst.functional.sparse_plus, (0.,), order=1)
+
+  def testSparseplusAndSparseSigmoid(self):
+    self.assertAllClose(
+      jax.grad(bst.functional.sparse_plus)(0.),
+      bst.functional.sparse_sigmoid(0.),
+      check_dtypes=False)
+    self.assertAllClose(
+      jax.grad(bst.functional.sparse_plus)(2.),
+      bst.functional.sparse_sigmoid(2.),
+      check_dtypes=False)
+    self.assertAllClose(
+      jax.grad(bst.functional.sparse_plus)(-2.),
+      bst.functional.sparse_sigmoid(-2.),
+      check_dtypes=False)
+
+  def testSquareplusGrad(self):
+    check_grads(bst.functional.squareplus, (1e-8,), order=4,
+                )
+
+  def testSquareplusGradZero(self):
+    check_grads(bst.functional.squareplus, (0.,), order=1,
+                )
+
+  def testSquareplusGradNegInf(self):
+    check_grads(bst.functional.squareplus, (-float('inf'),), order=1,
+                )
+
+  def testSquareplusGradNan(self):
+    check_grads(bst.functional.squareplus, (float('nan'),), order=1,
+                )
+
+  @parameterized.parameters([float] + jtu.dtypes.floating)
+  def testSquareplusZero(self, dtype):
+    self.assertEqual(dtype(1), bst.functional.squareplus(dtype(0), dtype(4)))
+
+  def testMishGrad(self):
+    check_grads(bst.functional.mish, (1e-8,), order=4,
+                )
+
+  def testMishGradZero(self):
+    check_grads(bst.functional.mish, (0.,), order=1,
+                )
+
+  def testMishGradNegInf(self):
+    check_grads(bst.functional.mish, (-float('inf'),), order=1,
+                )
+
+  def testMishGradNan(self):
+    check_grads(bst.functional.mish, (float('nan'),), order=1,
+                )
+
+  @parameterized.parameters([float] + jtu.dtypes.floating)
+  def testMishZero(self, dtype):
+    self.assertEqual(dtype(0), bst.functional.mish(dtype(0)))
+
+  def testReluGrad(self):
+    rtol = None
+    check_grads(bst.functional.relu, (1.,), order=3, rtol=rtol)
+    check_grads(bst.functional.relu, (-1.,), order=3, rtol=rtol)
+    jaxpr = jax.make_jaxpr(jax.grad(bst.functional.relu))(0.)
+    self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
+
+  def testRelu6Grad(self):
+    rtol = None
+    check_grads(bst.functional.relu6, (1.,), order=3, rtol=rtol)
+    check_grads(bst.functional.relu6, (-1.,), order=3, rtol=rtol)
+    self.assertAllClose(jax.grad(bst.functional.relu6)(0.), 0., check_dtypes=False)
+    self.assertAllClose(jax.grad(bst.functional.relu6)(6.), 0., check_dtypes=False)
+
+  def testSoftplusValue(self):
+    val = bst.functional.softplus(89.)
+    self.assertAllClose(val, 89., check_dtypes=False)
+
+  def testSparseplusValue(self):
+    val = bst.functional.sparse_plus(89.)
+    self.assertAllClose(val, 89., check_dtypes=False)
+
+  def testSparsesigmoidValue(self):
+    self.assertAllClose(bst.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
+    self.assertAllClose(bst.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
+    self.assertAllClose(bst.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
+
+  def testSquareplusValue(self):
+    val = bst.functional.squareplus(1e3)
+    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
+
+  def testMishValue(self):
+    val = bst.functional.mish(1e3)
+    self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
+
+  def testEluValue(self):
+    val = bst.functional.elu(1e4)
+    self.assertAllClose(val, 1e4, check_dtypes=False)
+
+  def testGluValue(self):
+    val = bst.functional.glu(jnp.array([1.0, 0.0]), axis=0)
+    self.assertAllClose(val, jnp.array([0.5]))
+
+  @parameterized.parameters(False, True)
+  def testGeluIntType(self, approximate):
+    val_float = bst.functional.gelu(jnp.array(-1.0), approximate=approximate)
+    val_int = bst.functional.gelu(jnp.array(-1), approximate=approximate)
+    self.assertAllClose(val_float, val_int)
+
+  @parameterized.parameters(False, True)
+  def testGelu(self, approximate):
+    def gelu_reference(x):
+      return x * scipy.stats.norm.cdf(x)
+
+    rng = jtu.rand_default(self.rng())
+    args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
+    self._CheckAgainstNumpy(
+      gelu_reference, partial(bst.functional.gelu, approximate=approximate), args_maker,
+      check_dtypes=False, tol=1e-3 if approximate else None)
+
+  @parameterized.parameters(*itertools.product(
+    (jnp.float32, jnp.bfloat16, jnp.float16),
+    (partial(bst.functional.gelu, approximate=False),
+     partial(bst.functional.gelu, approximate=True),
+     bst.functional.relu,
+     bst.functional.softplus,
+     bst.functional.sparse_plus,
+     bst.functional.sigmoid,
+     bst.functional.squareplus,
+     bst.functional.mish)))
+  def testDtypeMatchesInput(self, dtype, fn):
+    x = jnp.zeros((), dtype=dtype)
+    out = fn(x)
+    self.assertEqual(out.dtype, dtype)
+
+  def testEluMemory(self):
+    # see https://github.com/google/jax/pull/1640
+    with jax.enable_checks(False):  # With checks we materialize the array
+      jax.make_jaxpr(lambda: bst.functional.elu(jnp.ones((10 ** 12,))))  # don't oom
+
+  def testHardTanhMemory(self):
+    # see https://github.com/google/jax/pull/1640
+    with jax.enable_checks(False):  # With checks we materialize the array
+      jax.make_jaxpr(lambda: bst.functional.hard_tanh(jnp.ones((10 ** 12,))))  # don't oom
+
+  @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+  def testSoftmaxEmptyArray(self, fn):
+    x = jnp.array([], dtype=float)
+    self.assertArraysEqual(fn(x), x)
+
+  @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+  def testSoftmaxEmptyMask(self, fn):
+    x = jnp.array([5.5, 1.3, -4.2, 0.9])
+    m = jnp.zeros_like(x, dtype=bool)
+    expected = jnp.full_like(x, 0.0 if fn is bst.functional.softmax else -jnp.inf)
+    self.assertArraysEqual(fn(x, where=m), expected)
+
+  @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+  def testSoftmaxWhereMask(self, fn):
+    x = jnp.array([5.5, 1.3, -4.2, 0.9])
+    m = jnp.array([True, False, True, True])
+
+    out = fn(x, where=m)
+    self.assertAllClose(out[m], fn(x[m]))
+
+    probs = out if fn is bst.functional.softmax else jnp.exp(out)
+    self.assertAllClose(probs.sum(), 1.0)
+
+  @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+  def testSoftmaxWhereGrad(self, fn):
+    # regression test for https://github.com/google/jax/issues/19490
+    x = jnp.array([36., 10000.])
+    mask = x < 1000
+
+    f = lambda x, mask: fn(x, where=mask)[0]
+
+    self.assertAllClose(jax.grad(f)(x, mask), jnp.zeros_like(x))
+
+  def testSoftmaxGrad(self):
+    x = jnp.array([5.5, 1.3, -4.2, 0.9])
+    jtu.check_grads(bst.functional.softmax, (x,), order=2, atol=5e-3)
+
+  def testStandardizeWhereMask(self):
+    x = jnp.array([5.5, 1.3, -4.2, 0.9])
+    m = jnp.array([True, False, True, True])
+    x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
+
+    out_masked = jnp.take(bst.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
+    out_filtered = bst.functional.standardize(x_filtered)
+
+    self.assertAllClose(out_masked, out_filtered)
+
+  def testOneHot(self):
+    actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3)
+    expected = jnp.array([[1., 0., 0.],
+                          [0., 1., 0.],
+                          [0., 0., 1.]])
+    self.assertAllClose(actual, expected, check_dtypes=False)
+
+    actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3)
+    expected = jnp.array([[0., 1., 0.],
+                          [0., 0., 1.],
+                          [1., 0., 0.]])
+    self.assertAllClose(actual, expected, check_dtypes=False)
+
+  def testOneHotOutOfBound(self):
+    actual = bst.functional.one_hot(jnp.array([-1, 3]), 3)
+    expected = jnp.array([[0., 0., 0.],
+                          [0., 0., 0.]])
+    self.assertAllClose(actual, expected, check_dtypes=False)
+
+  def testOneHotNonArrayInput(self):
+    actual = bst.functional.one_hot([0, 1, 2], 3)
+    expected = jnp.array([[1., 0., 0.],
+                          [0., 1., 0.],
+                          [0., 0., 1.]])
+    self.assertAllClose(actual, expected, check_dtypes=False)
+
+  def testOneHotCustomDtype(self):
+    actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
+    expected = jnp.array([[True, False, False],
+                          [False, True, False],
+                          [False, False, True]])
+    self.assertAllClose(actual, expected)
+
+  def testOneHotAxis(self):
+    expected = jnp.array([[0., 1., 0.],
+                          [0., 0., 1.],
+                          [1., 0., 0.]]).T
+
+    actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
+    self.assertAllClose(actual, expected, check_dtypes=False)
+
+    actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
+    self.assertAllClose(actual, expected, check_dtypes=False)
+
+  def testTanhExists(self):
+    print(bst.functional.tanh)  # doesn't crash
+
+  def testCustomJVPLeak(self):
+    # https://github.com/google/jax/issues/8171
+    @jax.jit
+    def fwd():
+      a = jnp.array(1.)
+
+      def f(hx, _):
+        hx = bst.functional.sigmoid(hx + a)
+        return hx, None
+
+      hx = jnp.array(0.)
+      jax.lax.scan(f, hx, None, length=2)
+
+    with jax.checking_leaks():
+      fwd()  # doesn't crash
+
+  def testCustomJVPLeak2(self):
+    # https://github.com/google/jax/issues/8171
+    # The above test uses jax.bst.functional.sigmoid, as in the original #8171, but that
+    # function no longer actually has a custom_jvp! So we inline the old def.
+
+    @jax.custom_jvp
+    def sigmoid(x):
+      one = jnp.float32(1)
+      return jax.lax.div(one, jax.lax.add(one, jax.lax.exp(jax.lax.neg(x))))
+
+    sigmoid.defjvps(lambda g, ans, x: g * ans * (jnp.float32(1) - ans))
+
+    @jax.jit
+    def fwd():
+      a = jnp.array(1., 'float32')
+
+      def f(hx, _):
+        hx = sigmoid(hx + a)
+        return hx, None
+
+      hx = jnp.array(0., 'float32')
+      jax.lax.scan(f, hx, None, length=2)
+
+    with jax.checking_leaks():
+      fwd()  # doesn't crash
brainstate/functional/_normalization.py
CHANGED
@@ -15,12 +15,13 @@

 from __future__ import annotations

-from typing import Optional
+from typing import Optional, Union

+import brainunit as u
 import jax
-import jax.numpy as jnp

-from
+from brainstate._utils import set_module_as
+from brainstate.typing import ArrayLike

 __all__ = [
   'weight_standardization',
@@ -29,18 +30,18 @@ __all__ = [

 @set_module_as('brainstate.functional')
 def weight_standardization(
-    w:
+    w: ArrayLike,
     eps: float = 1e-4,
     gain: Optional[jax.Array] = None,
     out_axis: int = -1,
-):
+) -> Union[jax.Array, u.Quantity]:
   """
   Scaled Weight Standardization,
   see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.

   Parameters
   ----------
-  w :
+  w : ArrayLike
     The weight tensor.
   eps : float
     A small value to avoid division by zero.
@@ -51,7 +52,7 @@ def weight_standardization(

   Returns
   -------
-
+  ArrayLike
     The scaled weight tensor.
   """
   if out_axis < 0:
@@ -63,9 +64,19 @@ def weight_standardization(
     fan_in *= w.shape[i]
     axes.append(i)
   # normalize the weight
-  mean =
-  var =
-
+  mean = u.math.mean(w, axis=axes, keepdims=True)
+  var = u.math.var(w, axis=axes, keepdims=True)
+
+  temp = u.math.maximum(var * fan_in, eps)
+  if isinstance(temp, u.Quantity):
+    unit = temp.unit
+    temp = temp.mantissa
+    if unit.is_unitless:
+      scale = jax.lax.rsqrt(temp)
+    else:
+      scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
+  else:
+    scale = jax.lax.rsqrt(temp)
   if gain is not None:
     scale = gain * scale
   shift = mean * scale
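A minimal usage sketch of the updated weight_standardization (not part of the diff itself; it assumes this brainstate release plus brainunit are installed, and that u.mV is a valid brainunit unit — the only API facts taken from the diff are the signature and the Union[jax.Array, u.Quantity] return annotation):

import brainunit as u
import brainstate as bst
import jax.numpy as jnp

# Plain weights behave as before and come back as a jax.Array.
w = jnp.ones((4, 3))
ws = bst.functional.weight_standardization(w, eps=1e-4)

# Unit-carrying weights exercise the new u.Quantity branch above; per the
# return annotation the result may be a u.Quantity rather than a bare array.
wq = u.Quantity(jnp.ones((4, 3)), unit=u.mV)   # assumption: u.mV exists in brainunit
wsq = bst.functional.weight_standardization(wq)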
brainstate/init/_generic.py
CHANGED
@@ -24,6 +24,7 @@ import numpy as np
 from brainstate._state import State
 from brainstate.typing import ArrayLike
 from ._base import to_size
+from brainstate.mixin import Mode

 __all__ = [
   'param',
@@ -36,7 +37,7 @@ def _is_scalar(x):
   return bu.math.isscalar(x)


-def
+def are_broadcastable_shapes(shape1, shape2):
   """
   Check if two shapes are broadcastable.

@@ -88,6 +89,7 @@ def param(
     batch_size: Optional[int] = None,
     allow_none: bool = True,
     allow_scalar: bool = True,
+    mode: Mode = None,
 ):
   """Initialize parameters.

@@ -143,7 +145,7 @@
     raise ValueError(f'Unknown parameter type: {type(parameter)}')

   # Check if the shape of the parameter matches the given size
-  if not
+  if not are_broadcastable_shapes(parameter.shape, sizes):
     raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')

   # Expand the parameter to match the given batch size
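For context, a hedged sketch of the helper that param() now calls by name (the function name and its docstring come from the diff; the private import path and the boolean return value are assumptions):

from brainstate.init._generic import are_broadcastable_shapes  # assumed import path

# NumPy-style broadcasting check used by param() above.
are_broadcastable_shapes((10, 1), (10, 4))   # expected True: the shapes broadcast
are_broadcastable_shapes((3, 2), (4, 2))     # expected False: leading dims conflict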
brainstate/mixin.py
CHANGED
brainstate/nn/__init__.py
CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================

+from . import metrics
 from ._base import *
 from ._base import __all__ as base_all
 from ._connections import *
@@ -39,6 +40,8 @@ from ._readout import *
 from ._readout import __all__ as readout_all
 from ._synouts import *
 from ._synouts import __all__ as synouts_all
+from .event import *
+from .event import __all__ as event_all

 __all__ = (
   base_all +
@@ -53,7 +56,8 @@ __all__ = (
   readout_all +
   synouts_all +
   _projection_all +
-  _misc_all
+  _misc_all +
+  event_all
 )

 del (
@@ -68,5 +72,6 @@ del (
   readout_all,
   synouts_all,
   _projection_all,
-  _misc_all
+  _misc_all,
+  event_all
 )
brainstate/nn/_base.py
CHANGED
@@ -20,8 +20,8 @@ from __future__ import annotations
 import inspect
 from typing import Sequence, Optional, Tuple, Union

-from
-from
+from brainstate._module import Module, UpdateReturn, Container, visible_module_dict
+from brainstate.mixin import Mixin, DelayedInitializer, DelayedInit

 __all__ = [
   'ExplicitInOutSize',
brainstate/nn/_connections.py
CHANGED
@@ -25,10 +25,10 @@ import jax
 import jax.numpy as jnp

 from ._base import DnnLayer
-from
-from
-from
-from
+from brainstate import init, functional
+from brainstate._state import ParamState
+from brainstate.mixin import Mode
+from brainstate.typing import ArrayLike

 T = TypeVar('T')

brainstate/nn/_dynamics.py
CHANGED
@@ -24,11 +24,11 @@ import jax.numpy as jnp

 from ._base import ExplicitInOutSize
 from ._misc import exp_euler_step
-from
-from
-from
-from
-from
+from brainstate import environ, init, surrogate
+from brainstate._module import Dynamics
+from brainstate._state import ShortTermState
+from brainstate.mixin import DelayedInit, Mode, AlignPost
+from brainstate.typing import DTypeLike, ArrayLike, Size

 __all__ = [
   # neuron models
brainstate/nn/_elementwise.py
CHANGED
@@ -19,16 +19,16 @@ from __future__ import annotations

 from typing import Optional

-import brainunit as
+import brainunit as u
 import jax.numpy as jnp
 import jax.typing

 from ._base import ElementWiseBlock
-from
-from
-from
-from
-from
+from brainstate import environ, random, functional as F
+from brainstate._module import Module
+from brainstate._state import ParamState
+from brainstate.mixin import Mode
+from brainstate.typing import ArrayLike

 __all__ = [
   # activation functions
@@ -83,7 +83,7 @@ class Threshold(Module, ElementWiseBlock):
     self.value = value

   def __call__(self, x: ArrayLike) -> ArrayLike:
-    dtype =
+    dtype = u.math.get_dtype(x)
     return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
                      x,
                      jnp.asarray(self.value, dtype=dtype))
@@ -1143,7 +1143,7 @@ class Dropout(Module, ElementWiseBlock):
     self.prob = prob

   def __call__(self, x):
-    dtype =
+    dtype = u.math.get_dtype(x)
     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
     if fit_phase and self.prob < 1.:
       keep_mask = random.bernoulli(self.prob, x.shape)
@@ -1173,7 +1173,7 @@ class _DropoutNd(Module, ElementWiseBlock):
     self.channel_axis = channel_axis

   def __call__(self, x):
-    dtype =
+    dtype = u.math.get_dtype(x)
     # get fit phase
     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')

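A hedged sketch of the Dropout path shown above: units are only dropped when the 'fit' flag read via environ.get is truthy and prob < 1. Only environ.get appears in the diff, so the environ.set call below is an assumption about how that flag is provided:

import brainstate as bst
import jax.numpy as jnp

layer = bst.nn.Dropout(0.5)     # self.prob is the keep probability (see the bernoulli keep_mask above)
bst.environ.set(fit=True)       # assumption: sets the 'fit' flag that environ.get reads
y = layer(jnp.ones((2, 100)))   # roughly half of the entries are dropped during fitting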
brainstate/nn/_embedding.py
CHANGED
@@ -17,9 +17,9 @@ from typing import Optional, Callable, Union

 from ._base import DnnLayer
 from .. import init
-from
-from
-from
+from brainstate._state import ParamState
+from brainstate.mixin import Mode, Training
+from brainstate.typing import ArrayLike

 __all__ = [
   'Embedding',
brainstate/nn/_normalizations.py
CHANGED
@@ -25,9 +25,9 @@ import jax.numpy as jnp

 from ._base import DnnLayer
 from .. import environ, init
-from
-from
-from
+from brainstate._state import LongTermState, ParamState
+from brainstate.mixin import Mode
+from brainstate.typing import DTypeLike, ArrayLike, Size, Axes

 __all__ = [
   'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
brainstate/nn/_others.py
CHANGED
@@ -23,8 +23,8 @@ import brainunit as bu
 import jax.numpy as jnp

 from ._base import DnnLayer
-from
-from
+from brainstate.mixin import Mode
+from brainstate import random, environ, typing, init

 __all__ = [
   'DropoutFixed',
brainstate/nn/_poolings.py
CHANGED
@@ -21,7 +21,7 @@ import functools
 from typing import Sequence, Optional
 from typing import Union, Tuple, Callable, List

-import brainunit as
+import brainunit as u
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -29,7 +29,7 @@ import numpy as np
 from ._base import DnnLayer, ExplicitInOutSize
 from .. import environ
 from ..mixin import Mode
-from
+from brainstate.typing import Size

 __all__ = [
   'Flatten', 'Unflatten',
@@ -85,7 +85,7 @@ class Flatten(DnnLayer, ExplicitInOutSize):

     if in_size is not None:
       self.in_size = tuple(in_size)
-      y = jax.eval_shape(functools.partial(
+      y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
                          jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
       self.out_size = y.shape

@@ -101,7 +101,7 @@ class Flatten(DnnLayer, ExplicitInOutSize):
       start_axis = self.start_axis + dim_diff
     else:
       start_axis = x.ndim + self.start_axis
-    return
+    return u.math.flatten(x, start_axis, self.end_axis)

   def __repr__(self) -> str:
     return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
@@ -153,12 +153,12 @@ class Unflatten(DnnLayer, ExplicitInOutSize):

     if in_size is not None:
       self.in_size = tuple(in_size)
-      y = jax.eval_shape(functools.partial(
+      y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
                          jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
       self.out_size = y.shape

   def update(self, x):
-    return
+    return u.math.unflatten(x, self.axis, self.sizes)

   def __repr__(self):
     return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
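A hedged sketch of the Flatten change above, which now delegates to u.math.flatten; the constructor argument names are inferred from the functools.partial call in the diff, and calling the layer directly (rather than layer.update) assumes brainstate modules dispatch __call__ to update:

import brainstate as bst
import jax.numpy as jnp

layer = bst.nn.Flatten(start_axis=1)   # flatten everything after the batch axis
y = layer(jnp.ones((8, 4, 5)))         # expected shape (8, 20) with torch.nn.Flatten-like semantics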
brainstate/nn/_rate_rnns.py
CHANGED
@@ -27,7 +27,7 @@ from .. import random, init, functional
 from .._module import Module
 from .._state import ShortTermState, ParamState
 from ..mixin import DelayedInit, Mode
-from
+from brainstate.typing import ArrayLike

 __all__ = [
   'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
brainstate/nn/_readout.py
CHANGED
@@ -29,7 +29,7 @@ from ._misc import exp_euler_step
 from .. import environ, init, surrogate
 from .._state import ShortTermState, ParamState
 from ..mixin import Mode
-from
+from brainstate.typing import Size, ArrayLike, DTypeLike

 __all__ = [
   'LeakyRateReadout',
|