brainstate 0.2.0__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +169 -169
- brainstate/_compatible_import.py +340 -340
- brainstate/_compatible_import_test.py +681 -681
- brainstate/_deprecation.py +210 -210
- brainstate/_deprecation_test.py +2319 -2319
- brainstate/_error.py +45 -45
- brainstate/_state.py +1652 -1652
- brainstate/_state_test.py +52 -52
- brainstate/_utils.py +47 -47
- brainstate/environ.py +1495 -1495
- brainstate/environ_test.py +1223 -1223
- brainstate/graph/__init__.py +22 -22
- brainstate/graph/_node.py +240 -240
- brainstate/graph/_node_test.py +589 -589
- brainstate/graph/_operation.py +1624 -1624
- brainstate/graph/_operation_test.py +1147 -1147
- brainstate/mixin.py +1433 -1433
- brainstate/mixin_test.py +1017 -1017
- brainstate/nn/__init__.py +137 -137
- brainstate/nn/_activations.py +1100 -1100
- brainstate/nn/_activations_test.py +354 -354
- brainstate/nn/_collective_ops.py +633 -633
- brainstate/nn/_collective_ops_test.py +774 -774
- brainstate/nn/_common.py +226 -226
- brainstate/nn/_common_test.py +154 -154
- brainstate/nn/_conv.py +2010 -2010
- brainstate/nn/_conv_test.py +849 -849
- brainstate/nn/_delay.py +575 -575
- brainstate/nn/_delay_test.py +243 -243
- brainstate/nn/_dropout.py +618 -618
- brainstate/nn/_dropout_test.py +477 -477
- brainstate/nn/_dynamics.py +1267 -1267
- brainstate/nn/_dynamics_test.py +67 -67
- brainstate/nn/_elementwise.py +1298 -1298
- brainstate/nn/_elementwise_test.py +829 -829
- brainstate/nn/_embedding.py +408 -408
- brainstate/nn/_embedding_test.py +156 -156
- brainstate/nn/_event_fixedprob.py +233 -233
- brainstate/nn/_event_fixedprob_test.py +115 -115
- brainstate/nn/_event_linear.py +83 -83
- brainstate/nn/_event_linear_test.py +121 -121
- brainstate/nn/_exp_euler.py +254 -254
- brainstate/nn/_exp_euler_test.py +377 -377
- brainstate/nn/_linear.py +744 -744
- brainstate/nn/_linear_test.py +475 -475
- brainstate/nn/_metrics.py +1070 -1070
- brainstate/nn/_metrics_test.py +611 -611
- brainstate/nn/_module.py +384 -384
- brainstate/nn/_module_test.py +40 -40
- brainstate/nn/_normalizations.py +1334 -1334
- brainstate/nn/_normalizations_test.py +699 -699
- brainstate/nn/_paddings.py +1020 -1020
- brainstate/nn/_paddings_test.py +722 -722
- brainstate/nn/_poolings.py +2239 -2239
- brainstate/nn/_poolings_test.py +952 -952
- brainstate/nn/_rnns.py +946 -946
- brainstate/nn/_rnns_test.py +592 -592
- brainstate/nn/_utils.py +216 -216
- brainstate/nn/_utils_test.py +401 -401
- brainstate/nn/init.py +809 -809
- brainstate/nn/init_test.py +180 -180
- brainstate/random/__init__.py +270 -270
- brainstate/random/_rand_funs.py +3938 -3938
- brainstate/random/_rand_funs_test.py +640 -640
- brainstate/random/_rand_seed.py +675 -675
- brainstate/random/_rand_seed_test.py +48 -48
- brainstate/random/_rand_state.py +1617 -1617
- brainstate/random/_rand_state_test.py +551 -551
- brainstate/transform/__init__.py +59 -59
- brainstate/transform/_ad_checkpoint.py +176 -176
- brainstate/transform/_ad_checkpoint_test.py +49 -49
- brainstate/transform/_autograd.py +1025 -1025
- brainstate/transform/_autograd_test.py +1289 -1289
- brainstate/transform/_conditions.py +316 -316
- brainstate/transform/_conditions_test.py +220 -220
- brainstate/transform/_error_if.py +94 -94
- brainstate/transform/_error_if_test.py +52 -52
- brainstate/transform/_eval_shape.py +145 -145
- brainstate/transform/_eval_shape_test.py +38 -38
- brainstate/transform/_jit.py +399 -399
- brainstate/transform/_jit_test.py +143 -143
- brainstate/transform/_loop_collect_return.py +675 -675
- brainstate/transform/_loop_collect_return_test.py +58 -58
- brainstate/transform/_loop_no_collection.py +283 -283
- brainstate/transform/_loop_no_collection_test.py +50 -50
- brainstate/transform/_make_jaxpr.py +2016 -2016
- brainstate/transform/_make_jaxpr_test.py +1510 -1510
- brainstate/transform/_mapping.py +529 -529
- brainstate/transform/_mapping_test.py +194 -194
- brainstate/transform/_progress_bar.py +255 -255
- brainstate/transform/_random.py +171 -171
- brainstate/transform/_unvmap.py +256 -256
- brainstate/transform/_util.py +286 -286
- brainstate/typing.py +837 -837
- brainstate/typing_test.py +780 -780
- brainstate/util/__init__.py +27 -27
- brainstate/util/_others.py +1024 -1024
- brainstate/util/_others_test.py +962 -962
- brainstate/util/_pretty_pytree.py +1301 -1301
- brainstate/util/_pretty_pytree_test.py +675 -675
- brainstate/util/_pretty_repr.py +462 -462
- brainstate/util/_pretty_repr_test.py +696 -696
- brainstate/util/filter.py +945 -945
- brainstate/util/filter_test.py +911 -911
- brainstate/util/struct.py +910 -910
- brainstate/util/struct_test.py +602 -602
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -108
- brainstate-0.2.1.dist-info/RECORD +111 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
- brainstate-0.2.0.dist-info/RECORD +0 -111
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
brainstate/nn/_linear.py
CHANGED
@@ -1,744 +1,744 @@
|
|
1
|
-
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
# -*- coding: utf-8 -*-
|
17
|
-
|
18
|
-
from typing import Callable, Union, Optional
|
19
|
-
|
20
|
-
import brainunit as u
|
21
|
-
import jax.numpy as jnp
|
22
|
-
|
23
|
-
from brainstate._state import ParamState
|
24
|
-
from brainstate.typing import ArrayLike, Size
|
25
|
-
from . import init as init
|
26
|
-
from ._module import Module
|
27
|
-
from ._normalizations import weight_standardization
|
28
|
-
|
29
|
-
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    'Linear',
    'ScaledWSLinear',
    'SignedWLinear',
    'SparseLinear',
    'AllToAll',
    'OneToOne',
    'LoRA',
]
|
38
|
-
|
39
|
-
|
40
|
-
class Linear(Module):
    """
    Linear transformation layer.

    Applies a linear transformation to the incoming data: :math:`y = xW + b`

    The weight matrix has shape ``(in_size[-1], out_size[-1])`` and acts on the
    last axis of the input; leading axes of ``in_size`` (which must equal the
    leading axes of ``out_size``) are carried through unchanged.

    Parameters
    ----------
    in_size : int or tuple of int
        The input feature size.
    out_size : int or tuple of int
        The output feature size.
    w_init : Callable or ArrayLike, optional
        Weight initializer. Default is ``KaimingNormal()``.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer. If ``None``, no bias is added. Default is ``ZeroInit()``.
    w_mask : ArrayLike, Callable, or None, optional
        Optional mask for the weights. It is passed through ``init.param`` with
        the weight shape ``(in_size[-1], out_size[-1])``. If provided, weights
        are element-wise multiplied by this mask on every forward pass.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input feature size.
    out_size : tuple
        Output feature size.
    w_mask : ArrayLike or None
        Weight mask if provided.
    weight : ParamState
        Parameter state containing 'weight' and optionally 'bias'.

    Raises
    ------
    AssertionError
        If ``in_size[:-1] != out_size[:-1]``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a linear layer
        >>> layer = bst.nn.Linear((10,), (5,))
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # Linear layer without bias
        >>> layer = bst.nn.Linear((10,), (5,), b_init=None)
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # The weight only acts on the last axis, so the mask must share its
        # 2-D shape. ``in_size + out_size`` would give a higher-rank mask as
        # soon as the sizes have leading axes, and ``weight * w_mask`` would
        # then broadcast the weight to a wrong shape.
        weight_shape = (self.in_size[-1], self.out_size[-1])

        # w_mask
        self.w_mask = init.param(w_mask, weight_shape)

        # weights
        params = dict(weight=init.param(w_init, weight_shape, allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = param_type(params)

    def update(self, x):
        """Compute ``x @ (W * mask) + b`` over the last axis of ``x``."""
        params = self.weight.value
        weight = params['weight']
        if self.w_mask is not None:
            weight = weight * self.w_mask
        y = u.linalg.dot(x, weight)
        if 'bias' in params:
            y = y + params['bias']
        return y
|
133
|
-
|
134
|
-
|
135
|
-
class SignedWLinear(Module):
    """
    Linear layer with signed absolute weights.

    This layer uses absolute values of weights multiplied by a sign matrix,
    ensuring all effective weights have controlled signs. No bias is used.

    The weight matrix has shape ``(in_size[-1], out_size[-1])`` and acts on the
    last axis of the input, matching :class:`Linear`.

    Parameters
    ----------
    in_size : int or tuple of int
        The input feature size.
    out_size : int or tuple of int
        The output feature size.
    w_init : Callable or ArrayLike, optional
        Weight initializer. Default is ``KaimingNormal()``.
    w_sign : ArrayLike or None, optional
        Sign matrix for the weights. If ``None``, all weights are positive
        (absolute values used). If provided, it is multiplied element-wise with
        ``abs(weight)``, so it should have the weight shape
        ``(in_size[-1], out_size[-1])`` or be broadcastable to it.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input feature size.
    out_size : tuple
        Output feature size.
    w_sign : ArrayLike or None
        Sign matrix for weights.
    weight : ParamState
        Parameter state containing the weight values.

    Raises
    ------
    AssertionError
        If ``in_size[:-1] != out_size[:-1]``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a signed weight linear layer with all positive weights
        >>> layer = bst.nn.SignedWLinear((10,), (5,))
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # With custom sign matrix (e.g., inhibitory connections)
        >>> w_sign = jnp.ones((10, 5)) * -1.0  # all negative
        >>> layer = bst.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        w_sign: Optional[ArrayLike] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # w_sign
        self.w_sign = w_sign

        # weights: a 2-D matrix over the last axis. ``in_size + out_size`` would
        # be rank-4 when the sizes have leading axes, and ``matmul`` in
        # ``update`` cannot contract such a weight with the input.
        weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
        self.weight = param_type(weight)

    def update(self, x):
        """Compute ``x @ (|W| * w_sign)``, or ``x @ |W|`` when no sign is given."""
        w = self.weight.value
        if self.w_sign is None:
            return u.math.matmul(x, u.math.abs(w))
        else:
            return u.math.matmul(x, u.math.abs(w) * self.w_sign)
