brainstate 0.0.2.post20240913__py2.py3-none-any.whl → 0.0.2.post20241009__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. brainstate/__init__.py +4 -2
  2. brainstate/_module.py +102 -67
  3. brainstate/_state.py +2 -2
  4. brainstate/_visualization.py +47 -0
  5. brainstate/environ.py +116 -9
  6. brainstate/environ_test.py +56 -0
  7. brainstate/functional/_activations.py +134 -56
  8. brainstate/functional/_activations_test.py +331 -0
  9. brainstate/functional/_normalization.py +21 -10
  10. brainstate/init/_generic.py +4 -2
  11. brainstate/mixin.py +1 -1
  12. brainstate/nn/__init__.py +7 -2
  13. brainstate/nn/_base.py +2 -2
  14. brainstate/nn/_connections.py +4 -4
  15. brainstate/nn/_dynamics.py +5 -5
  16. brainstate/nn/_elementwise.py +9 -9
  17. brainstate/nn/_embedding.py +3 -3
  18. brainstate/nn/_normalizations.py +3 -3
  19. brainstate/nn/_others.py +2 -2
  20. brainstate/nn/_poolings.py +6 -6
  21. brainstate/nn/_rate_rnns.py +1 -1
  22. brainstate/nn/_readout.py +1 -1
  23. brainstate/nn/_synouts.py +1 -1
  24. brainstate/nn/event/__init__.py +25 -0
  25. brainstate/nn/event/_misc.py +34 -0
  26. brainstate/nn/event/csr.py +312 -0
  27. brainstate/nn/event/csr_test.py +118 -0
  28. brainstate/nn/event/fixed_probability.py +276 -0
  29. brainstate/nn/event/fixed_probability_test.py +127 -0
  30. brainstate/nn/event/linear.py +220 -0
  31. brainstate/nn/event/linear_test.py +111 -0
  32. brainstate/nn/metrics.py +390 -0
  33. brainstate/optim/__init__.py +5 -1
  34. brainstate/optim/_optax_optimizer.py +208 -0
  35. brainstate/optim/_optax_optimizer_test.py +14 -0
  36. brainstate/random/__init__.py +24 -0
  37. brainstate/{random.py → random/_rand_funs.py} +7 -1596
  38. brainstate/random/_rand_seed.py +169 -0
  39. brainstate/random/_rand_state.py +1491 -0
  40. brainstate/{_random_for_unit.py → random/_random_for_unit.py} +1 -1
  41. brainstate/{random_test.py → random/random_test.py} +208 -191
  42. brainstate/transform/_jit.py +1 -1
  43. brainstate/transform/_jit_test.py +19 -0
  44. brainstate/transform/_make_jaxpr.py +1 -1
  45. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/METADATA +1 -1
  46. brainstate-0.0.2.post20241009.dist-info/RECORD +87 -0
  47. brainstate-0.0.2.post20240913.dist-info/RECORD +0 -70
  48. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/LICENSE +0 -0
  49. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/WHEEL +0 -0
  50. {brainstate-0.0.2.post20240913.dist-info → brainstate-0.0.2.post20241009.dist-info}/top_level.txt +0 -0
brainstate/functional/_activations_test.py ADDED
@@ -0,0 +1,331 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Tests for nn module."""
+
+ import itertools
+ from functools import partial
+
+ import jax
+ import jax.numpy as jnp
+ import scipy.stats
+ from absl.testing import parameterized
+ from jax._src import test_util as jtu
+ from jax.test_util import check_grads
+
+ import brainstate as bst
+
+
+ class NNFunctionsTest(jtu.JaxTestCase):
+ @jtu.skip_on_flag("jax_skip_slow_tests", True)
+ def testSoftplusGrad(self):
+ check_grads(bst.functional.softplus, (1e-8,), order=4, )
+
+ def testSoftplusGradZero(self):
+ check_grads(bst.functional.softplus, (0.,), order=1)
+
+ def testSoftplusGradInf(self):
+ self.assertAllClose(1., jax.grad(bst.functional.softplus)(float('inf')))
+
+ def testSoftplusGradNegInf(self):
+ check_grads(bst.functional.softplus, (-float('inf'),), order=1)
+
+ def testSoftplusGradNan(self):
+ check_grads(bst.functional.softplus, (float('nan'),), order=1)
+
+ @parameterized.parameters([int, float] + jtu.dtypes.floating + jtu.dtypes.integer)
+ def testSoftplusZero(self, dtype):
+ self.assertEqual(jnp.log(dtype(2)), bst.functional.softplus(dtype(0)))
+
+ def testSparseplusGradZero(self):
+ check_grads(bst.functional.sparse_plus, (-2.,), order=1)
+
+ def testSparseplusGrad(self):
+ check_grads(bst.functional.sparse_plus, (0.,), order=1)
+
+ def testSparseplusAndSparseSigmoid(self):
+ self.assertAllClose(
+ jax.grad(bst.functional.sparse_plus)(0.),
+ bst.functional.sparse_sigmoid(0.),
+ check_dtypes=False)
+ self.assertAllClose(
+ jax.grad(bst.functional.sparse_plus)(2.),
+ bst.functional.sparse_sigmoid(2.),
+ check_dtypes=False)
+ self.assertAllClose(
+ jax.grad(bst.functional.sparse_plus)(-2.),
+ bst.functional.sparse_sigmoid(-2.),
+ check_dtypes=False)
+
+ def testSquareplusGrad(self):
+ check_grads(bst.functional.squareplus, (1e-8,), order=4,
+ )
+
+ def testSquareplusGradZero(self):
+ check_grads(bst.functional.squareplus, (0.,), order=1,
+ )
+
+ def testSquareplusGradNegInf(self):
+ check_grads(bst.functional.squareplus, (-float('inf'),), order=1,
+ )
+
+ def testSquareplusGradNan(self):
+ check_grads(bst.functional.squareplus, (float('nan'),), order=1,
+ )
+
+ @parameterized.parameters([float] + jtu.dtypes.floating)
+ def testSquareplusZero(self, dtype):
+ self.assertEqual(dtype(1), bst.functional.squareplus(dtype(0), dtype(4)))
+
+ def testMishGrad(self):
+ check_grads(bst.functional.mish, (1e-8,), order=4,
+ )
+
+ def testMishGradZero(self):
+ check_grads(bst.functional.mish, (0.,), order=1,
+ )
+
+ def testMishGradNegInf(self):
+ check_grads(bst.functional.mish, (-float('inf'),), order=1,
+ )
+
+ def testMishGradNan(self):
+ check_grads(bst.functional.mish, (float('nan'),), order=1,
+ )
+
+ @parameterized.parameters([float] + jtu.dtypes.floating)
+ def testMishZero(self, dtype):
+ self.assertEqual(dtype(0), bst.functional.mish(dtype(0)))
+
+ def testReluGrad(self):
+ rtol = None
+ check_grads(bst.functional.relu, (1.,), order=3, rtol=rtol)
+ check_grads(bst.functional.relu, (-1.,), order=3, rtol=rtol)
+ jaxpr = jax.make_jaxpr(jax.grad(bst.functional.relu))(0.)
+ self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
+
+ def testRelu6Grad(self):
+ rtol = None
+ check_grads(bst.functional.relu6, (1.,), order=3, rtol=rtol)
+ check_grads(bst.functional.relu6, (-1.,), order=3, rtol=rtol)
+ self.assertAllClose(jax.grad(bst.functional.relu6)(0.), 0., check_dtypes=False)
+ self.assertAllClose(jax.grad(bst.functional.relu6)(6.), 0., check_dtypes=False)
+
+ def testSoftplusValue(self):
+ val = bst.functional.softplus(89.)
+ self.assertAllClose(val, 89., check_dtypes=False)
+
+ def testSparseplusValue(self):
+ val = bst.functional.sparse_plus(89.)
+ self.assertAllClose(val, 89., check_dtypes=False)
+
+ def testSparsesigmoidValue(self):
+ self.assertAllClose(bst.functional.sparse_sigmoid(-2.), 0., check_dtypes=False)
+ self.assertAllClose(bst.functional.sparse_sigmoid(2.), 1., check_dtypes=False)
+ self.assertAllClose(bst.functional.sparse_sigmoid(0.), .5, check_dtypes=False)
+
+ def testSquareplusValue(self):
+ val = bst.functional.squareplus(1e3)
+ self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
+
+ def testMishValue(self):
+ val = bst.functional.mish(1e3)
+ self.assertAllClose(val, 1e3, check_dtypes=False, atol=1e-3)
+
+ def testEluValue(self):
+ val = bst.functional.elu(1e4)
+ self.assertAllClose(val, 1e4, check_dtypes=False)
+
+ def testGluValue(self):
+ val = bst.functional.glu(jnp.array([1.0, 0.0]), axis=0)
+ self.assertAllClose(val, jnp.array([0.5]))
+
+ @parameterized.parameters(False, True)
+ def testGeluIntType(self, approximate):
+ val_float = bst.functional.gelu(jnp.array(-1.0), approximate=approximate)
+ val_int = bst.functional.gelu(jnp.array(-1), approximate=approximate)
+ self.assertAllClose(val_float, val_int)
+
+ @parameterized.parameters(False, True)
+ def testGelu(self, approximate):
+ def gelu_reference(x):
+ return x * scipy.stats.norm.cdf(x)
+
+ rng = jtu.rand_default(self.rng())
+ args_maker = lambda: [rng((4, 5, 6), jnp.float32)]
+ self._CheckAgainstNumpy(
+ gelu_reference, partial(bst.functional.gelu, approximate=approximate), args_maker,
+ check_dtypes=False, tol=1e-3 if approximate else None)
+
+ @parameterized.parameters(*itertools.product(
+ (jnp.float32, jnp.bfloat16, jnp.float16),
+ (partial(bst.functional.gelu, approximate=False),
+ partial(bst.functional.gelu, approximate=True),
+ bst.functional.relu,
+ bst.functional.softplus,
+ bst.functional.sparse_plus,
+ bst.functional.sigmoid,
+ bst.functional.squareplus,
+ bst.functional.mish)))
+ def testDtypeMatchesInput(self, dtype, fn):
+ x = jnp.zeros((), dtype=dtype)
+ out = fn(x)
+ self.assertEqual(out.dtype, dtype)
+
+ def testEluMemory(self):
+ # see https://github.com/google/jax/pull/1640
+ with jax.enable_checks(False): # With checks we materialize the array
+ jax.make_jaxpr(lambda: bst.functional.elu(jnp.ones((10 ** 12,)))) # don't oom
+
+ def testHardTanhMemory(self):
+ # see https://github.com/google/jax/pull/1640
+ with jax.enable_checks(False): # With checks we materialize the array
+ jax.make_jaxpr(lambda: bst.functional.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom
+
+ @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+ def testSoftmaxEmptyArray(self, fn):
+ x = jnp.array([], dtype=float)
+ self.assertArraysEqual(fn(x), x)
+
+ @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+ def testSoftmaxEmptyMask(self, fn):
+ x = jnp.array([5.5, 1.3, -4.2, 0.9])
+ m = jnp.zeros_like(x, dtype=bool)
+ expected = jnp.full_like(x, 0.0 if fn is bst.functional.softmax else -jnp.inf)
+ self.assertArraysEqual(fn(x, where=m), expected)
+
+ @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+ def testSoftmaxWhereMask(self, fn):
+ x = jnp.array([5.5, 1.3, -4.2, 0.9])
+ m = jnp.array([True, False, True, True])
+
+ out = fn(x, where=m)
+ self.assertAllClose(out[m], fn(x[m]))
+
+ probs = out if fn is bst.functional.softmax else jnp.exp(out)
+ self.assertAllClose(probs.sum(), 1.0)
+
+ @parameterized.parameters([bst.functional.softmax, bst.functional.log_softmax])
+ def testSoftmaxWhereGrad(self, fn):
+ # regression test for https://github.com/google/jax/issues/19490
+ x = jnp.array([36., 10000.])
+ mask = x < 1000
+
+ f = lambda x, mask: fn(x, where=mask)[0]
+
+ self.assertAllClose(jax.grad(f)(x, mask), jnp.zeros_like(x))
+
+ def testSoftmaxGrad(self):
+ x = jnp.array([5.5, 1.3, -4.2, 0.9])
+ jtu.check_grads(bst.functional.softmax, (x,), order=2, atol=5e-3)
+
+ def testStandardizeWhereMask(self):
+ x = jnp.array([5.5, 1.3, -4.2, 0.9])
+ m = jnp.array([True, False, True, True])
+ x_filtered = jnp.take(x, jnp.array([0, 2, 3]))
+
+ out_masked = jnp.take(bst.functional.standardize(x, where=m), jnp.array([0, 2, 3]))
+ out_filtered = bst.functional.standardize(x_filtered)
+
+ self.assertAllClose(out_masked, out_filtered)
+
+ def testOneHot(self):
+ actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3)
+ expected = jnp.array([[1., 0., 0.],
+ [0., 1., 0.],
+ [0., 0., 1.]])
+ self.assertAllClose(actual, expected, check_dtypes=False)
+
+ actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3)
+ expected = jnp.array([[0., 1., 0.],
+ [0., 0., 1.],
+ [1., 0., 0.]])
+ self.assertAllClose(actual, expected, check_dtypes=False)
+
+ def testOneHotOutOfBound(self):
+ actual = bst.functional.one_hot(jnp.array([-1, 3]), 3)
+ expected = jnp.array([[0., 0., 0.],
+ [0., 0., 0.]])
+ self.assertAllClose(actual, expected, check_dtypes=False)
+
+ def testOneHotNonArrayInput(self):
+ actual = bst.functional.one_hot([0, 1, 2], 3)
+ expected = jnp.array([[1., 0., 0.],
+ [0., 1., 0.],
+ [0., 0., 1.]])
+ self.assertAllClose(actual, expected, check_dtypes=False)
+
+ def testOneHotCustomDtype(self):
+ actual = bst.functional.one_hot(jnp.array([0, 1, 2]), 3, dtype=jnp.bool_)
+ expected = jnp.array([[True, False, False],
+ [False, True, False],
+ [False, False, True]])
+ self.assertAllClose(actual, expected)
+
+ def testOneHotAxis(self):
+ expected = jnp.array([[0., 1., 0.],
+ [0., 0., 1.],
+ [1., 0., 0.]]).T
+
+ actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=0)
+ self.assertAllClose(actual, expected, check_dtypes=False)
+
+ actual = bst.functional.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
+ self.assertAllClose(actual, expected, check_dtypes=False)
+
+ def testTanhExists(self):
+ print(bst.functional.tanh) # doesn't crash
+
+ def testCustomJVPLeak(self):
+ # https://github.com/google/jax/issues/8171
+ @jax.jit
+ def fwd():
+ a = jnp.array(1.)
+
+ def f(hx, _):
+ hx = bst.functional.sigmoid(hx + a)
+ return hx, None
+
+ hx = jnp.array(0.)
+ jax.lax.scan(f, hx, None, length=2)
+
+ with jax.checking_leaks():
+ fwd() # doesn't crash
+
+ def testCustomJVPLeak2(self):
+ # https://github.com/google/jax/issues/8171
+ # The above test uses jax.bst.functional.sigmoid, as in the original #8171, but that
+ # function no longer actually has a custom_jvp! So we inline the old def.
+
+ @jax.custom_jvp
+ def sigmoid(x):
+ one = jnp.float32(1)
+ return jax.lax.div(one, jax.lax.add(one, jax.lax.exp(jax.lax.neg(x))))
+
+ sigmoid.defjvps(lambda g, ans, x: g * ans * (jnp.float32(1) - ans))
+
+ @jax.jit
+ def fwd():
+ a = jnp.array(1., 'float32')
+
+ def f(hx, _):
+ hx = sigmoid(hx + a)
+ return hx, None
+
+ hx = jnp.array(0., 'float32')
+ jax.lax.scan(f, hx, None, length=2)
+
+ with jax.checking_leaks():
+ fwd() # doesn't crash
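The new test suite above pins down, among other things, the masked-softmax behaviour exposed through brainstate.functional. A minimal sketch of what testSoftmaxWhereMask and testSoftmaxEmptyMask assert (the array values are illustrative, not taken from the package):

import jax.numpy as jnp
import brainstate as bst

x = jnp.array([5.5, 1.3, -4.2, 0.9])
m = jnp.array([True, False, True, True])

# Entries excluded by `where` receive zero probability; the remaining
# probabilities renormalize to 1, matching the assertions in the tests above.
probs = bst.functional.softmax(x, where=m)
assert jnp.isclose(probs.sum(), 1.0)
assert probs[1] == 0.0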
brainstate/functional/_normalization.py CHANGED
@@ -15,12 +15,13 @@
  
  from __future__ import annotations
  
- from typing import Optional
+ from typing import Optional, Union
  
+ import brainunit as u
  import jax
- import jax.numpy as jnp
  
- from .._utils import set_module_as
+ from brainstate._utils import set_module_as
+ from brainstate.typing import ArrayLike
  
  __all__ = [
  'weight_standardization',
@@ -29,18 +30,18 @@ __all__ = [
  
  @set_module_as('brainstate.functional')
  def weight_standardization(
- w: jax.typing.ArrayLike,
+ w: ArrayLike,
  eps: float = 1e-4,
  gain: Optional[jax.Array] = None,
  out_axis: int = -1,
- ):
+ ) -> Union[jax.Array, u.Quantity]:
  """
  Scaled Weight Standardization,
  see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.
  
  Parameters
  ----------
- w : jax.typing.ArrayLike
+ w : ArrayLike
  The weight tensor.
  eps : float
  A small value to avoid division by zero.
@@ -51,7 +52,7 @@ def weight_standardization(
  
  Returns
  -------
- jax.typing.ArrayLike
+ ArrayLike
  The scaled weight tensor.
  """
  if out_axis < 0:
@@ -63,9 +64,19 @@
  fan_in *= w.shape[i]
  axes.append(i)
  # normalize the weight
- mean = jnp.mean(w, axis=axes, keepdims=True)
- var = jnp.var(w, axis=axes, keepdims=True)
- scale = jax.lax.rsqrt(jnp.maximum(var * fan_in, eps))
+ mean = u.math.mean(w, axis=axes, keepdims=True)
+ var = u.math.var(w, axis=axes, keepdims=True)
+
+ temp = u.math.maximum(var * fan_in, eps)
+ if isinstance(temp, u.Quantity):
+ unit = temp.unit
+ temp = temp.mantissa
+ if unit.is_unitless:
+ scale = jax.lax.rsqrt(temp)
+ else:
+ scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
+ else:
+ scale = jax.lax.rsqrt(temp)
  if gain is not None:
  scale = gain * scale
  shift = mean * scale
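The reworked block above routes the statistics through brainunit (u.math), so weight_standardization can also handle unit-carrying Quantity inputs by attaching 1 / unit ** 0.5 to the rsqrt scale when the variance is not unitless. A minimal usage sketch with a plain array (shapes and values are illustrative; plain-array inputs keep the previous jax.Array behaviour):

import jax.numpy as jnp
import brainstate as bst

# fan_in is accumulated over the non-output axes (axis 0 here, with out_axis=-1)
w = jnp.arange(12.0).reshape(3, 4)
w_std = bst.functional.weight_standardization(w, eps=1e-4, out_axis=-1)
# per the docstring, the returned value is the scaled weight tensor
print(w_std)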
brainstate/init/_generic.py CHANGED
@@ -24,6 +24,7 @@ import numpy as np
  from brainstate._state import State
  from brainstate.typing import ArrayLike
  from ._base import to_size
+ from brainstate.mixin import Mode
  
  __all__ = [
  'param',
@@ -36,7 +37,7 @@ def _is_scalar(x):
  return bu.math.isscalar(x)
  
  
- def are_shapes_broadcastable(shape1, shape2):
+ def are_broadcastable_shapes(shape1, shape2):
  """
  Check if two shapes are broadcastable.
  
@@ -88,6 +89,7 @@ def param(
  batch_size: Optional[int] = None,
  allow_none: bool = True,
  allow_scalar: bool = True,
+ mode: Mode = None,
  ):
  """Initialize parameters.
  
@@ -143,7 +145,7 @@
  raise ValueError(f'Unknown parameter type: {type(parameter)}')
  
  # Check if the shape of the parameter matches the given size
- if not are_shapes_broadcastable(parameter.shape, sizes):
+ if not are_broadcastable_shapes(parameter.shape, sizes):
  raise ValueError(f'The shape of the parameter {parameter.shape} does not match with the given size {sizes}')
  
  # Expand the parameter to match the given batch size
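The renamed helper are_broadcastable_shapes guards param against mismatched parameter/size pairs. For reference, a minimal sketch (not the library's implementation) of the standard broadcasting rule such a check follows, namely that trailing dimensions must be equal or 1:

def broadcastable(shape1, shape2):
    # walk the shapes from the right; every aligned pair must match or contain a 1
    for a, b in zip(reversed(shape1), reversed(shape2)):
        if a != b and a != 1 and b != 1:
            return False
    return True

assert broadcastable((3, 1, 4), (5, 4))      # (3, 1, 4) broadcasts against (5, 4)
assert not broadcastable((3, 2), (3, 4))     # 2 vs 4 cannot broadcast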
brainstate/mixin.py CHANGED
@@ -18,7 +18,7 @@
  from typing import Sequence, Optional, TypeVar
  from typing import (_SpecialForm, _type_check, _remove_dups_flatten, _UnionGenericAlias)
  
- from .typing import PyTree
+ from brainstate.typing import PyTree
  
  T = TypeVar('T')
  State = None
brainstate/nn/__init__.py CHANGED
@@ -13,6 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  
+ from . import metrics
  from ._base import *
  from ._base import __all__ as base_all
  from ._connections import *
@@ -39,6 +40,8 @@ from ._readout import *
  from ._readout import __all__ as readout_all
  from ._synouts import *
  from ._synouts import __all__ as synouts_all
+ from .event import *
+ from .event import __all__ as event_all
  
  __all__ = (
  base_all +
@@ -53,7 +56,8 @@ __all__ = (
  readout_all +
  synouts_all +
  _projection_all +
- _misc_all
+ _misc_all +
+ event_all
  )
  
  del (
@@ -68,5 +72,6 @@ del (
  readout_all,
  synouts_all,
  _projection_all,
- _misc_all
+ _misc_all,
+ event_all
  )
brainstate/nn/_base.py CHANGED
@@ -20,8 +20,8 @@ from __future__ import annotations
  import inspect
  from typing import Sequence, Optional, Tuple, Union
  
- from .._module import Module, UpdateReturn, Container, visible_module_dict
- from ..mixin import Mixin, DelayedInitializer, DelayedInit
+ from brainstate._module import Module, UpdateReturn, Container, visible_module_dict
+ from brainstate.mixin import Mixin, DelayedInitializer, DelayedInit
  
  __all__ = [
  'ExplicitInOutSize',
brainstate/nn/_connections.py CHANGED
@@ -25,10 +25,10 @@ import jax
  import jax.numpy as jnp
  
  from ._base import DnnLayer
- from .. import init, functional
- from .._state import ParamState
- from ..mixin import Mode
- from ..typing import ArrayLike
+ from brainstate import init, functional
+ from brainstate._state import ParamState
+ from brainstate.mixin import Mode
+ from brainstate.typing import ArrayLike
  
  T = TypeVar('T')
  
brainstate/nn/_dynamics.py CHANGED
@@ -24,11 +24,11 @@ import jax.numpy as jnp
  
  from ._base import ExplicitInOutSize
  from ._misc import exp_euler_step
- from .. import environ, init, surrogate
- from .._module import Dynamics
- from .._state import ShortTermState
- from ..mixin import DelayedInit, Mode, AlignPost
- from ..typing import DTypeLike, ArrayLike, Size
+ from brainstate import environ, init, surrogate
+ from brainstate._module import Dynamics
+ from brainstate._state import ShortTermState
+ from brainstate.mixin import DelayedInit, Mode, AlignPost
+ from brainstate.typing import DTypeLike, ArrayLike, Size
  
  __all__ = [
  # neuron models
brainstate/nn/_elementwise.py CHANGED
@@ -19,16 +19,16 @@ from __future__ import annotations
  
  from typing import Optional
  
- import brainunit as bu
+ import brainunit as u
  import jax.numpy as jnp
  import jax.typing
  
  from ._base import ElementWiseBlock
- from .. import environ, random, functional as F
- from .._module import Module
- from .._state import ParamState
- from ..mixin import Mode
- from ..typing import ArrayLike
+ from brainstate import environ, random, functional as F
+ from brainstate._module import Module
+ from brainstate._state import ParamState
+ from brainstate.mixin import Mode
+ from brainstate.typing import ArrayLike
  
  __all__ = [
  # activation functions
@@ -83,7 +83,7 @@ class Threshold(Module, ElementWiseBlock):
  self.value = value
  
  def __call__(self, x: ArrayLike) -> ArrayLike:
- dtype = bu.math.get_dtype(x)
+ dtype = u.math.get_dtype(x)
  return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
  x,
  jnp.asarray(self.value, dtype=dtype))
@@ -1143,7 +1143,7 @@ class Dropout(Module, ElementWiseBlock):
  self.prob = prob
  
  def __call__(self, x):
- dtype = bu.math.get_dtype(x)
+ dtype = u.math.get_dtype(x)
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
  if fit_phase and self.prob < 1.:
  keep_mask = random.bernoulli(self.prob, x.shape)
@@ -1173,7 +1173,7 @@ class _DropoutNd(Module, ElementWiseBlock):
  self.channel_axis = channel_axis
  
  def __call__(self, x):
- dtype = bu.math.get_dtype(x)
+ dtype = u.math.get_dtype(x)
  # get fit phase
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
  
brainstate/nn/_embedding.py CHANGED
@@ -17,9 +17,9 @@ from typing import Optional, Callable, Union
  
  from ._base import DnnLayer
  from .. import init
- from .._state import ParamState
- from ..mixin import Mode, Training
- from ..typing import ArrayLike
+ from brainstate._state import ParamState
+ from brainstate.mixin import Mode, Training
+ from brainstate.typing import ArrayLike
  
  __all__ = [
  'Embedding',
brainstate/nn/_normalizations.py CHANGED
@@ -25,9 +25,9 @@ import jax.numpy as jnp
  
  from ._base import DnnLayer
  from .. import environ, init
- from .._state import LongTermState, ParamState
- from ..mixin import Mode
- from ..typing import DTypeLike, ArrayLike, Size, Axes
+ from brainstate._state import LongTermState, ParamState
+ from brainstate.mixin import Mode
+ from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
  
  __all__ = [
  'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
brainstate/nn/_others.py CHANGED
@@ -23,8 +23,8 @@ import brainunit as bu
  import jax.numpy as jnp
  
  from ._base import DnnLayer
- from .. import random, environ, typing, init
- from ..mixin import Mode
+ from brainstate.mixin import Mode
+ from brainstate import random, environ, typing, init
  
  __all__ = [
  'DropoutFixed',
brainstate/nn/_poolings.py CHANGED
@@ -21,7 +21,7 @@ import functools
  from typing import Sequence, Optional
  from typing import Union, Tuple, Callable, List
  
- import brainunit as bu
+ import brainunit as u
  import jax
  import jax.numpy as jnp
  import numpy as np
@@ -29,7 +29,7 @@ import numpy as np
  from ._base import DnnLayer, ExplicitInOutSize
  from .. import environ
  from ..mixin import Mode
- from ..typing import Size
+ from brainstate.typing import Size
  
  __all__ = [
  'Flatten', 'Unflatten',
@@ -85,7 +85,7 @@ class Flatten(DnnLayer, ExplicitInOutSize):
  
  if in_size is not None:
  self.in_size = tuple(in_size)
- y = jax.eval_shape(functools.partial(bu.math.flatten, start_axis=start_axis, end_axis=end_axis),
+ y = jax.eval_shape(functools.partial(u.math.flatten, start_axis=start_axis, end_axis=end_axis),
  jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
  self.out_size = y.shape
  
@@ -101,7 +101,7 @@ class Flatten(DnnLayer, ExplicitInOutSize):
  start_axis = self.start_axis + dim_diff
  else:
  start_axis = x.ndim + self.start_axis
- return bu.math.flatten(x, start_axis, self.end_axis)
+ return u.math.flatten(x, start_axis, self.end_axis)
  
  def __repr__(self) -> str:
  return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
@@ -153,12 +153,12 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
  
  if in_size is not None:
  self.in_size = tuple(in_size)
- y = jax.eval_shape(functools.partial(bu.math.unflatten, axis=axis, sizes=sizes),
+ y = jax.eval_shape(functools.partial(u.math.unflatten, axis=axis, sizes=sizes),
  jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
  self.out_size = y.shape
  
  def update(self, x):
- return bu.math.unflatten(x, self.axis, self.sizes)
+ return u.math.unflatten(x, self.axis, self.sizes)
  
  def __repr__(self):
  return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
brainstate/nn/_rate_rnns.py CHANGED
@@ -27,7 +27,7 @@ from .. import random, init, functional
  from .._module import Module
  from .._state import ShortTermState, ParamState
  from ..mixin import DelayedInit, Mode
- from ..typing import ArrayLike
+ from brainstate.typing import ArrayLike
  
  __all__ = [
  'RNNCell', 'ValinaRNNCell', 'GRUCell', 'MGUCell', 'LSTMCell', 'URLSTMCell',
brainstate/nn/_readout.py CHANGED
@@ -29,7 +29,7 @@ from ._misc import exp_euler_step
  from .. import environ, init, surrogate
  from .._state import ShortTermState, ParamState
  from ..mixin import Mode
- from ..typing import Size, ArrayLike, DTypeLike
+ from brainstate.typing import Size, ArrayLike, DTypeLike
  
  __all__ = [
  'LeakyRateReadout',