brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/augment/_autograd.py +9 -6
- brainstate/event/__init__.py +4 -2
- brainstate/event/_csr.py +26 -18
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_fixed_probability.py +589 -152
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +13 -10
- brainstate/event/_linear.py +267 -127
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +8 -3
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
- brainstate/nn/_dynamics/_projection_base.py +1 -1
- brainstate/nn/_exp_euler.py +1 -1
- brainstate/nn/_interaction/__init__.py +13 -4
- brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
- brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/optim/_lr_scheduler.py +1 -1
- brainstate/optim/_optax_optimizer.py +18 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -18,10 +18,8 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import collections.abc
|
21
|
-
import numbers
|
22
21
|
from typing import Callable, Tuple, Union, Sequence, Optional, TypeVar
|
23
22
|
|
24
|
-
import brainunit as u
|
25
23
|
import jax
|
26
24
|
import jax.numpy as jnp
|
27
25
|
|
@@ -33,10 +31,8 @@ from brainstate.typing import ArrayLike
|
|
33
31
|
T = TypeVar('T')
|
34
32
|
|
35
33
|
__all__ = [
|
36
|
-
'Linear', 'ScaledWSLinear', 'SignedWLinear', 'CSRLinear',
|
37
34
|
'Conv1d', 'Conv2d', 'Conv3d',
|
38
35
|
'ScaledWSConv1d', 'ScaledWSConv2d', 'ScaledWSConv3d',
|
39
|
-
'AllToAll',
|
40
36
|
]
|
41
37
|
|
42
38
|
|
@@ -79,168 +75,6 @@ def replicate(
|
|
79
75
|
f"sequence of length {num_replicate}.")
|
80
76
|
|
81
77
|
|
82
|
-
class Linear(Module):
|
83
|
-
"""
|
84
|
-
Linear layer.
|
85
|
-
"""
|
86
|
-
__module__ = 'brainstate.nn'
|
87
|
-
|
88
|
-
def __init__(
|
89
|
-
self,
|
90
|
-
in_size: Union[int, Sequence[int]],
|
91
|
-
out_size: Union[int, Sequence[int]],
|
92
|
-
w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
|
93
|
-
b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
|
94
|
-
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
95
|
-
name: Optional[str] = None,
|
96
|
-
):
|
97
|
-
super().__init__(name=name)
|
98
|
-
|
99
|
-
# input and output shape
|
100
|
-
self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
|
101
|
-
self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
|
102
|
-
|
103
|
-
# w_mask
|
104
|
-
self.w_mask = init.param(w_mask, self.in_size + self.out_size)
|
105
|
-
|
106
|
-
# weights
|
107
|
-
params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
|
108
|
-
if b_init is not None:
|
109
|
-
params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
|
110
|
-
|
111
|
-
# weight + op
|
112
|
-
self.weight = ParamState(params)
|
113
|
-
|
114
|
-
def update(self, x):
|
115
|
-
params = self.weight.value
|
116
|
-
weight = params['weight']
|
117
|
-
if self.w_mask is not None:
|
118
|
-
weight = weight * self.w_mask
|
119
|
-
y = u.math.dot(x, weight)
|
120
|
-
if 'bias' in params:
|
121
|
-
y = y + params['bias']
|
122
|
-
return y
|
123
|
-
|
124
|
-
|
125
|
-
class SignedWLinear(Module):
|
126
|
-
"""
|
127
|
-
Linear layer with signed weights.
|
128
|
-
"""
|
129
|
-
__module__ = 'brainstate.nn'
|
130
|
-
|
131
|
-
def __init__(
|
132
|
-
self,
|
133
|
-
in_size: Union[int, Sequence[int]],
|
134
|
-
out_size: Union[int, Sequence[int]],
|
135
|
-
w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
|
136
|
-
w_sign: Optional[ArrayLike] = None,
|
137
|
-
name: Optional[str] = None,
|
138
|
-
|
139
|
-
):
|
140
|
-
super().__init__(name=name)
|
141
|
-
|
142
|
-
# input and output shape
|
143
|
-
self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
|
144
|
-
self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
|
145
|
-
|
146
|
-
# w_mask
|
147
|
-
self.w_sign = w_sign
|
148
|
-
|
149
|
-
# weights
|
150
|
-
weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
|
151
|
-
self.weight = ParamState(weight)
|
152
|
-
|
153
|
-
def _operation(self, x, w):
|
154
|
-
if self.w_sign is None:
|
155
|
-
return jnp.matmul(x, jnp.abs(w))
|
156
|
-
else:
|
157
|
-
return jnp.matmul(x, jnp.abs(w) * self.w_sign)
|
158
|
-
|
159
|
-
def update(self, x):
|
160
|
-
return self._operation(x, self.weight.value)
|
161
|
-
|
162
|
-
|
163
|
-
class ScaledWSLinear(Module):
|
164
|
-
"""
|
165
|
-
Linear Layer with Weight Standardization.
|
166
|
-
|
167
|
-
Applies weight standardization to the weights of the linear layer.
|
168
|
-
|
169
|
-
Parameters
|
170
|
-
----------
|
171
|
-
in_size: int, sequence of int
|
172
|
-
The input size.
|
173
|
-
out_size: int, sequence of int
|
174
|
-
The output size.
|
175
|
-
w_init: Callable, ArrayLike
|
176
|
-
The initializer for the weights.
|
177
|
-
b_init: Callable, ArrayLike
|
178
|
-
The initializer for the bias.
|
179
|
-
w_mask: ArrayLike, Callable
|
180
|
-
The optional mask of the weights.
|
181
|
-
ws_gain: bool
|
182
|
-
Whether to use gain for the weights. The default is True.
|
183
|
-
eps: float
|
184
|
-
The epsilon value for the weight standardization.
|
185
|
-
name: str
|
186
|
-
The name of the object.
|
187
|
-
|
188
|
-
"""
|
189
|
-
__module__ = 'brainstate.nn'
|
190
|
-
|
191
|
-
def __init__(
|
192
|
-
self,
|
193
|
-
in_size: Union[int, Sequence[int]],
|
194
|
-
out_size: Union[int, Sequence[int]],
|
195
|
-
w_init: Callable = init.KaimingNormal(),
|
196
|
-
b_init: Callable = init.ZeroInit(),
|
197
|
-
w_mask: Optional[Union[ArrayLike, Callable]] = None,
|
198
|
-
ws_gain: bool = True,
|
199
|
-
eps: float = 1e-4,
|
200
|
-
name: str = None,
|
201
|
-
):
|
202
|
-
super().__init__(name=name)
|
203
|
-
|
204
|
-
# input and output shape
|
205
|
-
self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
|
206
|
-
self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
|
207
|
-
|
208
|
-
# w_mask
|
209
|
-
self.w_mask = init.param(w_mask, (self.in_size[0], 1))
|
210
|
-
|
211
|
-
# parameters
|
212
|
-
self.eps = eps
|
213
|
-
|
214
|
-
# weights
|
215
|
-
params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
|
216
|
-
if b_init is not None:
|
217
|
-
params['bias'] = init.param(b_init, self.out_size, allow_none=False)
|
218
|
-
# gain
|
219
|
-
if ws_gain:
|
220
|
-
s = params['weight'].shape
|
221
|
-
params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
|
222
|
-
|
223
|
-
# weight operation
|
224
|
-
self.weight = ParamState(params)
|
225
|
-
|
226
|
-
def update(self, x):
|
227
|
-
return self._operation(x, self.weight.value)
|
228
|
-
|
229
|
-
def _operation(self, x, params):
|
230
|
-
w = params['weight']
|
231
|
-
w = functional.weight_standardization(w, self.eps, params.get('gain', None))
|
232
|
-
if self.w_mask is not None:
|
233
|
-
w = w * self.w_mask
|
234
|
-
y = jnp.dot(x, w)
|
235
|
-
if 'bias' in params:
|
236
|
-
y = y + params['bias']
|
237
|
-
return y
|
238
|
-
|
239
|
-
|
240
|
-
class CSRLinear(Module):
|
241
|
-
__module__ = 'brainstate.nn'
|
242
|
-
|
243
|
-
|
244
78
|
class _BaseConv(Module):
|
245
79
|
# the number of spatial dimensions
|
246
80
|
num_spatial_dims: int
|
@@ -663,64 +497,3 @@ _ws_conv_doc = '''
|
|
663
497
|
ScaledWSConv1d.__doc__ = ScaledWSConv1d.__doc__ % _ws_conv_doc
|
664
498
|
ScaledWSConv2d.__doc__ = ScaledWSConv2d.__doc__ % _ws_conv_doc
|
665
499
|
ScaledWSConv3d.__doc__ = ScaledWSConv3d.__doc__ % _ws_conv_doc
|
666
|
-
|
667
|
-
|
668
|
-
class AllToAll(Module):
|
669
|
-
"""Synaptic matrix multiplication with All2All connections.
|
670
|
-
|
671
|
-
Args:
|
672
|
-
in_size: int. The number of neurons in the presynaptic neuron group.
|
673
|
-
out_size: int. The number of neurons in the postsynaptic neuron group.
|
674
|
-
w_init: The synaptic weights.
|
675
|
-
include_self: bool. Whether connect the neuron with at the same position.
|
676
|
-
name: str. The object name.
|
677
|
-
"""
|
678
|
-
|
679
|
-
def __init__(
|
680
|
-
self,
|
681
|
-
in_size: Union[int, Sequence[int]],
|
682
|
-
out_size: Union[int, Sequence[int]],
|
683
|
-
w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
|
684
|
-
include_self: bool = True,
|
685
|
-
name: Optional[str] = None,
|
686
|
-
):
|
687
|
-
super().__init__(name=name)
|
688
|
-
|
689
|
-
# input and output shape
|
690
|
-
self.in_size = (in_size,) if isinstance(in_size, numbers.Integral) else tuple(in_size)
|
691
|
-
self.out_size = (out_size,) if isinstance(out_size, numbers.Integral) else tuple(out_size)
|
692
|
-
assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
|
693
|
-
'and "out_size" must be the same.')
|
694
|
-
|
695
|
-
# weights
|
696
|
-
self.weight = ParamState(init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
|
697
|
-
|
698
|
-
# others
|
699
|
-
self.include_self = include_self
|
700
|
-
|
701
|
-
def update(self, pre_val):
|
702
|
-
if u.math.ndim(self.weight.value) == 0: # weight is a scalar
|
703
|
-
if pre_val.ndim == 1:
|
704
|
-
post_val = u.math.sum(pre_val)
|
705
|
-
else:
|
706
|
-
post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
|
707
|
-
if not self.include_self:
|
708
|
-
if self.in_size == self.out_size:
|
709
|
-
post_val = post_val - pre_val
|
710
|
-
elif self.in_size[-1] > self.out_size[-1]:
|
711
|
-
val = pre_val[..., :self.out_size[-1]]
|
712
|
-
post_val = post_val - val
|
713
|
-
else:
|
714
|
-
size = list(self.out_size)
|
715
|
-
size[-1] = self.out_size[-1] - self.in_size[-1]
|
716
|
-
val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
|
717
|
-
post_val = post_val - val
|
718
|
-
post_val = self.weight.value * post_val
|
719
|
-
|
720
|
-
else: # weight is a matrix
|
721
|
-
assert u.math.ndim(self.weight.value) == 2, '"weight" must be a 2D matrix.'
|
722
|
-
if not self.include_self:
|
723
|
-
post_val = pre_val @ u.math.fill_diagonal(self.weight.value, 0.)
|
724
|
-
else:
|
725
|
-
post_val = pre_val @ self.weight.value
|
726
|
-
return post_val
|
@@ -235,20 +235,5 @@ class TestConvTranspose3d(parameterized.TestCase):
|
|
235
235
|
y = conv_transpose_module(x)
|
236
236
|
print(y.shape)
|
237
237
|
|
238
|
-
|
239
|
-
class TestDense(parameterized.TestCase):
|
240
|
-
@parameterized.product(
|
241
|
-
size=[(10,),
|
242
|
-
(20, 10),
|
243
|
-
(5, 8, 10)],
|
244
|
-
num_out=[20, ]
|
245
|
-
)
|
246
|
-
def test_Dense1(self, size, num_out):
|
247
|
-
f = bst.nn.Linear(10, num_out)
|
248
|
-
x = bst.random.random(size)
|
249
|
-
y = f(x)
|
250
|
-
self.assertTrue(y.shape == size[:-1] + (num_out,))
|
251
|
-
|
252
|
-
|
253
238
|
if __name__ == '__main__':
|
254
239
|
absltest.main()
|