|
223
|
-
|
224
|
-
|
225
|
-
class ScaledWSLinear(Module):
    """
    Linear layer with weight standardization.

    Applies weight standardization [1]_ to normalize weights before the linear
    transformation, which can improve training stability and performance.

    The weight matrix has shape ``(in_size[-1], out_size[-1])`` and acts on the
    last axis of the input, matching :class:`Linear`.

    Parameters
    ----------
    in_size : int or tuple of int
        The input feature size.
    out_size : int or tuple of int
        The output feature size.
    w_init : Callable, optional
        Weight initializer. Default is ``KaimingNormal()``.
    b_init : Callable or None, optional
        Bias initializer. If ``None``, no bias is added. Default is ``ZeroInit()``.
    w_mask : ArrayLike, Callable, or None, optional
        Optional mask for the weights. It is passed through ``init.param`` with
        the weight shape ``(in_size[-1], out_size[-1])`` and multiplied with the
        standardized weights.
    ws_gain : bool, optional
        Whether to use a learnable gain parameter for weight standardization.
        Default is ``True``.
    eps : float, optional
        Small constant for numerical stability in standardization.
        Default is ``1e-4``.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input feature size.
    out_size : tuple
        Output feature size.
    w_mask : ArrayLike or None
        Weight mask if provided.
    eps : float
        Epsilon for numerical stability.
    weight : ParamState
        Parameter state containing 'weight', optionally 'bias' and 'gain'.

    Raises
    ------
    AssertionError
        If ``in_size[:-1] != out_size[:-1]``.

    References
    ----------
    .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
           Weight standardization. arXiv preprint arXiv:1903.10520.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a weight-standardized linear layer
        >>> layer = bst.nn.ScaledWSLinear((10,), (5,))
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # Without learnable gain
        >>> layer = bst.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Callable = init.KaimingNormal(),
        b_init: Optional[Callable] = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        ws_gain: bool = True,
        eps: float = 1e-4,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # The weight is a 2-D matrix over the last axis; ``in_size + out_size``
        # would be rank-4 with leading axes and break ``dot`` in ``update``.
        weight_shape = (self.in_size[-1], self.out_size[-1])

        # w_mask: same shape as the weight it multiplies. The previous shape
        # ``(in_size[0], 1)`` used the leading axis, not the feature axis.
        self.w_mask = init.param(w_mask, weight_shape)

        # parameters
        self.eps = eps

        # weights
        params = dict(weight=init.param(w_init, weight_shape, allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        # gain: one learnable scale per output column, broadcast over the rest.
        if ws_gain:
            s = params['weight'].shape
            params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
        self.weight = param_type(params)

    def update(self, x):
        """Standardize the weight, apply the mask, then compute ``x @ W + b``."""
        params = self.weight.value
        w = params['weight']
        w = weight_standardization(w, self.eps, params.get('gain', None))
        if self.w_mask is not None:
            w = w * self.w_mask
        y = u.linalg.dot(x, w)
        if 'bias' in params:
            y = y + params['bias']
        return y
|
341
|
-
|
342
|
-
|
343
|
-
class SparseLinear(Module):
    """
    Linear layer with sparse weight matrix.

    Supports sparse matrices from ``brainunit.sparse`` including CSR, CSC,
    and COO formats. Only the non-zero entries are stored and updated.

    Parameters
    ----------
    spar_mat : brainunit.sparse.SparseMatrix
        The sparse weight matrix defining the connectivity structure. Its
        ``data`` array becomes the trainable weight.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer. If ``None``, no bias is added.
    in_size : int or tuple of int, optional
        The input size. If not provided, it is inferred as
        ``(spar_mat.shape[0],)``.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input feature size.
    out_size : tuple
        Output feature size, ``(spar_mat.shape[-1],)``.
    spar_mat : brainunit.sparse.SparseMatrix
        The sparse matrix structure.
    weight : ParamState
        Parameter state containing the sparse 'weight' data and optionally 'bias'.

    Raises
    ------
    AssertionError
        If ``spar_mat`` is not a ``SparseMatrix``, or if
        ``in_size[:-1] != out_size[:-1]``.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import brainunit as u
        >>> import jax.numpy as jnp
        >>>
        >>> # 3x3 CSR matrix with one non-zero per row:
        >>> # (0, 1) = 1.0, (1, 2) = 2.0, (2, 0) = 3.0
        >>> values = jnp.array([1.0, 2.0, 3.0])
        >>> col_indices = jnp.array([1, 2, 0])
        >>> row_ptr = jnp.array([0, 1, 2, 3])  # length = n_rows + 1
        >>> spar_mat = u.sparse.CSR((values, col_indices, row_ptr), shape=(3, 3))
        >>> layer = bst.nn.SparseLinear(spar_mat, in_size=(3,))
        >>> x = jnp.ones((5, 3))
        >>> y = layer(x)
        >>> y.shape
        (5, 3)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        spar_mat: u.sparse.SparseMatrix,
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        in_size: Optional[Size] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # Validate first: everything below reads ``spar_mat.shape`` / ``.data``.
        assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'

        # input and output shape (input inferred from the matrix rows if absent)
        self.in_size = in_size if in_size is not None else (spar_mat.shape[0],)
        self.out_size = spar_mat.shape[-1]
        assert self.in_size[:-1] == self.out_size[:-1], (
            'The first n-1 dimensions of "in_size" '
            'and "out_size" must be the same.'
        )

        # weights: only the non-zero values are trainable; the sparsity
        # pattern stays fixed in ``self.spar_mat``.
        self.spar_mat = spar_mat
        params = dict(weight=spar_mat.data)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = param_type(params)

    def update(self, x):
        """Compute ``x @ S + b``, where ``S`` is ``spar_mat`` with the current data."""
        params = self.weight.value
        y = x @ self.spar_mat.with_data(params['weight'])
        if 'bias' in params:
            y = y + params['bias']
        return y
