brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,582 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from __future__ import annotations
|
19
|
+
|
20
|
+
from typing import Callable, Union, Optional
|
21
|
+
|
22
|
+
import brainunit as u
|
23
|
+
import jax
|
24
|
+
import jax.numpy as jnp
|
25
|
+
from jax.experimental.sparse.coo import coo_matvec_p, coo_matmat_p, COOInfo
|
26
|
+
from jax.experimental.sparse.csr import csr_matvec_p, csr_matmat_p
|
27
|
+
|
28
|
+
from brainstate import init, functional
|
29
|
+
from brainstate._state import ParamState
|
30
|
+
from brainstate.nn._module import Module
|
31
|
+
from brainstate.typing import ArrayLike, Size
|
32
|
+
|
33
|
+
# Public layers exported by this module.
__all__ = [
    'Linear', 'ScaledWSLinear', 'SignedWLinear',
    'CSRLinear', 'CSCLinear', 'COOLinear',
    'AllToAll', 'OneToOne',
]
|
43
|
+
|
44
|
+
|
45
|
+
class Linear(Module):
    """
    Linear (fully-connected) layer computing ``y = x @ (W * mask) + b``.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size. Only its last dimension is contracted with the weight.
    out_size: int, sequence of int
        The output size. All dimensions except the last must equal those of ``in_size``.
    w_init: Callable, ArrayLike
        The initializer (or value) of the weight matrix of shape ``(in_size[-1], out_size[-1])``.
    b_init: Callable, ArrayLike, optional
        The initializer (or value) of the bias of shape ``(out_size[-1],)``.
        ``None`` disables the bias.
    w_mask: ArrayLike, Callable, optional
        Optional mask multiplied element-wise with the weight matrix before the
        product. It is created with the weight's shape ``(in_size[-1], out_size[-1])``.
    name: str, optional
        The name of the object.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # The weight only spans the last (feature) dimensions. The mask must be
        # created with exactly that shape; ``in_size + out_size`` would yield a
        # higher-rank mask whenever the sizes have more than one dimension.
        weight_shape = (self.in_size[-1], self.out_size[-1])

        # w_mask
        self.w_mask = init.param(w_mask, weight_shape)

        # weights
        params = dict(weight=init.param(w_init, weight_shape, allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = ParamState(params)

    def update(self, x):
        """
        Apply the linear transformation.

        Parameters
        ----------
        x: ArrayLike
            Input whose last axis has size ``in_size[-1]``.

        Returns
        -------
        ArrayLike
            ``x @ weight`` (masked when ``w_mask`` is given), plus the bias if present.
        """
        params = self.weight.value
        weight = params['weight']
        if self.w_mask is not None:
            weight = weight * self.w_mask
        y = u.math.dot(x, weight)
        if 'bias' in params:
            y = y + params['bias']
        return y
|
86
|
+
|
87
|
+
|
88
|
+
class SignedWLinear(Module):
    """
    Linear layer with signed weights.

    The layer computes ``y = x @ (|W| * w_sign)``. The magnitude ``|W|`` is
    learned, and the sign pattern ``w_sign`` is fixed. When ``w_sign`` is
    ``None``, all effective weights are non-negative.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size. Only its last dimension is contracted with the weight.
    out_size: int, sequence of int
        The output size. All dimensions except the last must equal those of ``in_size``.
    w_init: Callable, ArrayLike
        The initializer (or value) of the weight of shape ``(in_size[-1], out_size[-1])``.
    w_sign: ArrayLike, optional
        Fixed sign multiplier applied to ``|W|``. ``None`` means no sign flip.
    name: str, optional
        The name of the object.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        w_sign: Optional[ArrayLike] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # w_sign: a fixed (non-trainable) sign pattern
        self.w_sign = w_sign

        # weights: only the last (feature) dimensions are contracted by ``matmul``,
        # so the matrix must be (in_size[-1], out_size[-1]) even for multi-dim sizes.
        weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
        self.weight = ParamState(weight)

    def update(self, x):
        """
        Apply the signed linear transformation.

        Parameters
        ----------
        x: ArrayLike
            Input whose last axis has size ``in_size[-1]``.

        Returns
        -------
        ArrayLike
            ``x @ (|W| * w_sign)``, or ``x @ |W|`` when ``w_sign`` is ``None``.
        """
        w = u.math.abs(self.weight.value)
        if self.w_sign is not None:
            w = w * self.w_sign
        return u.math.matmul(x, w)
