brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (57)
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +191 -48
  3. brainstate/_module_test.py +95 -21
  4. brainstate/_state.py +17 -0
  5. brainstate/environ.py +2 -2
  6. brainstate/functional/__init__.py +3 -2
  7. brainstate/functional/_activations.py +7 -26
  8. brainstate/functional/_normalization.py +3 -0
  9. brainstate/functional/_others.py +49 -0
  10. brainstate/functional/_spikes.py +0 -1
  11. brainstate/mixin.py +2 -2
  12. brainstate/nn/__init__.py +4 -0
  13. brainstate/nn/_base.py +10 -7
  14. brainstate/nn/_dynamics.py +20 -0
  15. brainstate/nn/_elementwise.py +5 -4
  16. brainstate/nn/_embedding.py +66 -0
  17. brainstate/nn/_misc.py +4 -3
  18. brainstate/nn/_others.py +3 -2
  19. brainstate/nn/_poolings.py +21 -20
  20. brainstate/nn/_poolings_test.py +4 -4
  21. brainstate/nn/_rate_rnns.py +17 -0
  22. brainstate/nn/_readout.py +6 -0
  23. brainstate/optim/__init__.py +0 -1
  24. brainstate/optim/_lr_scheduler_test.py +13 -0
  25. brainstate/optim/_sgd_optimizer.py +18 -17
  26. brainstate/transform/__init__.py +2 -3
  27. brainstate/transform/_autograd.py +1 -1
  28. brainstate/transform/_autograd_test.py +0 -2
  29. brainstate/transform/_jit.py +47 -21
  30. brainstate/transform/_jit_test.py +0 -3
  31. brainstate/transform/_make_jaxpr.py +164 -3
  32. brainstate/transform/_make_jaxpr_test.py +0 -2
  33. brainstate/transform/_progress_bar.py +1 -3
  34. brainstate/util.py +0 -1
  35. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
  36. brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
  37. brainstate/math/__init__.py +0 -21
  38. brainstate/math/_einops.py +0 -787
  39. brainstate/math/_einops_parsing.py +0 -169
  40. brainstate/math/_einops_parsing_test.py +0 -126
  41. brainstate/math/_einops_test.py +0 -346
  42. brainstate/math/_misc.py +0 -298
  43. brainstate/math/_misc_test.py +0 -58
  44. brainstate/nn/functional/__init__.py +0 -25
  45. brainstate/nn/functional/_activations.py +0 -754
  46. brainstate/nn/functional/_normalization.py +0 -69
  47. brainstate/nn/functional/_spikes.py +0 -90
  48. brainstate/nn/init/__init__.py +0 -26
  49. brainstate/nn/init/_base.py +0 -36
  50. brainstate/nn/init/_generic.py +0 -175
  51. brainstate/nn/init/_random_inits.py +0 -489
  52. brainstate/nn/init/_regular_inits.py +0 -109
  53. brainstate/nn/surrogate.py +0 -1740
  54. brainstate-0.0.1.dist-info/RECORD +0 -79
  55. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
  56. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
  57. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
brainstate/environ.py CHANGED
@@ -18,12 +18,12 @@ from .util import MemScaling, IdMemScaling
  __all__ = [
  'set', 'context', 'get', 'all',
  'set_host_device_count', 'set_platform',
- 'get_host_device_count', 'get_platform', 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
+ 'get_host_device_count', 'get_platform',
+ 'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
  'tolerance',
  'dftype', 'ditype', 'dutype', 'dctype',
  ]

-
  # Default, there are several shared arguments in the global context.
  I = 'i'  # the index of the current computation.
  T = 't'  # the current time of the current computation.
brainstate/functional/__init__.py CHANGED
@@ -18,8 +18,9 @@ from ._activations import *
  from ._activations import __all__ as __activations_all__
  from ._normalization import *
  from ._normalization import __all__ as __others_all__
+ from ._others import *
+ from ._others import __all__ as __others_all__
  from ._spikes import *
  from ._spikes import __all__ as __spikes_all__

- __all__ = __spikes_all__ + __others_all__ + __activations_all__
-
+ __all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
brainstate/functional/_activations.py CHANGED
@@ -27,7 +27,7 @@ import jax.numpy as jnp
  from jax.scipy.special import logsumexp
  from jax.typing import ArrayLike

- from brainstate import math, random
+ from .. import random

  __all__ = [
  "tanh",
@@ -136,10 +136,7 @@ def prelu(x, a=0.25):
  parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
  a separate :math:`a` is used for each input channel.
  """
- dtype = math.get_dtype(x)
- return jnp.where(x >= jnp.asarray(0., dtype),
- x,
- jnp.asarray(a, dtype) * x)
+ return jnp.where(x >= 0., x, a * x)


  def soft_shrink(x, lambd=0.5):
@@ -161,11 +158,7 @@ def soft_shrink(x, lambd=0.5):
  - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  - Output: :math:`(*)`, same shape as the input.
  """
- dtype = math.get_dtype(x)
- lambd = jnp.asarray(lambd, dtype)
- return jnp.where(x > lambd,
- x - lambd,
- jnp.where(x < -lambd, x + lambd, jnp.asarray(0., dtype)))
+ return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))


  def mish(x):
@@ -217,9 +210,8 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333):
  .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
  https://arxiv.org/abs/1505.00853
  """
- dtype = math.get_dtype(x)
- a = random.uniform(lower, upper, size=jnp.shape(x), dtype=dtype)
- return jnp.where(x >= jnp.asarray(0., dtype), x, jnp.asarray(a, dtype) * x)
+ a = random.uniform(lower, upper, size=jnp.shape(x), dtype=x.dtype)
+ return jnp.where(x >= 0., x, a * x)


  def hard_shrink(x, lambd=0.5):
@@ -243,11 +235,7 @@ def hard_shrink(x, lambd=0.5):
  - Output: :math:`(*)`, same shape as the input.

  """
- dtype = math.get_dtype(x)
- lambd = jnp.asarray(lambd, dtype)
- return jnp.where(x > lambd,
- x,
- jnp.where(x < -lambd, x, jnp.asarray(0., dtype)))
+ return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.))


  def relu(x: ArrayLike) -> jax.Array:
@@ -298,8 +286,7 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
  x : input array
  b : smoothness parameter
  """
- dtype = math.get_dtype(x)
- return jax.nn.squareplus(x, jnp.asarray(b, dtype))
+ return jax.nn.squareplus(x, b)


  def softplus(x: ArrayLike) -> jax.Array:
@@ -417,8 +404,6 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
  See also:
  :func:`selu`
  """
- dtype = math.get_dtype(x)
- alpha = jnp.asarray(alpha, dtype)
  return jax.nn.elu(x, alpha)


@@ -445,8 +430,6 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
  See also:
  :func:`relu`
  """
- dtype = math.get_dtype(x)
- negative_slope = jnp.asarray(negative_slope, dtype)
  return jax.nn.leaky_relu(x, negative_slope=negative_slope)


@@ -493,8 +476,6 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
  Returns:
  An array.
  """
- dtype = math.get_dtype(x)
- alpha = jnp.asarray(alpha, dtype)
  return jax.nn.celu(x, alpha)

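These activation rewrites drop the explicit dtype coercion from the removed brainstate.math helpers and lean on JAX's ordinary type promotion; the call signatures are unchanged. A quick usage sketch, assuming the functions are re-exported as brainstate.functional.* (consistent with the set_module_as('brainstate.functional') decorator used elsewhere in this diff):

    import jax.numpy as jnp
    import brainstate as bst

    x = jnp.array([-2.0, -0.3, 0.0, 0.4, 2.0])
    print(bst.functional.prelu(x, a=0.25))           # negative inputs scaled by a; dtype follows JAX promotion
    print(bst.functional.soft_shrink(x, lambd=0.5))  # values pulled toward zero by lambd, zeroed inside [-0.5, 0.5]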
brainstate/functional/_normalization.py CHANGED
@@ -20,11 +20,14 @@ from typing import Optional
  import jax
  import jax.numpy as jnp

+ from .._utils import set_module_as
+
  __all__ = [
  'weight_standardization',
  ]


+ @set_module_as('brainstate.functional')
  def weight_standardization(
  w: jax.typing.ArrayLike,
  eps: float = 1e-4,
brainstate/functional/_others.py ADDED
@@ -0,0 +1,49 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ from functools import partial
+ from typing import Any
+
+ import jax
+ import jax.numpy as jnp
+
+ PyTree = Any
+
+ __all__ = [
+ 'clip_grad_norm',
+ ]
+
+
+ def clip_grad_norm(
+ grad: PyTree,
+ max_norm: float | jax.Array,
+ norm_type: int | str | None = None
+ ):
+ """
+ Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were
+ concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+ grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
+ max_norm (float): max norm of the gradients.
+ norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+ """
+ norm_fn = partial(jnp.linalg.norm, ord=norm_type)
+ norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
+ return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
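The new clip_grad_norm rescales a whole gradient pytree once its global norm exceeds max_norm. A minimal usage sketch, assuming the function is reachable as brainstate.functional.clip_grad_norm (the functional/__init__.py change above re-exports it from ._others):

    import jax.numpy as jnp
    import brainstate as bst

    # a gradient pytree whose global L2 norm is sqrt(1200) ~= 34.6
    grads = {'w': jnp.full((3, 3), 10.0), 'b': jnp.full((3,), 10.0)}
    clipped = bst.functional.clip_grad_norm(grads, max_norm=1.0)
    # once the global norm exceeds max_norm, every leaf is scaled by max_norm / global_norm
    print(clipped['w'][0, 0])   # ~= 10.0 / 34.6 ~= 0.29

Note that norm_type is forwarded to jnp.linalg.norm as ord on each leaf, so norm_type=2 means the spectral norm for 2-D leaves; the default (None) gives the usual Frobenius/L2 behaviour.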
brainstate/functional/_spikes.py CHANGED
@@ -87,4 +87,3 @@ def spike_bitwise(x, y, op: str):
  return spike_bitwise_ixor(x, y)
  else:
  raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
-
brainstate/mixin.py CHANGED
@@ -68,7 +68,7 @@ class DelayedInit(Mixin):
  Note this Mixin can be applied in any Python object.
  """

- non_hash_params: Optional[Sequence[str]] = None
+ non_hashable_params: Optional[Sequence[str]] = None

  @classmethod
  def delayed(cls, *args, **kwargs) -> 'DelayedInitializer':
@@ -94,7 +94,7 @@ class DelayedInitializer(metaclass=NoSubclassMeta):
  """

  def __init__(self, cls: T, *desc_tuple, **desc_dict):
- self.cls = cls
+ self.cls: type = cls

  # arguments
  self.args = desc_tuple
brainstate/nn/__init__.py CHANGED
@@ -21,6 +21,8 @@ from ._dynamics import *
  from ._dynamics import __all__ as dynamics_all
  from ._elementwise import *
  from ._elementwise import __all__ as elementwise_all
+ from ._embedding import *
+ from ._embedding import __all__ as embed_all
  from ._misc import *
  from ._misc import __all__ as _misc_all
  from ._normalizations import *
@@ -43,6 +45,7 @@ __all__ = (
  connections_all +
  dynamics_all +
  elementwise_all +
+ embed_all +
  normalizations_all +
  others_all +
  poolings_all +
@@ -58,6 +61,7 @@ del (
  connections_all,
  dynamics_all,
  elementwise_all,
+ embed_all,
  normalizations_all,
  others_all,
  poolings_all,
brainstate/nn/_base.py CHANGED
@@ -55,22 +55,24 @@ class ExplicitInOutSize(Mixin):

  @property
  def in_size(self) -> Tuple[int, ...]:
- if self._in_size is None:
- raise ValueError(f"The input shape is not set in this node: {self} ")
  return self._in_size

  @in_size.setter
- def in_size(self, in_size: Sequence[int]):
+ def in_size(self, in_size: Sequence[int] | int):
+ if isinstance(in_size, int):
+ in_size = (in_size,)
+ assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {type(in_size)}"
  self._in_size = tuple(in_size)

  @property
  def out_size(self) -> Tuple[int, ...]:
- if self._out_size is None:
- raise ValueError(f"The output shape is not set in this node: {self}")
  return self._out_size

  @out_size.setter
- def out_size(self, out_size: Sequence[int]):
+ def out_size(self, out_size: Sequence[int] | int):
+ if isinstance(out_size, int):
+ out_size = (out_size,)
+ assert isinstance(out_size, (tuple, list)), f"Invalid type of out_size: {type(out_size)}"
  self._out_size = tuple(out_size)


@@ -152,7 +154,8 @@ class Sequential(Module, UpdateReturn, Container, ExplicitInOutSize):
  self.children = visible_module_dict(self.format_elements(object, first, *tuple_modules, **dict_modules))

  # the input and output shape
- self.in_size = tuple(first.in_size)
+ if first.in_size is not None:
+ self.in_size = first.in_size
  self.out_size = tuple(in_size)

  def _format_module(self, module, in_size):
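With the relaxed setters, in_size and out_size now accept a bare int and promote it to a 1-tuple, and reading an unset size returns None instead of raising (which the Sequential change above relies on). A minimal sketch, assuming the mixin can be exercised standalone via its private module path:

    from brainstate.nn._base import ExplicitInOutSize

    class _Demo(ExplicitInOutSize):   # hypothetical subclass, for illustration only
        pass

    d = _Demo()
    d.in_size = 32                 # a bare int is promoted to (32,)
    d.out_size = (4, 4)            # sequences still pass through unchanged
    print(d.in_size, d.out_size)   # (32,) (4, 4)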
brainstate/nn/_dynamics.py CHANGED
@@ -103,6 +103,9 @@ class IF(Neuron):
  def init_state(self, batch_size: int = None, **kwargs):
  self.V = ShortTermState(init.param(jnp.zeros, self.varshape, batch_size))

+ def reset_state(self, batch_size: int = None, **kwargs):
+ self.V.value = init.param(jnp.zeros, self.varshape, batch_size)
+
  def get_spike(self, V=None):
  V = self.V.value if V is None else V
  v_scaled = (V - self.V_th) / self.V_th
@@ -160,6 +163,9 @@ class LIF(Neuron):
  def init_state(self, batch_size: int = None, **kwargs):
  self.V = ShortTermState(init.param(init.Constant(self.V_reset), self.varshape, batch_size))

+ def reset_state(self, batch_size: int = None, **kwargs):
+ self.V.value = init.param(init.Constant(self.V_reset), self.varshape, batch_size)
+
  def get_spike(self, V=None):
  V = self.V.value if V is None else V
  v_scaled = (V - self.V_th) / self.V_th
@@ -214,6 +220,10 @@ class ALIF(Neuron):
  self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
  self.a = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))

+ def reset_state(self, batch_size: int = None, **kwargs):
+ self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
+ self.a.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
  def get_spike(self, V=None, a=None):
  V = self.V.value if V is None else V
  a = self.a.value if a is None else a
@@ -275,6 +285,9 @@ class Expon(Synapse):
  def init_state(self, batch_size: int = None, **kwargs):
  self.g = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))

+ def reset_state(self, batch_size: int = None, **kwargs):
+ self.g.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
  def update(self, x=None):
  self.g.value = exp_euler_step(self.dg, self.g.value, environ.get('t'))
  if x is not None:
@@ -325,6 +338,10 @@ class STP(Synapse):
  self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))
  self.u = ShortTermState(init.param(init.Constant(self.U), self.varshape, batch_size))

+ def reset_state(self, batch_size: int = None, **kwargs):
+ self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+ self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
+
  def du(self, u, t):
  return self.U - u / self.tau_f

@@ -390,6 +407,9 @@ class STD(Synapse):
  def init_state(self, batch_size: int = None, **kwargs):
  self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))

+ def reset_state(self, batch_size: int = None, **kwargs):
+ self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+
  def update(self, pre_spike):
  t = environ.get('t')
  x = exp_euler_step(self.dx, self.x.value, t)
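Each of these dynamics classes now pairs init_state (allocate the ShortTermState containers) with reset_state (overwrite their .value in place). A rough sketch of the intended call pattern; the constructor argument below is an assumption about the LIF signature rather than something shown in this diff:

    import brainstate as bst

    lif = bst.nn.LIF(10)             # hypothetical: first argument assumed to be the population size
    lif.init_state(batch_size=None)  # allocates lif.V as a ShortTermState filled with V_reset
    # ... run a simulation; lif.V.value drifts away from V_reset ...
    lif.reset_state()                # new in this release: reset lif.V.value without reallocating the state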
brainstate/nn/_elementwise.py CHANGED
@@ -19,11 +19,12 @@ from __future__ import annotations

  from typing import Optional

+ import brainunit as bu
  import jax.numpy as jnp
  import jax.typing

  from ._base import ElementWiseBlock
- from .. import math, environ, random, functional as F
+ from .. import environ, random, functional as F
  from .._module import Module
  from .._state import ParamState
  from ..mixin import Mode
@@ -82,7 +83,7 @@ class Threshold(Module, ElementWiseBlock):
  self.value = value

  def __call__(self, x: ArrayLike) -> ArrayLike:
- dtype = math.get_dtype(x)
+ dtype = bu.math.get_dtype(x)
  return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
  x,
  jnp.asarray(self.value, dtype=dtype))
@@ -1142,7 +1143,7 @@ class Dropout(Module, ElementWiseBlock):
  self.prob = prob

  def __call__(self, x):
- dtype = math.get_dtype(x)
+ dtype = bu.math.get_dtype(x)
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
  if fit_phase:
  keep_mask = random.bernoulli(self.prob, x.shape)
@@ -1172,7 +1173,7 @@ class _DropoutNd(Module, ElementWiseBlock):
  self.channel_axis = channel_axis

  def __call__(self, x):
- dtype = math.get_dtype(x)
+ dtype = bu.math.get_dtype(x)
  # get fit phase
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')

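Throughout nn, dtype resolution moves from the removed brainstate.math module to brainunit; the call itself is unchanged apart from the import. A small sketch of the replacement call used above:

    import brainunit as bu
    import jax.numpy as jnp

    x = jnp.ones(3, dtype=jnp.float16)
    print(bu.math.get_dtype(x))   # float16, now resolved by brainunit instead of brainstate.math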
brainstate/nn/_embedding.py ADDED
@@ -0,0 +1,66 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Optional, Callable, Union
+
+ from ._base import DnnLayer
+ from .. import init
+ from .._state import ParamState
+ from ..mixin import Mode, Training
+ from ..typing import ArrayLike
+
+ __all__ = [
+ 'Embedding',
+ ]
+
+
+ class Embedding(DnnLayer):
+ r"""
+ A simple lookup table that stores embeddings of a fixed size.
+
+ Args:
+ num_embeddings: Size of embedding dictionary. Must be non-negative.
+ embedding_size: Size of each embedding vector. Must be non-negative.
+ embed_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
+
+ """
+
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_size: int,
+ embed_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+ name: Optional[str] = None,
+ mode: Optional[Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+ if num_embeddings < 0:
+ raise ValueError("num_embeddings must not be negative.")
+ if embedding_size < 0:
+ raise ValueError("embedding_size must not be negative.")
+ self.num_embeddings = num_embeddings
+ self.embedding_size = embedding_size
+ self.out_size = (embedding_size,)
+
+ weight = init.param(embed_init, (self.num_embeddings, self.embedding_size))
+ if self.mode.has(Training):
+ self.weight = ParamState(weight)
+ else:
+ self.weight = weight
+
+ def update(self, indices: ArrayLike):
+ if self.mode.has(Training):
+ return self.weight.value[indices]
+ return self.weight[indices]
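The new Embedding layer is a plain lookup table: update indexes a (num_embeddings, embedding_size) weight matrix, stored as a ParamState when the layer is in training mode. A usage sketch, assuming the layer is called like the other DnnLayer modules in the test files (i.e. layer(x) dispatches to update):

    import jax.numpy as jnp
    import brainstate as bst

    emb = bst.nn.Embedding(num_embeddings=100, embedding_size=8)
    indices = jnp.array([4, 7, 7, 2])
    out = emb(indices)   # row lookup in the (100, 8) table
    print(out.shape)     # (4, 8)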
brainstate/nn/_misc.py CHANGED
@@ -20,9 +20,10 @@ from enum import Enum
  from functools import wraps
  from typing import Sequence, Callable

+ import brainunit as bu
  import jax.numpy as jnp

- from .. import environ, math
+ from .. import environ
  from .._state import State
  from ..transform import vector_grad

@@ -96,7 +97,7 @@ def exp_euler(fun):
  )
  dt = environ.get('dt')
  linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
- phi = math.exprel(dt * linear)
+ phi = bu.math.exprel(dt * linear)
  return args[0] + dt * phi * derivative

  return integral
@@ -128,5 +129,5 @@ def exp_euler_step(fun: Callable, *args, **kwargs):
  )
  dt = environ.get('dt')
  linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
- phi = math.exprel(dt * linear)
+ phi = bu.math.exprel(dt * linear)
  return args[0] + dt * phi * derivative
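Both helpers implement the exponential Euler step y_{n+1} = y_n + dt * exprel(dt * J) * f(y_n), where J is the derivative of f at y_n and exprel(z) = (e^z - 1) / z; only the exprel provider changes here (brainunit instead of the removed brainstate.math). For a linear ODE the step is exact, which gives a quick sanity check; the sketch below assumes exp_euler_step is re-exported from brainstate.nn and that dt is read from the global environment, as the code above shows:

    import jax.numpy as jnp
    import brainstate as bst

    bst.environ.set(dt=0.1)   # exp_euler_step reads 'dt' from the global environment
    tau = 2.0

    def dv(v, t):
        return -v / tau       # dv/dt = -v / tau, a linear ODE

    v0 = jnp.asarray(1.0)
    v1 = bst.nn.exp_euler_step(dv, v0, 0.)
    print(v1, v0 * jnp.exp(-0.1 / tau))   # identical: exponential Euler is exact for linear f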
brainstate/nn/_others.py CHANGED
@@ -19,10 +19,11 @@ from __future__ import annotations
  from functools import partial
  from typing import Optional

+ import brainunit as bu
  import jax.numpy as jnp

  from ._base import DnnLayer
- from .. import random, math, environ, typing, init
+ from .. import random, environ, typing, init
  from ..mixin import Mode

  __all__ = [
@@ -88,7 +89,7 @@ class DropoutFixed(DnnLayer):
  self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)

  def update(self, x):
- dtype = math.get_dtype(x)
+ dtype = bu.math.get_dtype(x)
  fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
  if fit_phase:
  assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
brainstate/nn/_poolings.py CHANGED
@@ -21,12 +21,13 @@ import functools
  from typing import Sequence, Optional
  from typing import Union, Tuple, Callable, List

+ import brainunit as bu
  import jax
  import jax.numpy as jnp
  import numpy as np

  from ._base import DnnLayer, ExplicitInOutSize
- from .. import environ, math
+ from .. import environ
  from ..mixin import Mode
  from ..typing import Size

@@ -53,8 +54,8 @@ class Flatten(DnnLayer, ExplicitInOutSize):

  Args:
  in_size: Sequence of int. The shape of the input tensor.
- start_dim: first dim to flatten (default = 1).
- end_dim: last dim to flatten (default = -1).
+ start_axis: first dim to flatten (default = 1).
+ end_axis: last dim to flatten (default = -1).

  Examples::
  >>> import brainstate as bst
@@ -74,36 +75,36 @@ class Flatten(DnnLayer, ExplicitInOutSize):

  def __init__(
  self,
- start_dim: int = 0,
- end_dim: int = -1,
+ start_axis: int = 0,
+ end_axis: int = -1,
  in_size: Optional[Size] = None
  ) -> None:
  super().__init__()
- self.start_dim = start_dim
- self.end_dim = end_dim
+ self.start_axis = start_axis
+ self.end_axis = end_axis

  if in_size is not None:
  self.in_size = tuple(in_size)
- y = jax.eval_shape(functools.partial(math.flatten, start_dim=start_dim, end_dim=end_dim),
+ y = jax.eval_shape(functools.partial(bu.math.flatten, start_axis=start_axis, end_axis=end_axis),
  jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
  self.out_size = y.shape

  def update(self, x):
  if self._in_size is None:
- start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim
+ start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
  else:
  assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
  dim_diff = x.ndim - len(self.in_size)
  if self.in_size != x.shape[dim_diff:]:
  raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
- if self.start_dim >= 0:
- start_dim = self.start_dim + dim_diff
+ if self.start_axis >= 0:
+ start_axis = self.start_axis + dim_diff
  else:
- start_dim = x.ndim + self.start_dim
- return math.flatten(x, start_dim, self.end_dim)
+ start_axis = x.ndim + self.start_axis
+ return bu.math.flatten(x, start_axis, self.end_axis)

  def __repr__(self) -> str:
- return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})'
+ return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'


  class _MaxPool(DnnLayer, ExplicitInOutSize):
@@ -124,7 +125,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
  :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

  Args:
- dim: int, Dimension to be unflattened.
+ axis: int, Dimension to be unflattened.
  sizes: Sequence of int. New shape of the unflattened dimension.
  in_size: Sequence of int. The shape of the input tensor.
  """
@@ -132,7 +133,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):

  def __init__(
  self,
- dim: int,
+ axis: int,
  sizes: Size,
  mode: Mode = None,
  name: str = None,
@@ -140,7 +141,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
  ) -> None:
  super().__init__(mode=mode, name=name)

- self.dim = dim
+ self.axis = axis
  self.sizes = sizes
  if isinstance(sizes, (tuple, list)):
  for idx, elem in enumerate(sizes):
@@ -152,15 +153,15 @@ class Unflatten(DnnLayer, ExplicitInOutSize):

  if in_size is not None:
  self.in_size = tuple(in_size)
- y = jax.eval_shape(functools.partial(math.unflatten, dim=dim, sizes=sizes),
+ y = jax.eval_shape(functools.partial(bu.math.unflatten, axis=axis, sizes=sizes),
  jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
  self.out_size = y.shape

  def update(self, x):
- return math.unflatten(x, self.dim, self.sizes)
+ return bu.math.unflatten(x, self.axis, self.sizes)

  def __repr__(self):
- return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})'
+ return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'


  class _MaxPool(DnnLayer, ExplicitInOutSize):
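With the rename from start_dim/end_dim/dim to start_axis/end_axis/axis (and the backend moving to brainunit), the two reshaping layers are used as below. A small sketch mirroring the updated tests in _poolings_test.py; the Unflatten call pattern is assumed to match Flatten's:

    import brainstate as bst
    import brainstate.nn as nn

    x = bst.random.rand(16, 32, 32, 8)
    flat = nn.Flatten(start_axis=1)   # keyword renamed from start_dim in 0.0.1
    y = flat(x)                       # shape (16, 32 * 32 * 8)

    unflat = nn.Unflatten(axis=1, sizes=(32, 32, 8))   # keyword renamed from dim
    z = unflat(y)                     # back to shape (16, 32, 32, 8)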
brainstate/nn/_poolings_test.py CHANGED
@@ -18,7 +18,7 @@ class TestFlatten(parameterized.TestCase):
  (10, 20, 30),
  ]:
  arr = bst.random.rand(*size)
- f = nn.Flatten(start_dim=0)
+ f = nn.Flatten(start_axis=0)
  out = f(arr)
  self.assertTrue(out.shape == (np.prod(size),))

@@ -29,21 +29,21 @@ class TestFlatten(parameterized.TestCase):
  (10, 20, 30),
  ]:
  arr = bst.random.rand(*size)
- f = nn.Flatten(start_dim=1)
+ f = nn.Flatten(start_axis=1)
  out = f(arr)
  self.assertTrue(out.shape == (size[0], np.prod(size[1:])))

  def test_flatten3(self):
  size = (16, 32, 32, 8)
  arr = bst.random.rand(*size)
- f = nn.Flatten(start_dim=0, in_size=(32, 8))
+ f = nn.Flatten(start_axis=0, in_size=(32, 8))
  out = f(arr)
  self.assertTrue(out.shape == (16, 32, 32 * 8))

  def test_flatten4(self):
  size = (16, 32, 32, 8)
  arr = bst.random.rand(*size)
- f = nn.Flatten(start_dim=1, in_size=(32, 32, 8))
+ f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
  out = f(arr)
  self.assertTrue(out.shape == (16, 32, 32 * 8))