brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -5
- brainstate/_module.py +191 -48
- brainstate/_module_test.py +95 -21
- brainstate/_state.py +17 -0
- brainstate/environ.py +2 -2
- brainstate/functional/__init__.py +3 -2
- brainstate/functional/_activations.py +7 -26
- brainstate/functional/_normalization.py +3 -0
- brainstate/functional/_others.py +49 -0
- brainstate/functional/_spikes.py +0 -1
- brainstate/mixin.py +2 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_base.py +10 -7
- brainstate/nn/_dynamics.py +20 -0
- brainstate/nn/_elementwise.py +5 -4
- brainstate/nn/_embedding.py +66 -0
- brainstate/nn/_misc.py +4 -3
- brainstate/nn/_others.py +3 -2
- brainstate/nn/_poolings.py +21 -20
- brainstate/nn/_poolings_test.py +4 -4
- brainstate/nn/_rate_rnns.py +17 -0
- brainstate/nn/_readout.py +6 -0
- brainstate/optim/__init__.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +13 -0
- brainstate/optim/_sgd_optimizer.py +18 -17
- brainstate/transform/__init__.py +2 -3
- brainstate/transform/_autograd.py +1 -1
- brainstate/transform/_autograd_test.py +0 -2
- brainstate/transform/_jit.py +47 -21
- brainstate/transform/_jit_test.py +0 -3
- brainstate/transform/_make_jaxpr.py +164 -3
- brainstate/transform/_make_jaxpr_test.py +0 -2
- brainstate/transform/_progress_bar.py +1 -3
- brainstate/util.py +0 -1
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
- brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
- brainstate/math/__init__.py +0 -21
- brainstate/math/_einops.py +0 -787
- brainstate/math/_einops_parsing.py +0 -169
- brainstate/math/_einops_parsing_test.py +0 -126
- brainstate/math/_einops_test.py +0 -346
- brainstate/math/_misc.py +0 -298
- brainstate/math/_misc_test.py +0 -58
- brainstate/nn/functional/__init__.py +0 -25
- brainstate/nn/functional/_activations.py +0 -754
- brainstate/nn/functional/_normalization.py +0 -69
- brainstate/nn/functional/_spikes.py +0 -90
- brainstate/nn/init/__init__.py +0 -26
- brainstate/nn/init/_base.py +0 -36
- brainstate/nn/init/_generic.py +0 -175
- brainstate/nn/init/_random_inits.py +0 -489
- brainstate/nn/init/_regular_inits.py +0 -109
- brainstate/nn/surrogate.py +0 -1740
- brainstate-0.0.1.dist-info/RECORD +0 -79
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
brainstate/environ.py
CHANGED
@@ -18,12 +18,12 @@ from .util import MemScaling, IdMemScaling
 __all__ = [
   'set', 'context', 'get', 'all',
   'set_host_device_count', 'set_platform',
-  'get_host_device_count', 'get_platform',
+  'get_host_device_count', 'get_platform',
+  'get_dt', 'get_mode', 'get_mem_scaling', 'get_precision',
   'tolerance',
   'dftype', 'ditype', 'dutype', 'dctype',
 ]
 
-
 # Default, there are several shared arguments in the global context.
 I = 'i'  # the index of the current computation.
 T = 't'  # the current time of the current computation.
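The new exports add convenience getters for values stored in the global environment context. A minimal usage sketch (assuming `set` accepts the corresponding keywords, as the `environ.get('dt')` calls elsewhere in this diff imply):

```python
import brainstate as bst

# register shared simulation settings in the global environment context
bst.environ.set(dt=0.1)

# read them back through the newly exported getters
print(bst.environ.get_dt())    # 0.1
print(bst.environ.get('dt'))   # equivalent, as used inside brainstate.nn._misc
```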
brainstate/functional/__init__.py
CHANGED
@@ -18,8 +18,9 @@ from ._activations import *
 from ._activations import __all__ as __activations_all__
 from ._normalization import *
 from ._normalization import __all__ as __others_all__
+from ._others import *
+from ._others import __all__ as __others_all__
 from ._spikes import *
 from ._spikes import __all__ as __spikes_all__
 
-__all__ = __spikes_all__ + __others_all__ + __activations_all__
-
+__all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
brainstate/functional/_activations.py
CHANGED
@@ -27,7 +27,7 @@ import jax.numpy as jnp
 from jax.scipy.special import logsumexp
 from jax.typing import ArrayLike
 
-from
+from .. import random
 
 __all__ = [
   "tanh",
@@ -136,10 +136,7 @@ def prelu(x, a=0.25):
   parameter :math:`a` across all input channels. If called with `nn.PReLU(nChannels)`,
   a separate :math:`a` is used for each input channel.
   """
-
-  return jnp.where(x >= jnp.asarray(0., dtype),
-                   x,
-                   jnp.asarray(a, dtype) * x)
+  return jnp.where(x >= 0., x, a * x)
 
 
 def soft_shrink(x, lambd=0.5):
@@ -161,11 +158,7 @@ def soft_shrink(x, lambd=0.5):
     - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
     - Output: :math:`(*)`, same shape as the input.
   """
-
-  lambd = jnp.asarray(lambd, dtype)
-  return jnp.where(x > lambd,
-                   x - lambd,
-                   jnp.where(x < -lambd, x + lambd, jnp.asarray(0., dtype)))
+  return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))
 
 
 def mish(x):
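For reference, the simplified `soft_shrink` expression behaves as follows; a standalone sketch of the formula in the hunk above, not tied to the package API:

```python
import jax.numpy as jnp

def soft_shrink(x, lambd=0.5):
  # shift values outside [-lambd, lambd] toward zero; zero out everything inside
  return jnp.where(x > lambd, x - lambd, jnp.where(x < -lambd, x + lambd, 0.))

print(soft_shrink(jnp.array([-1.0, -0.2, 0.0, 0.3, 1.2])))
# [-0.5  0.   0.   0.   0.7]
```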
@@ -217,9 +210,8 @@ def rrelu(x, lower=0.125, upper=0.3333333333333333):
   .. _`Empirical Evaluation of Rectified Activations in Convolutional Network`:
     https://arxiv.org/abs/1505.00853
   """
-
-
-  return jnp.where(x >= jnp.asarray(0., dtype), x, jnp.asarray(a, dtype) * x)
+  a = random.uniform(lower, upper, size=jnp.shape(x), dtype=x.dtype)
+  return jnp.where(x >= 0., x, a * x)
 
 
 def hard_shrink(x, lambd=0.5):
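The new body samples the negative slope uniformly per element via the package's stateful `random.uniform`. An equivalent standalone sketch using explicit JAX PRNG keys (the explicit-key form is an assumption for illustration, not the package API):

```python
import jax
import jax.numpy as jnp

def rrelu(x, key, lower=1 / 8., upper=1 / 3.):
  # sample a per-element slope a ~ U(lower, upper), then apply a leaky ReLU with it
  a = jax.random.uniform(key, shape=jnp.shape(x), dtype=x.dtype,
                         minval=lower, maxval=upper)
  return jnp.where(x >= 0., x, a * x)

x = jnp.array([-2.0, -1.0, 0.0, 1.0])
print(rrelu(x, jax.random.PRNGKey(0)))  # negative entries get a random slope in [1/8, 1/3]
```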
@@ -243,11 +235,7 @@ def hard_shrink(x, lambd=0.5):
     - Output: :math:`(*)`, same shape as the input.
 
   """
-
-  lambd = jnp.asarray(lambd, dtype)
-  return jnp.where(x > lambd,
-                   x,
-                   jnp.where(x < -lambd, x, jnp.asarray(0., dtype)))
+  return jnp.where(x > lambd, x, jnp.where(x < -lambd, x, 0.))
 
 
 def relu(x: ArrayLike) -> jax.Array:
@@ -298,8 +286,7 @@ def squareplus(x: ArrayLike, b: ArrayLike = 4) -> jax.Array:
     x : input array
     b : smoothness parameter
   """
-
-  return jax.nn.squareplus(x, jnp.asarray(b, dtype))
+  return jax.nn.squareplus(x, b)
 
 
 def softplus(x: ArrayLike) -> jax.Array:
@@ -417,8 +404,6 @@ def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
   See also:
     :func:`selu`
   """
-  dtype = math.get_dtype(x)
-  alpha = jnp.asarray(alpha, dtype)
   return jax.nn.elu(x, alpha)
 
 
@@ -445,8 +430,6 @@ def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> jax.Array:
   See also:
     :func:`relu`
   """
-  dtype = math.get_dtype(x)
-  negative_slope = jnp.asarray(negative_slope, dtype)
   return jax.nn.leaky_relu(x, negative_slope=negative_slope)
 
 
@@ -493,8 +476,6 @@ def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> jax.Array:
   Returns:
     An array.
   """
-  dtype = math.get_dtype(x)
-  alpha = jnp.asarray(alpha, dtype)
   return jax.nn.celu(x, alpha)
 
 
brainstate/functional/_normalization.py
CHANGED
@@ -20,11 +20,14 @@ from typing import Optional
 import jax
 import jax.numpy as jnp
 
+from .._utils import set_module_as
+
 __all__ = [
   'weight_standardization',
 ]
 
 
+@set_module_as('brainstate.functional')
 def weight_standardization(
     w: jax.typing.ArrayLike,
     eps: float = 1e-4,
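`set_module_as` comes from `brainstate._utils`; judging by its use here, it makes the decorated function report `brainstate.functional` as its home module in docs and reprs. A minimal sketch of that idea (an assumption about the helper, not its actual implementation):

```python
def set_module_as(module_name: str):
  # decorator factory: make the wrapped function report `module_name` as its module
  def decorator(fn):
    fn.__module__ = module_name
    return fn
  return decorator

@set_module_as('brainstate.functional')
def weight_standardization(w, eps=1e-4):
  ...
```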
brainstate/functional/_others.py
ADDED
@@ -0,0 +1,49 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import annotations
+
+from functools import partial
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+
+PyTree = Any
+
+__all__ = [
+  'clip_grad_norm',
+]
+
+
+def clip_grad_norm(
+    grad: PyTree,
+    max_norm: float | jax.Array,
+    norm_type: int | str | None = None
+):
+  """
+  Clips gradient norm of an iterable of parameters.
+
+  The norm is computed over all gradients together, as if they were
+  concatenated into a single vector. Gradients are modified in-place.
+
+  Args:
+    grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
+    max_norm (float): max norm of the gradients.
+    norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
+  """
+  norm_fn = partial(jnp.linalg.norm, ord=norm_type)
+  norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
+  return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
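A quick usage sketch for the new helper, which the `functional/__init__.py` change above re-exports as `brainstate.functional.clip_grad_norm`:

```python
import jax.numpy as jnp
import brainstate as bst

grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.0])}
clipped = bst.functional.clip_grad_norm(grads, max_norm=1.0)
# the global L2 norm of `grads` is 5.0, so every leaf is scaled by roughly 1/5
print(clipped['w'])  # ~[0.6, 0.8]
```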
brainstate/functional/_spikes.py
CHANGED
brainstate/mixin.py
CHANGED
@@ -68,7 +68,7 @@ class DelayedInit(Mixin):
   Note this Mixin can be applied in any Python object.
   """
 
-
+  non_hashable_params: Optional[Sequence[str]] = None
 
   @classmethod
   def delayed(cls, *args, **kwargs) -> 'DelayedInitializer':
@@ -94,7 +94,7 @@ class DelayedInitializer(metaclass=NoSubclassMeta):
   """
 
   def __init__(self, cls: T, *desc_tuple, **desc_dict):
-    self.cls = cls
+    self.cls: type = cls
 
     # arguments
     self.args = desc_tuple
brainstate/nn/__init__.py
CHANGED
@@ -21,6 +21,8 @@ from ._dynamics import *
 from ._dynamics import __all__ as dynamics_all
 from ._elementwise import *
 from ._elementwise import __all__ as elementwise_all
+from ._embedding import *
+from ._embedding import __all__ as embed_all
 from ._misc import *
 from ._misc import __all__ as _misc_all
 from ._normalizations import *
@@ -43,6 +45,7 @@ __all__ = (
   connections_all +
   dynamics_all +
   elementwise_all +
+  embed_all +
   normalizations_all +
   others_all +
   poolings_all +
@@ -58,6 +61,7 @@ del (
   connections_all,
   dynamics_all,
   elementwise_all,
+  embed_all,
   normalizations_all,
   others_all,
   poolings_all,
brainstate/nn/_base.py
CHANGED
@@ -55,22 +55,24 @@ class ExplicitInOutSize(Mixin):
 
   @property
   def in_size(self) -> Tuple[int, ...]:
-    if self._in_size is None:
-      raise ValueError(f"The input shape is not set in this node: {self} ")
     return self._in_size
 
   @in_size.setter
-  def in_size(self, in_size: Sequence[int]):
+  def in_size(self, in_size: Sequence[int] | int):
+    if isinstance(in_size, int):
+      in_size = (in_size,)
+    assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {type(in_size)}"
     self._in_size = tuple(in_size)
 
   @property
   def out_size(self) -> Tuple[int, ...]:
-    if self._out_size is None:
-      raise ValueError(f"The output shape is not set in this node: {self}")
     return self._out_size
 
   @out_size.setter
-  def out_size(self, out_size: Sequence[int]):
+  def out_size(self, out_size: Sequence[int] | int):
+    if isinstance(out_size, int):
+      out_size = (out_size,)
+    assert isinstance(out_size, (tuple, list)), f"Invalid type of out_size: {type(out_size)}"
     self._out_size = tuple(out_size)
 
 
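With the relaxed setters, a plain integer is now accepted and normalised to a 1-element shape. A small self-contained sketch of the new behaviour (mimicking the mixin rather than importing it):

```python
from __future__ import annotations
from typing import Sequence, Tuple

class InOutSizeDemo:
  """Mimics the relaxed in_size setter introduced above."""
  _in_size: Tuple[int, ...] = None

  @property
  def in_size(self) -> Tuple[int, ...]:
    return self._in_size

  @in_size.setter
  def in_size(self, in_size: Sequence[int] | int):
    if isinstance(in_size, int):
      in_size = (in_size,)  # an int now means a 1-element shape
    assert isinstance(in_size, (tuple, list)), f"Invalid type of in_size: {type(in_size)}"
    self._in_size = tuple(in_size)

node = InOutSizeDemo()
node.in_size = 10
assert node.in_size == (10,)
node.in_size = (3, 4)
assert node.in_size == (3, 4)
```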
@@ -152,7 +154,8 @@ class Sequential(Module, UpdateReturn, Container, ExplicitInOutSize):
     self.children = visible_module_dict(self.format_elements(object, first, *tuple_modules, **dict_modules))
 
     # the input and output shape
-
+    if first.in_size is not None:
+      self.in_size = first.in_size
     self.out_size = tuple(in_size)
 
   def _format_module(self, module, in_size):
brainstate/nn/_dynamics.py
CHANGED
@@ -103,6 +103,9 @@ class IF(Neuron):
   def init_state(self, batch_size: int = None, **kwargs):
     self.V = ShortTermState(init.param(jnp.zeros, self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(jnp.zeros, self.varshape, batch_size)
+
   def get_spike(self, V=None):
     V = self.V.value if V is None else V
     v_scaled = (V - self.V_th) / self.V_th
@@ -160,6 +163,9 @@ class LIF(Neuron):
   def init_state(self, batch_size: int = None, **kwargs):
     self.V = ShortTermState(init.param(init.Constant(self.V_reset), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(init.Constant(self.V_reset), self.varshape, batch_size)
+
   def get_spike(self, V=None):
     V = self.V.value if V is None else V
     v_scaled = (V - self.V_th) / self.V_th
@@ -214,6 +220,10 @@ class ALIF(Neuron):
     self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
     self.a = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
+    self.a.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   def get_spike(self, V=None, a=None):
     V = self.V.value if V is None else V
     a = self.a.value if a is None else a
@@ -275,6 +285,9 @@ class Expon(Synapse):
   def init_state(self, batch_size: int = None, **kwargs):
     self.g = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.g.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   def update(self, x=None):
     self.g.value = exp_euler_step(self.dg, self.g.value, environ.get('t'))
     if x is not None:
@@ -325,6 +338,10 @@ class STP(Synapse):
     self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))
     self.u = ShortTermState(init.param(init.Constant(self.U), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+    self.u.value = init.param(init.Constant(self.U), self.varshape, batch_size)
+
   def du(self, u, t):
     return self.U - u / self.tau_f
 
@@ -390,6 +407,9 @@ class STD(Synapse):
   def init_state(self, batch_size: int = None, **kwargs):
     self.x = ShortTermState(init.param(init.Constant(1.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.x.value = init.param(init.Constant(1.), self.varshape, batch_size)
+
   def update(self, pre_spike):
     t = environ.get('t')
     x = exp_euler_step(self.dx, self.x.value, t)
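Each of these classes now pairs `init_state` (which creates the `ShortTermState` objects) with a `reset_state` that re-initialises their values in place between trials. A hedged usage sketch (the exact constructor arguments of `LIF` may differ):

```python
import brainstate as bst

neuron = bst.nn.LIF(10)            # 10 LIF units
neuron.init_state(batch_size=4)    # allocates neuron.V as a ShortTermState
# ... run one trial ...
neuron.reset_state(batch_size=4)   # re-initialises neuron.V in place for the next trial
```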
brainstate/nn/_elementwise.py
CHANGED
@@ -19,11 +19,12 @@ from __future__ import annotations
 
 from typing import Optional
 
+import brainunit as bu
 import jax.numpy as jnp
 import jax.typing
 
 from ._base import ElementWiseBlock
-from .. import
+from .. import environ, random, functional as F
 from .._module import Module
 from .._state import ParamState
 from ..mixin import Mode
@@ -82,7 +83,7 @@ class Threshold(Module, ElementWiseBlock):
     self.value = value
 
   def __call__(self, x: ArrayLike) -> ArrayLike:
-    dtype = math.get_dtype(x)
+    dtype = bu.math.get_dtype(x)
     return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
                      x,
                      jnp.asarray(self.value, dtype=dtype))
@@ -1142,7 +1143,7 @@ class Dropout(Module, ElementWiseBlock):
     self.prob = prob
 
   def __call__(self, x):
-    dtype = math.get_dtype(x)
+    dtype = bu.math.get_dtype(x)
     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
     if fit_phase:
       keep_mask = random.bernoulli(self.prob, x.shape)
@@ -1172,7 +1173,7 @@ class _DropoutNd(Module, ElementWiseBlock):
     self.channel_axis = channel_axis
 
   def __call__(self, x):
-    dtype = math.get_dtype(x)
+    dtype = bu.math.get_dtype(x)
     # get fit phase
     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
 
brainstate/nn/_embedding.py
ADDED
@@ -0,0 +1,66 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Optional, Callable, Union
+
+from ._base import DnnLayer
+from .. import init
+from .._state import ParamState
+from ..mixin import Mode, Training
+from ..typing import ArrayLike
+
+__all__ = [
+  'Embedding',
+]
+
+
+class Embedding(DnnLayer):
+  r"""
+  A simple lookup table that stores embeddings of a fixed size.
+
+  Args:
+    num_embeddings: Size of embedding dictionary. Must be non-negative.
+    embedding_size: Size of each embedding vector. Must be non-negative.
+    embed_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
+
+  """
+
+  def __init__(
+      self,
+      num_embeddings: int,
+      embedding_size: int,
+      embed_init: Union[Callable, ArrayLike] = init.LecunUniform(),
+      name: Optional[str] = None,
+      mode: Optional[Mode] = None,
+  ):
+    super().__init__(name=name, mode=mode)
+    if num_embeddings < 0:
+      raise ValueError("num_embeddings must not be negative.")
+    if embedding_size < 0:
+      raise ValueError("embedding_size must not be negative.")
+    self.num_embeddings = num_embeddings
+    self.embedding_size = embedding_size
+    self.out_size = (embedding_size,)
+
+    weight = init.param(embed_init, (self.num_embeddings, self.embedding_size))
+    if self.mode.has(Training):
+      self.weight = ParamState(weight)
+    else:
+      self.weight = weight
+
+  def update(self, indices: ArrayLike):
+    if self.mode.has(Training):
+      return self.weight.value[indices]
+    return self.weight[indices]
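A usage sketch for the new layer: it is re-exported as `brainstate.nn.Embedding` by the `nn/__init__.py` change above, and `update` performs a plain table lookup on integer indices (in `Training` mode the table is a trainable `ParamState`):

```python
import jax.numpy as jnp
import brainstate as bst

emb = bst.nn.Embedding(num_embeddings=100, embedding_size=8)
tokens = jnp.array([[1, 5, 7], [2, 0, 99]])
vectors = emb.update(tokens)
print(vectors.shape)  # (2, 3, 8)
```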
brainstate/nn/_misc.py
CHANGED
@@ -20,9 +20,10 @@ from enum import Enum
 from functools import wraps
 from typing import Sequence, Callable
 
+import brainunit as bu
 import jax.numpy as jnp
 
-from .. import environ
+from .. import environ
 from .._state import State
 from ..transform import vector_grad
 
@@ -96,7 +97,7 @@ def exp_euler(fun):
     )
     dt = environ.get('dt')
     linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-    phi = math.exprel(dt * linear)
+    phi = bu.math.exprel(dt * linear)
     return args[0] + dt * phi * derivative
 
   return integral
@@ -128,5 +129,5 @@ def exp_euler_step(fun: Callable, *args, **kwargs):
   )
   dt = environ.get('dt')
   linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-  phi = math.exprel(dt * linear)
+  phi = bu.math.exprel(dt * linear)
  return args[0] + dt * phi * derivative
brainstate/nn/_others.py
CHANGED
@@ -19,10 +19,11 @@ from __future__ import annotations
 from functools import partial
 from typing import Optional
 
+import brainunit as bu
 import jax.numpy as jnp
 
 from ._base import DnnLayer
-from .. import random,
+from .. import random, environ, typing, init
 from ..mixin import Mode
 
 __all__ = [
@@ -88,7 +89,7 @@ class DropoutFixed(DnnLayer):
     self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)
 
   def update(self, x):
-    dtype = math.get_dtype(x)
+    dtype = bu.math.get_dtype(x)
     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
     if fit_phase:
       assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
brainstate/nn/_poolings.py
CHANGED
@@ -21,12 +21,13 @@ import functools
 from typing import Sequence, Optional
 from typing import Union, Tuple, Callable, List
 
+import brainunit as bu
 import jax
 import jax.numpy as jnp
 import numpy as np
 
 from ._base import DnnLayer, ExplicitInOutSize
-from .. import environ
+from .. import environ
 from ..mixin import Mode
 from ..typing import Size
 
@@ -53,8 +54,8 @@ class Flatten(DnnLayer, ExplicitInOutSize):
 
   Args:
     in_size: Sequence of int. The shape of the input tensor.
-
-
+    start_axis: first dim to flatten (default = 1).
+    end_axis: last dim to flatten (default = -1).
 
   Examples::
     >>> import brainstate as bst
@@ -74,36 +75,36 @@ class Flatten(DnnLayer, ExplicitInOutSize):
 
   def __init__(
       self,
-
-
+      start_axis: int = 0,
+      end_axis: int = -1,
       in_size: Optional[Size] = None
   ) -> None:
     super().__init__()
-    self.
-    self.
+    self.start_axis = start_axis
+    self.end_axis = end_axis
 
     if in_size is not None:
       self.in_size = tuple(in_size)
-      y = jax.eval_shape(functools.partial(math.flatten,
+      y = jax.eval_shape(functools.partial(bu.math.flatten, start_axis=start_axis, end_axis=end_axis),
                          jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
       self.out_size = y.shape
 
   def update(self, x):
     if self._in_size is None:
-
+      start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
     else:
       assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
       dim_diff = x.ndim - len(self.in_size)
       if self.in_size != x.shape[dim_diff:]:
         raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
-      if self.
-
+      if self.start_axis >= 0:
+        start_axis = self.start_axis + dim_diff
       else:
-
-    return math.flatten(x,
+        start_axis = x.ndim + self.start_axis
+    return bu.math.flatten(x, start_axis, self.end_axis)
 
   def __repr__(self) -> str:
-    return f'{self.__class__.__name__}(
+    return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
 
 
 class Unflatten(DnnLayer, ExplicitInOutSize):
@@ -124,7 +125,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
   :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
 
   Args:
-
+    axis: int, Dimension to be unflattened.
     sizes: Sequence of int. New shape of the unflattened dimension.
     in_size: Sequence of int. The shape of the input tensor.
   """
@@ -132,7 +133,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
 
   def __init__(
       self,
-
+      axis: int,
      sizes: Size,
      mode: Mode = None,
      name: str = None,
@@ -140,7 +141,7 @@
   ) -> None:
     super().__init__(mode=mode, name=name)
 
-    self.
+    self.axis = axis
     self.sizes = sizes
     if isinstance(sizes, (tuple, list)):
       for idx, elem in enumerate(sizes):
@@ -152,15 +153,15 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
 
     if in_size is not None:
       self.in_size = tuple(in_size)
-      y = jax.eval_shape(functools.partial(math.unflatten,
+      y = jax.eval_shape(functools.partial(bu.math.unflatten, axis=axis, sizes=sizes),
                          jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
       self.out_size = y.shape
 
   def update(self, x):
-    return math.unflatten(x, self.
+    return bu.math.unflatten(x, self.axis, self.sizes)
 
   def __repr__(self):
-    return f'{self.__class__.__name__}(
+    return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
 
 
 class _MaxPool(DnnLayer, ExplicitInOutSize):
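Flatten's new constructor keywords are exercised in the tests below; for Unflatten, a hedged usage sketch (keyword names taken from this hunk, exact defaults may differ):

```python
import brainstate as bst

x = bst.random.rand(16, 64)
layer = bst.nn.Unflatten(axis=-1, sizes=(8, 8), in_size=(64,))
y = layer.update(x)
print(y.shape)  # (16, 8, 8)
```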
brainstate/nn/_poolings_test.py
CHANGED
@@ -18,7 +18,7 @@ class TestFlatten(parameterized.TestCase):
       (10, 20, 30),
     ]:
       arr = bst.random.rand(*size)
-      f = nn.Flatten(
+      f = nn.Flatten(start_axis=0)
       out = f(arr)
       self.assertTrue(out.shape == (np.prod(size),))
 
@@ -29,21 +29,21 @@ class TestFlatten(parameterized.TestCase):
       (10, 20, 30),
     ]:
       arr = bst.random.rand(*size)
-      f = nn.Flatten(
+      f = nn.Flatten(start_axis=1)
       out = f(arr)
       self.assertTrue(out.shape == (size[0], np.prod(size[1:])))
 
   def test_flatten3(self):
     size = (16, 32, 32, 8)
     arr = bst.random.rand(*size)
-    f = nn.Flatten(
+    f = nn.Flatten(start_axis=0, in_size=(32, 8))
     out = f(arr)
     self.assertTrue(out.shape == (16, 32, 32 * 8))
 
   def test_flatten4(self):
     size = (16, 32, 32, 8)
     arr = bst.random.rand(*size)
-    f = nn.Flatten(
+    f = nn.Flatten(start_axis=1, in_size=(32, 32, 8))
     out = f(arr)
     self.assertTrue(out.shape == (16, 32, 32 * 8))
 