|
124
|
+
|
125
|
+
|
126
|
+
class ScaledWSLinear(Module):
    """
    Linear Layer with Weight Standardization.

    Applies weight standardization to the weights of the linear layer before
    the product, i.e. ``y = x @ (WS(W) * mask) + b``.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size. Only its last dimension is contracted with the weight.
    out_size: int, sequence of int
        The output size. All dimensions except the last must equal those of ``in_size``.
    w_init: Callable, ArrayLike
        The initializer for the weights, of shape ``(in_size[-1], out_size[-1])``.
    b_init: Callable, ArrayLike
        The initializer for the bias, of shape ``(out_size[-1],)``. ``None`` disables the bias.
    w_mask: ArrayLike, Callable
        The optional mask of the weights. It is created with shape ``(in_size[-1], 1)``
        (one entry per input feature) and multiplied with the standardized weights.
    ws_gain: bool
        Whether to use a gain for the weights, stored in the same ``ParamState``
        as the weight. The default is True.
    eps: float
        The epsilon value for the weight standardization.
    name: str
        The name of the object.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Callable = init.KaimingNormal(),
        b_init: Callable = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        ws_gain: bool = True,
        eps: float = 1e-4,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # Only the last (feature) dimensions take part in the product. Using the
        # full sizes (or ``in_size[0]``) would produce shapes that cannot be
        # contracted by ``dot`` whenever the sizes have more than one dimension.
        n_in, n_out = self.in_size[-1], self.out_size[-1]

        # w_mask
        self.w_mask = init.param(w_mask, (n_in, 1))

        # parameters
        self.eps = eps

        # weights
        params = dict(weight=init.param(w_init, (n_in, n_out), allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, n_out, allow_none=False)
        # gain: one value per output feature, broadcast over the other axes
        if ws_gain:
            s = params['weight'].shape
            params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
        self.weight = ParamState(params)

    def update(self, x):
        """
        Apply the weight-standardized linear transformation.

        Parameters
        ----------
        x: ArrayLike
            Input whose last axis has size ``in_size[-1]``.

        Returns
        -------
        ArrayLike
            ``x @ WS(weight)`` (masked when ``w_mask`` is given), plus the bias if present.
        """
        params = self.weight.value
        w = params['weight']
        w = functional.weight_standardization(w, self.eps, params.get('gain', None))
        if self.w_mask is not None:
            w = w * self.w_mask
        y = u.math.dot(x, w)
        if 'bias' in params:
            y = y + params['bias']
        return y
|
199
|
+
|
200
|
+
|
201
|
+
def csr_matmat(data, indices, indptr, B: jax.Array, *, shape, transpose: bool = False) -> jax.Array:
    """Product of a CSR sparse matrix and a dense matrix.

    Args:
      data : array of shape ``(nse,)``.
      indices : array of shape ``(nse,)``
      indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
      B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
        dtype ``data.dtype``
      shape : length-2 sequence representing the sparse matrix shape.
      transpose : boolean specifying whether to transpose the sparse matrix
        before computing.

    Returns:
      C : array of shape ``(shape[1] if transpose else shape[0], cols)``
        representing the matrix-matrix product.
    """
    # ``bind`` parameters should be hashable, so list-like shapes such as
    # ``[n, m]`` are normalized to a tuple of Python ints.
    shape = tuple(int(d) for d in shape)
    return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=bool(transpose))
|
218
|
+
|
219
|
+
|
220
|
+
def csr_matvec(data, indices, indptr, v, *, shape, transpose=False) -> jax.Array:
    """Product of a CSR sparse matrix and a dense vector.

    Args:
      data : array of shape ``(nse,)``.
      indices : array of shape ``(nse,)``
      indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
      v : array of shape ``(shape[0] if transpose else shape[1],)``
        and dtype ``data.dtype``
      shape : length-2 sequence representing the matrix shape
      transpose : boolean specifying whether to transpose the sparse matrix
        before computing.

    Returns:
      y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
        the matrix-vector product.
    """
    # ``bind`` parameters should be hashable, so list-like shapes such as
    # ``[n, m]`` are normalized to a tuple of Python ints.
    shape = tuple(int(d) for d in shape)
    return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=bool(transpose))
|
238
|
+
|
239
|
+
|
240
|
+
class CSRLinear(Module):
    """
    Linear layer with a Compressed Sparse Row (CSR) weight matrix.

    The sparse weight ``W`` has shape ``(in_size[-1], out_size[-1])``. Its rows
    (input features) are compressed by ``indptr``, and ``indices`` holds the
    column (output feature) of each stored entry. The layer computes
    ``y = x @ W`` plus an optional bias.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size.
    out_size: int, sequence of int
        The output size. All dimensions except the last must equal those of ``in_size``.
    indptr: ArrayLike
        1D row pointer array with ``in_size[-1] + 1`` entries.
    indices: ArrayLike
        1D column index array, one entry per stored value.
    weight: Callable, ArrayLike
        Initializer (or values) for the ``len(indices)`` stored weights.
    b_init: Callable, ArrayLike, optional
        Initializer of the bias of shape ``(out_size[-1],)``. ``None`` disables the bias.
    name: str, optional
        The name of the object.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        indptr: ArrayLike,
        indices: ArrayLike,
        weight: Union[Callable, ArrayLike],
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # CSR data structure (rows = input features)
        indptr = jnp.asarray(indptr)
        indices = jnp.asarray(indices)
        assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
        assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
        assert indptr.size == self.in_size[-1] + 1, f"indptr must have size {self.in_size[-1] + 1}. Got: {indptr.size}"
        with jax.ensure_compile_time_eval():
            self.indptr = u.math.asarray(indptr)
            self.indices = u.math.asarray(indices)

        # weights
        weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
        params = dict(weight=weight)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = ParamState(params)

    def update(self, x):
        """
        Compute ``x @ W`` (+ bias) for a 1D input ``(in,)`` or a 2D batch ``(batch, in)``.

        Raises
        ------
        NotImplementedError
            If ``x`` has more than two dimensions.
        """
        params = self.weight.value
        data = params['weight']
        data, w_unit = u.get_mantissa(data), u.get_unit(data)
        x, x_unit = u.get_mantissa(x), u.get_unit(x)
        # W is stored as an (in, out) CSR matrix and ``x @ W == W.T @ x``, so the
        # sparse matrix must be transposed. A tuple keeps the primitive params hashable.
        shape = (self.in_size[-1], self.out_size[-1])
        if x.ndim == 1:
            y = csr_matvec(data, self.indices, self.indptr, x, shape=shape, transpose=True)
        elif x.ndim == 2:
            # The dense operand's leading axis is the contracted one: feed
            # (in, batch) and turn the (out, batch) result back into (batch, out).
            y = csr_matmat(data, self.indices, self.indptr, x.T, shape=shape, transpose=True).T
        else:
            raise NotImplementedError(f"matmul with object of shape {x.shape}")
        y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
        if 'bias' in params:
            y = y + params['bias']
        return y
