brainstate 0.1.8__py2.py3-none-any.whl → 0.1.10__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +58 -51
- brainstate/_compatible_import.py +148 -148
- brainstate/_state.py +1605 -1663
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/augment/__init__.py +30 -30
- brainstate/augment/_autograd.py +778 -778
- brainstate/augment/_autograd_test.py +1289 -1289
- brainstate/augment/_eval_shape.py +99 -99
- brainstate/augment/_eval_shape_test.py +38 -38
- brainstate/augment/_mapping.py +1060 -1060
- brainstate/augment/_mapping_test.py +597 -597
- brainstate/augment/_random.py +151 -151
- brainstate/compile/__init__.py +38 -38
- brainstate/compile/_ad_checkpoint.py +204 -204
- brainstate/compile/_ad_checkpoint_test.py +49 -49
- brainstate/compile/_conditions.py +256 -256
- brainstate/compile/_conditions_test.py +220 -220
- brainstate/compile/_error_if.py +92 -92
- brainstate/compile/_error_if_test.py +52 -52
- brainstate/compile/_jit.py +346 -346
- brainstate/compile/_jit_test.py +143 -143
- brainstate/compile/_loop_collect_return.py +536 -536
- brainstate/compile/_loop_collect_return_test.py +58 -58
- brainstate/compile/_loop_no_collection.py +184 -184
- brainstate/compile/_loop_no_collection_test.py +50 -50
- brainstate/compile/_make_jaxpr.py +888 -888
- brainstate/compile/_make_jaxpr_test.py +156 -156
- brainstate/compile/_progress_bar.py +202 -202
- brainstate/compile/_unvmap.py +159 -159
- brainstate/compile/_util.py +147 -147
- brainstate/environ.py +563 -563
- brainstate/environ_test.py +62 -62
- brainstate/functional/__init__.py +27 -26
- brainstate/graph/__init__.py +29 -29
- brainstate/graph/_graph_node.py +244 -244
- brainstate/graph/_graph_node_test.py +73 -73
- brainstate/graph/_graph_operation.py +1738 -1738
- brainstate/graph/_graph_operation_test.py +563 -563
- brainstate/init/__init__.py +26 -26
- brainstate/init/_base.py +52 -52
- brainstate/init/_generic.py +244 -244
- brainstate/init/_random_inits.py +553 -553
- brainstate/init/_random_inits_test.py +149 -149
- brainstate/init/_regular_inits.py +105 -105
- brainstate/init/_regular_inits_test.py +50 -50
- brainstate/mixin.py +365 -363
- brainstate/mixin_test.py +77 -73
- brainstate/nn/__init__.py +135 -131
- brainstate/{functional → nn}/_activations.py +808 -813
- brainstate/{functional → nn}/_activations_test.py +331 -331
- brainstate/nn/_collective_ops.py +514 -514
- brainstate/nn/_collective_ops_test.py +43 -43
- brainstate/nn/_common.py +178 -178
- brainstate/nn/_conv.py +501 -501
- brainstate/nn/_conv_test.py +238 -238
- brainstate/nn/_delay.py +588 -502
- brainstate/nn/_delay_test.py +238 -184
- brainstate/nn/_dropout.py +426 -426
- brainstate/nn/_dropout_test.py +100 -100
- brainstate/nn/_dynamics.py +1343 -1343
- brainstate/nn/_dynamics_test.py +78 -78
- brainstate/nn/_elementwise.py +1119 -1119
- brainstate/nn/_elementwise_test.py +169 -169
- brainstate/nn/_embedding.py +58 -58
- brainstate/nn/_exp_euler.py +92 -92
- brainstate/nn/_exp_euler_test.py +35 -35
- brainstate/nn/_fixedprob.py +239 -239
- brainstate/nn/_fixedprob_test.py +114 -114
- brainstate/nn/_inputs.py +608 -608
- brainstate/nn/_linear.py +424 -424
- brainstate/nn/_linear_mv.py +83 -83
- brainstate/nn/_linear_mv_test.py +120 -120
- brainstate/nn/_linear_test.py +107 -107
- brainstate/nn/_ltp.py +28 -28
- brainstate/nn/_module.py +377 -377
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_neuron.py +705 -705
- brainstate/nn/_neuron_test.py +161 -161
- brainstate/nn/_normalizations.py +975 -918
- brainstate/nn/_normalizations_test.py +73 -73
- brainstate/{functional → nn}/_others.py +46 -46
- brainstate/nn/_poolings.py +1177 -1177
- brainstate/nn/_poolings_test.py +217 -217
- brainstate/nn/_projection.py +486 -486
- brainstate/nn/_rate_rnns.py +554 -554
- brainstate/nn/_rate_rnns_test.py +63 -63
- brainstate/nn/_readout.py +209 -209
- brainstate/nn/_readout_test.py +53 -53
- brainstate/nn/_stp.py +236 -236
- brainstate/nn/_synapse.py +505 -505
- brainstate/nn/_synapse_test.py +131 -131
- brainstate/nn/_synaptic_projection.py +423 -423
- brainstate/nn/_synouts.py +162 -162
- brainstate/nn/_synouts_test.py +57 -57
- brainstate/nn/_utils.py +89 -89
- brainstate/nn/metrics.py +388 -388
- brainstate/optim/__init__.py +38 -38
- brainstate/optim/_base.py +64 -64
- brainstate/optim/_lr_scheduler.py +448 -448
- brainstate/optim/_lr_scheduler_test.py +50 -50
- brainstate/optim/_optax_optimizer.py +152 -152
- brainstate/optim/_optax_optimizer_test.py +53 -53
- brainstate/optim/_sgd_optimizer.py +1104 -1104
- brainstate/random/__init__.py +24 -24
- brainstate/random/_rand_funs.py +3616 -3616
- brainstate/random/_rand_funs_test.py +567 -567
- brainstate/random/_rand_seed.py +210 -210
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1409 -1409
- brainstate/random/_random_for_unit.py +52 -52
- brainstate/surrogate.py +1957 -1957
- brainstate/transform.py +23 -23
- brainstate/typing.py +304 -304
- brainstate/util/__init__.py +50 -50
- brainstate/util/caller.py +98 -98
- brainstate/util/error.py +55 -55
- brainstate/util/filter.py +469 -469
- brainstate/util/others.py +540 -540
- brainstate/util/pretty_pytree.py +945 -945
- brainstate/util/pretty_pytree_test.py +159 -159
- brainstate/util/pretty_repr.py +328 -328
- brainstate/util/pretty_table.py +2954 -2954
- brainstate/util/scaling.py +258 -258
- brainstate/util/struct.py +523 -523
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/METADATA +91 -99
- brainstate-0.1.10.dist-info/RECORD +130 -0
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/WHEEL +1 -1
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info/licenses}/LICENSE +202 -202
- brainstate/functional/_normalization.py +0 -81
- brainstate/functional/_spikes.py +0 -204
- brainstate-0.1.8.dist-info/RECORD +0 -132
- {brainstate-0.1.8.dist-info → brainstate-0.1.10.dist-info}/top_level.txt +0 -0
brainstate/nn/_linear.py
CHANGED
@@ -1,424 +1,424 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# -*- coding: utf-8 -*-

from typing import Callable, Union, Optional

import brainunit as u
import jax.numpy as jnp

from brainstate import init, functional
from brainstate._state import ParamState
from brainstate.typing import ArrayLike, Size
from ._module import Module

__all__ = [
    'Linear',
    'ScaledWSLinear',
    'SignedWLinear',
    'SparseLinear',
    'AllToAll',
    'OneToOne',
    'LoRA',
]


class Linear(Module):
    """
    Linear layer.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # w_mask
        self.w_mask = init.param(w_mask, self.in_size + self.out_size)

        # weights
        params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = param_type(params)

    def update(self, x):
        params = self.weight.value
        weight = params['weight']
        if self.w_mask is not None:
            weight = weight * self.w_mask
        y = u.linalg.dot(x, weight)
        if 'bias' in params:
            y = y + params['bias']
        return y


class SignedWLinear(Module):
    """
    Linear layer with signed weights.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        w_sign: Optional[ArrayLike] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # w_mask
        self.w_sign = w_sign

        # weights
        weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
        self.weight = param_type(weight)

    def update(self, x):
        w = self.weight.value
        if self.w_sign is None:
            return u.math.matmul(x, u.math.abs(w))
        else:
            return u.math.matmul(x, u.math.abs(w) * self.w_sign)


class ScaledWSLinear(Module):
    """
    Linear Layer with Weight Standardization.

    Applies weight standardization to the weights of the linear layer.

    Parameters
    ----------
    in_size: int, sequence of int
        The input size.
    out_size: int, sequence of int
        The output size.
    w_init: Callable, ArrayLike
        The initializer for the weights.
    b_init: Callable, ArrayLike
        The initializer for the bias.
    w_mask: ArrayLike, Callable
        The optional mask of the weights.
    ws_gain: bool
        Whether to use gain for the weights. The default is True.
    eps: float
        The epsilon value for the weight standardization.
    name: str
        The name of the object.

    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Callable = init.KaimingNormal(),
        b_init: Callable = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        ws_gain: bool = True,
        eps: float = 1e-4,
        name: str = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # w_mask
        self.w_mask = init.param(w_mask, (self.in_size[0], 1))

        # parameters
        self.eps = eps

        # weights
        params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size, allow_none=False)
        # gain
        if ws_gain:
            s = params['weight'].shape
            params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
        self.weight = param_type(params)

    def update(self, x):
        params = self.weight.value
        w = params['weight']
        w = functional.weight_standardization(w, self.eps, params.get('gain', None))
        if self.w_mask is not None:
            w = w * self.w_mask
        y = u.linalg.dot(x, w)
        if 'bias' in params:
            y = y + params['bias']
        return y


class SparseLinear(Module):
    """
    Linear layer with Sparse Matrix (can be ``brainunit.sparse.CSR``,
    ``brainunit.sparse.CSC``, ``brainunit.sparse.COO``, or any other sparse matrix).

    Args:
      spar_mat: SparseMatrix. The sparse weight matrix.
      in_size: Size. The input size.
      name: str. The object name.
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        spar_mat: u.sparse.SparseMatrix,
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        in_size: Size = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        if in_size is not None:
            self.in_size = in_size
        self.out_size = spar_mat.shape[-1]
        if in_size is not None:
            assert self.in_size[:-1] == self.out_size[:-1], (
                'The first n-1 dimensions of "in_size" '
                'and "out_size" must be the same.'
            )

        # weights
        assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'
        self.spar_mat = spar_mat
        params = dict(weight=spar_mat.data)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = param_type(params)

    def update(self, x):
        data = self.weight.value['weight']
        y = x @ self.spar_mat.with_data(data)
        if 'bias' in self.weight.value:
            y = y + self.weight.value['bias']
        return y


class AllToAll(Module):
    """
    Synaptic matrix multiplication with All-to-All connections.

    Args:
      in_size: Size. The number of neurons in the pre-synaptic neuron group.
      out_size: Size. The number of neurons in the postsynaptic neuron group.
      w_init: The synaptic weight initializer.
      include_self: bool. Whether connect the neuron with at the same position.
      name: str. The object name.
    """

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        include_self: bool = True,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # others
        self.include_self = include_self

        # weights
        weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
        params = dict(weight=weight)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = param_type(params)

    def update(self, pre_val):
        params = self.weight.value
        pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
        w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])

        if u.math.ndim(w_val) == 0:  # weight is a scalar
            if pre_val.ndim == 1:
                post_val = u.math.sum(pre_val)
            else:
                post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
            if not self.include_self:
                if self.in_size == self.out_size:
                    post_val = post_val - pre_val
                elif self.in_size[-1] > self.out_size[-1]:
                    val = pre_val[..., :self.out_size[-1]]
                    post_val = post_val - val
                else:
                    size = list(self.out_size)
                    size[-1] = self.out_size[-1] - self.in_size[-1]
                    val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
                    post_val = post_val - val
            post_val = w_val * post_val

        else:  # weight is a matrix
            assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
            if not self.include_self:
                post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
            else:
                post_val = pre_val @ w_val

        post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
        if 'bias' in params:
            post_val = post_val + params['bias']
        return post_val


class OneToOne(Module):
    """
    Synaptic matrix multiplication with One2One connection.

    Args:
      in_size: Size. The number of neurons in the pre-synaptic neuron group.
      w_init: The synaptic weight initializer.
      b_init: The synaptic bias initializer.
      name: str. The object name.
    """

    def __init__(
        self,
        in_size: Size,
        w_init: Union[Callable, ArrayLike] = init.Normal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = in_size

        # weights
        param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
        if b_init is not None:
            param['bias'] = init.param(b_init, self.out_size, allow_none=False)
        self.weight = param_type(param)

    def update(self, pre_val):
        post_val = pre_val * self.weight.value['weight']
        if 'bias' in self.weight.value:
            post_val = post_val + self.weight.value['bias']
        return post_val


class LoRA(Module):
    """A standalone LoRA layer.

    Example usage::

      >>> import brainstate as brainstate
      >>> import jax, jax.numpy as jnp
      >>> layer = brainstate.nn.LoRA(3, 2, 4)
      >>> layer.weight.value
      {'lora_a': Array([[ 0.25141352, -0.09826107],
                        [ 0.2328382 ,  0.38869813],
                        [ 0.27069277,  0.7678282 ]], dtype=float32),
       'lora_b': Array([[-0.8372317 ,  0.21012013, -0.52999765, -0.31939325],
                        [ 0.64234126, -0.42980042,  1.2549229 , -0.47134295]], dtype=float32)}
      >>> # Wrap around existing layer
      >>> linear = brainstate.nn.Linear(3, 4)
      >>> wrapper = brainstate.nn.LoRA(3, 2, 4, base_module=linear)
      >>> assert wrapper.base_module == linear
      >>> y = layer(jnp.ones((16, 3)))
      >>> y.shape
      (16, 4)

    Args:
      in_features: the number of input features.
      lora_rank: the rank of the LoRA dimension.
      out_features: the number of output features.
      base_module: a base module to call and substitute, if possible.
      kernel_init: initializer function for the weight matrices.
      param_type: the type of the LoRA params.
    """

    def __init__(
        self,
        in_features: int,
        lora_rank: int,
        out_features: int,
        *,
        base_module: Optional[Module] = None,
        kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
        param_type: type = ParamState,
    ):
        super().__init__()

        # input and output shape
        self.in_size = in_features
        self.out_size = out_features
        self.in_features = in_features
        self.out_features = out_features

        # others
        self.base_module = base_module

        # weights
        param = dict(
            lora_a=kernel_init((in_features, lora_rank)),
            lora_b=kernel_init((lora_rank, out_features))
        )
        self.weight = param_type(param)

    def __call__(self, x: ArrayLike):
        out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
        if self.base_module is not None:
            if not callable(self.base_module):
                raise ValueError('`self.base_module` must be callable.')
            out += self.base_module(x)
        return out
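The layers in this file share a small construct-then-`update` pattern: weights live in a `ParamState` (or other `param_type`) and the forward pass reads `self.weight.value`. Below is a minimal usage sketch, not taken from the package's documentation; it only assumes the `Linear.update` signature shown above and the `brainstate.nn.Linear(3, 4)` construction style from the LoRA docstring.

# Hypothetical usage sketch for brainstate.nn.Linear as defined in this file.
# Assumes brainstate 0.1.10 with a JAX backend; shapes follow the code above.
import jax.numpy as jnp
import brainstate

layer = brainstate.nn.Linear(3, 4)      # in_size=3, out_size=4; weight shape (3, 4) plus bias
x = jnp.ones((16, 3))                   # a batch of 16 inputs with 3 features each
y = layer.update(x)                     # dot(x, weight) + bias, as implemented in Linear.update
print(y.shape)                          # expected: (16, 4)

The same call pattern applies to the other layers here (e.g. `SignedWLinear`, `ScaledWSLinear`), which differ only in how the stored weight is transformed before the matrix product.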