brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
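The renames above split the old brainstate/transform package into brainstate/compile (jit, conditionals, loops, make_jaxpr) and brainstate/augment (autograd, mapping, eval_shape), and promote brainstate/nn/event to a top-level brainstate/event package. A minimal, hypothetical migration sketch implied by those renames (assuming public symbols moved together with their modules; verify against the 0.1.0 API before relying on it):

import brainstate.augment   # was brainstate.transform (_autograd, _mapping, _eval_shape)
import brainstate.compile   # was brainstate.transform (_jit, _conditions, loops, _make_jaxpr)
import brainstate.event     # was brainstate.nn.event (csr, linear, fixed_probability)

# Old layout (0.0.2.post20241010)              New layout (0.1.0)
# brainstate.transform  (jit / cond / scan)    -> brainstate.compile
# brainstate.transform  (grad / vmap-style)    -> brainstate.augment
# brainstate.nn.event                          -> brainstate.event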
brainstate/nn/_embedding.py
DELETED
@@ -1,66 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from typing import Optional, Callable, Union
-
-from ._base import DnnLayer
-from .. import init
-from brainstate._state import ParamState
-from brainstate.mixin import Mode, Training
-from brainstate.typing import ArrayLike
-
-__all__ = [
-    'Embedding',
-]
-
-
-class Embedding(DnnLayer):
-    r"""
-    A simple lookup table that stores embeddings of a fixed size.
-
-    Args:
-        num_embeddings: Size of embedding dictionary. Must be non-negative.
-        embedding_size: Size of each embedding vector. Must be non-negative.
-        embed_init: The initializer for the embedding lookup table, of shape `(num_embeddings, embedding_size)`.
-
-    """
-
-    def __init__(
-        self,
-        num_embeddings: int,
-        embedding_size: int,
-        embed_init: Union[Callable, ArrayLike] = init.LecunUniform(),
-        name: Optional[str] = None,
-        mode: Optional[Mode] = None,
-    ):
-        super().__init__(name=name, mode=mode)
-        if num_embeddings < 0:
-            raise ValueError("num_embeddings must not be negative.")
-        if embedding_size < 0:
-            raise ValueError("embedding_size must not be negative.")
-        self.num_embeddings = num_embeddings
-        self.embedding_size = embedding_size
-        self.out_size = (embedding_size,)
-
-        weight = init.param(embed_init, (self.num_embeddings, self.embedding_size))
-        if self.mode.has(Training):
-            self.weight = ParamState(weight)
-        else:
-            self.weight = weight
-
-    def update(self, indices: ArrayLike):
-        if self.mode.has(Training):
-            return self.weight.value[indices]
-        return self.weight[indices]
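For reference, the lookup performed by the removed Embedding.update above is plain integer indexing into a (num_embeddings, embedding_size) table (a replacement layer appears in brainstate/nn/_interaction/_embedding.py in 0.1.0). A stand-alone illustration of that semantics, not the library API:

import jax.numpy as jnp

# Hypothetical stand-alone illustration of the embedding lookup.
num_embeddings, embedding_size = 10, 4
table = jnp.arange(num_embeddings * embedding_size, dtype=jnp.float32)
table = table.reshape(num_embeddings, embedding_size)   # plays the role of the ParamState value

ids = jnp.array([0, 3, 3, 7])
vectors = table[ids]                                     # shape (4, 4), one row per index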
brainstate/nn/_misc.py
DELETED
@@ -1,133 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-from __future__ import annotations
-
-from enum import Enum
-from functools import wraps
-from typing import Sequence, Callable
-
-import brainunit as bu
-import jax.numpy as jnp
-
-from .. import environ
-from .._state import State
-from ..transform import vector_grad
-
-__all__ = [
-    # 'exp_euler',
-    'exp_euler_step',
-]
-
-git_issue_addr = 'https://github.com/brainpy/brainscale/issues'
-
-
-def state_traceback(states: Sequence[State]):
-    """
-    Traceback the states of the brain model.
-
-    Parameters
-    ----------
-    states : Sequence[bst.State]
-        The states of the brain model.
-
-    Returns
-    -------
-    str
-        The traceback information of the states.
-    """
-    state_info = []
-    for i, state in enumerate(states):
-        state_info.append(f'State {i}: {state}\n'
-                          f'defined at \n'
-                          f'{state.source_info.traceback}\n')
-    return '\n'.join(state_info)
-
-
-class BaseEnum(Enum):
-    @classmethod
-    def get_by_name(cls, name: str):
-        for item in cls:
-            if item.name == name:
-                return item
-        raise ValueError(f'Cannot find the {cls.__name__} type {name}.')
-
-    @classmethod
-    def get(cls, type_: str | Enum):
-        if isinstance(type_, cls):
-            return type_
-        elif isinstance(type_, str):
-            return cls.get_by_name(type_)
-        else:
-            raise ValueError(f'Cannot find the {cls.__name__} type {type_}.')
-
-
-def exp_euler(fun):
-    """
-    Exponential Euler method for solving ODEs.
-
-    Args:
-        fun: Callable. The function to be solved.
-
-    Returns:
-        The integral function.
-    """
-
-    @wraps(fun)
-    def integral(*args, **kwargs):
-        assert len(args) > 0, 'The input arguments should not be empty.'
-        if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
-            raise ValueError(
-                'The input data type should be float32, float64, float16, or bfloat16 '
-                'when using Exponential Euler method.'
-                f'But we got {args[0].dtype}.'
-            )
-        dt = environ.get('dt')
-        linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-        phi = bu.math.exprel(dt * linear)
-        return args[0] + dt * phi * derivative
-
-    return integral
-
-
-def exp_euler_step(fun: Callable, *args, **kwargs):
-    """
-    Exponential Euler method for solving ODEs.
-
-    Examples
-    --------
-    >>> def fun(x, t):
-    ...     return -x
-    >>> x = 1.0
-    >>> exp_euler_step(fun, x, None)
-
-    Args:
-        fun: Callable. The function to be solved.
-
-    Returns:
-        The integral function.
-    """
-    assert len(args) > 0, 'The input arguments should not be empty.'
-    if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]:
-        raise ValueError(
-            'The input data type should be float32, float64, float16, or bfloat16 '
-            'when using Exponential Euler method.'
-            f'But we got {args[0].dtype}.'
-        )
-    dt = environ.get('dt')
-    linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-    phi = bu.math.exprel(dt * linear)
-    return args[0] + dt * phi * derivative
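The removed exp_euler_step above implements the exponential Euler update x_new = x + dt * exprel(dt * f'(x)) * f(x), obtaining the local linearization f'(x) and the derivative f(x) in one pass through vector_grad (a successor lives in brainstate/nn/_exp_euler.py in the new layout). A scalar-only sketch of the same update in plain JAX, using jax.grad and a hand-rolled exprel instead of the library helpers:

import jax
import jax.numpy as jnp

def exprel(z):
    # (exp(z) - 1) / z with the z -> 0 limit handled explicitly;
    # the removed code uses brainunit's bu.math.exprel instead.
    safe = jnp.where(z == 0.0, 1.0, z)
    return jnp.where(z == 0.0, 1.0, jnp.expm1(safe) / safe)

def exp_euler_step_sketch(f, x, dt):
    # Scalar-only sketch: the library uses vector_grad(fun, return_value=True)
    # to compute f'(x) and f(x) together.
    derivative = f(x)
    linear = jax.grad(f)(x)
    return x + dt * exprel(dt * linear) * derivative

# dx/dt = -x, one step of size 0.1 from x = 1.0 gives ~exp(-0.1)
x_next = exp_euler_step_sketch(lambda x: -x, 1.0, 0.1)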
brainstate/nn/_normalizations.py
DELETED
@@ -1,389 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# -*- coding: utf-8 -*-
-
-from __future__ import annotations
-
-import numbers
-from typing import Callable, Union, Sequence, Optional, Any
-
-import jax
-import jax.numpy as jnp
-
-from ._base import DnnLayer
-from .. import environ, init
-from brainstate._state import LongTermState, ParamState
-from brainstate.mixin import Mode
-from brainstate.typing import DTypeLike, ArrayLike, Size, Axes
-
-__all__ = [
-    'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d',
-]
-
-
-def _canonicalize_axes(ndim: int, feature_axes: Sequence[int]):
-    axes = []
-    for axis in feature_axes:
-        if axis < 0:
-            axis += ndim
-        if axis < 0 or axis >= ndim:
-            raise ValueError(f'Invalid axis {axis} for {ndim}D input')
-        axes.append(axis)
-    return tuple(axes)
-
-
-def _abs_sq(x):
-    """Computes the elementwise square of the absolute value |x|^2."""
-    if jnp.iscomplexobj(x):
-        return jax.lax.square(jax.lax.real(x)) + jax.lax.square(jax.lax.imag(x))
-    else:
-        return jax.lax.square(x)
-
-
-def _compute_stats(
-    x: ArrayLike,
-    axes: Sequence[int],
-    dtype: DTypeLike,
-    axis_name: Optional[str] = None,
-    axis_index_groups: Optional[Sequence[int]] = None,
-    use_mean: bool = True,
-):
-    """Computes mean and variance statistics.
-
-    This implementation takes care of a few important details:
-    - Computes in float32 precision for stability in half precision training.
-    - mean and variance are computable in a single XLA fusion,
-      by using Var = E[|x|^2] - |E[x]|^2 instead of Var = E[|x - E[x]|^2]).
-    - Clips negative variances to zero which can happen due to
-      roundoff errors. This avoids downstream NaNs.
-    - Supports averaging across a parallel axis and subgroups of a parallel axis
-      with a single `lax.pmean` call to avoid latency.
-
-    Arguments:
-      x: Input array.
-      axes: The axes in ``x`` to compute mean and variance statistics for.
-      dtype: tp.Optional dtype specifying the minimal precision. Statistics
-        are always at least float32 for stability (default: dtype of x).
-      axis_name: tp.Optional name for the pmapped axis to compute mean over.
-      axis_index_groups: tp.Optional axis indices.
-      use_mean: If true, calculate the mean from the input and use it when
-        computing the variance. If false, set the mean to zero and compute
-        the variance without subtracting the mean.
-
-    Returns:
-      A pair ``(mean, var)``.
-    """
-    if dtype is None:
-        dtype = jax.numpy.result_type(x)
-    # promote x to at least float32, this avoids half precision computation
-    # but preserves double or complex floating points
-    dtype = jax.numpy.promote_types(dtype, environ.dftype())
-    x = jnp.asarray(x, dtype)
-
-    # Compute mean and mean of squared values.
-    mean2 = jnp.mean(_abs_sq(x), axes)
-    if use_mean:
-        mean = jnp.mean(x, axes)
-    else:
-        mean = jnp.zeros(mean2.shape, dtype=dtype)
-
-    # If axis_name is provided, we need to average the mean and mean2 across
-    if axis_name is not None:
-        concatenated_mean = jnp.concatenate([mean, mean2])
-        mean, mean2 = jnp.split(
-            jax.lax.pmean(
-                concatenated_mean,
-                axis_name=axis_name,
-                axis_index_groups=axis_index_groups,
-            ),
-            2,
-        )
-
-    # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
-    # to floating point round-off errors.
-    var = jnp.maximum(0.0, mean2 - _abs_sq(mean))
-    return mean, var
-
-
-def _normalize(
-    x: ArrayLike,
-    mean: Optional[ArrayLike],
-    var: Optional[ArrayLike],
-    weights: Optional[ParamState],
-    reduction_axes: Sequence[int],
-    dtype: DTypeLike,
-    epsilon: Union[numbers.Number, jax.Array],
-):
-    """Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
-
-    Arguments:
-      x: The input.
-      mean: Mean to use for normalization.
-      var: Variance to use for normalization.
-      weights: The scale and bias parameters.
-      reduction_axes: The axes in ``x`` to reduce.
-      dtype: The dtype of the result (default: infer from input and params).
-      epsilon: Normalization epsilon.
-
-    Returns:
-      The normalized input.
-    """
-    if mean is not None:
-        assert var is not None, 'mean and var must be both None or not None.'
-        stats_shape = list(x.shape)
-        for axis in reduction_axes:
-            stats_shape[axis] = 1
-        mean = mean.reshape(stats_shape)
-        var = var.reshape(stats_shape)
-        y = x - mean
-        mul = jax.lax.rsqrt(var + jnp.asarray(epsilon, dtype))
-        y = y * mul
-        if weights is not None:
-            y = _scale_operation(y, weights.value)
-    else:
-        assert var is None, 'mean and var must be both None or not None.'
-        assert weights is None, 'scale and bias are not supported without mean and var'
-        y = x
-    return jnp.asarray(y, dtype)
-
-
-def _scale_operation(x, param):
-    if 'scale' in param:
-        x = x * param['scale']
-    if 'bias' in param:
-        x = x + param['bias']
-    return x
-
-
-class _BatchNorm(DnnLayer):
-    __module__ = 'brainstate.nn'
-    num_spatial_dims: int
-
-    def __init__(
-        self,
-        in_size: Size,
-        feature_axis: Axes = -1,
-        track_running_stats: bool = True,
-        epsilon: float = 1e-5,
-        momentum: float = 0.99,
-        affine: bool = True,
-        bias_initializer: Union[ArrayLike, Callable] = init.Constant(0.),
-        scale_initializer: Union[ArrayLike, Callable] = init.Constant(1.),
-        axis_name: Optional[Union[str, Sequence[str]]] = None,
-        axis_index_groups: Optional[Sequence[Sequence[int]]] = None,
-        mode: Optional[Mode] = None,
-        name: Optional[str] = None,
-        dtype: Any = None,
-    ):
-        super().__init__(name=name, mode=mode)
-
-        # parameters
-        self.in_size = tuple(in_size)
-        self.out_size = tuple(in_size)
-        self.affine = affine
-        self.bias_initializer = bias_initializer
-        self.scale_initializer = scale_initializer
-        self.dtype = dtype or environ.dftype()
-        self.track_running_stats = track_running_stats
-        self.momentum = jnp.asarray(momentum, dtype=self.dtype)
-        self.epsilon = jnp.asarray(epsilon, dtype=self.dtype)
-
-        # parameters about axis
-        feature_axis = (feature_axis,) if isinstance(feature_axis, int) else feature_axis
-        self.feature_axis = _canonicalize_axes(len(in_size), feature_axis)
-        self.axis_name = axis_name
-        self.axis_index_groups = axis_index_groups
-
-        # variables
-        feature_shape = tuple([ax if i in self.feature_axis else 1 for i, ax in enumerate(in_size)])
-        if self.track_running_stats:
-            self.running_mean = LongTermState(jnp.zeros(feature_shape, dtype=self.dtype))
-            self.running_var = LongTermState(jnp.ones(feature_shape, dtype=self.dtype))
-        else:
-            self.running_mean = None
-            self.running_var = None
-
-        # parameters
-        if self.affine:
-            assert track_running_stats, "Affine parameters are not needed when track_running_stats is False."
-            bias = init.param(self.bias_initializer, feature_shape)
-            scale = init.param(self.scale_initializer, feature_shape)
-            self.weight = ParamState(dict(bias=bias, scale=scale))
-        else:
-            self.weight = None
-
-    def _check_input_dim(self, x):
-        if x.ndim == self.num_spatial_dims + 2:
-            x_shape = x.shape[1:]
-        elif x.ndim == self.num_spatial_dims + 1:
-            x_shape = x.shape
-        else:
-            raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
-                             f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
-        if self.in_size != x_shape:
-            raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
-
-    def update(self, x):
-        # input shape and batch mode or not
-        if x.ndim == self.num_spatial_dims + 2:
-            x_shape = x.shape[1:]
-            batch = True
-        elif x.ndim == self.num_spatial_dims + 1:
-            x_shape = x.shape
-            batch = False
-        else:
-            raise ValueError(f"expected {self.num_spatial_dims + 2}D (with batch) or "
-                             f"{self.num_spatial_dims + 1}D (without batch) input (got {x.ndim}D input, {x.shape})")
-        if self.in_size != x_shape:
-            raise ValueError(f"The expected input shape is {self.in_size}, while we got {x_shape}.")
-
-        # reduce the feature axis
-        if batch:
-            reduction_axes = tuple(i for i in range(x.ndim) if (i - 1) not in self.feature_axis)
-        else:
-            reduction_axes = tuple(i for i in range(x.ndim) if i not in self.feature_axis)
-
-        # fitting phase
-        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
-
-        # compute the running mean and variance
-        if self.track_running_stats:
-            if fit_phase:
-                mean, var = _compute_stats(
-                    x,
-                    reduction_axes,
-                    dtype=self.dtype,
-                    axis_name=self.axis_name,
-                    axis_index_groups=self.axis_index_groups,
-                )
-                self.running_mean.value = self.momentum * self.running_mean.value + (1 - self.momentum) * mean
-                self.running_var.value = self.momentum * self.running_var.value + (1 - self.momentum) * var
-            else:
-                mean = self.running_mean.value
-                var = self.running_var.value
-        else:
-            mean, var = None, None
-
-        # normalize
-        return _normalize(x, mean, var, self.weight, reduction_axes, self.dtype, self.epsilon)
-
-
-class BatchNorm1d(_BatchNorm):
-    r"""1-D batch normalization [1]_.
-
-    The data should be of `(b, l, c)`, where `b` is the batch dimension,
-    `l` is the layer dimension, and `c` is the channel dimension.
-
-    %s
-    """
-    __module__ = 'brainstate.nn'
-    num_spatial_dims: int = 1
-
-
-class BatchNorm2d(_BatchNorm):
-    r"""2-D batch normalization [1]_.
-
-    The data should be of `(b, h, w, c)`, where `b` is the batch dimension,
-    `h` is the height dimension, `w` is the width dimension, and `c` is the
-    channel dimension.
-
-    %s
-    """
-    __module__ = 'brainstate.nn'
-    num_spatial_dims: int = 2
-
-
-class BatchNorm3d(_BatchNorm):
-    r"""3-D batch normalization [1]_.
-
-    The data should be of `(b, h, w, d, c)`, where `b` is the batch dimension,
-    `h` is the height dimension, `w` is the width dimension, `d` is the depth
-    dimension, and `c` is the channel dimension.
-
-    %s
-    """
-    __module__ = 'brainstate.nn'
-    num_spatial_dims: int = 3
-
-
-_bn_doc = r'''
-
-This layer aims to reduce the internal covariant shift of data. It
-normalizes a batch of data by fixing the mean and variance of inputs
-on each feature (channel). Most commonly, the first axis of the data
-is the batch, and the last is the channel. However, users can specify
-the axes to be normalized.
-
-.. math::
-   y=\frac{x-\mathrm{E}[x]}{\sqrt{\operatorname{Var}[x]+\epsilon}} * \gamma+\beta
-
-.. note::
-    This :attr:`momentum` argument is different from one used in optimizer
-    classes and the conventional notion of momentum. Mathematically, the
-    update rule for running statistics here is
-    :math:`\hat{x}_\text{new} = \text{momentum} \times \hat{x} + (1-\text{momentum}) \times x_t`,
-    where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
-    new observed value.
-
-Parameters
-----------
-in_size: sequence of int
-    The input shape, without batch size.
-feature_axis: int, tuple, list
-    The feature or non-batch axis of the input.
-track_running_stats: bool
-    A boolean value that when set to ``True``, this module tracks the running mean and variance,
-    and when set to ``False``, this module does not track such statistics, and initializes
-    statistics buffers ``running_mean`` and ``running_var`` as ``None``. When these buffers are ``None``,
-    this module always uses batch statistics. in both training and eval modes. Default: ``True``.
-momentum: float
-    The value used for the ``running_mean`` and ``running_var`` computation. Default: 0.99
-epsilon: float
-    A value added to the denominator for numerical stability. Default: 1e-5
-affine: bool
-    A boolean value that when set to ``True``, this module has
-    learnable affine parameters. Default: ``True``
-bias_initializer: ArrayLike, Callable
-    An initializer generating the original translation matrix. If not ``None``, bias (beta) is added.
-    Default: ``init.Constant(0.)``
-scale_initializer: ArrayLike, Callable
-    An initializer generating the original scaling matrix. If not ``None``, multiply by scale (gamma).
-    Default: ``init.Constant(1.)``
-axis_name: optional, str, sequence of str
-    If not ``None``, it should be a string (or sequence of
-    strings) representing the axis name(s) over which this module is being
-    run within a jax map (e.g. ``jax.pmap`` or ``jax.vmap``). Supplying this
-    argument means that batch statistics are calculated across all replicas
-    on the named axes.
-axis_index_groups: optional, sequence
-    Specifies how devices are grouped. Valid
-    only within ``jax.pmap`` collectives.
-    Groups of axis indices within that named axis
-    representing subsets of devices to reduce over (default: None). For
-    example, `[[0, 1], [2, 3]]` would independently batch-normalize over
-    the examples on the first two and last two devices. See `jax.lax.psum`
-    for more details.
-
-References
-----------
-.. [1] Ioffe, Sergey and Christian Szegedy. “Batch Normalization: Accelerating Deep Network Training
-       by Reducing Internal Covariate Shift.” ArXiv abs/1502.03167 (2015): n. pag.
-
-'''
-
-BatchNorm1d.__doc__ = BatchNorm1d.__doc__ % _bn_doc
-BatchNorm2d.__doc__ = BatchNorm2d.__doc__ % _bn_doc
-BatchNorm3d.__doc__ = BatchNorm3d.__doc__ % _bn_doc
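The running-statistics rule used by the removed _BatchNorm.update above (and spelled out in the momentum note of _bn_doc) is new_stat = momentum * old_stat + (1 - momentum) * batch_stat. A minimal numeric sketch of that update, not the library API:

import jax.numpy as jnp

momentum = 0.99
running_mean = jnp.zeros((1, 4))                    # broadcastable feature shape
batch_mean = jnp.array([[0.2, -0.1, 0.5, 0.0]])     # statistics of the current batch

# the estimate moves slowly toward the batch statistic
running_mean = momentum * running_mean + (1 - momentum) * batch_mean
# -> [[0.002, -0.001, 0.005, 0.0]]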
brainstate/nn/_others.py
DELETED
@@ -1,101 +0,0 @@
-# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-
-from __future__ import annotations
-
-from functools import partial
-from typing import Optional
-
-import brainunit as bu
-import jax.numpy as jnp
-
-from ._base import DnnLayer
-from brainstate.mixin import Mode
-from brainstate import random, environ, typing, init
-
-__all__ = [
-    'DropoutFixed',
-]
-
-
-class DropoutFixed(DnnLayer):
-    """
-    A dropout layer with the fixed dropout mask along the time axis once after initialized.
-
-    In training, to compensate for the fraction of input values dropped (`rate`),
-    all surviving values are multiplied by `1 / (1 - rate)`.
-
-    This layer is active only during training (``mode=brainstate.mixin.Training``). In other
-    circumstances it is a no-op.
-
-    .. [1] Srivastava, Nitish, et al. "Dropout: a simple way to prevent
-           neural networks from overfitting." The journal of machine learning
-           research 15.1 (2014): 1929-1958.
-
-    .. admonition:: Tip
-        :class: tip
-
-        This kind of Dropout is firstly described in `Enabling Spike-based Backpropagation for Training Deep Neural
-        Network Architectures <https://arxiv.org/abs/1903.06379>`_:
-
-        There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of
-        training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of :math:`p`)
-        are disconnected from the network while weighting by its posterior probability (:math:`1-p`). However, in SNNs, each
-        iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate
-        the output error and modify the network parameters only at the last time step. For dropout to be effective in
-        our training method, it has to be ensured that the set of connected units within an iteration of mini-batch
-        data is not changed, such that the neural network is constituted by the same random subset of units during
-        each forward propagation within a single iteration. On the other hand, if the units are randomly connected at
-        each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an
-        iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters
-        are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire
-        time window within an iteration.
-
-    Args:
-        in_size: The size of the input tensor.
-        prob: Probability to keep element of the tensor.
-        mode: Mode. The computation mode of the object.
-        name: str. The name of the dynamic system.
-    """
-    __module__ = 'brainstate.nn'
-
-    def __init__(
-        self,
-        in_size: typing.Size,
-        prob: float = 0.5,
-        mode: Optional[Mode] = None,
-        name: Optional[str] = None
-    ) -> None:
-        super().__init__(mode=mode, name=name)
-        assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
-        self.prob = prob
-        self.in_size = in_size
-        self.out_size = in_size
-
-    def init_state(self, batch_size=None, **kwargs):
-        self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)
-
-    def update(self, x):
-        dtype = bu.math.get_dtype(x)
-        fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
-        if fit_phase:
-            assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
-                                                f"Please call `init_state()` method first.")
-            return jnp.where(self.mask,
-                             jnp.asarray(x / self.prob, dtype=dtype),
-                             jnp.asarray(0., dtype=dtype))
-        else:
-            return x
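A stand-alone sketch of the fixed-mask behaviour described in the DropoutFixed docstring above: the Bernoulli mask is drawn once (per init_state call) and reused at every time step, with surviving units rescaled by 1 / prob, where prob is the keep probability as in the removed code. This is an illustration in plain JAX, not the library API:

import jax
import jax.numpy as jnp

keep_prob = 0.5
key = jax.random.PRNGKey(0)
mask = jax.random.bernoulli(key, keep_prob, shape=(8,))   # drawn once, fixed over time

def dropout_fixed(x):
    # Same mask at every step; survivors scaled to preserve the expectation.
    return jnp.where(mask, x / keep_prob, 0.0)

xs = jnp.ones((10, 8))                 # 10 time steps of an 8-unit layer
ys = jax.vmap(dropout_fixed)(xs)       # identical masking at each step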