brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +95 -29
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.9.dist-info/RECORD +0 -130
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_linear.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -20,10 +20,11 @@ from typing import Callable, Union, Optional
|
|
20
20
|
import brainunit as u
|
21
21
|
import jax.numpy as jnp
|
22
22
|
|
23
|
-
from brainstate import init, functional
|
24
23
|
from brainstate._state import ParamState
|
25
24
|
from brainstate.typing import ArrayLike, Size
|
25
|
+
from . import init as init
|
26
26
|
from ._module import Module
|
27
|
+
from ._normalizations import weight_standardization
|
27
28
|
|
28
29
|
__all__ = [
|
29
30
|
'Linear',
|
@@ -38,7 +39,58 @@ __all__ = [
|
|
38
39
|
|
39
40
|
class Linear(Module):
|
40
41
|
"""
|
41
|
-
Linear layer.
|
42
|
+
Linear transformation layer.
|
43
|
+
|
44
|
+
Applies a linear transformation to the incoming data: :math:`y = xW + b`
|
45
|
+
|
46
|
+
Parameters
|
47
|
+
----------
|
48
|
+
in_size : int or tuple of int
|
49
|
+
The input feature size.
|
50
|
+
out_size : int or tuple of int
|
51
|
+
The output feature size.
|
52
|
+
w_init : Callable or ArrayLike, optional
|
53
|
+
Weight initializer. Default is ``KaimingNormal()``.
|
54
|
+
b_init : Callable, ArrayLike, or None, optional
|
55
|
+
Bias initializer. If ``None``, no bias is added. Default is ``ZeroInit()``.
|
56
|
+
w_mask : ArrayLike, Callable, or None, optional
|
57
|
+
Optional mask for the weights. If provided, weights will be element-wise
|
58
|
+
multiplied by this mask.
|
59
|
+
name : str, optional
|
60
|
+
Name of the module.
|
61
|
+
param_type : type, optional
|
62
|
+
Type of parameter state. Default is ``ParamState``.
|
63
|
+
|
64
|
+
Attributes
|
65
|
+
----------
|
66
|
+
in_size : tuple
|
67
|
+
Input feature size.
|
68
|
+
out_size : tuple
|
69
|
+
Output feature size.
|
70
|
+
w_mask : ArrayLike or None
|
71
|
+
Weight mask if provided.
|
72
|
+
weight : ParamState
|
73
|
+
Parameter state containing 'weight' and optionally 'bias'.
|
74
|
+
|
75
|
+
Examples
|
76
|
+
--------
|
77
|
+
.. code-block:: python
|
78
|
+
|
79
|
+
>>> import brainstate as bst
|
80
|
+
>>> import jax.numpy as jnp
|
81
|
+
>>>
|
82
|
+
>>> # Create a linear layer
|
83
|
+
>>> layer = bst.nn.Linear((10,), (5,))
|
84
|
+
>>> x = jnp.ones((32, 10))
|
85
|
+
>>> y = layer(x)
|
86
|
+
>>> y.shape
|
87
|
+
(32, 5)
|
88
|
+
>>>
|
89
|
+
>>> # Linear layer without bias
|
90
|
+
>>> layer = bst.nn.Linear((10,), (5,), b_init=None)
|
91
|
+
>>> y = layer(x)
|
92
|
+
>>> y.shape
|
93
|
+
(32, 5)
|
42
94
|
"""
|
43
95
|
__module__ = 'brainstate.nn'
|
44
96
|
|
@@ -82,7 +134,59 @@ class Linear(Module):
|
|
82
134
|
|
83
135
|
class SignedWLinear(Module):
|
84
136
|
"""
|
85
|
-
Linear layer with signed weights.
|
137
|
+
Linear layer with signed absolute weights.
|
138
|
+
|
139
|
+
This layer uses absolute values of weights multiplied by a sign matrix,
|
140
|
+
ensuring all effective weights have controlled signs.
|
141
|
+
|
142
|
+
Parameters
|
143
|
+
----------
|
144
|
+
in_size : int or tuple of int
|
145
|
+
The input feature size.
|
146
|
+
out_size : int or tuple of int
|
147
|
+
The output feature size.
|
148
|
+
w_init : Callable or ArrayLike, optional
|
149
|
+
Weight initializer. Default is ``KaimingNormal()``.
|
150
|
+
w_sign : ArrayLike or None, optional
|
151
|
+
Sign matrix for the weights. If ``None``, all weights are positive
|
152
|
+
(absolute values used). If provided, should have the same shape as
|
153
|
+
the weight matrix.
|
154
|
+
name : str, optional
|
155
|
+
Name of the module.
|
156
|
+
param_type : type, optional
|
157
|
+
Type of parameter state. Default is ``ParamState``.
|
158
|
+
|
159
|
+
Attributes
|
160
|
+
----------
|
161
|
+
in_size : tuple
|
162
|
+
Input feature size.
|
163
|
+
out_size : tuple
|
164
|
+
Output feature size.
|
165
|
+
w_sign : ArrayLike or None
|
166
|
+
Sign matrix for weights.
|
167
|
+
weight : ParamState
|
168
|
+
Parameter state containing the weight values.
|
169
|
+
|
170
|
+
Examples
|
171
|
+
--------
|
172
|
+
.. code-block:: python
|
173
|
+
|
174
|
+
>>> import brainstate as bst
|
175
|
+
>>> import jax.numpy as jnp
|
176
|
+
>>>
|
177
|
+
>>> # Create a signed weight linear layer with all positive weights
|
178
|
+
>>> layer = bst.nn.SignedWLinear((10,), (5,))
|
179
|
+
>>> x = jnp.ones((32, 10))
|
180
|
+
>>> y = layer(x)
|
181
|
+
>>> y.shape
|
182
|
+
(32, 5)
|
183
|
+
>>>
|
184
|
+
>>> # With custom sign matrix (e.g., inhibitory connections)
|
185
|
+
>>> w_sign = jnp.ones((10, 5)) * -1.0 # all negative
|
186
|
+
>>> layer = bst.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
|
187
|
+
>>> y = layer(x)
|
188
|
+
>>> y.shape
|
189
|
+
(32, 5)
|
86
190
|
"""
|
87
191
|
__module__ = 'brainstate.nn'
|
88
192
|
|
@@ -120,29 +224,71 @@ class SignedWLinear(Module):
|
|
120
224
|
|
121
225
|
class ScaledWSLinear(Module):
|
122
226
|
"""
|
123
|
-
Linear
|
227
|
+
Linear layer with weight standardization.
|
124
228
|
|
125
|
-
Applies weight standardization to
|
229
|
+
Applies weight standardization [1]_ to normalize weights before the linear
|
230
|
+
transformation, which can improve training stability and performance.
|
126
231
|
|
127
232
|
Parameters
|
128
233
|
----------
|
129
|
-
in_size: int
|
130
|
-
|
131
|
-
out_size: int
|
132
|
-
|
133
|
-
w_init: Callable,
|
134
|
-
|
135
|
-
b_init: Callable,
|
136
|
-
|
137
|
-
w_mask: ArrayLike, Callable
|
138
|
-
|
139
|
-
ws_gain: bool
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
234
|
+
in_size : int or tuple of int
|
235
|
+
The input feature size.
|
236
|
+
out_size : int or tuple of int
|
237
|
+
The output feature size.
|
238
|
+
w_init : Callable, optional
|
239
|
+
Weight initializer. Default is ``KaimingNormal()``.
|
240
|
+
b_init : Callable, optional
|
241
|
+
Bias initializer. Default is ``ZeroInit()``.
|
242
|
+
w_mask : ArrayLike, Callable, or None, optional
|
243
|
+
Optional mask for the weights.
|
244
|
+
ws_gain : bool, optional
|
245
|
+
Whether to use a learnable gain parameter for weight standardization.
|
246
|
+
Default is ``True``.
|
247
|
+
eps : float, optional
|
248
|
+
Small constant for numerical stability in standardization.
|
249
|
+
Default is ``1e-4``.
|
250
|
+
name : str, optional
|
251
|
+
Name of the module.
|
252
|
+
param_type : type, optional
|
253
|
+
Type of parameter state. Default is ``ParamState``.
|
254
|
+
|
255
|
+
Attributes
|
256
|
+
----------
|
257
|
+
in_size : tuple
|
258
|
+
Input feature size.
|
259
|
+
out_size : tuple
|
260
|
+
Output feature size.
|
261
|
+
w_mask : ArrayLike or None
|
262
|
+
Weight mask if provided.
|
263
|
+
eps : float
|
264
|
+
Epsilon for numerical stability.
|
265
|
+
weight : ParamState
|
266
|
+
Parameter state containing 'weight' and, optionally, 'bias' and 'gain'.
|
267
|
+
|
268
|
+
References
|
269
|
+
----------
|
270
|
+
.. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
|
271
|
+
Weight standardization. arXiv preprint arXiv:1903.10520.
|
272
|
+
|
273
|
+
Examples
|
274
|
+
--------
|
275
|
+
.. code-block:: python
|
276
|
+
|
277
|
+
>>> import brainstate as bst
|
278
|
+
>>> import jax.numpy as jnp
|
279
|
+
>>>
|
280
|
+
>>> # Create a weight-standardized linear layer
|
281
|
+
>>> layer = bst.nn.ScaledWSLinear((10,), (5,))
|
282
|
+
>>> x = jnp.ones((32, 10))
|
283
|
+
>>> y = layer(x)
|
284
|
+
>>> y.shape
|
285
|
+
(32, 5)
|
286
|
+
>>>
|
287
|
+
>>> # Without learnable gain
|
288
|
+
>>> layer = bst.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
|
289
|
+
>>> y = layer(x)
|
290
|
+
>>> y.shape
|
291
|
+
(32, 5)
|
146
292
|
"""
|
147
293
|
__module__ = 'brainstate.nn'
|
148
294
|
|
@@ -185,7 +331,7 @@ class ScaledWSLinear(Module):
|
|
185
331
|
def update(self, x):
|
186
332
|
params = self.weight.value
|
187
333
|
w = params['weight']
|
188
|
-
w =
|
334
|
+
w = weight_standardization(w, self.eps, params.get('gain', None))
|
189
335
|
if self.w_mask is not None:
|
190
336
|
w = w * self.w_mask
|
191
337
|
y = u.linalg.dot(x, w)
|
@@ -196,13 +342,53 @@ class ScaledWSLinear(Module):
|
|
196
342
|
|
197
343
|
class SparseLinear(Module):
|
198
344
|
"""
|
199
|
-
Linear layer with
|
200
|
-
|
345
|
+
Linear layer with sparse weight matrix.
|
346
|
+
|
347
|
+
Supports sparse matrices from ``brainunit.sparse`` including CSR, CSC,
|
348
|
+
and COO formats. Only the non-zero entries are stored and updated.
|
201
349
|
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
350
|
+
Parameters
|
351
|
+
----------
|
352
|
+
spar_mat : brainunit.sparse.SparseMatrix
|
353
|
+
The sparse weight matrix defining the connectivity structure.
|
354
|
+
b_init : Callable, ArrayLike, or None, optional
|
355
|
+
Bias initializer. If ``None``, no bias is added.
|
356
|
+
in_size : int or tuple of int, optional
|
357
|
+
The input size. If not provided, inferred from ``spar_mat``.
|
358
|
+
name : str, optional
|
359
|
+
Name of the module.
|
360
|
+
param_type : type, optional
|
361
|
+
Type of parameter state. Default is ``ParamState``.
|
362
|
+
|
363
|
+
Attributes
|
364
|
+
----------
|
365
|
+
in_size : tuple
|
366
|
+
Input feature size.
|
367
|
+
out_size : int
|
368
|
+
Output feature size.
|
369
|
+
spar_mat : brainunit.sparse.SparseMatrix
|
370
|
+
The sparse matrix structure.
|
371
|
+
weight : ParamState
|
372
|
+
Parameter state containing the sparse 'weight' data and optionally 'bias'.
|
373
|
+
|
374
|
+
Examples
|
375
|
+
--------
|
376
|
+
.. code-block:: python
|
377
|
+
|
378
|
+
>>> import brainstate as bst
|
379
|
+
>>> import brainunit as u
|
380
|
+
>>> import jax.numpy as jnp
|
381
|
+
>>>
|
382
|
+
>>> # Create a sparse linear layer with CSR matrix
|
383
|
+
>>> indices = jnp.array([[0, 1], [1, 2], [2, 0]])
|
384
|
+
>>> values = jnp.array([1.0, 2.0, 3.0])
|
385
|
+
>>> spar_mat = u.sparse.CSR((values, indices[:, 1], jnp.arange(4)),
|
386
|
+
... shape=(3, 3))
|
387
|
+
>>> layer = bst.nn.SparseLinear(spar_mat, in_size=(3,))
|
388
|
+
>>> x = jnp.ones((5, 3))
|
389
|
+
>>> y = layer(x)
|
390
|
+
>>> y.shape
|
391
|
+
(5, 3)
|
206
392
|
"""
|
207
393
|
__module__ = 'brainstate.nn'
|
208
394
|
|
@@ -244,15 +430,61 @@ class SparseLinear(Module):
|
|
244
430
|
|
245
431
|
class AllToAll(Module):
|
246
432
|
"""
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
433
|
+
All-to-all connection layer.
|
434
|
+
|
435
|
+
Performs matrix multiplication with optional exclusion of self-connections,
|
436
|
+
commonly used in recurrent neural networks and graph neural networks.
|
437
|
+
|
438
|
+
Parameters
|
439
|
+
----------
|
440
|
+
in_size : int or tuple of int
|
441
|
+
The number of neurons in the pre-synaptic group.
|
442
|
+
out_size : int or tuple of int
|
443
|
+
The number of neurons in the post-synaptic group.
|
444
|
+
w_init : Callable or ArrayLike, optional
|
445
|
+
Weight initializer. Default is ``KaimingNormal()``.
|
446
|
+
b_init : Callable, ArrayLike, or None, optional
|
447
|
+
Bias initializer. If ``None``, no bias is added.
|
448
|
+
include_self : bool, optional
|
449
|
+
Whether to include self-connections (diagonal elements).
|
450
|
+
Default is ``True``.
|
451
|
+
name : str, optional
|
452
|
+
Name of the module.
|
453
|
+
param_type : type, optional
|
454
|
+
Type of parameter state. Default is ``ParamState``.
|
455
|
+
|
456
|
+
Attributes
|
457
|
+
----------
|
458
|
+
in_size : tuple
|
459
|
+
Input size.
|
460
|
+
out_size : tuple
|
461
|
+
Output size.
|
462
|
+
include_self : bool
|
463
|
+
Whether self-connections are included.
|
464
|
+
weight : ParamState
|
465
|
+
Parameter state containing 'weight' and optionally 'bias'.
|
466
|
+
|
467
|
+
Examples
|
468
|
+
--------
|
469
|
+
.. code-block:: python
|
470
|
+
|
471
|
+
>>> import brainstate as bst
|
472
|
+
>>> import jax.numpy as jnp
|
473
|
+
>>>
|
474
|
+
>>> # All-to-all with self-connections
|
475
|
+
>>> layer = bst.nn.AllToAll((10,), (10,), include_self=True)
|
476
|
+
>>> x = jnp.ones((32, 10))
|
477
|
+
>>> y = layer(x)
|
478
|
+
>>> y.shape
|
479
|
+
(32, 10)
|
480
|
+
>>>
|
481
|
+
>>> # All-to-all without self-connections (recurrent layer)
|
482
|
+
>>> layer = bst.nn.AllToAll((10,), (10,), include_self=False)
|
483
|
+
>>> y = layer(x)
|
484
|
+
>>> y.shape
|
485
|
+
(32, 10)
|
255
486
|
"""
|
487
|
+
__module__ = 'brainstate.nn'
|
256
488
|
|
257
489
|
def __init__(
|
258
490
|
self,
|
@@ -320,14 +552,55 @@ class AllToAll(Module):
|
|
320
552
|
|
321
553
|
class OneToOne(Module):
|
322
554
|
"""
|
323
|
-
|
555
|
+
One-to-one connection layer.
|
556
|
+
|
557
|
+
Applies element-wise multiplication with a weight vector, implementing
|
558
|
+
diagonal connectivity where each input unit connects only to its
|
559
|
+
corresponding output unit.
|
324
560
|
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
561
|
+
Parameters
|
562
|
+
----------
|
563
|
+
in_size : int or tuple of int
|
564
|
+
The number of neurons. Input and output sizes are the same.
|
565
|
+
w_init : Callable or ArrayLike, optional
|
566
|
+
Weight initializer. Default is ``Normal()``.
|
567
|
+
b_init : Callable, ArrayLike, or None, optional
|
568
|
+
Bias initializer. If ``None``, no bias is added.
|
569
|
+
name : str, optional
|
570
|
+
Name of the module.
|
571
|
+
param_type : type, optional
|
572
|
+
Type of parameter state. Default is ``ParamState``.
|
573
|
+
|
574
|
+
Attributes
|
575
|
+
----------
|
576
|
+
in_size : tuple
|
577
|
+
Input size.
|
578
|
+
out_size : tuple
|
579
|
+
Output size (same as input size).
|
580
|
+
weight : ParamState
|
581
|
+
Parameter state containing 'weight' and optionally 'bias'.
|
582
|
+
|
583
|
+
Examples
|
584
|
+
--------
|
585
|
+
.. code-block:: python
|
586
|
+
|
587
|
+
>>> import brainstate as bst
|
588
|
+
>>> import jax.numpy as jnp
|
589
|
+
>>>
|
590
|
+
>>> # One-to-one connection
|
591
|
+
>>> layer = bst.nn.OneToOne((10,))
|
592
|
+
>>> x = jnp.ones((32, 10))
|
593
|
+
>>> y = layer(x)
|
594
|
+
>>> y.shape
|
595
|
+
(32, 10)
|
596
|
+
>>>
|
597
|
+
>>> # With bias
|
598
|
+
>>> layer = bst.nn.OneToOne((10,), b_init=bst.nn.init.Constant(0.1))
|
599
|
+
>>> y = layer(x)
|
600
|
+
>>> y.shape
|
601
|
+
(32, 10)
|
330
602
|
"""
|
603
|
+
__module__ = 'brainstate.nn'
|
331
604
|
|
332
605
|
def __init__(
|
333
606
|
self,
|
@@ -357,35 +630,76 @@ class OneToOne(Module):
|
|
357
630
|
|
358
631
|
|
359
632
|
class LoRA(Module):
|
360
|
-
"""
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
633
|
+
"""
|
634
|
+
Low-Rank Adaptation (LoRA) layer.
|
635
|
+
|
636
|
+
Implements parameter-efficient fine-tuning using low-rank decomposition [1]_.
|
637
|
+
Can be used standalone or as a wrapper around an existing module.
|
638
|
+
|
639
|
+
Parameters
|
640
|
+
----------
|
641
|
+
in_features : int
|
642
|
+
The number of input features.
|
643
|
+
lora_rank : int
|
644
|
+
The rank of the low-rank decomposition. Lower rank means fewer parameters.
|
645
|
+
out_features : int
|
646
|
+
The number of output features.
|
647
|
+
base_module : Module, optional
|
648
|
+
A base module to wrap. If provided, the LoRA output will be added to
|
649
|
+
the base module's output. Default is ``None``.
|
650
|
+
kernel_init : Callable or ArrayLike, optional
|
651
|
+
Initializer for the LoRA weight matrices. Default is ``LecunNormal()``.
|
652
|
+
param_type : type, optional
|
653
|
+
Type of parameter state. Default is ``ParamState``.
|
654
|
+
|
655
|
+
Attributes
|
656
|
+
----------
|
657
|
+
in_size : int
|
658
|
+
Input feature size.
|
659
|
+
out_size : int
|
660
|
+
Output feature size.
|
661
|
+
in_features : int
|
662
|
+
Number of input features.
|
663
|
+
out_features : int
|
664
|
+
Number of output features.
|
665
|
+
base_module : Module or None
|
666
|
+
The wrapped base module if provided.
|
667
|
+
weight : ParamState
|
668
|
+
Parameter state containing 'lora_a' and 'lora_b' matrices.
|
669
|
+
|
670
|
+
References
|
671
|
+
----------
|
672
|
+
.. [1] Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S.,
|
673
|
+
Wang, L., & Chen, W. (2021). LoRA: Low-Rank Adaptation of Large
|
674
|
+
Language Models. arXiv preprint arXiv:2106.09685.
|
675
|
+
|
676
|
+
Examples
|
677
|
+
--------
|
678
|
+
.. code-block:: python
|
679
|
+
|
680
|
+
>>> import brainstate as bst
|
681
|
+
>>> import jax.numpy as jnp
|
682
|
+
>>>
|
683
|
+
>>> # Standalone LoRA layer
|
684
|
+
>>> layer = bst.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
|
685
|
+
>>> x = jnp.ones((32, 10))
|
686
|
+
>>> y = layer(x)
|
378
687
|
>>> y.shape
|
379
|
-
(
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
out_features
|
385
|
-
|
386
|
-
|
387
|
-
|
688
|
+
(32, 5)
|
689
|
+
>>>
|
690
|
+
>>> # Wrap around existing linear layer
|
691
|
+
>>> base = bst.nn.Linear((10,), (5,))
|
692
|
+
>>> lora_layer = bst.nn.LoRA(in_features=10, lora_rank=2,
|
693
|
+
... out_features=5, base_module=base)
|
694
|
+
>>> y = lora_layer(x)
|
695
|
+
>>> y.shape
|
696
|
+
(32, 5)
|
697
|
+
>>>
|
698
|
+
>>> # Check parameter count - LoRA has fewer parameters
|
699
|
+
>>> # Base layer: 10 * 5 = 50 parameters
|
700
|
+
>>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
|
388
701
|
"""
|
702
|
+
__module__ = 'brainstate.nn'
|
389
703
|
|
390
704
|
def __init__(
|
391
705
|
self,
|
@@ -396,6 +710,7 @@ class LoRA(Module):
|
|
396
710
|
base_module: Optional[Module] = None,
|
397
711
|
kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
|
398
712
|
param_type: type = ParamState,
|
713
|
+
in_size: Size = None,
|
399
714
|
):
|
400
715
|
super().__init__()
|
401
716
|
|
@@ -415,6 +730,11 @@ class LoRA(Module):
|
|
415
730
|
)
|
416
731
|
self.weight = param_type(param)
|
417
732
|
|
733
|
+
# in_size
|
734
|
+
if in_size is not None:
|
735
|
+
self.in_size = in_size
|
736
|
+
self.out_size = tuple(self.in_size[:-1]) + (out_features,)
|
737
|
+
|
418
738
|
def __call__(self, x: ArrayLike):
|
419
739
|
out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
|
420
740
|
if self.base_module is not None:
|