|
429
|
-
|
430
|
-
|
431
|
-
class AllToAll(Module):
    """
    All-to-all connection layer.

    Performs matrix multiplication with optional exclusion of self-connections,
    commonly used in recurrent neural networks and graph neural networks.

    If the stored weight is 0-d (a scalar), every output unit receives
    ``w * sum(inputs)``. When ``include_self=False``, input ``i``'s own
    contribution to output ``i`` is then removed for the first
    ``min(n_in, n_out)`` units. If the weight is a 2-D matrix, the diagonal is
    zeroed instead when ``include_self=False``.

    Units are stripped before the computation. The result is rebuilt with unit
    ``weight_unit * input_unit`` and passed through ``u.maybe_decimal`` before
    the bias is added.

    Parameters
    ----------
    in_size : int or tuple of int
        The number of neurons in the pre-synaptic group.
    out_size : int or tuple of int
        The number of neurons in the post-synaptic group.
    w_init : Callable or ArrayLike, optional
        Weight initializer. Default is ``KaimingNormal()``.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer. If ``None``, no bias is added.
    include_self : bool, optional
        Whether to include self-connections (diagonal elements).
        Default is ``True``.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input size.
    out_size : tuple
        Output size.
    include_self : bool
        Whether self-connections are included.
    weight : ParamState
        Parameter state containing 'weight' and optionally 'bias'.

    Raises
    ------
    AssertionError
        If ``in_size[:-1] != out_size[:-1]``, or if a non-scalar weight is not
        2-D.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # All-to-all with self-connections
        >>> layer = bst.nn.AllToAll((10,), (10,), include_self=True)
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 10)
        >>>
        >>> # All-to-all without self-connections (recurrent layer)
        >>> layer = bst.nn.AllToAll((10,), (10,), include_self=False)
        >>> y = layer(x)
        >>> y.shape
        (32, 10)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        include_self: bool = True,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                         'and "out_size" must be the same.')

        # others
        self.include_self = include_self

        # weights
        weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
        params = dict(weight=weight)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = param_type(params)

    def update(self, pre_val):
        """Propagate ``pre_val`` through the all-to-all weight (see class docs)."""
        params = self.weight.value
        # Work on unit-less mantissas; the physical unit is re-attached below.
        pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
        w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])

        if u.math.ndim(w_val) == 0:  # weight is a scalar
            # Every output unit receives the sum over all input units.
            if pre_val.ndim == 1:
                post_val = u.math.sum(pre_val)
            else:
                post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
            if not self.include_self:
                # Remove the self term: input i feeds output i for the first
                # min(n_in, n_out) units.
                n_in, n_out = self.in_size[-1], self.out_size[-1]
                if self.in_size == self.out_size:
                    post_val = post_val - pre_val
                elif n_in > n_out:
                    val = pre_val[..., :n_out]
                    post_val = post_val - val
                else:
                    # Zero-pad the LAST axis up to n_out while keeping any
                    # leading (batch) axes. Concatenating along axis 0 would
                    # fail for batched or multi-axis input.
                    pad_shape = tuple(pre_val.shape[:-1]) + (n_out - n_in,)
                    val = u.math.concatenate(
                        [pre_val, u.math.zeros(pad_shape, dtype=pre_val.dtype)],
                        axis=-1,
                    )
                    post_val = post_val - val
            post_val = w_val * post_val

        else:  # weight is a matrix
            assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
            if not self.include_self:
                post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
            else:
                post_val = pre_val @ w_val

        post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
        if 'bias' in params:
            post_val = post_val + params['bias']
        return post_val
|
551
|
-
|
552
|
-
|
553
|
-
class OneToOne(Module):
    """
    One-to-one connection layer.

    Every input unit drives exactly one output unit through its own weight.
    The layer computes ``y = x * w`` element-wise, plus ``b`` when a bias is
    configured. This is equivalent to multiplying by a diagonal weight matrix.

    Parameters
    ----------
    in_size : int or tuple of int
        The number of neurons. Input and output sizes are the same.
    w_init : Callable or ArrayLike, optional
        Weight initializer, evaluated with shape ``in_size``.
        Default is ``Normal()``.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer, evaluated with shape ``in_size``. If ``None``, no
        bias is added.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input size.
    out_size : tuple
        Output size (identical to ``in_size``).
    weight : ParamState
        Parameter state holding 'weight' and, if configured, 'bias'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # One-to-one connection
        >>> layer = bst.nn.OneToOne((10,))
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 10)
        >>>
        >>> # With bias
        >>> layer = bst.nn.OneToOne((10,), b_init=bst.nn.init.ZeroInit())
        >>> y = layer(x)
        >>> y.shape
        (32, 10)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        w_init: Union[Callable, ArrayLike] = init.Normal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # A one-to-one mapping preserves the shape.
        self.in_size = in_size
        self.out_size = in_size

        trainables = {'weight': init.param(w_init, self.in_size, allow_none=False)}
        if b_init is not None:
            trainables['bias'] = init.param(b_init, self.out_size, allow_none=False)
        self.weight = param_type(trainables)

    def update(self, pre_val):
        """Scale ``pre_val`` element-wise by the weight and add the optional bias."""
        store = self.weight.value
        scaled = pre_val * store['weight']
        return scaled + store['bias'] if 'bias' in store else scaled