|
296
|
+
|
297
|
+
|
298
|
+
class CSCLinear(Module):
    """
    Linear layer with a Compressed Sparse Column (CSC) weight matrix.

    The weight ``W`` has shape ``(in_size[-1], out_size[-1])`` and is stored
    column-compressed. ``indptr`` has one entry per output feature plus one, and
    ``indices`` holds the row (input feature) of each stored entry. These are
    exactly the CSR arrays of ``W.T`` with shape ``(out_size[-1], in_size[-1])``.
    The layer computes ``y = x @ W`` plus an optional bias.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size.
    out_size: int, sequence of int
        The output size. All dimensions except the last must equal those of ``in_size``.
    indptr: ArrayLike
        1D column pointer array with ``out_size[-1] + 1`` entries.
    indices: ArrayLike
        1D row index array, one entry per stored value.
    weight: Callable, ArrayLike
        Initializer (or values) for the ``len(indices)`` stored weights.
    b_init: Callable, ArrayLike, optional
        Initializer of the bias of shape ``(out_size[-1],)``. ``None`` disables the bias.
    name: str, optional
        The name of the object.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        indptr: ArrayLike,
        indices: ArrayLike,
        weight: Union[Callable, ArrayLike],
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # CSC data structure: columns (output features) are compressed, so
        # ``indptr`` has one entry per output feature plus one.
        indptr = jnp.asarray(indptr)
        indices = jnp.asarray(indices)
        assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
        assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
        assert indptr.size == self.out_size[-1] + 1, f"indptr must have size {self.out_size[-1] + 1}. Got: {indptr.size}"
        with jax.ensure_compile_time_eval():
            self.indptr = u.math.asarray(indptr)
            self.indices = u.math.asarray(indices)

        # weights
        weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
        params = dict(weight=weight)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = ParamState(params)

    def update(self, x):
        """
        Compute ``x @ W`` (+ bias) for a 1D input ``(in,)`` or a 2D batch ``(batch, in)``.

        Raises
        ------
        NotImplementedError
            If ``x`` has more than two dimensions.
        """
        params = self.weight.value
        data = params['weight']
        data, w_unit = u.get_mantissa(data), u.get_unit(data)
        x, x_unit = u.get_mantissa(x), u.get_unit(x)
        # The stored arrays are the CSR form of W.T (out, in), and
        # ``x @ W == W.T @ x``, so no transpose is needed.
        shape = (self.out_size[-1], self.in_size[-1])
        if x.ndim == 1:
            y = csr_matvec(data, self.indices, self.indptr, x, shape=shape, transpose=False)
        elif x.ndim == 2:
            # Contract the dense operand's leading axis: feed (in, batch) and
            # turn the (out, batch) result back into (batch, out).
            y = csr_matmat(data, self.indices, self.indptr, x.T, shape=shape, transpose=False).T
        else:
            raise NotImplementedError(f"matmul with object of shape {x.shape}")
        y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
        if 'bias' in params:
            y = y + params['bias']
        return y
|
354
|
+
|
355
|
+
|
356
|
+
def coo_matvec(
    data: jax.Array,
    row: jax.Array,
    col: jax.Array,
    v: jax.Array, *,
    spinfo: COOInfo,
    transpose: bool = False
) -> jax.Array:
    """Multiply a COO sparse matrix by a dense vector.

    Args:
      data : array of shape ``(nse,)`` holding the stored values.
      row : array of shape ``(nse,)`` holding the row index of each value.
      col : array of shape ``(nse,)`` and dtype ``row.dtype`` holding the
        column index of each value.
      v : array of shape ``(spinfo.shape[0] if transpose else spinfo.shape[1],)``
        and dtype ``data.dtype``.
      spinfo : ``COOInfo`` holding the matrix shape and whether the row /
        column indices are sorted.
      transpose : if True, multiply by the transpose of the sparse matrix.

    Returns:
      y : array of shape ``(spinfo.shape[1] if transpose else spinfo.shape[0],)``
        holding the matrix-vector product.
    """
    bind_params = {'spinfo': spinfo, 'transpose': transpose}
    return coo_matvec_p.bind(data, row, col, v, **bind_params)
|
381
|
+
|
382
|
+
|
383
|
+
def coo_matmat(
    data: jax.Array, row: jax.Array, col: jax.Array, B: jax.Array, *,
    spinfo: COOInfo, transpose: bool = False
) -> jax.Array:
    """Multiply a COO sparse matrix by a dense matrix.

    Args:
      data : array of shape ``(nse,)`` holding the stored values.
      row : array of shape ``(nse,)`` holding the row index of each value.
      col : array of shape ``(nse,)`` and dtype ``row.dtype`` holding the
        column index of each value.
      B : array of shape ``(spinfo.shape[0] if transpose else spinfo.shape[1], cols)``
        and dtype ``data.dtype``.
      spinfo : ``COOInfo`` holding the matrix shape and whether the row /
        column indices are sorted.
      transpose : if True, multiply by the transpose of the sparse matrix.

    Returns:
      C : array of shape ``(spinfo.shape[1] if transpose else spinfo.shape[0], cols)``
        holding the matrix-matrix product.
    """
    operands = (data, row, col, B)
    return coo_matmat_p.bind(*operands, spinfo=spinfo, transpose=transpose)
|
404
|
+
|
405
|
+
|
406
|
+
class COOLinear(Module):
    """
    Linear layer with a Coordinate (COO) format sparse weight matrix.

    The sparse weight ``W`` has shape ``(in_size[-1], out_size[-1])``. Stored
    entry ``k`` sits at row ``row[k]`` (input feature) and column ``col[k]``
    (output feature). The layer computes ``y = x @ W`` plus an optional bias.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size.
    out_size: int, sequence of int
        The output size. All dimensions except the last must equal those of ``in_size``.
    row: ArrayLike
        1D row (input feature) indices of the stored entries.
    col: ArrayLike
        1D column (output feature) indices, same size as ``row``.
    weight: Callable, ArrayLike
        Initializer (or values) for the ``len(row)`` stored weights.
    b_init: Callable, ArrayLike, optional
        Initializer of the bias of shape ``(out_size[-1],)``. ``None`` disables the bias.
    rows_sorted: bool
        Whether ``row`` is sorted; forwarded to ``COOInfo``.
    cols_sorted: bool
        Whether ``col`` is sorted; forwarded to ``COOInfo``.
    name: str, optional
        The name of the object.
    """

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        row: ArrayLike,
        col: ArrayLike,
        weight: Union[Callable, ArrayLike],
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        rows_sorted: bool = False,
        cols_sorted: bool = False,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # COO data structure
        row = jnp.asarray(row)
        col = jnp.asarray(col)
        assert row.ndim == 1, f"row must be 1D. Got: {row.ndim}"
        assert col.ndim == 1, f"col must be 1D. Got: {col.ndim}"
        assert row.size == col.size, f"row and col must have the same size. Got: {row.size} and {col.size}"
        with jax.ensure_compile_time_eval():
            self.row = u.math.asarray(row)
            self.col = u.math.asarray(col)

        # COO structure information
        self.rows_sorted = rows_sorted
        self.cols_sorted = cols_sorted

        # weights
        weight = init.param(weight, (len(row),), allow_none=False, allow_scalar=False)
        params = dict(weight=weight)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = ParamState(params)

    def update(self, x):
        """
        Compute ``x @ W`` (+ bias) for a 1D input ``(in,)`` or a 2D batch ``(batch, in)``.

        Raises
        ------
        NotImplementedError
            If ``x`` has more than two dimensions.
        """
        params = self.weight.value
        data = params['weight']
        data, w_unit = u.get_mantissa(data), u.get_unit(data)
        x, x_unit = u.get_mantissa(x), u.get_unit(x)
        spinfo = COOInfo(
            shape=(self.in_size[-1], self.out_size[-1]),
            rows_sorted=self.rows_sorted,
            cols_sorted=self.cols_sorted
        )
        # W is (in, out) and ``x @ W == W.T @ x``, so multiply by the transpose.
        if x.ndim == 1:
            y = coo_matvec(data, self.row, self.col, x, spinfo=spinfo, transpose=True)
        elif x.ndim == 2:
            # Contract the dense operand's leading axis: feed (in, batch) and
            # turn the (out, batch) result back into (batch, out).
            y = coo_matmat(data, self.row, self.col, x.T, spinfo=spinfo, transpose=True).T
        else:
            raise NotImplementedError(f"matmul with object of shape {x.shape}")
        y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
        if 'bias' in params:
            y = y + params['bias']
        return y
|
468
|
+
|
469
|
+
|
470
|
+
class AllToAll(Module):
    """
    Synaptic matrix multiplication with All-to-All connections.

    Args:
      in_size: Size. The number of neurons in the pre-synaptic neuron group.
      out_size: Size. The number of neurons in the postsynaptic neuron group.
      w_init: The synaptic weight initializer. If the initialized weight is a
        scalar, it is shared by all connections; otherwise it must be an
        ``(in_size[-1], out_size[-1])`` matrix.
      b_init: The synaptic bias initializer, of shape ``(out_size[-1],)``.
        ``None`` disables the bias.
      include_self: bool. Whether to connect neurons at the same position
        (the diagonal of the connection matrix).
      name: str. The object name.
    """

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        include_self: bool = True,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # others
        self.include_self = include_self

        # weights
        weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
        params = dict(weight=weight)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = ParamState(params)

    def update(self, pre_val):
        """
        Propagate pre-synaptic values through the all-to-all connection.

        Args:
          pre_val: Pre-synaptic values whose last axis has size ``in_size[-1]``.

        Returns:
          The post-synaptic values, carrying the unit ``weight_unit * pre_unit``,
          plus the bias if present.
        """
        params = self.weight.value
        pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
        w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])

        if u.math.ndim(w_val) == 0:  # weight is a scalar
            # With a shared weight, the input reduces to a sum over the feature axis.
            if pre_val.ndim == 1:
                post_val = u.math.sum(pre_val)
            else:
                post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
            if not self.include_self:
                # Remove the contribution of the neuron at the same position.
                if self.in_size == self.out_size:
                    post_val = post_val - pre_val
                elif self.in_size[-1] > self.out_size[-1]:
                    val = pre_val[..., :self.out_size[-1]]
                    post_val = post_val - val
                else:
                    # Zero-pad the feature axis up to out_size[-1], keeping any
                    # leading (batch or shared) axes of the input.
                    pad_shape = tuple(pre_val.shape[:-1]) + (self.out_size[-1] - self.in_size[-1],)
                    val = u.math.concatenate(
                        [pre_val, u.math.zeros(pad_shape, dtype=pre_val.dtype)],
                        axis=-1,
                    )
                    post_val = post_val - val
            post_val = w_val * post_val

        else:  # weight is a matrix
            assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
            if not self.include_self:
                post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
            else:
                post_val = pre_val @ w_val

        post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
        if 'bias' in params:
            post_val = post_val + params['bias']
        return post_val
|
543
|
+
|
544
|
+
|
545
|
+
class OneToOne(Module):
    """
    Synaptic matrix multiplication with One2One connection.

    Each post-synaptic neuron receives only the pre-synaptic value at the same
    position, scaled element-wise by its weight: ``post = pre * w (+ b)``.

    Args:
      in_size: Size. The number of neurons in the pre-synaptic neuron group.
        The post-synaptic group has the same size.
      w_init: The synaptic weight initializer, of shape ``in_size``.
      b_init: The synaptic bias initializer, of shape ``in_size``. ``None``
        disables the bias.
      name: str. The object name.
    """

    def __init__(
        self,
        in_size: Size,
        w_init: Union[Callable, ArrayLike] = init.Normal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = in_size

        # weights: wrapped in a ParamState (like every other layer here) so the
        # parameters are registered as trainable states rather than a plain dict.
        param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
        if b_init is not None:
            param['bias'] = init.param(b_init, self.out_size, allow_none=False)
        self.weight = ParamState(param)

    def update(self, pre_val):
        """
        Scale the pre-synaptic values element-wise by the weights.

        Args:
          pre_val: Pre-synaptic values, broadcastable with the weight.

        Returns:
          ``pre_val * weight`` with unit ``weight_unit * pre_unit``, plus the
          bias if present.
        """
        params = self.weight.value
        pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
        w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])
        post_val = pre_val * w_val
        post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
        if 'bias' in params:
            post_val = post_val + params['bias']
        return post_val
|
@@ -0,0 +1,42 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
from __future__ import annotations
|
18
|
+
|
19
|
+
import jax.numpy as jnp
|
20
|
+
import pytest
|
21
|
+
from absl.testing import absltest
|
22
|
+
from absl.testing import parameterized
|
23
|
+
|
24
|
+
import brainstate as bst
|
25
|
+
|
26
|
+
|
27
|
+
|
28
|
+
|
29
|
+
|
30
|
+
class TestDense(parameterized.TestCase):
    """Output-shape tests for ``brainstate.nn.Linear``."""

    @parameterized.product(
        size=[(10,),
              (20, 10),
              (5, 8, 10)],
        num_out=[20, ]
    )
    def test_Dense1(self, size, num_out):
        # The layer contracts only the last axis, so any leading axes of the
        # input must be carried through to the output unchanged.
        f = bst.nn.Linear(10, num_out)
        x = bst.random.random(size)
        y = f(x)
        # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
        self.assertEqual(y.shape, size[:-1] + (num_out,))
|
42
|
+
|