brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +4 -5
- brainstate/_module.py +191 -48
- brainstate/_module_test.py +95 -21
- brainstate/_state.py +17 -0
- brainstate/environ.py +2 -2
- brainstate/functional/__init__.py +3 -2
- brainstate/functional/_activations.py +7 -26
- brainstate/functional/_normalization.py +3 -0
- brainstate/functional/_others.py +49 -0
- brainstate/functional/_spikes.py +0 -1
- brainstate/mixin.py +2 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_base.py +10 -7
- brainstate/nn/_dynamics.py +20 -0
- brainstate/nn/_elementwise.py +5 -4
- brainstate/nn/_embedding.py +66 -0
- brainstate/nn/_misc.py +4 -3
- brainstate/nn/_others.py +3 -2
- brainstate/nn/_poolings.py +21 -20
- brainstate/nn/_poolings_test.py +4 -4
- brainstate/nn/_rate_rnns.py +17 -0
- brainstate/nn/_readout.py +6 -0
- brainstate/optim/__init__.py +0 -1
- brainstate/optim/_lr_scheduler_test.py +13 -0
- brainstate/optim/_sgd_optimizer.py +18 -17
- brainstate/transform/__init__.py +2 -3
- brainstate/transform/_autograd.py +1 -1
- brainstate/transform/_autograd_test.py +0 -2
- brainstate/transform/_jit.py +47 -21
- brainstate/transform/_jit_test.py +0 -3
- brainstate/transform/_make_jaxpr.py +164 -3
- brainstate/transform/_make_jaxpr_test.py +0 -2
- brainstate/transform/_progress_bar.py +1 -3
- brainstate/util.py +0 -1
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
- brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
- brainstate/math/__init__.py +0 -21
- brainstate/math/_einops.py +0 -787
- brainstate/math/_einops_parsing.py +0 -169
- brainstate/math/_einops_parsing_test.py +0 -126
- brainstate/math/_einops_test.py +0 -346
- brainstate/math/_misc.py +0 -298
- brainstate/math/_misc_test.py +0 -58
- brainstate/nn/functional/__init__.py +0 -25
- brainstate/nn/functional/_activations.py +0 -754
- brainstate/nn/functional/_normalization.py +0 -69
- brainstate/nn/functional/_spikes.py +0 -90
- brainstate/nn/init/__init__.py +0 -26
- brainstate/nn/init/_base.py +0 -36
- brainstate/nn/init/_generic.py +0 -175
- brainstate/nn/init/_random_inits.py +0 -489
- brainstate/nn/init/_regular_inits.py +0 -109
- brainstate/nn/surrogate.py +0 -1740
- brainstate-0.0.1.dist-info/RECORD +0 -79
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
- {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
@@ -1,69 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
from typing import Optional
|
19
|
-
|
20
|
-
import jax
|
21
|
-
import jax.numpy as jnp
|
22
|
-
|
23
|
-
__all__ = [
|
24
|
-
'weight_standardization',
|
25
|
-
]
|
26
|
-
|
27
|
-
|
28
|
-
def weight_standardization(
    w: jax.typing.ArrayLike,
    eps: float = 1e-4,
    gain: Optional[jax.Array] = None,
    out_axis: int = -1,
):
    """
    Scaled Weight Standardization,
    see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.

    The mean and variance of ``w`` are taken over every axis except
    ``out_axis``. The result is ``w * scale - mean * scale`` with
    ``scale = rsqrt(maximum(var * fan_in, eps))``, where ``fan_in`` is the
    product of the sizes of the reduced axes. When ``gain`` is given,
    ``scale`` is multiplied by it first.

    Parameters
    ----------
    w : jax.typing.ArrayLike
      The weight tensor. Any array-like value (JAX/NumPy array, nested
      list, scalar) is accepted; it is converted with ``jnp.asarray``.
    eps : float
      Lower bound applied to ``var * fan_in`` before the reciprocal square
      root, which avoids division by zero.
    gain : Array
      Optional multiplier applied to the scale, by default None.
    out_axis : int
      The output axis (excluded from the statistics), by default -1.
      Negative values count from the last axis.

    Returns
    -------
    jax.Array
      The standardized weight tensor, with the same shape as ``w``.
    """
    # The signature promises ArrayLike, but `.ndim` / `.shape` only exist on
    # real arrays, so coerce first (a no-op for inputs that are already arrays).
    w = jnp.asarray(w)
    if out_axis < 0:
        out_axis = w.ndim + out_axis

    # Every axis except the output axis contributes to the statistics.
    axes = tuple(i for i in range(w.ndim) if i != out_axis)
    fan_in = 1
    for i in axes:
        fan_in *= w.shape[i]

    # normalize the weight
    mean = jnp.mean(w, axis=axes, keepdims=True)
    var = jnp.var(w, axis=axes, keepdims=True)
    scale = jax.lax.rsqrt(jnp.maximum(var * fan_in, eps))
    if gain is not None:
        scale = gain * scale
    shift = mean * scale
    return w * scale - shift
|
@@ -1,90 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
__all__ = [
|
19
|
-
'spike_bitwise_or',
|
20
|
-
'spike_bitwise_and',
|
21
|
-
'spike_bitwise_iand',
|
22
|
-
'spike_bitwise_not',
|
23
|
-
'spike_bitwise_xor',
|
24
|
-
'spike_bitwise_ixor',
|
25
|
-
'spike_bitwise',
|
26
|
-
]
|
27
|
-
|
28
|
-
|
29
|
-
def spike_bitwise_or(x, y):
    """Arithmetic OR of two spike tensors: ``x + y - x * y``.

    For inputs that are 0 or 1 the result is 1 when at least one input is 1.
    """
    overlap = x * y
    return x + y - overlap
|
32
|
-
|
33
|
-
|
34
|
-
def spike_bitwise_and(x, y):
    """Arithmetic AND of two spike tensors: the element-wise product ``x * y``.

    For inputs that are 0 or 1 the result is 1 only when both inputs are 1.
    """
    both = x * y
    return both
|
37
|
-
|
38
|
-
|
39
|
-
def spike_bitwise_iand(x, y):
    """Arithmetic IAND of two spike tensors: ``(1 - x) * y``.

    For inputs that are 0 or 1 the result is 1 only when ``x`` is 0 and
    ``y`` is 1, i.e. ``(NOT x) AND y``.
    """
    not_x = 1 - x
    return not_x * y
|
42
|
-
|
43
|
-
|
44
|
-
def spike_bitwise_not(x):
    """Arithmetic NOT of a spike tensor: ``1 - x``.

    For an input that is 0 or 1 this swaps the two values.
    """
    flipped = 1 - x
    return flipped
|
47
|
-
|
48
|
-
|
49
|
-
def spike_bitwise_xor(x, y):
    """Arithmetic XOR of two spike tensors: ``x + y - 2 * x * y``.

    For inputs that are 0 or 1 the result is 1 when exactly one input is 1.
    """
    twice_overlap = 2 * x * y
    return x + y - twice_overlap
|
52
|
-
|
53
|
-
|
54
|
-
def spike_bitwise_ixor(x, y):
    """Arithmetic IXOR of two spike tensors: ``x * (1 - y) + (1 - x) * y``.

    For inputs that are 0 or 1 the result is 1 when exactly one input is 1.
    """
    # NOTE(review): this expression expands to ``x + y - 2 * x * y``, the same
    # value ``spike_bitwise_xor`` computes -- confirm whether a distinct
    # "inverted" XOR was intended here.
    x_only = x * (1 - y)
    y_only = (1 - x) * y
    return x_only + y_only
|
57
|
-
|
58
|
-
|
59
|
-
def spike_bitwise(x, y, op: str):
    """Apply the spike logic operation selected by ``op`` to ``x`` and ``y``.

    Supported values of ``op`` and the helper each one dispatches to:

    - ``'or'``   -> ``spike_bitwise_or``   (``x + y - x * y``)
    - ``'and'``  -> ``spike_bitwise_and``  (``x * y``)
    - ``'iand'`` -> ``spike_bitwise_iand`` (``(1 - x) * y``)
    - ``'xor'``  -> ``spike_bitwise_xor``  (``x + y - 2 * x * y``)
    - ``'ixor'`` -> ``spike_bitwise_ixor`` (``x * (1 - y) + (1 - x) * y``)

    Args:
      x: A spike tensor.
      y: A spike tensor.
      op: A string indicating the bitwise operation to perform.

    Returns:
      Whatever the selected helper returns for ``(x, y)``.

    Raises:
      NotImplementedError: If ``op`` is not one of the names listed above.
    """
    # Checked in order with ``==``, exactly like an if/elif chain would.
    operations = (
        ('or', spike_bitwise_or),
        ('and', spike_bitwise_and),
        ('iand', spike_bitwise_iand),
        ('xor', spike_bitwise_xor),
        ('ixor', spike_bitwise_ixor),
    )
    for mode, handler in operations:
        if op == mode:
            return handler(x, y)
    raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
|
90
|
-
|
brainstate/nn/init/__init__.py
DELETED
@@ -1,26 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
from ._base import *
|
18
|
-
from ._base import __all__ as _base_all
|
19
|
-
from ._generic import *
|
20
|
-
from ._generic import __all__ as _generic_all
|
21
|
-
from ._random_inits import *
|
22
|
-
from ._random_inits import __all__ as _random_inits_all
|
23
|
-
from ._regular_inits import *
|
24
|
-
from ._regular_inits import __all__ as _regular_inits_all
|
25
|
-
|
26
|
-
__all__ = _generic_all + _base_all + _regular_inits_all + _random_inits_all
|
brainstate/nn/init/_base.py
DELETED
@@ -1,36 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
from typing import Optional, Tuple
|
18
|
-
|
19
|
-
import numpy as np
|
20
|
-
|
21
|
-
__all__ = ['Initializer', 'to_size']
|
22
|
-
|
23
|
-
|
24
|
-
class Initializer:
    """Callable base type whose ``__call__`` is left unimplemented.

    Calling an instance of this class directly raises ``NotImplementedError``.
    """

    def __call__(self, *args, **kwargs):
        """Accept any arguments and raise ``NotImplementedError``."""
        raise NotImplementedError()
|
27
|
-
|
28
|
-
|
29
|
-
def to_size(x) -> Optional[Tuple[int]]:
    """Normalize a size specification to a tuple.

    - ``None`` is passed through unchanged.
    - An ``int`` or NumPy integer becomes a one-element tuple.
    - A ``list`` or ``tuple`` is converted with ``tuple(x)``.

    Raises
    ------
    ValueError
      If ``x`` is none of the above.
    """
    if x is None:
        return None
    if isinstance(x, (int, np.integer)):
        return (x,)
    if isinstance(x, (list, tuple)):
        return tuple(x)
    raise ValueError(f'Cannot make a size for {x}')
|
brainstate/nn/init/_generic.py
DELETED
@@ -1,175 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
import numbers
|
19
|
-
from typing import Union, Callable, Optional, Sequence
|
20
|
-
|
21
|
-
import jax
|
22
|
-
import jax.numpy as jnp
|
23
|
-
import numpy as np
|
24
|
-
|
25
|
-
from brainstate._state import State
|
26
|
-
from ._base import to_size
|
27
|
-
|
28
|
-
__all__ = [
|
29
|
-
'param',
|
30
|
-
'state',
|
31
|
-
'noise',
|
32
|
-
]
|
33
|
-
|
34
|
-
|
35
|
-
def _is_scalar(x):
    """Return True when ``x`` is an instance of ``numbers.Number``."""
    scalar_type = numbers.Number
    return isinstance(x, scalar_type)
|
37
|
-
|
38
|
-
|
39
|
-
def param(
    param: Union[Callable, np.ndarray, jax.Array, float, int, bool],
    sizes: Union[int, Sequence[int]],
    batch_size: Optional[int] = None,
    allow_none: bool = True,
    allow_scalar: bool = True,
):
    """Initialize parameters.

    Parameters
    ----------
    param: callable, Initializer, State, jax.Array, np.ndarray, float, int, bool
      The initialization of the parameter.

      - If it is None, None is returned (or ``ValueError`` is raised when
        ``allow_none`` is false).
      - If it is a Python number and ``allow_scalar`` is true, it is
        returned unchanged.
      - If it is callable, ``param(sizes)`` is returned, where ``sizes`` is
        a list with ``batch_size`` prepended when a batch size is given.
      - If it is a NumPy or JAX array, it is converted with ``jnp.asarray``
        and, when ``batch_size`` is given, repeated along a new leading axis.
      - If it is a ``State``, it is used as is, or rebuilt from its value
        repeated along a new leading axis when ``batch_size`` is given.
    sizes: int, sequence of int
      The shape of the parameter (without the batch dimension).
    batch_size: int
      The batch size. When given, it becomes the leading dimension.
    allow_none: bool
      Whether the parameter is allowed to be None.
    allow_scalar: bool
      Whether the parameter is allowed to be a scalar value. When true,
      arrays of shape ``()`` or ``(1,)`` also skip the shape check.

    Returns
    -------
    param: ArrayType, State, float, int, bool, None
      The initialized parameter.

    Raises
    ------
    ValueError
      If ``param`` is None while ``allow_none`` is false, if its type is not
      supported, or if its shape does not match the expected shape.

    See Also
    --------
    noise, state
    """
    if param is None:
        if allow_none:
            return None
        raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
                         f'Callable function, but we got None. ')
    sizes = list(to_size(sizes))
    if allow_scalar and _is_scalar(param):
        return param

    if batch_size is not None:
        sizes.insert(0, batch_size)

    if callable(param):
        return param(sizes)
    elif isinstance(param, (np.ndarray, jax.Array)):
        param = jnp.asarray(param)
        if batch_size is not None:
            param = jnp.repeat(jnp.expand_dims(param, axis=0), batch_size, axis=0)
    elif isinstance(param, State):
        if batch_size is not None:
            # The original referenced an undefined ``batch_axis`` here, which
            # raised NameError. The batch dimension is the leading axis, as in
            # the array branch above and in ``sizes.insert(0, batch_size)``.
            batched = jnp.repeat(jnp.expand_dims(param.value, axis=0), batch_size, axis=0)
            param = type(param)(batched)
    else:
        raise ValueError(f'Unknown parameter type: {type(param)}')

    # NOTE(review): for a ``State`` this relies on the object exposing
    # ``.shape`` -- confirm against the ``State`` implementation.
    if allow_scalar:
        if param.shape == () or param.shape == (1,):
            return param
    if param.shape != tuple(sizes):
        raise ValueError(f'The shape of the parameters should be {sizes}, but we got {param.shape}')
    return param