|
630
|
-
|
631
|
-
|
632
|
-
class LoRA(Module):
    """
    Low-Rank Adaptation (LoRA) layer.

    Implements parameter-efficient fine-tuning using low-rank decomposition [1]_.
    Can be used standalone or as a wrapper around an existing module.

    The forward pass computes ``x @ A @ B``, where ``A`` has shape
    ``(in_features, lora_rank)`` and ``B`` has shape ``(lora_rank, out_features)``.
    When ``base_module`` is given, its output ``base_module(x)`` is added.

    Parameters
    ----------
    in_features : int
        The number of input features.
    lora_rank : int
        The rank of the low-rank decomposition. Lower rank means fewer parameters.
    out_features : int
        The number of output features.
    base_module : Module, optional
        A base module to wrap. If provided, the LoRA output will be added to
        the base module's output. Default is ``None``.
    kernel_init : Callable or ArrayLike, optional
        Initializer for the LoRA weight matrices. It is called as
        ``kernel_init(shape)`` once for each matrix. Default is ``LecunNormal()``.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.
    in_size : int or tuple of int, optional
        Full input shape. If given, it overrides ``in_size`` (otherwise
        ``in_features``), and ``out_size`` becomes
        ``in_size[:-1] + (out_features,)``.
    name : str, optional
        Name of the module.

    Attributes
    ----------
    in_size : int or tuple
        Input feature size.
    out_size : int or tuple
        Output feature size.
    in_features : int
        Number of input features.
    out_features : int
        Number of output features.
    base_module : Module or None
        The wrapped base module if provided.
    weight : ParamState
        Parameter state containing 'lora_a' and 'lora_b' matrices.

    Raises
    ------
    ValueError
        At call time, if ``base_module`` is set but not callable.

    References
    ----------
    .. [1] Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S.,
           Wang, L., & Chen, W. (2021). LoRA: Low-Rank Adaptation of Large
           Language Models. arXiv preprint arXiv:2106.09685.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # Standalone LoRA layer
        >>> layer = bst.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # Wrap around existing linear layer
        >>> base = bst.nn.Linear((10,), (5,))
        >>> lora_layer = bst.nn.LoRA(in_features=10, lora_rank=2,
        ...                          out_features=5, base_module=base)
        >>> y = lora_layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # Check parameter count - LoRA has fewer parameters
        >>> # Base layer: 10 * 5 = 50 parameters
        >>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_features: int,
        lora_rank: int,
        out_features: int,
        *,
        base_module: Optional[Module] = None,
        kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
        param_type: type = ParamState,
        in_size: Optional[Size] = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_features
        self.out_size = out_features
        self.in_features = in_features
        self.out_features = out_features

        # others
        self.base_module = base_module

        # weights: the low-rank factors A (in x r) and B (r x out)
        param = dict(
            lora_a=kernel_init((in_features, lora_rank)),
            lora_b=kernel_init((lora_rank, out_features))
        )
        self.weight = param_type(param)

        # in_size: a full input shape overrides the feature-only sizes above
        if in_size is not None:
            self.in_size = in_size
            self.out_size = tuple(self.in_size[:-1]) + (out_features,)

    def update(self, x: ArrayLike):
        """Return ``x @ A @ B``, plus ``base_module(x)`` when a base module is set."""
        params = self.weight.value
        out = x @ params['lora_a'] @ params['lora_b']
        if self.base_module is not None:
            if not callable(self.base_module):
                raise ValueError('`self.base_module` must be callable.')
            out = out + self.base_module(x)
        return out

    def __call__(self, x: ArrayLike):
        # Kept for backward compatibility; the computation lives in ``update``
        # like every other layer in this module.
        return self.update(x)
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
# -*- coding: utf-8 -*-
|
17
|
+
|
18
|
+
from typing import Callable, Union, Optional
|
19
|
+
|
20
|
+
import brainunit as u
|
21
|
+
import jax.numpy as jnp
|
22
|
+
|
23
|
+
from brainstate._state import ParamState
|
24
|
+
from brainstate.typing import ArrayLike, Size
|
25
|
+
from . import init as init
|
26
|
+
from ._module import Module
|
27
|
+
from ._normalizations import weight_standardization
|
28
|
+
|
29
|
+
# Public API of this module: the linear-style connection layers defined below.
__all__ = [
    'Linear', 'ScaledWSLinear', 'SignedWLinear', 'SparseLinear',
    'AllToAll', 'OneToOne', 'LoRA',
]
|
38
|
+
|
39
|
+
|
40
|
+
class Linear(Module):
    """
    Linear transformation layer.

    Applies a linear transformation to the incoming data: :math:`y = xW + b`

    Only the last axis of the input is contracted with the weight matrix.
    The weight matrix has shape ``(in_size[-1], out_size[-1])``, and all
    leading dimensions of ``in_size`` and ``out_size`` must agree.

    Parameters
    ----------
    in_size : int or tuple of int
        The input feature size.
    out_size : int or tuple of int
        The output feature size.
    w_init : Callable or ArrayLike, optional
        Weight initializer for the ``(in_size[-1], out_size[-1])`` weight
        matrix. Default is ``KaimingNormal()``.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer, of size ``out_size[-1]``. If ``None``, no bias is
        added. Default is ``ZeroInit()``.
    w_mask : ArrayLike, Callable, or None, optional
        Optional mask for the weights. It is built with the same shape as the
        weight matrix, ``(in_size[-1], out_size[-1])``. If provided, the
        weights are element-wise multiplied by this mask.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input feature size.
    out_size : tuple
        Output feature size.
    w_mask : ArrayLike or None
        Weight mask if provided.
    weight : ParamState
        Parameter state containing 'weight' and optionally 'bias'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a linear layer
        >>> layer = bst.nn.Linear((10,), (5,))
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # Linear layer without bias
        >>> layer = bst.nn.Linear((10,), (5,), b_init=None)
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                          'and "out_size" must be the same.')

        # The mask multiplies the weight element-wise, so it must share the
        # weight's shape. The previous shape, ``in_size + out_size``, only
        # matched that shape for 1-D sizes. For multi-dimensional sizes it
        # could not broadcast against the weight.
        weight_shape = (self.in_size[-1], self.out_size[-1])

        # w_mask (initialized before the weight, preserving initializer order)
        self.w_mask = init.param(w_mask, weight_shape)

        # weights
        params = dict(weight=init.param(w_init, weight_shape, allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = param_type(params)

    def update(self, x):
        """Compute ``x @ (W * mask) + b``, using only the terms that are configured."""
        params = self.weight.value
        weight = params['weight']
        if self.w_mask is not None:
            weight = weight * self.w_mask
        y = u.linalg.dot(x, weight)
        if 'bias' in params:
            y = y + params['bias']
        return y
|
133
|
+
|
134
|
+
|
135
|
+
class SignedWLinear(Module):
    """
    Linear layer with signed absolute weights.

    The effective weight matrix is ``|W|``, optionally multiplied element-wise
    by a fixed sign matrix ``w_sign``. The sign of each connection is
    therefore set by ``w_sign``, and only the magnitudes come from the
    trainable parameter.

    Parameters
    ----------
    in_size : int or tuple of int
        The input feature size.
    out_size : int or tuple of int
        The output feature size.
    w_init : Callable or ArrayLike, optional
        Weight initializer, evaluated over ``in_size + out_size``.
        Default is ``KaimingNormal()``.
    w_sign : ArrayLike or None, optional
        Sign matrix for the weights. If ``None``, only ``|W|`` is used, so all
        effective weights are non-negative. If provided, it is multiplied
        element-wise with ``|W|``. It should have the same shape as the
        weight matrix.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input feature size.
    out_size : tuple
        Output feature size.
    w_sign : ArrayLike or None
        Sign matrix for weights.
    weight : ParamState
        Parameter state containing the weight values.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a signed weight linear layer with all positive weights
        >>> layer = bst.nn.SignedWLinear((10,), (5,))
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # With custom sign matrix (e.g., inhibitory connections)
        >>> w_sign = jnp.ones((10, 5)) * -1.0  # all negative
        >>> layer = bst.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        w_sign: Optional[ArrayLike] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # Shapes: every dimension except the last must agree.
        self.in_size = in_size
        self.out_size = out_size
        shapes_ok = self.in_size[:-1] == self.out_size[:-1]
        assert shapes_ok, ('The first n-1 dimensions of "in_size" '
                           'and "out_size" must be the same.')

        # Fixed sign pattern applied to |W| (None -> all non-negative).
        self.w_sign = w_sign

        # Trainable magnitudes, initialized over the full in_size + out_size shape.
        magnitudes = init.param(w_init, self.in_size + self.out_size, allow_none=False)
        self.weight = param_type(magnitudes)

    def update(self, x):
        """Multiply ``x`` by ``|W|``, signed by ``w_sign`` when one was given."""
        effective = u.math.abs(self.weight.value)
        if self.w_sign is not None:
            effective = effective * self.w_sign
        return u.math.matmul(x, effective)
|
223
|
+
|
224
|
+
|
225
|
+
class ScaledWSLinear(Module):
    """
    Linear layer with weight standardization.

    Applies weight standardization [1]_ to normalize weights before the linear
    transformation, which can improve training stability and performance.

    Parameters
    ----------
    in_size : int or tuple of int
        The input feature size.
    out_size : int or tuple of int
        The output feature size.
    w_init : Callable or ArrayLike, optional
        Weight initializer, evaluated over ``in_size + out_size``.
        Default is ``KaimingNormal()``.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer, evaluated over ``out_size``. If ``None``, no bias is
        added. Default is ``ZeroInit()``.
    w_mask : ArrayLike, Callable, or None, optional
        Optional mask, built with shape ``(in_size[0], 1)``. It is multiplied
        with the standardized weights.
    ws_gain : bool, optional
        Whether to use a learnable gain parameter for weight standardization.
        The gain has shape ``(1, ..., 1, weight.shape[-1])`` and is
        initialized to ones. Default is ``True``.
    eps : float, optional
        Small constant for numerical stability in standardization.
        Default is ``1e-4``.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input feature size.
    out_size : tuple
        Output feature size.
    w_mask : ArrayLike or None
        Weight mask if provided.
    eps : float
        Epsilon for numerical stability.
    weight : ParamState
        Parameter state containing 'weight', optionally 'bias' and 'gain'.

    References
    ----------
    .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
           Weight standardization. arXiv preprint arXiv:1903.10520.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # Create a weight-standardized linear layer
        >>> layer = bst.nn.ScaledWSLinear((10,), (5,))
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # Without learnable gain
        >>> layer = bst.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
        w_mask: Optional[Union[ArrayLike, Callable]] = None,
        ws_gain: bool = True,
        eps: float = 1e-4,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                          'and "out_size" must be the same.')

        # w_mask: one value per row of a 1-D input, broadcast across outputs.
        # NOTE(review): this uses in_size[0], not in_size[-1]. Confirm this is
        # intended for multi-dimensional in_size.
        self.w_mask = init.param(w_mask, (self.in_size[0], 1))

        # parameters
        self.eps = eps

        # weights
        params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size, allow_none=False)
        # gain: learnable per-output scale applied inside weight_standardization
        if ws_gain:
            s = params['weight'].shape
            params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
        self.weight = param_type(params)

    def update(self, x):
        """Standardize the weight (with optional gain), mask it, then compute ``x @ W + b``."""
        params = self.weight.value
        w = params['weight']
        w = weight_standardization(w, self.eps, params.get('gain', None))
        if self.w_mask is not None:
            w = w * self.w_mask
        y = u.linalg.dot(x, w)
        if 'bias' in params:
            y = y + params['bias']
        return y
