brainstate-0.1.10-py2.py3-none-any.whl → brainstate-0.2.1-py2.py3-none-any.whl

This diff shows the content of publicly released package versions as published to their respective registries, and is provided for informational purposes only.
Files changed (163)
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
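The renames above show the 0.2.1 reorganization: the brainstate/compile and brainstate/augment packages merge into brainstate/transform, the brainstate/init package folds into brainstate/nn/init.py, and the functional, optim, and surrogate modules are dropped from the wheel. A minimal migration sketch, assuming the public import paths mirror these file moves (the exact exported names are an inference from the renames, not confirmed by this diff):

import brainstate as bst

# 0.1.10 layout (old, for comparison):
#   bst.compile.jit        <- brainstate/compile/_jit.py
#   bst.augment.grad       <- brainstate/augment/_autograd.py
#   bst.init.KaimingNormal <- brainstate/init/_random_inits.py
# 0.2.1 layout, inferred from the renamed files:
jit_double = bst.transform.jit(lambda v: v * 2.0)  # brainstate/transform/_jit.py
w_init = bst.nn.init.KaimingNormal()               # brainstate/nn/init.py

print(jit_double(3.0))  # 6.0 (returned as a JAX array)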
brainstate/nn/_linear.py CHANGED
@@ -1,424 +1,744 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- from typing import Callable, Union, Optional
19
-
20
- import brainunit as u
21
- import jax.numpy as jnp
22
-
23
- from brainstate import init, functional
24
- from brainstate._state import ParamState
25
- from brainstate.typing import ArrayLike, Size
26
- from ._module import Module
27
-
28
- __all__ = [
29
- 'Linear',
30
- 'ScaledWSLinear',
31
- 'SignedWLinear',
32
- 'SparseLinear',
33
- 'AllToAll',
34
- 'OneToOne',
35
- 'LoRA',
36
- ]
37
-
38
-
39
- class Linear(Module):
40
- """
41
- Linear layer.
42
- """
43
- __module__ = 'brainstate.nn'
44
-
45
- def __init__(
46
- self,
47
- in_size: Size,
48
- out_size: Size,
49
- w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
50
- b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
51
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
52
- name: Optional[str] = None,
53
- param_type: type = ParamState,
54
- ):
55
- super().__init__(name=name)
56
-
57
- # input and output shape
58
- self.in_size = in_size
59
- self.out_size = out_size
60
- assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
61
- 'and "out_size" must be the same.')
62
-
63
- # w_mask
64
- self.w_mask = init.param(w_mask, self.in_size + self.out_size)
65
-
66
- # weights
67
- params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
68
- if b_init is not None:
69
- params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
70
- self.weight = param_type(params)
71
-
72
- def update(self, x):
73
- params = self.weight.value
74
- weight = params['weight']
75
- if self.w_mask is not None:
76
- weight = weight * self.w_mask
77
- y = u.linalg.dot(x, weight)
78
- if 'bias' in params:
79
- y = y + params['bias']
80
- return y
81
-
82
-
83
- class SignedWLinear(Module):
84
- """
85
- Linear layer with signed weights.
86
- """
87
- __module__ = 'brainstate.nn'
88
-
89
- def __init__(
90
- self,
91
- in_size: Size,
92
- out_size: Size,
93
- w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
94
- w_sign: Optional[ArrayLike] = None,
95
- name: Optional[str] = None,
96
- param_type: type = ParamState,
97
- ):
98
- super().__init__(name=name)
99
-
100
- # input and output shape
101
- self.in_size = in_size
102
- self.out_size = out_size
103
- assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
104
- 'and "out_size" must be the same.')
105
-
106
- # w_mask
107
- self.w_sign = w_sign
108
-
109
- # weights
110
- weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
111
- self.weight = param_type(weight)
112
-
113
- def update(self, x):
114
- w = self.weight.value
115
- if self.w_sign is None:
116
- return u.math.matmul(x, u.math.abs(w))
117
- else:
118
- return u.math.matmul(x, u.math.abs(w) * self.w_sign)
119
-
120
-
121
- class ScaledWSLinear(Module):
122
- """
123
- Linear Layer with Weight Standardization.
124
-
125
- Applies weight standardization to the weights of the linear layer.
126
-
127
- Parameters
128
- ----------
129
- in_size: int, sequence of int
130
- The input size.
131
- out_size: int, sequence of int
132
- The output size.
133
- w_init: Callable, ArrayLike
134
- The initializer for the weights.
135
- b_init: Callable, ArrayLike
136
- The initializer for the bias.
137
- w_mask: ArrayLike, Callable
138
- The optional mask of the weights.
139
- ws_gain: bool
140
- Whether to use gain for the weights. The default is True.
141
- eps: float
142
- The epsilon value for the weight standardization.
143
- name: str
144
- The name of the object.
145
-
146
- """
147
- __module__ = 'brainstate.nn'
148
-
149
- def __init__(
150
- self,
151
- in_size: Size,
152
- out_size: Size,
153
- w_init: Callable = init.KaimingNormal(),
154
- b_init: Callable = init.ZeroInit(),
155
- w_mask: Optional[Union[ArrayLike, Callable]] = None,
156
- ws_gain: bool = True,
157
- eps: float = 1e-4,
158
- name: str = None,
159
- param_type: type = ParamState,
160
- ):
161
- super().__init__(name=name)
162
-
163
- # input and output shape
164
- self.in_size = in_size
165
- self.out_size = out_size
166
- assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
167
- 'and "out_size" must be the same.')
168
-
169
- # w_mask
170
- self.w_mask = init.param(w_mask, (self.in_size[0], 1))
171
-
172
- # parameters
173
- self.eps = eps
174
-
175
- # weights
176
- params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
177
- if b_init is not None:
178
- params['bias'] = init.param(b_init, self.out_size, allow_none=False)
179
- # gain
180
- if ws_gain:
181
- s = params['weight'].shape
182
- params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
183
- self.weight = param_type(params)
184
-
185
- def update(self, x):
186
- params = self.weight.value
187
- w = params['weight']
188
- w = functional.weight_standardization(w, self.eps, params.get('gain', None))
189
- if self.w_mask is not None:
190
- w = w * self.w_mask
191
- y = u.linalg.dot(x, w)
192
- if 'bias' in params:
193
- y = y + params['bias']
194
- return y
195
-
196
-
197
- class SparseLinear(Module):
198
- """
199
- Linear layer with Sparse Matrix (can be ``brainunit.sparse.CSR``,
200
- ``brainunit.sparse.CSC``, ``brainunit.sparse.COO``, or any other sparse matrix).
201
-
202
- Args:
203
- spar_mat: SparseMatrix. The sparse weight matrix.
204
- in_size: Size. The input size.
205
- name: str. The object name.
206
- """
207
- __module__ = 'brainstate.nn'
208
-
209
- def __init__(
210
- self,
211
- spar_mat: u.sparse.SparseMatrix,
212
- b_init: Optional[Union[Callable, ArrayLike]] = None,
213
- in_size: Size = None,
214
- name: Optional[str] = None,
215
- param_type: type = ParamState,
216
- ):
217
- super().__init__(name=name)
218
-
219
- # input and output shape
220
- if in_size is not None:
221
- self.in_size = in_size
222
- self.out_size = spar_mat.shape[-1]
223
- if in_size is not None:
224
- assert self.in_size[:-1] == self.out_size[:-1], (
225
- 'The first n-1 dimensions of "in_size" '
226
- 'and "out_size" must be the same.'
227
- )
228
-
229
- # weights
230
- assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'
231
- self.spar_mat = spar_mat
232
- params = dict(weight=spar_mat.data)
233
- if b_init is not None:
234
- params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
235
- self.weight = param_type(params)
236
-
237
- def update(self, x):
238
- data = self.weight.value['weight']
239
- y = x @ self.spar_mat.with_data(data)
240
- if 'bias' in self.weight.value:
241
- y = y + self.weight.value['bias']
242
- return y
243
-
244
-
245
- class AllToAll(Module):
246
- """
247
- Synaptic matrix multiplication with All-to-All connections.
248
-
249
- Args:
250
- in_size: Size. The number of neurons in the pre-synaptic neuron group.
251
- out_size: Size. The number of neurons in the postsynaptic neuron group.
252
- w_init: The synaptic weight initializer.
253
- include_self: bool. Whether connect the neuron with at the same position.
254
- name: str. The object name.
255
- """
256
-
257
- def __init__(
258
- self,
259
- in_size: Size,
260
- out_size: Size,
261
- w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
262
- b_init: Optional[Union[Callable, ArrayLike]] = None,
263
- include_self: bool = True,
264
- name: Optional[str] = None,
265
- param_type: type = ParamState,
266
- ):
267
- super().__init__(name=name)
268
-
269
- # input and output shape
270
- self.in_size = in_size
271
- self.out_size = out_size
272
- assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
273
- 'and "out_size" must be the same.')
274
-
275
- # others
276
- self.include_self = include_self
277
-
278
- # weights
279
- weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
280
- params = dict(weight=weight)
281
- if b_init is not None:
282
- params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
283
- self.weight = param_type(params)
284
-
285
- def update(self, pre_val):
286
- params = self.weight.value
287
- pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
288
- w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])
289
-
290
- if u.math.ndim(w_val) == 0: # weight is a scalar
291
- if pre_val.ndim == 1:
292
- post_val = u.math.sum(pre_val)
293
- else:
294
- post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
295
- if not self.include_self:
296
- if self.in_size == self.out_size:
297
- post_val = post_val - pre_val
298
- elif self.in_size[-1] > self.out_size[-1]:
299
- val = pre_val[..., :self.out_size[-1]]
300
- post_val = post_val - val
301
- else:
302
- size = list(self.out_size)
303
- size[-1] = self.out_size[-1] - self.in_size[-1]
304
- val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
305
- post_val = post_val - val
306
- post_val = w_val * post_val
307
-
308
- else: # weight is a matrix
309
- assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
310
- if not self.include_self:
311
- post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
312
- else:
313
- post_val = pre_val @ w_val
314
-
315
- post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
316
- if 'bias' in params:
317
- post_val = post_val + params['bias']
318
- return post_val
319
-
320
-
321
- class OneToOne(Module):
322
- """
323
- Synaptic matrix multiplication with One2One connection.
324
-
325
- Args:
326
- in_size: Size. The number of neurons in the pre-synaptic neuron group.
327
- w_init: The synaptic weight initializer.
328
- b_init: The synaptic bias initializer.
329
- name: str. The object name.
330
- """
331
-
332
- def __init__(
333
- self,
334
- in_size: Size,
335
- w_init: Union[Callable, ArrayLike] = init.Normal(),
336
- b_init: Optional[Union[Callable, ArrayLike]] = None,
337
- name: Optional[str] = None,
338
- param_type: type = ParamState,
339
- ):
340
- super().__init__(name=name)
341
-
342
- # input and output shape
343
- self.in_size = in_size
344
- self.out_size = in_size
345
-
346
- # weights
347
- param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
348
- if b_init is not None:
349
- param['bias'] = init.param(b_init, self.out_size, allow_none=False)
350
- self.weight = param_type(param)
351
-
352
- def update(self, pre_val):
353
- post_val = pre_val * self.weight.value['weight']
354
- if 'bias' in self.weight.value:
355
- post_val = post_val + self.weight.value['bias']
356
- return post_val
357
-
358
-
359
- class LoRA(Module):
360
- """A standalone LoRA layer.
361
-
362
- Example usage::
363
-
364
- >>> import brainstate as brainstate
365
- >>> import jax, jax.numpy as jnp
366
- >>> layer = brainstate.nn.LoRA(3, 2, 4)
367
- >>> layer.weight.value
368
- {'lora_a': Array([[ 0.25141352, -0.09826107],
369
- [ 0.2328382 , 0.38869813],
370
- [ 0.27069277, 0.7678282 ]], dtype=float32),
371
- 'lora_b': Array([[-0.8372317 , 0.21012013, -0.52999765, -0.31939325],
372
- [ 0.64234126, -0.42980042, 1.2549229 , -0.47134295]], dtype=float32)}
373
- >>> # Wrap around existing layer
374
- >>> linear = brainstate.nn.Linear(3, 4)
375
- >>> wrapper = brainstate.nn.LoRA(3, 2, 4, base_module=linear)
376
- >>> assert wrapper.base_module == linear
377
- >>> y = layer(jnp.ones((16, 3)))
378
- >>> y.shape
379
- (16, 4)
380
-
381
- Args:
382
- in_features: the number of input features.
383
- lora_rank: the rank of the LoRA dimension.
384
- out_features: the number of output features.
385
- base_module: a base module to call and substitute, if possible.
386
- kernel_init: initializer function for the weight matrices.
387
- param_type: the type of the LoRA params.
388
- """
389
-
390
- def __init__(
391
- self,
392
- in_features: int,
393
- lora_rank: int,
394
- out_features: int,
395
- *,
396
- base_module: Optional[Module] = None,
397
- kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
398
- param_type: type = ParamState,
399
- ):
400
- super().__init__()
401
-
402
- # input and output shape
403
- self.in_size = in_features
404
- self.out_size = out_features
405
- self.in_features = in_features
406
- self.out_features = out_features
407
-
408
- # others
409
- self.base_module = base_module
410
-
411
- # weights
412
- param = dict(
413
- lora_a=kernel_init((in_features, lora_rank)),
414
- lora_b=kernel_init((lora_rank, out_features))
415
- )
416
- self.weight = param_type(param)
417
-
418
- def __call__(self, x: ArrayLike):
419
- out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
420
- if self.base_module is not None:
421
- if not callable(self.base_module):
422
- raise ValueError('`self.base_module` must be callable.')
423
- out += self.base_module(x)
424
- return out
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ # -*- coding: utf-8 -*-
17
+
18
+ from typing import Callable, Union, Optional
19
+
20
+ import brainunit as u
21
+ import jax.numpy as jnp
22
+
23
+ from brainstate._state import ParamState
24
+ from brainstate.typing import ArrayLike, Size
25
+ from . import init as init
26
+ from ._module import Module
27
+ from ._normalizations import weight_standardization
28
+
29
+ __all__ = [
30
+ 'Linear',
31
+ 'ScaledWSLinear',
32
+ 'SignedWLinear',
33
+ 'SparseLinear',
34
+ 'AllToAll',
35
+ 'OneToOne',
36
+ 'LoRA',
37
+ ]
38
+
39
+
40
+ class Linear(Module):
41
+ """
42
+ Linear transformation layer.
43
+
44
+ Applies a linear transformation to the incoming data: :math:`y = xW + b`
45
+
46
+ Parameters
47
+ ----------
48
+ in_size : int or tuple of int
49
+ The input feature size.
50
+ out_size : int or tuple of int
51
+ The output feature size.
52
+ w_init : Callable or ArrayLike, optional
53
+ Weight initializer. Default is ``KaimingNormal()``.
54
+ b_init : Callable, ArrayLike, or None, optional
55
+ Bias initializer. If ``None``, no bias is added. Default is ``ZeroInit()``.
56
+ w_mask : ArrayLike, Callable, or None, optional
57
+ Optional mask for the weights. If provided, weights will be element-wise
58
+ multiplied by this mask.
59
+ name : str, optional
60
+ Name of the module.
61
+ param_type : type, optional
62
+ Type of parameter state. Default is ``ParamState``.
63
+
64
+ Attributes
65
+ ----------
66
+ in_size : tuple
67
+ Input feature size.
68
+ out_size : tuple
69
+ Output feature size.
70
+ w_mask : ArrayLike or None
71
+ Weight mask if provided.
72
+ weight : ParamState
73
+ Parameter state containing 'weight' and optionally 'bias'.
74
+
75
+ Examples
76
+ --------
77
+ .. code-block:: python
78
+
79
+ >>> import brainstate as bst
80
+ >>> import jax.numpy as jnp
81
+ >>>
82
+ >>> # Create a linear layer
83
+ >>> layer = bst.nn.Linear((10,), (5,))
84
+ >>> x = jnp.ones((32, 10))
85
+ >>> y = layer(x)
86
+ >>> y.shape
87
+ (32, 5)
88
+ >>>
89
+ >>> # Linear layer without bias
90
+ >>> layer = bst.nn.Linear((10,), (5,), b_init=None)
91
+ >>> y = layer(x)
92
+ >>> y.shape
93
+ (32, 5)
94
+ """
95
+ __module__ = 'brainstate.nn'
96
+
97
+ def __init__(
98
+ self,
99
+ in_size: Size,
100
+ out_size: Size,
101
+ w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
102
+ b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
103
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
104
+ name: Optional[str] = None,
105
+ param_type: type = ParamState,
106
+ ):
107
+ super().__init__(name=name)
108
+
109
+ # input and output shape
110
+ self.in_size = in_size
111
+ self.out_size = out_size
112
+ assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
113
+ 'and "out_size" must be the same.')
114
+
115
+ # w_mask
116
+ self.w_mask = init.param(w_mask, self.in_size + self.out_size)
117
+
118
+ # weights
119
+ params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
120
+ if b_init is not None:
121
+ params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
122
+ self.weight = param_type(params)
123
+
124
+ def update(self, x):
125
+ params = self.weight.value
126
+ weight = params['weight']
127
+ if self.w_mask is not None:
128
+ weight = weight * self.w_mask
129
+ y = u.linalg.dot(x, weight)
130
+ if 'bias' in params:
131
+ y = y + params['bias']
132
+ return y
133
+
134
+
135
+ class SignedWLinear(Module):
136
+ """
137
+ Linear layer with signed absolute weights.
138
+
139
+ This layer uses absolute values of weights multiplied by a sign matrix,
140
+ ensuring all effective weights have controlled signs.
141
+
142
+ Parameters
143
+ ----------
144
+ in_size : int or tuple of int
145
+ The input feature size.
146
+ out_size : int or tuple of int
147
+ The output feature size.
148
+ w_init : Callable or ArrayLike, optional
149
+ Weight initializer. Default is ``KaimingNormal()``.
150
+ w_sign : ArrayLike or None, optional
151
+ Sign matrix for the weights. If ``None``, all weights are positive
152
+ (absolute values used). If provided, should have the same shape as
153
+ the weight matrix.
154
+ name : str, optional
155
+ Name of the module.
156
+ param_type : type, optional
157
+ Type of parameter state. Default is ``ParamState``.
158
+
159
+ Attributes
160
+ ----------
161
+ in_size : tuple
162
+ Input feature size.
163
+ out_size : tuple
164
+ Output feature size.
165
+ w_sign : ArrayLike or None
166
+ Sign matrix for weights.
167
+ weight : ParamState
168
+ Parameter state containing the weight values.
169
+
170
+ Examples
171
+ --------
172
+ .. code-block:: python
173
+
174
+ >>> import brainstate as bst
175
+ >>> import jax.numpy as jnp
176
+ >>>
177
+ >>> # Create a signed weight linear layer with all positive weights
178
+ >>> layer = bst.nn.SignedWLinear((10,), (5,))
179
+ >>> x = jnp.ones((32, 10))
180
+ >>> y = layer(x)
181
+ >>> y.shape
182
+ (32, 5)
183
+ >>>
184
+ >>> # With custom sign matrix (e.g., inhibitory connections)
185
+ >>> w_sign = jnp.ones((10, 5)) * -1.0 # all negative
186
+ >>> layer = bst.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
187
+ >>> y = layer(x)
188
+ >>> y.shape
189
+ (32, 5)
190
+ """
191
+ __module__ = 'brainstate.nn'
192
+
193
+ def __init__(
194
+ self,
195
+ in_size: Size,
196
+ out_size: Size,
197
+ w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
198
+ w_sign: Optional[ArrayLike] = None,
199
+ name: Optional[str] = None,
200
+ param_type: type = ParamState,
201
+ ):
202
+ super().__init__(name=name)
203
+
204
+ # input and output shape
205
+ self.in_size = in_size
206
+ self.out_size = out_size
207
+ assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
208
+ 'and "out_size" must be the same.')
209
+
210
+ # w_mask
211
+ self.w_sign = w_sign
212
+
213
+ # weights
214
+ weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
215
+ self.weight = param_type(weight)
216
+
217
+ def update(self, x):
218
+ w = self.weight.value
219
+ if self.w_sign is None:
220
+ return u.math.matmul(x, u.math.abs(w))
221
+ else:
222
+ return u.math.matmul(x, u.math.abs(w) * self.w_sign)
223
+
224
+
225
+ class ScaledWSLinear(Module):
226
+ """
227
+ Linear layer with weight standardization.
228
+
229
+ Applies weight standardization [1]_ to normalize weights before the linear
230
+ transformation, which can improve training stability and performance.
231
+
232
+ Parameters
233
+ ----------
234
+ in_size : int or tuple of int
235
+ The input feature size.
236
+ out_size : int or tuple of int
237
+ The output feature size.
238
+ w_init : Callable, optional
239
+ Weight initializer. Default is ``KaimingNormal()``.
240
+ b_init : Callable, optional
241
+ Bias initializer. Default is ``ZeroInit()``.
242
+ w_mask : ArrayLike, Callable, or None, optional
243
+ Optional mask for the weights.
244
+ ws_gain : bool, optional
245
+ Whether to use a learnable gain parameter for weight standardization.
246
+ Default is ``True``.
247
+ eps : float, optional
248
+ Small constant for numerical stability in standardization.
249
+ Default is ``1e-4``.
250
+ name : str, optional
251
+ Name of the module.
252
+ param_type : type, optional
253
+ Type of parameter state. Default is ``ParamState``.
254
+
255
+ Attributes
256
+ ----------
257
+ in_size : tuple
258
+ Input feature size.
259
+ out_size : tuple
260
+ Output feature size.
261
+ w_mask : ArrayLike or None
262
+ Weight mask if provided.
263
+ eps : float
264
+ Epsilon for numerical stability.
265
+ weight : ParamState
266
+ Parameter state containing 'weight', optionally 'bias' and 'gain'.
267
+
268
+ References
269
+ ----------
270
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
271
+ Weight standardization. arXiv preprint arXiv:1903.10520.
272
+
273
+ Examples
274
+ --------
275
+ .. code-block:: python
276
+
277
+ >>> import brainstate as bst
278
+ >>> import jax.numpy as jnp
279
+ >>>
280
+ >>> # Create a weight-standardized linear layer
281
+ >>> layer = bst.nn.ScaledWSLinear((10,), (5,))
282
+ >>> x = jnp.ones((32, 10))
283
+ >>> y = layer(x)
284
+ >>> y.shape
285
+ (32, 5)
286
+ >>>
287
+ >>> # Without learnable gain
288
+ >>> layer = bst.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
289
+ >>> y = layer(x)
290
+ >>> y.shape
291
+ (32, 5)
292
+ """
293
+ __module__ = 'brainstate.nn'
294
+
295
+ def __init__(
296
+ self,
297
+ in_size: Size,
298
+ out_size: Size,
299
+ w_init: Callable = init.KaimingNormal(),
300
+ b_init: Callable = init.ZeroInit(),
301
+ w_mask: Optional[Union[ArrayLike, Callable]] = None,
302
+ ws_gain: bool = True,
303
+ eps: float = 1e-4,
304
+ name: str = None,
305
+ param_type: type = ParamState,
306
+ ):
307
+ super().__init__(name=name)
308
+
309
+ # input and output shape
310
+ self.in_size = in_size
311
+ self.out_size = out_size
312
+ assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
313
+ 'and "out_size" must be the same.')
314
+
315
+ # w_mask
316
+ self.w_mask = init.param(w_mask, (self.in_size[0], 1))
317
+
318
+ # parameters
319
+ self.eps = eps
320
+
321
+ # weights
322
+ params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
323
+ if b_init is not None:
324
+ params['bias'] = init.param(b_init, self.out_size, allow_none=False)
325
+ # gain
326
+ if ws_gain:
327
+ s = params['weight'].shape
328
+ params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
329
+ self.weight = param_type(params)
330
+
331
+ def update(self, x):
332
+ params = self.weight.value
333
+ w = params['weight']
334
+ w = weight_standardization(w, self.eps, params.get('gain', None))
335
+ if self.w_mask is not None:
336
+ w = w * self.w_mask
337
+ y = u.linalg.dot(x, w)
338
+ if 'bias' in params:
339
+ y = y + params['bias']
340
+ return y
341
+
342
+
343
+ class SparseLinear(Module):
344
+ """
345
+ Linear layer with sparse weight matrix.
346
+
347
+ Supports sparse matrices from ``brainunit.sparse`` including CSR, CSC,
348
+ and COO formats. Only the non-zero entries are stored and updated.
349
+
350
+ Parameters
351
+ ----------
352
+ spar_mat : brainunit.sparse.SparseMatrix
353
+ The sparse weight matrix defining the connectivity structure.
354
+ b_init : Callable, ArrayLike, or None, optional
355
+ Bias initializer. If ``None``, no bias is added.
356
+ in_size : int or tuple of int, optional
357
+ The input size. If not provided, inferred from ``spar_mat``.
358
+ name : str, optional
359
+ Name of the module.
360
+ param_type : type, optional
361
+ Type of parameter state. Default is ``ParamState``.
362
+
363
+ Attributes
364
+ ----------
365
+ in_size : tuple
366
+ Input feature size.
367
+ out_size : int
368
+ Output feature size.
369
+ spar_mat : brainunit.sparse.SparseMatrix
370
+ The sparse matrix structure.
371
+ weight : ParamState
372
+ Parameter state containing the sparse 'weight' data and optionally 'bias'.
373
+
374
+ Examples
375
+ --------
376
+ .. code-block:: python
377
+
378
+ >>> import brainstate as bst
379
+ >>> import brainunit as u
380
+ >>> import jax.numpy as jnp
381
+ >>>
382
+ >>> # Create a sparse linear layer with CSR matrix
383
+ >>> indices = jnp.array([[0, 1], [1, 2], [2, 0]])
384
+ >>> values = jnp.array([1.0, 2.0, 3.0])
385
+ >>> spar_mat = u.sparse.CSR((values, indices[:, 1], indices[:, 0]),
386
+ ... shape=(3, 3))
387
+ >>> layer = bst.nn.SparseLinear(spar_mat, in_size=(3,))
388
+ >>> x = jnp.ones((5, 3))
389
+ >>> y = layer(x)
390
+ >>> y.shape
391
+ (5, 3)
392
+ """
393
+ __module__ = 'brainstate.nn'
394
+
395
+ def __init__(
396
+ self,
397
+ spar_mat: u.sparse.SparseMatrix,
398
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
399
+ in_size: Size = None,
400
+ name: Optional[str] = None,
401
+ param_type: type = ParamState,
402
+ ):
403
+ super().__init__(name=name)
404
+
405
+ # input and output shape
406
+ if in_size is not None:
407
+ self.in_size = in_size
408
+ self.out_size = spar_mat.shape[-1]
409
+ if in_size is not None:
410
+ assert self.in_size[:-1] == self.out_size[:-1], (
411
+ 'The first n-1 dimensions of "in_size" '
412
+ 'and "out_size" must be the same.'
413
+ )
414
+
415
+ # weights
416
+ assert isinstance(spar_mat, u.sparse.SparseMatrix), '"weight" must be a SparseMatrix.'
417
+ self.spar_mat = spar_mat
418
+ params = dict(weight=spar_mat.data)
419
+ if b_init is not None:
420
+ params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
421
+ self.weight = param_type(params)
422
+
423
+ def update(self, x):
424
+ data = self.weight.value['weight']
425
+ y = x @ self.spar_mat.with_data(data)
426
+ if 'bias' in self.weight.value:
427
+ y = y + self.weight.value['bias']
428
+ return y
429
+
430
+
431
+ class AllToAll(Module):
432
+ """
433
+ All-to-all connection layer.
434
+
435
+ Performs matrix multiplication with optional exclusion of self-connections,
436
+ commonly used in recurrent neural networks and graph neural networks.
437
+
438
+ Parameters
439
+ ----------
440
+ in_size : int or tuple of int
441
+ The number of neurons in the pre-synaptic group.
442
+ out_size : int or tuple of int
443
+ The number of neurons in the post-synaptic group.
444
+ w_init : Callable or ArrayLike, optional
445
+ Weight initializer. Default is ``KaimingNormal()``.
446
+ b_init : Callable, ArrayLike, or None, optional
447
+ Bias initializer. If ``None``, no bias is added.
448
+ include_self : bool, optional
449
+ Whether to include self-connections (diagonal elements).
450
+ Default is ``True``.
451
+ name : str, optional
452
+ Name of the module.
453
+ param_type : type, optional
454
+ Type of parameter state. Default is ``ParamState``.
455
+
456
+ Attributes
457
+ ----------
458
+ in_size : tuple
459
+ Input size.
460
+ out_size : tuple
461
+ Output size.
462
+ include_self : bool
463
+ Whether self-connections are included.
464
+ weight : ParamState
465
+ Parameter state containing 'weight' and optionally 'bias'.
466
+
467
+ Examples
468
+ --------
469
+ .. code-block:: python
470
+
471
+ >>> import brainstate as bst
472
+ >>> import jax.numpy as jnp
473
+ >>>
474
+ >>> # All-to-all with self-connections
475
+ >>> layer = bst.nn.AllToAll((10,), (10,), include_self=True)
476
+ >>> x = jnp.ones((32, 10))
477
+ >>> y = layer(x)
478
+ >>> y.shape
479
+ (32, 10)
480
+ >>>
481
+ >>> # All-to-all without self-connections (recurrent layer)
482
+ >>> layer = bst.nn.AllToAll((10,), (10,), include_self=False)
483
+ >>> y = layer(x)
484
+ >>> y.shape
485
+ (32, 10)
486
+ """
487
+ __module__ = 'brainstate.nn'
488
+
489
+ def __init__(
490
+ self,
491
+ in_size: Size,
492
+ out_size: Size,
493
+ w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
494
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
495
+ include_self: bool = True,
496
+ name: Optional[str] = None,
497
+ param_type: type = ParamState,
498
+ ):
499
+ super().__init__(name=name)
500
+
501
+ # input and output shape
502
+ self.in_size = in_size
503
+ self.out_size = out_size
504
+ assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
505
+ 'and "out_size" must be the same.')
506
+
507
+ # others
508
+ self.include_self = include_self
509
+
510
+ # weights
511
+ weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
512
+ params = dict(weight=weight)
513
+ if b_init is not None:
514
+ params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
515
+ self.weight = param_type(params)
516
+
517
+ def update(self, pre_val):
518
+ params = self.weight.value
519
+ pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
520
+ w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])
521
+
522
+ if u.math.ndim(w_val) == 0: # weight is a scalar
523
+ if pre_val.ndim == 1:
524
+ post_val = u.math.sum(pre_val)
525
+ else:
526
+ post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
527
+ if not self.include_self:
528
+ if self.in_size == self.out_size:
529
+ post_val = post_val - pre_val
530
+ elif self.in_size[-1] > self.out_size[-1]:
531
+ val = pre_val[..., :self.out_size[-1]]
532
+ post_val = post_val - val
533
+ else:
534
+ size = list(self.out_size)
535
+ size[-1] = self.out_size[-1] - self.in_size[-1]
536
+ val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
537
+ post_val = post_val - val
538
+ post_val = w_val * post_val
539
+
540
+ else: # weight is a matrix
541
+ assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
542
+ if not self.include_self:
543
+ post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
544
+ else:
545
+ post_val = pre_val @ w_val
546
+
547
+ post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
548
+ if 'bias' in params:
549
+ post_val = post_val + params['bias']
550
+ return post_val
551
+
552
+
553
+ class OneToOne(Module):
554
+ """
555
+ One-to-one connection layer.
556
+
557
+ Applies element-wise multiplication with a weight vector, implementing
558
+ diagonal connectivity where each input unit connects only to its
559
+ corresponding output unit.
560
+
561
+ Parameters
562
+ ----------
563
+ in_size : int or tuple of int
564
+ The number of neurons. Input and output sizes are the same.
565
+ w_init : Callable or ArrayLike, optional
566
+ Weight initializer. Default is ``Normal()``.
567
+ b_init : Callable, ArrayLike, or None, optional
568
+ Bias initializer. If ``None``, no bias is added.
569
+ name : str, optional
570
+ Name of the module.
571
+ param_type : type, optional
572
+ Type of parameter state. Default is ``ParamState``.
573
+
574
+ Attributes
575
+ ----------
576
+ in_size : tuple
577
+ Input size.
578
+ out_size : tuple
579
+ Output size (same as input size).
580
+ weight : ParamState
581
+ Parameter state containing 'weight' and optionally 'bias'.
582
+
583
+ Examples
584
+ --------
585
+ .. code-block:: python
586
+
587
+ >>> import brainstate as bst
588
+ >>> import jax.numpy as jnp
589
+ >>>
590
+ >>> # One-to-one connection
591
+ >>> layer = bst.nn.OneToOne((10,))
592
+ >>> x = jnp.ones((32, 10))
593
+ >>> y = layer(x)
594
+ >>> y.shape
595
+ (32, 10)
596
+ >>>
597
+ >>> # With bias
598
+ >>> layer = bst.nn.OneToOne((10,), b_init=bst.init.Constant(0.1))
599
+ >>> y = layer(x)
600
+ >>> y.shape
601
+ (32, 10)
602
+ """
603
+ __module__ = 'brainstate.nn'
604
+
605
+ def __init__(
606
+ self,
607
+ in_size: Size,
608
+ w_init: Union[Callable, ArrayLike] = init.Normal(),
609
+ b_init: Optional[Union[Callable, ArrayLike]] = None,
610
+ name: Optional[str] = None,
611
+ param_type: type = ParamState,
612
+ ):
613
+ super().__init__(name=name)
614
+
615
+ # input and output shape
616
+ self.in_size = in_size
617
+ self.out_size = in_size
618
+
619
+ # weights
620
+ param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
621
+ if b_init is not None:
622
+ param['bias'] = init.param(b_init, self.out_size, allow_none=False)
623
+ self.weight = param_type(param)
624
+
625
+ def update(self, pre_val):
626
+ post_val = pre_val * self.weight.value['weight']
627
+ if 'bias' in self.weight.value:
628
+ post_val = post_val + self.weight.value['bias']
629
+ return post_val
630
+
631
+
632
+ class LoRA(Module):
633
+ """
634
+ Low-Rank Adaptation (LoRA) layer.
635
+
636
+ Implements parameter-efficient fine-tuning using low-rank decomposition [1]_.
637
+ Can be used standalone or as a wrapper around an existing module.
638
+
639
+ Parameters
640
+ ----------
641
+ in_features : int
642
+ The number of input features.
643
+ lora_rank : int
644
+ The rank of the low-rank decomposition. Lower rank means fewer parameters.
645
+ out_features : int
646
+ The number of output features.
647
+ base_module : Module, optional
648
+ A base module to wrap. If provided, the LoRA output will be added to
649
+ the base module's output. Default is ``None``.
650
+ kernel_init : Callable or ArrayLike, optional
651
+ Initializer for the LoRA weight matrices. Default is ``LecunNormal()``.
652
+ param_type : type, optional
653
+ Type of parameter state. Default is ``ParamState``.
654
+
655
+ Attributes
656
+ ----------
657
+ in_size : int
658
+ Input feature size.
659
+ out_size : int
660
+ Output feature size.
661
+ in_features : int
662
+ Number of input features.
663
+ out_features : int
664
+ Number of output features.
665
+ base_module : Module or None
666
+ The wrapped base module if provided.
667
+ weight : ParamState
668
+ Parameter state containing 'lora_a' and 'lora_b' matrices.
669
+
670
+ References
671
+ ----------
672
+ .. [1] Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S.,
673
+ Wang, L., & Chen, W. (2021). LoRA: Low-Rank Adaptation of Large
674
+ Language Models. arXiv preprint arXiv:2106.09685.
675
+
676
+ Examples
677
+ --------
678
+ .. code-block:: python
679
+
680
+ >>> import brainstate as bst
681
+ >>> import jax.numpy as jnp
682
+ >>>
683
+ >>> # Standalone LoRA layer
684
+ >>> layer = bst.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
685
+ >>> x = jnp.ones((32, 10))
686
+ >>> y = layer(x)
687
+ >>> y.shape
688
+ (32, 5)
689
+ >>>
690
+ >>> # Wrap around existing linear layer
691
+ >>> base = bst.nn.Linear((10,), (5,))
692
+ >>> lora_layer = bst.nn.LoRA(in_features=10, lora_rank=2,
693
+ ... out_features=5, base_module=base)
694
+ >>> y = lora_layer(x)
695
+ >>> y.shape
696
+ (32, 5)
697
+ >>>
698
+ >>> # Check parameter count - LoRA has fewer parameters
699
+ >>> # Base layer: 10 * 5 = 50 parameters
700
+ >>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
701
+ """
702
+ __module__ = 'brainstate.nn'
703
+
704
+ def __init__(
705
+ self,
706
+ in_features: int,
707
+ lora_rank: int,
708
+ out_features: int,
709
+ *,
710
+ base_module: Optional[Module] = None,
711
+ kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
712
+ param_type: type = ParamState,
713
+ in_size: Size = None,
714
+ ):
715
+ super().__init__()
716
+
717
+ # input and output shape
718
+ self.in_size = in_features
719
+ self.out_size = out_features
720
+ self.in_features = in_features
721
+ self.out_features = out_features
722
+
723
+ # others
724
+ self.base_module = base_module
725
+
726
+ # weights
727
+ param = dict(
728
+ lora_a=kernel_init((in_features, lora_rank)),
729
+ lora_b=kernel_init((lora_rank, out_features))
730
+ )
731
+ self.weight = param_type(param)
732
+
733
+ # in_size
734
+ if in_size is not None:
735
+ self.in_size = in_size
736
+ self.out_size = tuple(self.in_size[:-1]) + (out_features,)
737
+
738
+ def __call__(self, x: ArrayLike):
739
+ out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
740
+ if self.base_module is not None:
741
+ if not callable(self.base_module):
742
+ raise ValueError('`self.base_module` must be callable.')
743
+ out += self.base_module(x)
744
+ return out
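The rewritten LoRA constructor above also accepts a keyword-only in_size argument that its docstring does not describe; when given, the added code derives out_size = in_size[:-1] + (out_features,). A short usage sketch under that reading (module path bst.nn.LoRA as in the examples above; behavior inferred from the added __init__, not from separate documentation):

import jax.numpy as jnp
import brainstate as bst

# Passing in_size overrides the scalar in_size/out_size bookkeeping:
layer = bst.nn.LoRA(in_features=10, lora_rank=2, out_features=5, in_size=(8, 10))
print(layer.out_size)             # (8, 5), per the added __init__
y = layer(jnp.ones((32, 8, 10)))  # x @ lora_a (10x2) @ lora_b (2x5)
print(y.shape)                    # (32, 8, 5)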