|
106
|
-
|
107
|
-
|
108
|
-
def state(
    init: Union[Callable, np.ndarray, jax.Array],
    sizes: Union[int, Sequence[int]] = None,
    batch_size: Optional[int] = None,
):
    """
    Build a ``State`` from a callable initializer or from concrete data.

    - If ``init`` is callable, ``sizes`` is required; the callable is invoked
      with the shape as a list (with ``batch_size`` prepended when it is an
      ``int``) and its result is wrapped in a ``State``.
    - Otherwise ``init`` is treated as data. When ``sizes`` is given, the
      shape of ``init`` must equal it. When ``batch_size`` is an ``int``, the
      data is repeated ``batch_size`` times along a new leading axis before
      being wrapped in a ``State``.

    Raises
    ------
    ValueError
      If ``init`` is callable and ``sizes`` is None, or if the shape of the
      data does not match ``sizes``.
    """
    sizes = to_size(sizes)
    with_batch = isinstance(batch_size, int)

    if callable(init):
        if sizes is None:
            raise ValueError('"varshape" cannot be None when data is a callable function.')
        full_shape = [batch_size, *sizes] if with_batch else list(sizes)
        return State(init(full_shape))

    if sizes is not None and jnp.shape(init) != sizes:
        raise ValueError(f'The shape of "data" {jnp.shape(init)} does not match with "var_shape" {sizes}')
    if not with_batch:
        return State(init)
    stacked = jnp.expand_dims(init, axis=0)
    return State(jnp.repeat(stacked, batch_size, axis=0))
|
135
|
-
|
136
|
-
|
137
|
-
def noise(
    noises: Optional[Union[int, float, np.ndarray, jax.Array, Callable]],
    size: Union[int, Sequence[int]],
    num_vars: int = 1,
    noise_idx: int = 0,
) -> Optional[Callable]:
    """Initialize a noise function.

    Parameters
    ----------
    noises: Any
      ``None`` (no noise), a callable (returned unchanged), or a value that
      is passed through :func:`param` with ``allow_none=False``.
    size: Shape
      The size of the noise.
    num_vars: int
      The number of variables. When greater than 1, the noise value is
      placed in a tuple of this length whose other entries are ``None``.
    noise_idx: int
      The index of the current noise among all noise variables.

    Returns
    -------
    noise_func: function, None
      ``None`` when ``noises`` is ``None``; ``noises`` itself when it is
      callable; otherwise a function that ignores its arguments and returns
      the prepared noise value (or tuple).

    See Also
    --------
    param, state

    """
    if noises is None:
        return None
    if callable(noises):
        return noises

    value = param(noises, size, allow_none=False)
    if num_vars > 1:
        slots = [None] * num_vars
        slots[noise_idx] = value
        value = tuple(slots)
    return lambda *args, **kwargs: value
|