|
341
|
+
|
342
|
+
|
343
|
+
class SparseLinear(Module):
    """
    Linear layer with sparse weight matrix.

    Supports sparse matrices from ``brainunit.sparse`` including CSR, CSC,
    and COO formats. The sparsity structure is fixed. Only the non-zero
    entries (``spar_mat.data``) are stored as trainable parameters.

    Parameters
    ----------
    spar_mat : brainunit.sparse.SparseMatrix
        The sparse weight matrix defining the connectivity structure.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer, of size ``spar_mat.shape[-1]``. If ``None``, no bias
        is added.
    in_size : int or tuple of int, optional
        The input size. When given, its leading dimensions must match those
        of ``out_size``.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input feature size.
    out_size : int or tuple
        Output feature size, taken from ``spar_mat.shape[-1]``.
    spar_mat : brainunit.sparse.SparseMatrix
        The sparse matrix structure.
    weight : ParamState
        Parameter state containing the sparse 'weight' data and optionally 'bias'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import brainunit as u
        >>> import jax.numpy as jnp
        >>>
        >>> # 3x3 CSR matrix with one non-zero per row:
        >>> # row 0 -> col 1 (1.0), row 1 -> col 2 (2.0), row 2 -> col 0 (3.0)
        >>> data = jnp.array([1.0, 2.0, 3.0])
        >>> indices = jnp.array([1, 2, 0])    # column index of each stored value
        >>> indptr = jnp.array([0, 1, 2, 3])  # row i spans data[indptr[i]:indptr[i + 1]]
        >>> spar_mat = u.sparse.CSR((data, indices, indptr), shape=(3, 3))
        >>> layer = bst.nn.SparseLinear(spar_mat, in_size=(3,))
        >>> x = jnp.ones((5, 3))
        >>> y = layer(x)
        >>> y.shape
        (5, 3)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        spar_mat: u.sparse.SparseMatrix,
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        in_size: Size = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # Validate the matrix type before reading any of its attributes.
        assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'
        num_out = spar_mat.shape[-1]

        # input and output shape
        self.out_size = num_out
        if in_size is not None:
            self.in_size = in_size
            assert self.in_size[:-1] == self.out_size[:-1], (
                'The first n-1 dimensions of "in_size" '
                'and "out_size" must be the same.'
            )

        # weights: only the stored non-zero values are trainable;
        # the sparsity structure stays in ``self.spar_mat``.
        self.spar_mat = spar_mat
        params = dict(weight=spar_mat.data)
        if b_init is not None:
            params['bias'] = init.param(b_init, num_out, allow_none=False)
        self.weight = param_type(params)

    def update(self, x):
        """Compute ``x @ spar_mat`` with the current trainable data, plus bias if present."""
        params = self.weight.value
        y = x @ self.spar_mat.with_data(params['weight'])
        if 'bias' in params:
            y = y + params['bias']
        return y
|
429
|
+
|
430
|
+
|
431
|
+
class AllToAll(Module):
    """
    All-to-all connection layer.

    Performs matrix multiplication with optional exclusion of self-connections,
    commonly used in recurrent neural networks and graph neural networks.

    If the initialized weight is a scalar, every connection shares that
    weight. The output is then the weighted sum of the inputs over the last
    axis. Otherwise the weight must be a 2-D ``(in_size[-1], out_size[-1])``
    matrix.

    Parameters
    ----------
    in_size : int or tuple of int
        The number of neurons in the pre-synaptic group.
    out_size : int or tuple of int
        The number of neurons in the post-synaptic group.
    w_init : Callable or ArrayLike, optional
        Weight initializer. Default is ``KaimingNormal()``.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer. If ``None``, no bias is added.
    include_self : bool, optional
        Whether to include self-connections (diagonal elements).
        Default is ``True``.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input size.
    out_size : tuple
        Output size.
    include_self : bool
        Whether self-connections are included.
    weight : ParamState
        Parameter state containing 'weight' and optionally 'bias'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # All-to-all with self-connections
        >>> layer = bst.nn.AllToAll((10,), (10,), include_self=True)
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 10)
        >>>
        >>> # All-to-all without self-connections (recurrent layer)
        >>> layer = bst.nn.AllToAll((10,), (10,), include_self=False)
        >>> y = layer(x)
        >>> y.shape
        (32, 10)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        out_size: Size,
        w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        include_self: bool = True,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_size
        self.out_size = out_size
        assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
                                                          'and "out_size" must be the same.')

        # others
        self.include_self = include_self

        # weights
        weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
        params = dict(weight=weight)
        if b_init is not None:
            params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
        self.weight = param_type(params)

    def update(self, pre_val):
        """Propagate ``pre_val`` through all-to-all connections, handling physical units."""
        params = self.weight.value
        # Work on unitless mantissas; the product unit is re-attached at the end.
        pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
        w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])

        if u.math.ndim(w_val) == 0:  # weight is a scalar
            # With a shared weight, each post unit receives the total input.
            if pre_val.ndim == 1:
                post_val = u.math.sum(pre_val)
            else:
                post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
            if not self.include_self:
                # Remove each unit's own contribution: post unit i excludes pre unit i.
                n_pre, n_post = self.in_size[-1], self.out_size[-1]
                if self.in_size == self.out_size:
                    post_val = post_val - pre_val
                elif n_pre > n_post:
                    post_val = post_val - pre_val[..., :n_post]
                else:
                    # Post units beyond n_pre have no self-connection. Zero-pad
                    # the feature (last) axis so batched and multi-dimensional
                    # inputs align with the output.
                    pad_shape = tuple(pre_val.shape[:-1]) + (n_post - n_pre,)
                    val = u.math.concatenate(
                        [pre_val, u.math.zeros(pad_shape, dtype=pre_val.dtype)],
                        axis=-1,
                    )
                    post_val = post_val - val
            post_val = w_val * post_val

        else:  # weight is a matrix
            assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
            if not self.include_self:
                post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
            else:
                post_val = pre_val @ w_val

        post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
        if 'bias' in params:
            post_val = post_val + params['bias']
        return post_val
