brainstate 0.2.1-py2.py3-none-any.whl → 0.2.2-py2.py3-none-any.whl

This diff shows the changes between publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
brainstate/nn/_linear.py CHANGED
@@ -1,744 +1,744 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import Callable, Union, Optional
19
-
20
- import brainunit as u
21
- import jax.numpy as jnp
22
-
23
- from brainstate._state import ParamState
24
- from brainstate.typing import ArrayLike, Size
25
- from . import init as init
26
- from ._module import Module
27
- from ._normalizations import weight_standardization
28
-
29
- __all__ = [
30
- 'Linear',
31
- 'ScaledWSLinear',
32
- 'SignedWLinear',
33
- 'SparseLinear',
34
- 'AllToAll',
35
- 'OneToOne',
36
- 'LoRA',
37
- ]
38
-
39
-
40
- class Linear(Module):
41
- """
42
- Linear transformation layer.
43
-
44
- Applies a linear transformation to the incoming data: :math:`y = xW + b`
45
-
46
- Parameters
47
- ----------
48
- in_size : int or tuple of int
49
- The input feature size.
50
- out_size : int or tuple of int
51
- The output feature size.
52
- w_init : Callable or ArrayLike, optional
53
- Weight initializer. Default is ``KaimingNormal()``.
54
- b_init : Callable, ArrayLike, or None, optional
55
- Bias initializer. If ``None``, no bias is added. Default is ``ZeroInit()``.
56
- w_mask : ArrayLike, Callable, or None, optional
57
- Optional mask for the weights. If provided, weights will be element-wise
58
- multiplied by this mask.
59
- name : str, optional
60
- Name of the module.
61
- param_type : type, optional
62
- Type of parameter state. Default is ``ParamState``.
63
-
64
- Attributes
65
- ----------
66
- in_size : tuple
67
- Input feature size.
68
- out_size : tuple
69
- Output feature size.
70
- w_mask : ArrayLike or None
71
- Weight mask if provided.
72
- weight : ParamState
73
- Parameter state containing 'weight' and optionally 'bias'.
74
-
75
- Examples
76
- --------
77
- .. code-block:: python
78
-
79
- >>> import brainstate as bst
80
- >>> import jax.numpy as jnp
81
- >>>
82
- >>> # Create a linear layer
83
- >>> layer = bst.nn.Linear((10,), (5,))
84
- >>> x = jnp.ones((32, 10))
85
- >>> y = layer(x)
86
- >>> y.shape
87
- (32, 5)
88
- >>>
89
- >>> # Linear layer without bias
90
- >>> layer = bst.nn.Linear((10,), (5,), b_init=None)
91
- >>> y = layer(x)
92
- >>> y.shape
93
- (32, 5)
94
- """
95
- __module__ = 'brainstate.nn'
96
-
97
- def __init__(
98
- self,
99
- in_size: Size,
100
- out_size: Size,
101
- w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
102
- b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
103
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
104
- name: Optional[str] = None,
105
- param_type: type = ParamState,
106
- ):
107
- super().__init__(name=name)
108
-
109
- # input and output shape
110
- self.in_size = in_size
111
- self.out_size = out_size
112
- assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
113
- 'and "out_size" must be the same.')
114
-
115
- # w_mask
116
- self.w_mask = init.param(w_mask, self.in_size + self.out_size)
117
-
118
- # weights
119
- params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
120
- if b_init is not None:
121
- params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
122
- self.weight = param_type(params)
123
-
124
- def update(self, x):
125
- params = self.weight.value
126
- weight = params['weight']
127
- if self.w_mask is not None:
128
- weight = weight * self.w_mask
129
- y = u.linalg.dot(x, weight)
130
- if 'bias' in params:
131
- y = y + params['bias']
132
- return y
133
-
134
-
135
- class SignedWLinear(Module):
136
- """
137
- Linear layer with signed absolute weights.
138
-
139
- This layer uses absolute values of weights multiplied by a sign matrix,
140
- ensuring all effective weights have controlled signs.
141
-
142
- Parameters
143
- ----------
144
- in_size : int or tuple of int
145
- The input feature size.
146
- out_size : int or tuple of int
147
- The output feature size.
148
- w_init : Callable or ArrayLike, optional
149
- Weight initializer. Default is ``KaimingNormal()``.
150
- w_sign : ArrayLike or None, optional
151
- Sign matrix for the weights. If ``None``, all weights are positive
152
- (absolute values used). If provided, should have the same shape as
153
- the weight matrix.
154
- name : str, optional
155
- Name of the module.
156
- param_type : type, optional
157
- Type of parameter state. Default is ``ParamState``.
158
-
159
- Attributes
160
- ----------
161
- in_size : tuple
162
- Input feature size.
163
- out_size : tuple
164
- Output feature size.
165
- w_sign : ArrayLike or None
166
- Sign matrix for weights.
167
- weight : ParamState
168
- Parameter state containing the weight values.
169
-
170
- Examples
171
- --------
172
- .. code-block:: python
173
-
174
- >>> import brainstate as bst
175
- >>> import jax.numpy as jnp
176
- >>>
177
- >>> # Create a signed weight linear layer with all positive weights
178
- >>> layer = bst.nn.SignedWLinear((10,), (5,))
179
- >>> x = jnp.ones((32, 10))
180
- >>> y = layer(x)
181
- >>> y.shape
182
- (32, 5)
183
- >>>
184
- >>> # With custom sign matrix (e.g., inhibitory connections)
185
- >>> w_sign = jnp.ones((10, 5)) * -1.0 # all negative
186
- >>> layer = bst.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
187
- >>> y = layer(x)
188
- >>> y.shape
189
- (32, 5)
190
- """
191
- __module__ = 'brainstate.nn'
192
-
193
- def __init__(
194
- self,
195
- in_size: Size,
196
- out_size: Size,
197
- w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
198
- w_sign: Optional[ArrayLike] = None,
199
- name: Optional[str] = None,
200
- param_type: type = ParamState,
201
- ):
202
- super().__init__(name=name)
203
-
204
- # input and output shape
205
- self.in_size = in_size
206
- self.out_size = out_size
207
- assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
208
- 'and "out_size" must be the same.')
209
-
210
- # w_mask
211
- self.w_sign = w_sign
212
-
213
- # weights
214
- weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
215
- self.weight = param_type(weight)
216
-
217
- def update(self, x):
218
- w = self.weight.value
219
- if self.w_sign is None:
220
- return u.math.matmul(x, u.math.abs(w))
221
- else:
222
- return u.math.matmul(x, u.math.abs(w) * self.w_sign)
223
-
224
-
225
- class ScaledWSLinear(Module):
226
- """
227
- Linear layer with weight standardization.
228
-
229
- Applies weight standardization [1]_ to normalize weights before the linear
230
- transformation, which can improve training stability and performance.
231
-
232
- Parameters
233
- ----------
234
- in_size : int or tuple of int
235
- The input feature size.
236
- out_size : int or tuple of int
237
- The output feature size.
238
- w_init : Callable, optional
239
- Weight initializer. Default is ``KaimingNormal()``.
240
- b_init : Callable, optional
241
- Bias initializer. Default is ``ZeroInit()``.
242
- w_mask : ArrayLike, Callable, or None, optional
243
- Optional mask for the weights.
244
- ws_gain : bool, optional
245
- Whether to use a learnable gain parameter for weight standardization.
246
- Default is ``True``.
247
- eps : float, optional
248
- Small constant for numerical stability in standardization.
249
- Default is ``1e-4``.
250
- name : str, optional
251
- Name of the module.
252
- param_type : type, optional
253
- Type of parameter state. Default is ``ParamState``.
254
-
255
- Attributes
256
- ----------
257
- in_size : tuple
258
- Input feature size.
259
- out_size : tuple
260
- Output feature size.
261
- w_mask : ArrayLike or None
262
- Weight mask if provided.
263
- eps : float
264
- Epsilon for numerical stability.
265
- weight : ParamState
266
- Parameter state containing 'weight', optionally 'bias' and 'gain'.
267
-
268
- References
269
- ----------
270
- .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
271
- Weight standardization. arXiv preprint arXiv:1903.10520.
272
-
273
- Examples
274
- --------
275
- .. code-block:: python
276
-
277
- >>> import brainstate as bst
278
- >>> import jax.numpy as jnp
279
- >>>
280
- >>> # Create a weight-standardized linear layer
281
- >>> layer = bst.nn.ScaledWSLinear((10,), (5,))
282
- >>> x = jnp.ones((32, 10))
283
- >>> y = layer(x)
284
- >>> y.shape
285
- (32, 5)
286
- >>>
287
- >>> # Without learnable gain
288
- >>> layer = bst.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
289
- >>> y = layer(x)
290
- >>> y.shape
291
- (32, 5)
292
- """
293
- __module__ = 'brainstate.nn'
294
-
295
- def __init__(
296
- self,
297
- in_size: Size,
298
- out_size: Size,
299
- w_init: Callable = init.KaimingNormal(),
300
- b_init: Callable = init.ZeroInit(),
301
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
302
- ws_gain: bool = True,
303
- eps: float = 1e-4,
304
- name: str = None,
305
- param_type: type = ParamState,
306
- ):
307
- super().__init__(name=name)
308
-
309
- # input and output shape
310
- self.in_size = in_size
311
- self.out_size = out_size
312
- assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
313
- 'and "out_size" must be the same.')
314
-
315
- # w_mask
316
- self.w_mask = init.param(w_mask, (self.in_size[0], 1))
317
-
318
- # parameters
319
- self.eps = eps
320
-
321
- # weights
322
- params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
323
- if b_init is not None:
324
- params['bias'] = init.param(b_init, self.out_size, allow_none=False)
325
- # gain
326
- if ws_gain:
327
- s = params['weight'].shape
328
- params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
329
- self.weight = param_type(params)
330
-
331
- def update(self, x):
332
- params = self.weight.value
333
- w = params['weight']
334
- w = weight_standardization(w, self.eps, params.get('gain', None))
335
- if self.w_mask is not None:
336
- w = w * self.w_mask
337
- y = u.linalg.dot(x, w)
338
- if 'bias' in params:
339
- y = y + params['bias']
340
- return y
341
-
342
-
343
- class SparseLinear(Module):
344
- """
345
- Linear layer with sparse weight matrix.
346
-
347
- Supports sparse matrices from ``brainunit.sparse`` including CSR, CSC,
348
- and COO formats. Only the non-zero entries are stored and updated.
349
-
350
- Parameters
351
- ----------
352
- spar_mat : brainunit.sparse.SparseMatrix
353
- The sparse weight matrix defining the connectivity structure.
354
- b_init : Callable, ArrayLike, or None, optional
355
- Bias initializer. If ``None``, no bias is added.
356
- in_size : int or tuple of int, optional
357
- The input size. If not provided, inferred from ``spar_mat``.
358
- name : str, optional
359
- Name of the module.
360
- param_type : type, optional
361
- Type of parameter state. Default is ``ParamState``.
362
-
363
- Attributes
364
- ----------
365
- in_size : tuple
366
- Input feature size.
367
- out_size : int
368
- Output feature size.
369
- spar_mat : brainunit.sparse.SparseMatrix
370
- The sparse matrix structure.
371
- weight : ParamState
372
- Parameter state containing the sparse 'weight' data and optionally 'bias'.
373
-
374
- Examples
375
- --------
376
- .. code-block:: python
377
-
378
- >>> import brainstate as bst
379
- >>> import brainunit as u
380
- >>> import jax.numpy as jnp
381
- >>>
382
- >>> # Create a sparse linear layer with CSR matrix
383
- >>> indices = jnp.array([[0, 1], [1, 2], [2, 0]])
384
- >>> values = jnp.array([1.0, 2.0, 3.0])
385
- >>> spar_mat = u.sparse.CSR((values, indices[:, 1], indices[:, 0]),
386
- ... shape=(3, 3))
387
- >>> layer = bst.nn.SparseLinear(spar_mat, in_size=(3,))
388
- >>> x = jnp.ones((5, 3))
389
- >>> y = layer(x)
390
- >>> y.shape
391
- (5, 3)
392
- """
393
- __module__ = 'brainstate.nn'
394
-
395
- def __init__(
396
- self,
397
- spar_mat: u.sparse.SparseMatrix,
398
- b_init: Optional[Union[Callable, ArrayLike]] = None,
399
- in_size: Size = None,
400
- name: Optional[str] = None,
401
- param_type: type = ParamState,
402
- ):
403
- super().__init__(name=name)
404
-
405
- # input and output shape
406
- if in_size is not None:
407
- self.in_size = in_size
408
- self.out_size = spar_mat.shape[-1]
409
- if in_size is not None:
410
- assert self.in_size[:-1] == self.out_size[:-1], (
411
- 'The first n-1 dimensions of "in_size" '
412
- 'and "out_size" must be the same.'
413
- )
414
-
415
- # weights
416
- assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'
417
- self.spar_mat = spar_mat
418
- params = dict(weight=spar_mat.data)
419
- if b_init is not None:
420
- params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
421
- self.weight = param_type(params)
422
-
423
- def update(self, x):
424
- data = self.weight.value['weight']
425
- y = x @ self.spar_mat.with_data(data)
426
- if 'bias' in self.weight.value:
427
- y = y + self.weight.value['bias']
428
- return y
429
-
430
-
431
- class AllToAll(Module):
432
- """
433
- All-to-all connection layer.
434
-
435
- Performs matrix multiplication with optional exclusion of self-connections,
436
- commonly used in recurrent neural networks and graph neural networks.
437
-
438
- Parameters
439
- ----------
440
- in_size : int or tuple of int
441
- The number of neurons in the pre-synaptic group.
442
- out_size : int or tuple of int
443
- The number of neurons in the post-synaptic group.
444
- w_init : Callable or ArrayLike, optional
445
- Weight initializer. Default is ``KaimingNormal()``.
446
- b_init : Callable, ArrayLike, or None, optional
447
- Bias initializer. If ``None``, no bias is added.
448
- include_self : bool, optional
449
- Whether to include self-connections (diagonal elements).
450
- Default is ``True``.
451
- name : str, optional
452
- Name of the module.
453
- param_type : type, optional
454
- Type of parameter state. Default is ``ParamState``.
455
-
456
- Attributes
457
- ----------
458
- in_size : tuple
459
- Input size.
460
- out_size : tuple
461
- Output size.
462
- include_self : bool
463
- Whether self-connections are included.
464
- weight : ParamState
465
- Parameter state containing 'weight' and optionally 'bias'.
466
-
467
- Examples
468
- --------
469
- .. code-block:: python
470
-
471
- >>> import brainstate as bst
472
- >>> import jax.numpy as jnp
473
- >>>
474
- >>> # All-to-all with self-connections
475
- >>> layer = bst.nn.AllToAll((10,), (10,), include_self=True)
476
- >>> x = jnp.ones((32, 10))
477
- >>> y = layer(x)
478
- >>> y.shape
479
- (32, 10)
480
- >>>
481
- >>> # All-to-all without self-connections (recurrent layer)
482
- >>> layer = bst.nn.AllToAll((10,), (10,), include_self=False)
483
- >>> y = layer(x)
484
- >>> y.shape
485
- (32, 10)
486
- """
487
- __module__ = 'brainstate.nn'
488
-
489
- def __init__(
490
- self,
491
- in_size: Size,
492
- out_size: Size,
493
- w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
494
- b_init: Optional[Union[Callable, ArrayLike]] = None,
495
- include_self: bool = True,
496
- name: Optional[str] = None,
497
- param_type: type = ParamState,
498
- ):
499
- super().__init__(name=name)
500
-
501
- # input and output shape
502
- self.in_size = in_size
503
- self.out_size = out_size
504
- assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
505
- 'and "out_size" must be the same.')
506
-
507
- # others
508
- self.include_self = include_self
509
-
510
- # weights
511
- weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
512
- params = dict(weight=weight)
513
- if b_init is not None:
514
- params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
515
- self.weight = param_type(params)
516
-
517
- def update(self, pre_val):
518
- params = self.weight.value
519
- pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
520
- w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])
521
-
522
- if u.math.ndim(w_val) == 0: # weight is a scalar
523
- if pre_val.ndim == 1:
524
- post_val = u.math.sum(pre_val)
525
- else:
526
- post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
527
- if not self.include_self:
528
- if self.in_size == self.out_size:
529
- post_val = post_val - pre_val
530
- elif self.in_size[-1] > self.out_size[-1]:
531
- val = pre_val[..., :self.out_size[-1]]
532
- post_val = post_val - val
533
- else:
534
- size = list(self.out_size)
535
- size[-1] = self.out_size[-1] - self.in_size[-1]
536
- val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
537
- post_val = post_val - val
538
- post_val = w_val * post_val
539
-
540
- else: # weight is a matrix
541
- assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
542
- if not self.include_self:
543
- post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
544
- else:
545
- post_val = pre_val @ w_val
546
-
547
- post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
548
- if 'bias' in params:
549
- post_val = post_val + params['bias']
550
- return post_val
551
-
552
-
553
- class OneToOne(Module):
554
- """
555
- One-to-one connection layer.
556
-
557
- Applies element-wise multiplication with a weight vector, implementing
558
- diagonal connectivity where each input unit connects only to its
559
- corresponding output unit.
560
-
561
- Parameters
562
- ----------
563
- in_size : int or tuple of int
564
- The number of neurons. Input and output sizes are the same.
565
- w_init : Callable or ArrayLike, optional
566
- Weight initializer. Default is ``Normal()``.
567
- b_init : Callable, ArrayLike, or None, optional
568
- Bias initializer. If ``None``, no bias is added.
569
- name : str, optional
570
- Name of the module.
571
- param_type : type, optional
572
- Type of parameter state. Default is ``ParamState``.
573
-
574
- Attributes
575
- ----------
576
- in_size : tuple
577
- Input size.
578
- out_size : tuple
579
- Output size (same as input size).
580
- weight : ParamState
581
- Parameter state containing 'weight' and optionally 'bias'.
582
-
583
- Examples
584
- --------
585
- .. code-block:: python
586
-
587
- >>> import brainstate as bst
588
- >>> import jax.numpy as jnp
589
- >>>
590
- >>> # One-to-one connection
591
- >>> layer = bst.nn.OneToOne((10,))
592
- >>> x = jnp.ones((32, 10))
593
- >>> y = layer(x)
594
- >>> y.shape
595
- (32, 10)
596
- >>>
597
- >>> # With bias
598
- >>> layer = bst.nn.OneToOne((10,), b_init=bst.init.Constant(0.1))
599
- >>> y = layer(x)
600
- >>> y.shape
601
- (32, 10)
602
- """
603
- __module__ = 'brainstate.nn'
604
-
605
- def __init__(
606
- self,
607
- in_size: Size,
608
- w_init: Union[Callable, ArrayLike] = init.Normal(),
609
- b_init: Optional[Union[Callable, ArrayLike]] = None,
610
- name: Optional[str] = None,
611
- param_type: type = ParamState,
612
- ):
613
- super().__init__(name=name)
614
-
615
- # input and output shape
616
- self.in_size = in_size
617
- self.out_size = in_size
618
-
619
- # weights
620
- param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
621
- if b_init is not None:
622
- param['bias'] = init.param(b_init, self.out_size, allow_none=False)
623
- self.weight = param_type(param)
624
-
625
- def update(self, pre_val):
626
- post_val = pre_val * self.weight.value['weight']
627
- if 'bias' in self.weight.value:
628
- post_val = post_val + self.weight.value['bias']
629
- return post_val
630
-
631
-
632
- class LoRA(Module):
633
- """
634
- Low-Rank Adaptation (LoRA) layer.
635
-
636
- Implements parameter-efficient fine-tuning using low-rank decomposition [1]_.
637
- Can be used standalone or as a wrapper around an existing module.
638
-
639
- Parameters
640
- ----------
641
- in_features : int
642
- The number of input features.
643
- lora_rank : int
644
- The rank of the low-rank decomposition. Lower rank means fewer parameters.
645
- out_features : int
646
- The number of output features.
647
- base_module : Module, optional
648
- A base module to wrap. If provided, the LoRA output will be added to
649
- the base module's output. Default is ``None``.
650
- kernel_init : Callable or ArrayLike, optional
651
- Initializer for the LoRA weight matrices. Default is ``LecunNormal()``.
652
- param_type : type, optional
653
- Type of parameter state. Default is ``ParamState``.
654
-
655
- Attributes
656
- ----------
657
- in_size : int
658
- Input feature size.
659
- out_size : int
660
- Output feature size.
661
- in_features : int
662
- Number of input features.
663
- out_features : int
664
- Number of output features.
665
- base_module : Module or None
666
- The wrapped base module if provided.
667
- weight : ParamState
668
- Parameter state containing 'lora_a' and 'lora_b' matrices.
669
-
670
- References
671
- ----------
672
- .. [1] Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S.,
673
- Wang, L., & Chen, W. (2021). LoRA: Low-Rank Adaptation of Large
674
- Language Models. arXiv preprint arXiv:2106.09685.
675
-
676
- Examples
677
- --------
678
- .. code-block:: python
679
-
680
- >>> import brainstate as bst
681
- >>> import jax.numpy as jnp
682
- >>>
683
- >>> # Standalone LoRA layer
684
- >>> layer = bst.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
685
- >>> x = jnp.ones((32, 10))
686
- >>> y = layer(x)
687
- >>> y.shape
688
- (32, 5)
689
- >>>
690
- >>> # Wrap around existing linear layer
691
- >>> base = bst.nn.Linear((10,), (5,))
692
- >>> lora_layer = bst.nn.LoRA(in_features=10, lora_rank=2,
693
- ... out_features=5, base_module=base)
694
- >>> y = lora_layer(x)
695
- >>> y.shape
696
- (32, 5)
697
- >>>
698
- >>> # Check parameter count - LoRA has fewer parameters
699
- >>> # Base layer: 10 * 5 = 50 parameters
700
- >>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
701
- """
702
- __module__ = 'brainstate.nn'
703
-
704
- def __init__(
705
- self,
706
- in_features: int,
707
- lora_rank: int,
708
- out_features: int,
709
- *,
710
- base_module: Optional[Module] = None,
711
- kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
712
- param_type: type = ParamState,
713
- in_size: Size = None,
714
- ):
715
- super().__init__()
716
-
717
- # input and output shape
718
- self.in_size = in_features
719
- self.out_size = out_features
720
- self.in_features = in_features
721
- self.out_features = out_features
722
-
723
- # others
724
- self.base_module = base_module
725
-
726
- # weights
727
- param = dict(
728
- lora_a=kernel_init((in_features, lora_rank)),
729
- lora_b=kernel_init((lora_rank, out_features))
730
- )
731
- self.weight = param_type(param)
732
-
733
- # in_size
734
- if in_size is not None:
735
- self.in_size = in_size
736
- self.out_size = tuple(self.in_size[:-1]) + (out_features,)
737
-
738
- def __call__(self, x: ArrayLike):
739
- out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
740
- if self.base_module is not None:
741
- if not callable(self.base_module):
742
- raise ValueError('`self.base_module` must be callable.')
743
- out += self.base_module(x)
744
- return out
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Callable, Union, Optional
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate._state import ParamState
24
+ from brainstate.typing import ArrayLike, Size
25
+ from . import init as init
26
+ from ._module import Module
27
+ from ._normalizations import weight_standardization
28
+
29
+ __all__ = [
30
+ 'Linear',
31
+ 'ScaledWSLinear',
32
+ 'SignedWLinear',
33
+ 'SparseLinear',
34
+ 'AllToAll',
35
+ 'OneToOne',
36
+ 'LoRA',
37
+ ]
38
+
39
+
40
+ class Linear(Module):
41
+ """
42
+ Linear transformation layer.
43
+
44
+ Applies a linear transformation to the incoming data: :math:`y = xW + b`
45
+
46
+ Parameters
47
+ ----------
48
+ in_size : int or tuple of int
49
+ The input feature size.
50
+ out_size : int or tuple of int
51
+ The output feature size.
52
+ w_init : Callable or ArrayLike, optional
53
+ Weight initializer. Default is ``KaimingNormal()``.
54
+ b_init : Callable, ArrayLike, or None, optional
55
+ Bias initializer. If ``None``, no bias is added. Default is ``ZeroInit()``.
56
+ w_mask : ArrayLike, Callable, or None, optional
57
+ Optional mask for the weights. If provided, weights will be element-wise
58
+ multiplied by this mask.
59
+ name : str, optional
60
+ Name of the module.
61
+ param_type : type, optional
62
+ Type of parameter state. Default is ``ParamState``.
63
+
64
+ Attributes
65
+ ----------
66
+ in_size : tuple
67
+ Input feature size.
68
+ out_size : tuple
69
+ Output feature size.
70
+ w_mask : ArrayLike or None
71
+ Weight mask if provided.
72
+ weight : ParamState
73
+ Parameter state containing 'weight' and optionally 'bias'.
74
+
75
+ Examples
76
+ --------
77
+ .. code-block:: python
78
+
79
+ >>> import brainstate as brainstate
80
+ >>> import jax.numpy as jnp
81
+ >>>
82
+ >>> # Create a linear layer
83
+ >>> layer = brainstate.nn.Linear((10,), (5,))
84
+ >>> x = jnp.ones((32, 10))
85
+ >>> y = layer(x)
86
+ >>> y.shape
87
+ (32, 5)
88
+ >>>
89
+ >>> # Linear layer without bias
90
+ >>> layer = brainstate.nn.Linear((10,), (5,), b_init=None)
91
+ >>> y = layer(x)
92
+ >>> y.shape
93
+ (32, 5)
94
+ """
95
+ __module__ = 'brainstate.nn'
96
+
97
+ def __init__(
98
+ self,
99
+ in_size: Size,
100
+ out_size: Size,
101
+ w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
102
+ b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
103
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
104
+ name: Optional[str] = None,
105
+ param_type: type = ParamState,
106
+ ):
107
+ super().__init__(name=name)
108
+
109
+ # input and output shape
110
+ self.in_size = in_size
111
+ self.out_size = out_size
112
+ assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
113
+ 'and "out_size" must be the same.')
114
+
115
+ # w_mask
116
+ self.w_mask = init.param(w_mask, self.in_size + self.out_size)
117
+
118
+ # weights
119
+ params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
120
+ if b_init is not None:
121
+ params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
122
+ self.weight = param_type(params)
123
+
124
+ def update(self, x):
125
+ params = self.weight.value
126
+ weight = params['weight']
127
+ if self.w_mask is not None:
128
+ weight = weight * self.w_mask
129
+ y = u.linalg.dot(x, weight)
130
+ if 'bias' in params:
131
+ y = y + params['bias']
132
+ return y
133
+
134
+
135
+ class SignedWLinear(Module):
136
+ """
137
+ Linear layer with signed absolute weights.
138
+
139
+ This layer uses absolute values of weights multiplied by a sign matrix,
140
+ ensuring all effective weights have controlled signs.
141
+
142
+ Parameters
143
+ ----------
144
+ in_size : int or tuple of int
145
+ The input feature size.
146
+ out_size : int or tuple of int
147
+ The output feature size.
148
+ w_init : Callable or ArrayLike, optional
149
+ Weight initializer. Default is ``KaimingNormal()``.
150
+ w_sign : ArrayLike or None, optional
151
+ Sign matrix for the weights. If ``None``, all weights are positive
152
+ (absolute values used). If provided, should have the same shape as
153
+ the weight matrix.
154
+ name : str, optional
155
+ Name of the module.
156
+ param_type : type, optional
157
+ Type of parameter state. Default is ``ParamState``.
158
+
159
+ Attributes
160
+ ----------
161
+ in_size : tuple
162
+ Input feature size.
163
+ out_size : tuple
164
+ Output feature size.
165
+ w_sign : ArrayLike or None
166
+ Sign matrix for weights.
167
+ weight : ParamState
168
+ Parameter state containing the weight values.
169
+
170
+ Examples
171
+ --------
172
+ .. code-block:: python
173
+
174
+ >>> import brainstate as brainstate
175
+ >>> import jax.numpy as jnp
176
+ >>>
177
+ >>> # Create a signed weight linear layer with all positive weights
178
+ >>> layer = brainstate.nn.SignedWLinear((10,), (5,))
179
+ >>> x = jnp.ones((32, 10))
180
+ >>> y = layer(x)
181
+ >>> y.shape
182
+ (32, 5)
183
+ >>>
184
+ >>> # With custom sign matrix (e.g., inhibitory connections)
185
+ >>> w_sign = jnp.ones((10, 5)) * -1.0 # all negative
186
+ >>> layer = brainstate.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
187
+ >>> y = layer(x)
188
+ >>> y.shape
189
+ (32, 5)
190
+ """
191
+ __module__ = 'brainstate.nn'
192
+
193
+ def __init__(
194
+ self,
195
+ in_size: Size,
196
+ out_size: Size,
197
+ w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
198
+ w_sign: Optional[ArrayLike] = None,
199
+ name: Optional[str] = None,
200
+ param_type: type = ParamState,
201
+ ):
202
+ super().__init__(name=name)
203
+
204
+ # input and output shape
205
+ self.in_size = in_size
206
+ self.out_size = out_size
207
+ assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
208
+ 'and "out_size" must be the same.')
209
+
210
+ # w_mask
211
+ self.w_sign = w_sign
212
+
213
+ # weights
214
+ weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
215
+ self.weight = param_type(weight)
216
+
217
+ def update(self, x):
218
+ w = self.weight.value
219
+ if self.w_sign is None:
220
+ return u.math.matmul(x, u.math.abs(w))
221
+ else:
222
+ return u.math.matmul(x, u.math.abs(w) * self.w_sign)
223
+
224
+
225
+ class ScaledWSLinear(Module):
226
+ """
227
+ Linear layer with weight standardization.
228
+
229
+ Applies weight standardization [1]_ to normalize weights before the linear
230
+ transformation, which can improve training stability and performance.
231
+
232
+ Parameters
233
+ ----------
234
+ in_size : int or tuple of int
235
+ The input feature size.
236
+ out_size : int or tuple of int
237
+ The output feature size.
238
+ w_init : Callable, optional
239
+ Weight initializer. Default is ``KaimingNormal()``.
240
+ b_init : Callable, optional
241
+ Bias initializer. Default is ``ZeroInit()``.
242
+ w_mask : ArrayLike, Callable, or None, optional
243
+ Optional mask for the weights.
244
+ ws_gain : bool, optional
245
+ Whether to use a learnable gain parameter for weight standardization.
246
+ Default is ``True``.
247
+ eps : float, optional
248
+ Small constant for numerical stability in standardization.
249
+ Default is ``1e-4``.
250
+ name : str, optional
251
+ Name of the module.
252
+ param_type : type, optional
253
+ Type of parameter state. Default is ``ParamState``.
254
+
255
+ Attributes
256
+ ----------
257
+ in_size : tuple
258
+ Input feature size.
259
+ out_size : tuple
260
+ Output feature size.
261
+ w_mask : ArrayLike or None
262
+ Weight mask if provided.
263
+ eps : float
264
+ Epsilon for numerical stability.
265
+ weight : ParamState
266
+ Parameter state containing 'weight', optionally 'bias' and 'gain'.
267
+
268
+ References
269
+ ----------
270
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
271
+ Weight standardization. arXiv preprint arXiv:1903.10520.
272
+
273
+ Examples
274
+ --------
275
+ .. code-block:: python
276
+
277
+ >>> import brainstate as brainstate
278
+ >>> import jax.numpy as jnp
279
+ >>>
280
+ >>> # Create a weight-standardized linear layer
281
+ >>> layer = brainstate.nn.ScaledWSLinear((10,), (5,))
282
+ >>> x = jnp.ones((32, 10))
283
+ >>> y = layer(x)
284
+ >>> y.shape
285
+ (32, 5)
286
+ >>>
287
+ >>> # Without learnable gain
288
+ >>> layer = brainstate.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
289
+ >>> y = layer(x)
290
+ >>> y.shape
291
+ (32, 5)
292
+ """
293
+ __module__ = 'brainstate.nn'
294
+
295
+ def __init__(
296
+ self,
297
+ in_size: Size,
298
+ out_size: Size,
299
+ w_init: Callable = init.KaimingNormal(),
300
+ b_init: Callable = init.ZeroInit(),
301
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
302
+ ws_gain: bool = True,
303
+ eps: float = 1e-4,
304
+ name: str = None,
305
+ param_type: type = ParamState,
306
+ ):
307
+ super().__init__(name=name)
308
+
309
+ # input and output shape
310
+ self.in_size = in_size
311
+ self.out_size = out_size
312
+ assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
313
+ 'and "out_size" must be the same.')
314
+
315
+ # w_mask
316
+ self.w_mask = init.param(w_mask, (self.in_size[0], 1))
317
+
318
+ # parameters
319
+ self.eps = eps
320
+
321
+ # weights
322
+ params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
323
+ if b_init is not None:
324
+ params['bias'] = init.param(b_init, self.out_size, allow_none=False)
325
+ # gain
326
+ if ws_gain:
327
+ s = params['weight'].shape
328
+ params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
329
+ self.weight = param_type(params)
330
+
331
+ def update(self, x):
332
+ params = self.weight.value
333
+ w = params['weight']
334
+ w = weight_standardization(w, self.eps, params.get('gain', None))
335
+ if self.w_mask is not None:
336
+ w = w * self.w_mask
337
+ y = u.linalg.dot(x, w)
338
+ if 'bias' in params:
339
+ y = y + params['bias']
340
+ return y
341
+
342
+
343
+ class SparseLinear(Module):
344
+ """
345
+ Linear layer with sparse weight matrix.
346
+
347
+ Supports sparse matrices from ``brainunit.sparse`` including CSR, CSC,
348
+ and COO formats. Only the non-zero entries are stored and updated.
349
+
350
+ Parameters
351
+ ----------
352
+ spar_mat : brainunit.sparse.SparseMatrix
353
+ The sparse weight matrix defining the connectivity structure.
354
+ b_init : Callable, ArrayLike, or None, optional
355
+ Bias initializer. If ``None``, no bias is added.
356
+ in_size : int or tuple of int, optional
357
+ The input size. If not provided, inferred from ``spar_mat``.
358
+ name : str, optional
359
+ Name of the module.
360
+ param_type : type, optional
361
+ Type of parameter state. Default is ``ParamState``.
362
+
363
+ Attributes
364
+ ----------
365
+ in_size : tuple
366
+ Input feature size.
367
+ out_size : int
368
+ Output feature size.
369
+ spar_mat : brainunit.sparse.SparseMatrix
370
+ The sparse matrix structure.
371
+ weight : ParamState
372
+ Parameter state containing the sparse 'weight' data and optionally 'bias'.
373
+
374
+ Examples
375
+ --------
376
+ .. code-block:: python
377
+
378
+ >>> import brainstate as brainstate
379
+ >>> import brainunit as u
380
+ >>> import jax.numpy as jnp
381
+ >>>
382
+ >>> # Create a sparse linear layer with CSR matrix
383
+ >>> indices = jnp.array([[0, 1], [1, 2], [2, 0]])
384
+ >>> values = jnp.array([1.0, 2.0, 3.0])
385
+ >>> spar_mat = u.sparse.CSR((values, indices[:, 1], indices[:, 0]),
386
+ ... shape=(3, 3))
387
+ >>> layer = brainstate.nn.SparseLinear(spar_mat, in_size=(3,))
388
+ >>> x = jnp.ones((5, 3))
389
+ >>> y = layer(x)
390
+ >>> y.shape
391
+ (5, 3)
392
+ """
393
+ __module__ = 'brainstate.nn'
394
+
395
+ def __init__(
396
+ self,
397
+ spar_mat: u.sparse.SparseMatrix,
398
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
399
+ in_size: Size = None,
400
+ name: Optional[str] = None,
401
+ param_type: type = ParamState,
402
+ ):
403
+ super().__init__(name=name)
404
+
405
+ # input and output shape
406
+ if in_size is not None:
407
+ self.in_size = in_size
408
+ self.out_size = spar_mat.shape[-1]
409
+ if in_size is not None:
410
+ assert self.in_size[:-1] == self.out_size[:-1], (
411
+ 'The first n-1 dimensions of "in_size" '
412
+ 'and "out_size" must be the same.'
413
+ )
414
+
415
+ # weights
416
+ assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'
417
+ self.spar_mat = spar_mat
418
+ params = dict(weight=spar_mat.data)
419
+ if b_init is not None:
420
+ params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
421
+ self.weight = param_type(params)
422
+
423
+ def update(self, x):
424
+ data = self.weight.value['weight']
425
+ y = x @ self.spar_mat.with_data(data)
426
+ if 'bias' in self.weight.value:
427
+ y = y + self.weight.value['bias']
428
+ return y
429
+
430
+
431
+ class AllToAll(Module):
432
+ """
433
+ All-to-all connection layer.
434
+
435
+ Performs matrix multiplication with optional exclusion of self-connections,
436
+ commonly used in recurrent neural networks and graph neural networks.
437
+
438
+ Parameters
439
+ ----------
440
+ in_size : int or tuple of int
441
+ The number of neurons in the pre-synaptic group.
442
+ out_size : int or tuple of int
443
+ The number of neurons in the post-synaptic group.
444
+ w_init : Callable or ArrayLike, optional
445
+ Weight initializer. Default is ``KaimingNormal()``.
446
+ b_init : Callable, ArrayLike, or None, optional
447
+ Bias initializer. If ``None``, no bias is added.
448
+ include_self : bool, optional
449
+ Whether to include self-connections (diagonal elements).
450
+ Default is ``True``.
451
+ name : str, optional
452
+ Name of the module.
453
+ param_type : type, optional
454
+ Type of parameter state. Default is ``ParamState``.
455
+
456
+ Attributes
457
+ ----------
458
+ in_size : tuple
459
+ Input size.
460
+ out_size : tuple
461
+ Output size.
462
+ include_self : bool
463
+ Whether self-connections are included.
464
+ weight : ParamState
465
+ Parameter state containing 'weight' and optionally 'bias'.
466
+
467
+ Examples
468
+ --------
469
+ .. code-block:: python
470
+
471
+ >>> import brainstate as brainstate
472
+ >>> import jax.numpy as jnp
473
+ >>>
474
+ >>> # All-to-all with self-connections
475
+ >>> layer = brainstate.nn.AllToAll((10,), (10,), include_self=True)
476
+ >>> x = jnp.ones((32, 10))
477
+ >>> y = layer(x)
478
+ >>> y.shape
479
+ (32, 10)
480
+ >>>
481
+ >>> # All-to-all without self-connections (recurrent layer)
482
+ >>> layer = brainstate.nn.AllToAll((10,), (10,), include_self=False)
483
+ >>> y = layer(x)
484
+ >>> y.shape
485
+ (32, 10)
486
+ """
487
+ __module__ = 'brainstate.nn'
488
+
489
+ def __init__(
490
+ self,
491
+ in_size: Size,
492
+ out_size: Size,
493
+ w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
494
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
495
+ include_self: bool = True,
496
+ name: Optional[str] = None,
497
+ param_type: type = ParamState,
498
+ ):
499
+ super().__init__(name=name)
500
+
501
+ # input and output shape
502
+ self.in_size = in_size
503
+ self.out_size = out_size
504
+ assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
505
+ 'and "out_size" must be the same.')
506
+
507
+ # others
508
+ self.include_self = include_self
509
+
510
+ # weights
511
+ weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
512
+ params = dict(weight=weight)
513
+ if b_init is not None:
514
+ params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
515
+ self.weight = param_type(params)
516
+
517
+ def update(self, pre_val):
518
+ params = self.weight.value
519
+ pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
520
+ w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])
521
+
522
+ if u.math.ndim(w_val) == 0: # weight is a scalar
523
+ if pre_val.ndim == 1:
524
+ post_val = u.math.sum(pre_val)
525
+ else:
526
+ post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
527
+ if not self.include_self:
528
+ if self.in_size == self.out_size:
529
+ post_val = post_val - pre_val
530
+ elif self.in_size[-1] > self.out_size[-1]:
531
+ val = pre_val[..., :self.out_size[-1]]
532
+ post_val = post_val - val
533
+ else:
534
+ size = list(self.out_size)
535
+ size[-1] = self.out_size[-1] - self.in_size[-1]
536
+ val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
537
+ post_val = post_val - val
538
+ post_val = w_val * post_val
539
+
540
+ else: # weight is a matrix
541
+ assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
542
+ if not self.include_self:
543
+ post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
544
+ else:
545
+ post_val = pre_val @ w_val
546
+
547
+ post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
548
+ if 'bias' in params:
549
+ post_val = post_val + params['bias']
550
+ return post_val
551
+
552
+
553
+ class OneToOne(Module):
554
+ """
555
+ One-to-one connection layer.
556
+
557
+ Applies element-wise multiplication with a weight vector, implementing
558
+ diagonal connectivity where each input unit connects only to its
559
+ corresponding output unit.
560
+
561
+ Parameters
562
+ ----------
563
+ in_size : int or tuple of int
564
+ The number of neurons. Input and output sizes are the same.
565
+ w_init : Callable or ArrayLike, optional
566
+ Weight initializer. Default is ``Normal()``.
567
+ b_init : Callable, ArrayLike, or None, optional
568
+ Bias initializer. If ``None``, no bias is added.
569
+ name : str, optional
570
+ Name of the module.
571
+ param_type : type, optional
572
+ Type of parameter state. Default is ``ParamState``.
573
+
574
+ Attributes
575
+ ----------
576
+ in_size : tuple
577
+ Input size.
578
+ out_size : tuple
579
+ Output size (same as input size).
580
+ weight : ParamState
581
+ Parameter state containing 'weight' and optionally 'bias'.
582
+
583
+ Examples
584
+ --------
585
+ .. code-block:: python
586
+
587
+ >>> import brainstate as brainstate
588
+ >>> import jax.numpy as jnp
589
+ >>>
590
+ >>> # One-to-one connection
591
+ >>> layer = brainstate.nn.OneToOne((10,))
592
+ >>> x = jnp.ones((32, 10))
593
+ >>> y = layer(x)
594
+ >>> y.shape
595
+ (32, 10)
596
+ >>>
597
+ >>> # With bias
598
+ >>> layer = brainstate.nn.OneToOne((10,), b_init=brainstate.init.Constant(0.1))
599
+ >>> y = layer(x)
600
+ >>> y.shape
601
+ (32, 10)
602
+ """
603
+ __module__ = 'brainstate.nn'
604
+
605
+ def __init__(
606
+ self,
607
+ in_size: Size,
608
+ w_init: Union[Callable, ArrayLike] = init.Normal(),
609
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
610
+ name: Optional[str] = None,
611
+ param_type: type = ParamState,
612
+ ):
613
+ super().__init__(name=name)
614
+
615
+ # input and output shape
616
+ self.in_size = in_size
617
+ self.out_size = in_size
618
+
619
+ # weights
620
+ param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
621
+ if b_init is not None:
622
+ param['bias'] = init.param(b_init, self.out_size, allow_none=False)
623
+ self.weight = param_type(param)
624
+
625
+ def update(self, pre_val):
626
+ post_val = pre_val * self.weight.value['weight']
627
+ if 'bias' in self.weight.value:
628
+ post_val = post_val + self.weight.value['bias']
629
+ return post_val
630
+
631
+
632
+ class LoRA(Module):
633
+ """
634
+ Low-Rank Adaptation (LoRA) layer.
635
+
636
+ Implements parameter-efficient fine-tuning using low-rank decomposition [1]_.
637
+ Can be used standalone or as a wrapper around an existing module.
638
+
639
+ Parameters
640
+ ----------
641
+ in_features : int
642
+ The number of input features.
643
+ lora_rank : int
644
+ The rank of the low-rank decomposition. Lower rank means fewer parameters.
645
+ out_features : int
646
+ The number of output features.
647
+ base_module : Module, optional
648
+ A base module to wrap. If provided, the LoRA output will be added to
649
+ the base module's output. Default is ``None``.
650
+ kernel_init : Callable or ArrayLike, optional
651
+ Initializer for the LoRA weight matrices. Default is ``LecunNormal()``.
652
+ param_type : type, optional
653
+ Type of parameter state. Default is ``ParamState``.
654
+
655
+ Attributes
656
+ ----------
657
+ in_size : int
658
+ Input feature size.
659
+ out_size : int
660
+ Output feature size.
661
+ in_features : int
662
+ Number of input features.
663
+ out_features : int
664
+ Number of output features.
665
+ base_module : Module or None
666
+ The wrapped base module if provided.
667
+ weight : ParamState
668
+ Parameter state containing 'lora_a' and 'lora_b' matrices.
669
+
670
+ References
671
+ ----------
672
+ .. [1] Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S.,
673
+ Wang, L., & Chen, W. (2021). LoRA: Low-Rank Adaptation of Large
674
+ Language Models. arXiv preprint arXiv:2106.09685.
675
+
676
+ Examples
677
+ --------
678
+ .. code-block:: python
679
+
680
+ >>> import brainstate as brainstate
681
+ >>> import jax.numpy as jnp
682
+ >>>
683
+ >>> # Standalone LoRA layer
684
+ >>> layer = brainstate.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
685
+ >>> x = jnp.ones((32, 10))
686
+ >>> y = layer(x)
687
+ >>> y.shape
688
+ (32, 5)
689
+ >>>
690
+ >>> # Wrap around existing linear layer
691
+ >>> base = brainstate.nn.Linear((10,), (5,))
692
+ >>> lora_layer = brainstate.nn.LoRA(in_features=10, lora_rank=2,
693
+ ... out_features=5, base_module=base)
694
+ >>> y = lora_layer(x)
695
+ >>> y.shape
696
+ (32, 5)
697
+ >>>
698
+ >>> # Check parameter count - LoRA has fewer parameters
699
+ >>> # Base layer: 10 * 5 = 50 parameters
700
+ >>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
701
+ """
702
+ __module__ = 'brainstate.nn'
703
+
704
+ def __init__(
705
+ self,
706
+ in_features: int,
707
+ lora_rank: int,
708
+ out_features: int,
709
+ *,
710
+ base_module: Optional[Module] = None,
711
+ kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
712
+ param_type: type = ParamState,
713
+ in_size: Size = None,
714
+ ):
715
+ super().__init__()
716
+
717
+ # input and output shape
718
+ self.in_size = in_features
719
+ self.out_size = out_features
720
+ self.in_features = in_features
721
+ self.out_features = out_features
722
+
723
+ # others
724
+ self.base_module = base_module
725
+
726
+ # weights
727
+ param = dict(
728
+ lora_a=kernel_init((in_features, lora_rank)),
729
+ lora_b=kernel_init((lora_rank, out_features))
730
+ )
731
+ self.weight = param_type(param)
732
+
733
+ # in_size
734
+ if in_size is not None:
735
+ self.in_size = in_size
736
+ self.out_size = tuple(self.in_size[:-1]) + (out_features,)
737
+
738
+ def __call__(self, x: ArrayLike):
739
+ out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
740
+ if self.base_module is not None:
741
+ if not callable(self.base_module):
742
+ raise ValueError('`self.base_module` must be callable.')
743
+ out += self.base_module(x)
744
+ return out
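For reference, a minimal usage sketch based on the docstring examples in the new _linear.py above (assumes brainstate 0.2.2 and JAX are installed; the shapes mirror the docstrings and are illustrative only):

    import brainstate
    import jax.numpy as jnp

    # Plain linear layer: y = x @ W + b, as documented in the Linear docstring
    layer = brainstate.nn.Linear((10,), (5,))
    x = jnp.ones((32, 10))
    y = layer(x)
    print(y.shape)  # (32, 5)

    # LoRA adapter wrapping the same Linear; parameter counts follow the docstring:
    # base weight 10 * 5 = 50 values vs. LoRA 10 * 2 + 2 * 5 = 30 values
    lora = brainstate.nn.LoRA(in_features=10, lora_rank=2, out_features=5,
                              base_module=layer)
    print(lora(x).shape)  # (32, 5)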