|
551
|
+
|
552
|
+
|
553
|
+
class OneToOne(Module):
    """
    One-to-one connection layer.

    Each input unit is multiplied by its own weight and feeds only the output
    unit at the same position. This is a diagonal connectivity pattern,
    implemented as element-wise multiplication, with an optional per-unit bias.

    Parameters
    ----------
    in_size : int or tuple of int
        The number of neurons. Input and output sizes are the same.
    w_init : Callable or ArrayLike, optional
        Weight initializer, evaluated over ``in_size``. Default is ``Normal()``.
    b_init : Callable, ArrayLike, or None, optional
        Bias initializer. If ``None``, no bias is added.
    name : str, optional
        Name of the module.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.

    Attributes
    ----------
    in_size : tuple
        Input size.
    out_size : tuple
        Output size (same as input size).
    weight : ParamState
        Parameter state containing 'weight' and optionally 'bias'.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # One-to-one connection
        >>> layer = bst.nn.OneToOne((10,))
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 10)
        >>>
        >>> # With bias
        >>> layer = bst.nn.OneToOne((10,), b_init=bst.init.Constant(0.1))
        >>> y = layer(x)
        >>> y.shape
        (32, 10)
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_size: Size,
        w_init: Union[Callable, ArrayLike] = init.Normal(),
        b_init: Optional[Union[Callable, ArrayLike]] = None,
        name: Optional[str] = None,
        param_type: type = ParamState,
    ):
        super().__init__(name=name)

        # Element-wise connectivity: the output has exactly the input's shape.
        self.in_size = in_size
        self.out_size = in_size

        # One weight (and optionally one bias) per unit.
        values = {'weight': init.param(w_init, self.in_size, allow_none=False)}
        if b_init is not None:
            values['bias'] = init.param(b_init, self.out_size, allow_none=False)
        self.weight = param_type(values)

    def update(self, pre_val):
        """Scale ``pre_val`` element-wise by the weights and add the bias if present."""
        values = self.weight.value
        result = pre_val * values['weight']
        if 'bias' in values:
            result = result + values['bias']
        return result
|
630
|
+
|
631
|
+
|
632
|
+
class LoRA(Module):
    """
    Low-Rank Adaptation (LoRA) layer.

    Implements parameter-efficient fine-tuning using low-rank decomposition [1]_.
    Can be used standalone or as a wrapper around an existing module.
    The output is ``x @ lora_a @ lora_b``, plus ``base_module(x)`` when a
    base module is given.

    Parameters
    ----------
    in_features : int
        The number of input features.
    lora_rank : int
        The rank of the low-rank decomposition. Lower rank means fewer parameters.
    out_features : int
        The number of output features.
    base_module : Module, optional
        A base module to wrap. If provided, the LoRA output will be added to
        the base module's output. Default is ``None``.
    kernel_init : Callable or ArrayLike, optional
        Initializer for both LoRA matrices, ``lora_a`` of shape
        ``(in_features, lora_rank)`` and ``lora_b`` of shape
        ``(lora_rank, out_features)``. Default is ``LecunNormal()``.
    param_type : type, optional
        Type of parameter state. Default is ``ParamState``.
    in_size : int or tuple of int, optional
        Full input shape. When given, ``in_size`` is set to it and
        ``out_size`` becomes ``in_size[:-1] + (out_features,)``. Otherwise
        ``in_size``/``out_size`` are ``in_features``/``out_features``.
        Default is ``None``.
    name : str, optional
        Name of the module. Default is ``None``.

    Attributes
    ----------
    in_size : int or tuple
        Input feature size.
    out_size : int or tuple
        Output feature size.
    in_features : int
        Number of input features.
    out_features : int
        Number of output features.
    base_module : Module or None
        The wrapped base module if provided.
    weight : ParamState
        Parameter state containing 'lora_a' and 'lora_b' matrices.

    References
    ----------
    .. [1] Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S.,
           Wang, L., & Chen, W. (2021). LoRA: Low-Rank Adaptation of Large
           Language Models. arXiv preprint arXiv:2106.09685.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate as bst
        >>> import jax.numpy as jnp
        >>>
        >>> # Standalone LoRA layer
        >>> layer = bst.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
        >>> x = jnp.ones((32, 10))
        >>> y = layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # Wrap around existing linear layer
        >>> base = bst.nn.Linear((10,), (5,))
        >>> lora_layer = bst.nn.LoRA(in_features=10, lora_rank=2,
        ...                          out_features=5, base_module=base)
        >>> y = lora_layer(x)
        >>> y.shape
        (32, 5)
        >>>
        >>> # Check parameter count - LoRA has fewer parameters
        >>> # Base layer: 10 * 5 = 50 parameters
        >>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
    """
    __module__ = 'brainstate.nn'

    def __init__(
        self,
        in_features: int,
        lora_rank: int,
        out_features: int,
        *,
        base_module: Optional[Module] = None,
        kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
        param_type: type = ParamState,
        in_size: Size = None,
        name: Optional[str] = None,
    ):
        super().__init__(name=name)

        # input and output shape
        self.in_size = in_features
        self.out_size = out_features
        self.in_features = in_features
        self.out_features = out_features

        # others
        self.base_module = base_module

        # weights: lora_a projects down to the rank, lora_b projects back up.
        # ``init.param`` (as in the sibling layers) accepts both callables and
        # arrays, matching the documented type of ``kernel_init``.
        param = dict(
            lora_a=init.param(kernel_init, (in_features, lora_rank), allow_none=False),
            lora_b=init.param(kernel_init, (lora_rank, out_features), allow_none=False),
        )
        self.weight = param_type(param)

        # in_size
        if in_size is not None:
            self.in_size = in_size
            self.out_size = tuple(self.in_size[:-1]) + (out_features,)

    def update(self, x: ArrayLike):
        """Return ``x @ lora_a @ lora_b``, adding ``base_module(x)`` when a base module is set."""
        out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
        if self.base_module is not None:
            if not callable(self.base_module):
                raise ValueError('`self.base_module` must be callable.')
            out = out + self.base_module(x)
        return out

    def __call__(self, x: ArrayLike):
        # Kept as an explicit override for backward compatibility. The
        # computation lives in ``update``, like the other layers in this module.
        return self.update(x)
|