brainstate 0.1.9__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +95 -29
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.9.dist-info/RECORD +0 -130
  161. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.9.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
brainstate/nn/_linear.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -20,10 +20,11 @@ from typing import Callable, Union, Optional
20
20
  import brainunit as u
21
21
  import jax.numpy as jnp
22
22
 
23
- from brainstate import init, functional
24
23
  from brainstate._state import ParamState
25
24
  from brainstate.typing import ArrayLike, Size
25
+ from . import init as init
26
26
  from ._module import Module
27
+ from ._normalizations import weight_standardization
27
28
 
28
29
  __all__ = [
29
30
  'Linear',
@@ -38,7 +39,58 @@ __all__ = [
38
39
 
39
40
  class Linear(Module):
40
41
  """
41
- Linear layer.
42
+ Linear transformation layer.
43
+
44
+ Applies a linear transformation to the incoming data: :math:`y = xW + b`
45
+
46
+ Parameters
47
+ ----------
48
+ in_size : int or tuple of int
49
+ The input feature size.
50
+ out_size : int or tuple of int
51
+ The output feature size.
52
+ w_init : Callable or ArrayLike, optional
53
+ Weight initializer. Default is ``KaimingNormal()``.
54
+ b_init : Callable, ArrayLike, or None, optional
55
+ Bias initializer. If ``None``, no bias is added. Default is ``ZeroInit()``.
56
+ w_mask : ArrayLike, Callable, or None, optional
57
+ Optional mask for the weights. If provided, weights will be element-wise
58
+ multiplied by this mask.
59
+ name : str, optional
60
+ Name of the module.
61
+ param_type : type, optional
62
+ Type of parameter state. Default is ``ParamState``.
63
+
64
+ Attributes
65
+ ----------
66
+ in_size : tuple
67
+ Input feature size.
68
+ out_size : tuple
69
+ Output feature size.
70
+ w_mask : ArrayLike or None
71
+ Weight mask if provided.
72
+ weight : ParamState
73
+ Parameter state containing 'weight' and optionally 'bias'.
74
+
75
+ Examples
76
+ --------
77
+ .. code-block:: python
78
+
79
+ >>> import brainstate as bst
80
+ >>> import jax.numpy as jnp
81
+ >>>
82
+ >>> # Create a linear layer
83
+ >>> layer = bst.nn.Linear((10,), (5,))
84
+ >>> x = jnp.ones((32, 10))
85
+ >>> y = layer(x)
86
+ >>> y.shape
87
+ (32, 5)
88
+ >>>
89
+ >>> # Linear layer without bias
90
+ >>> layer = bst.nn.Linear((10,), (5,), b_init=None)
91
+ >>> y = layer(x)
92
+ >>> y.shape
93
+ (32, 5)
42
94
  """
43
95
  __module__ = 'brainstate.nn'
44
96
 
@@ -82,7 +134,59 @@ class Linear(Module):
82
134
 
83
135
  class SignedWLinear(Module):
84
136
  """
85
- Linear layer with signed weights.
137
+ Linear layer with signed absolute weights.
138
+
139
+ This layer uses absolute values of weights multiplied by a sign matrix,
140
+ ensuring all effective weights have controlled signs.
141
+
142
+ Parameters
143
+ ----------
144
+ in_size : int or tuple of int
145
+ The input feature size.
146
+ out_size : int or tuple of int
147
+ The output feature size.
148
+ w_init : Callable or ArrayLike, optional
149
+ Weight initializer. Default is ``KaimingNormal()``.
150
+ w_sign : ArrayLike or None, optional
151
+ Sign matrix for the weights. If ``None``, all weights are positive
152
+ (absolute values used). If provided, should have the same shape as
153
+ the weight matrix.
154
+ name : str, optional
155
+ Name of the module.
156
+ param_type : type, optional
157
+ Type of parameter state. Default is ``ParamState``.
158
+
159
+ Attributes
160
+ ----------
161
+ in_size : tuple
162
+ Input feature size.
163
+ out_size : tuple
164
+ Output feature size.
165
+ w_sign : ArrayLike or None
166
+ Sign matrix for weights.
167
+ weight : ParamState
168
+ Parameter state containing the weight values.
169
+
170
+ Examples
171
+ --------
172
+ .. code-block:: python
173
+
174
+ >>> import brainstate as bst
175
+ >>> import jax.numpy as jnp
176
+ >>>
177
+ >>> # Create a signed weight linear layer with all positive weights
178
+ >>> layer = bst.nn.SignedWLinear((10,), (5,))
179
+ >>> x = jnp.ones((32, 10))
180
+ >>> y = layer(x)
181
+ >>> y.shape
182
+ (32, 5)
183
+ >>>
184
+ >>> # With custom sign matrix (e.g., inhibitory connections)
185
+ >>> w_sign = jnp.ones((10, 5)) * -1.0 # all negative
186
+ >>> layer = bst.nn.SignedWLinear((10,), (5,), w_sign=w_sign)
187
+ >>> y = layer(x)
188
+ >>> y.shape
189
+ (32, 5)
86
190
  """
87
191
  __module__ = 'brainstate.nn'
88
192
 
@@ -120,29 +224,71 @@ class SignedWLinear(Module):
120
224
 
121
225
  class ScaledWSLinear(Module):
122
226
  """
123
- Linear Layer with Weight Standardization.
227
+ Linear layer with weight standardization.
124
228
 
125
- Applies weight standardization to the weights of the linear layer.
229
+ Applies weight standardization [1]_ to normalize weights before the linear
230
+ transformation, which can improve training stability and performance.
126
231
 
127
232
  Parameters
128
233
  ----------
129
- in_size: int, sequence of int
130
- The input size.
131
- out_size: int, sequence of int
132
- The output size.
133
- w_init: Callable, ArrayLike
134
- The initializer for the weights.
135
- b_init: Callable, ArrayLike
136
- The initializer for the bias.
137
- w_mask: ArrayLike, Callable
138
- The optional mask of the weights.
139
- ws_gain: bool
140
- Whether to use gain for the weights. The default is True.
141
- eps: float
142
- The epsilon value for the weight standardization.
143
- name: str
144
- The name of the object.
145
-
234
+ in_size : int or tuple of int
235
+ The input feature size.
236
+ out_size : int or tuple of int
237
+ The output feature size.
238
+ w_init : Callable, optional
239
+ Weight initializer. Default is ``KaimingNormal()``.
240
+ b_init : Callable, optional
241
+ Bias initializer. Default is ``ZeroInit()``.
242
+ w_mask : ArrayLike, Callable, or None, optional
243
+ Optional mask for the weights.
244
+ ws_gain : bool, optional
245
+ Whether to use a learnable gain parameter for weight standardization.
246
+ Default is ``True``.
247
+ eps : float, optional
248
+ Small constant for numerical stability in standardization.
249
+ Default is ``1e-4``.
250
+ name : str, optional
251
+ Name of the module.
252
+ param_type : type, optional
253
+ Type of parameter state. Default is ``ParamState``.
254
+
255
+ Attributes
256
+ ----------
257
+ in_size : tuple
258
+ Input feature size.
259
+ out_size : tuple
260
+ Output feature size.
261
+ w_mask : ArrayLike or None
262
+ Weight mask if provided.
263
+ eps : float
264
+ Epsilon for numerical stability.
265
+ weight : ParamState
266
+ Parameter state containing 'weight', optionally 'bias' and 'gain'.
267
+
268
+ References
269
+ ----------
270
+ .. [1] Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019).
271
+ Weight standardization. arXiv preprint arXiv:1903.10520.
272
+
273
+ Examples
274
+ --------
275
+ .. code-block:: python
276
+
277
+ >>> import brainstate as bst
278
+ >>> import jax.numpy as jnp
279
+ >>>
280
+ >>> # Create a weight-standardized linear layer
281
+ >>> layer = bst.nn.ScaledWSLinear((10,), (5,))
282
+ >>> x = jnp.ones((32, 10))
283
+ >>> y = layer(x)
284
+ >>> y.shape
285
+ (32, 5)
286
+ >>>
287
+ >>> # Without learnable gain
288
+ >>> layer = bst.nn.ScaledWSLinear((10,), (5,), ws_gain=False)
289
+ >>> y = layer(x)
290
+ >>> y.shape
291
+ (32, 5)
146
292
  """
147
293
  __module__ = 'brainstate.nn'
148
294
 
@@ -185,7 +331,7 @@ class ScaledWSLinear(Module):
185
331
  def update(self, x):
186
332
  params = self.weight.value
187
333
  w = params['weight']
188
- w = functional.weight_standardization(w, self.eps, params.get('gain', None))
334
+ w = weight_standardization(w, self.eps, params.get('gain', None))
189
335
  if self.w_mask is not None:
190
336
  w = w * self.w_mask
191
337
  y = u.linalg.dot(x, w)
@@ -196,13 +342,53 @@ class ScaledWSLinear(Module):
196
342
 
197
343
  class SparseLinear(Module):
198
344
  """
199
- Linear layer with Sparse Matrix (can be ``brainunit.sparse.CSR``,
200
- ``brainunit.sparse.CSC``, ``brainunit.sparse.COO``, or any other sparse matrix).
345
+ Linear layer with sparse weight matrix.
346
+
347
+ Supports sparse matrices from ``brainunit.sparse`` including CSR, CSC,
348
+ and COO formats. Only the non-zero entries are stored and updated.
201
349
 
202
- Args:
203
- spar_mat: SparseMatrix. The sparse weight matrix.
204
- in_size: Size. The input size.
205
- name: str. The object name.
350
+ Parameters
351
+ ----------
352
+ spar_mat : brainunit.sparse.SparseMatrix
353
+ The sparse weight matrix defining the connectivity structure.
354
+ b_init : Callable, ArrayLike, or None, optional
355
+ Bias initializer. If ``None``, no bias is added.
356
+ in_size : int or tuple of int, optional
357
+ The input size. If not provided, inferred from ``spar_mat``.
358
+ name : str, optional
359
+ Name of the module.
360
+ param_type : type, optional
361
+ Type of parameter state. Default is ``ParamState``.
362
+
363
+ Attributes
364
+ ----------
365
+ in_size : tuple
366
+ Input feature size.
367
+ out_size : int
368
+ Output feature size.
369
+ spar_mat : brainunit.sparse.SparseMatrix
370
+ The sparse matrix structure.
371
+ weight : ParamState
372
+ Parameter state containing the sparse 'weight' data and optionally 'bias'.
373
+
374
+ Examples
375
+ --------
376
+ .. code-block:: python
377
+
378
+ >>> import brainstate as bst
379
+ >>> import brainunit as u
380
+ >>> import jax.numpy as jnp
381
+ >>>
382
+ >>> # Create a sparse linear layer with CSR matrix
383
+ >>> indices = jnp.array([1, 2, 0])  # column index of each stored value
384
+ >>> indptr = jnp.array([0, 1, 2, 3])  # row i owns values[indptr[i]:indptr[i + 1]]
385
+ >>> values = jnp.array([1.0, 2.0, 3.0])
386
+ >>> spar_mat = u.sparse.CSR((values, indices, indptr), shape=(3, 3))
387
+ >>> layer = bst.nn.SparseLinear(spar_mat, in_size=(3,))
388
+ >>> x = jnp.ones((5, 3))
389
+ >>> y = layer(x)
390
+ >>> y.shape
391
+ (5, 3)
206
392
  """
207
393
  __module__ = 'brainstate.nn'
208
394
 
@@ -244,15 +430,61 @@ class SparseLinear(Module):
244
430
 
245
431
  class AllToAll(Module):
246
432
  """
247
- Synaptic matrix multiplication with All-to-All connections.
248
-
249
- Args:
250
- in_size: Size. The number of neurons in the pre-synaptic neuron group.
251
- out_size: Size. The number of neurons in the postsynaptic neuron group.
252
- w_init: The synaptic weight initializer.
253
- include_self: bool. Whether connect the neuron with at the same position.
254
- name: str. The object name.
433
+ All-to-all connection layer.
434
+
435
+ Performs matrix multiplication with optional exclusion of self-connections,
436
+ commonly used in recurrent neural networks and graph neural networks.
437
+
438
+ Parameters
439
+ ----------
440
+ in_size : int or tuple of int
441
+ The number of neurons in the pre-synaptic group.
442
+ out_size : int or tuple of int
443
+ The number of neurons in the post-synaptic group.
444
+ w_init : Callable or ArrayLike, optional
445
+ Weight initializer. Default is ``KaimingNormal()``.
446
+ b_init : Callable, ArrayLike, or None, optional
447
+ Bias initializer. If ``None``, no bias is added.
448
+ include_self : bool, optional
449
+ Whether to include self-connections (diagonal elements).
450
+ Default is ``True``.
451
+ name : str, optional
452
+ Name of the module.
453
+ param_type : type, optional
454
+ Type of parameter state. Default is ``ParamState``.
455
+
456
+ Attributes
457
+ ----------
458
+ in_size : tuple
459
+ Input size.
460
+ out_size : tuple
461
+ Output size.
462
+ include_self : bool
463
+ Whether self-connections are included.
464
+ weight : ParamState
465
+ Parameter state containing 'weight' and optionally 'bias'.
466
+
467
+ Examples
468
+ --------
469
+ .. code-block:: python
470
+
471
+ >>> import brainstate as bst
472
+ >>> import jax.numpy as jnp
473
+ >>>
474
+ >>> # All-to-all with self-connections
475
+ >>> layer = bst.nn.AllToAll((10,), (10,), include_self=True)
476
+ >>> x = jnp.ones((32, 10))
477
+ >>> y = layer(x)
478
+ >>> y.shape
479
+ (32, 10)
480
+ >>>
481
+ >>> # All-to-all without self-connections (recurrent layer)
482
+ >>> layer = bst.nn.AllToAll((10,), (10,), include_self=False)
483
+ >>> y = layer(x)
484
+ >>> y.shape
485
+ (32, 10)
255
486
  """
487
+ __module__ = 'brainstate.nn'
256
488
 
257
489
  def __init__(
258
490
  self,
@@ -320,14 +552,55 @@ class AllToAll(Module):
320
552
 
321
553
  class OneToOne(Module):
322
554
  """
323
- Synaptic matrix multiplication with One2One connection.
555
+ One-to-one connection layer.
556
+
557
+ Applies element-wise multiplication with a weight vector, implementing
558
+ diagonal connectivity where each input unit connects only to its
559
+ corresponding output unit.
324
560
 
325
- Args:
326
- in_size: Size. The number of neurons in the pre-synaptic neuron group.
327
- w_init: The synaptic weight initializer.
328
- b_init: The synaptic bias initializer.
329
- name: str. The object name.
561
+ Parameters
562
+ ----------
563
+ in_size : int or tuple of int
564
+ The number of neurons. Input and output sizes are the same.
565
+ w_init : Callable or ArrayLike, optional
566
+ Weight initializer. Default is ``Normal()``.
567
+ b_init : Callable, ArrayLike, or None, optional
568
+ Bias initializer. If ``None``, no bias is added.
569
+ name : str, optional
570
+ Name of the module.
571
+ param_type : type, optional
572
+ Type of parameter state. Default is ``ParamState``.
573
+
574
+ Attributes
575
+ ----------
576
+ in_size : tuple
577
+ Input size.
578
+ out_size : tuple
579
+ Output size (same as input size).
580
+ weight : ParamState
581
+ Parameter state containing 'weight' and optionally 'bias'.
582
+
583
+ Examples
584
+ --------
585
+ .. code-block:: python
586
+
587
+ >>> import brainstate as bst
588
+ >>> import jax.numpy as jnp
589
+ >>>
590
+ >>> # One-to-one connection
591
+ >>> layer = bst.nn.OneToOne((10,))
592
+ >>> x = jnp.ones((32, 10))
593
+ >>> y = layer(x)
594
+ >>> y.shape
595
+ (32, 10)
596
+ >>>
597
+ >>> # With bias
598
+ >>> layer = bst.nn.OneToOne((10,), b_init=bst.nn.init.Constant(0.1))
599
+ >>> y = layer(x)
600
+ >>> y.shape
601
+ (32, 10)
330
602
  """
603
+ __module__ = 'brainstate.nn'
331
604
 
332
605
  def __init__(
333
606
  self,
@@ -357,35 +630,76 @@ class OneToOne(Module):
357
630
 
358
631
 
359
632
  class LoRA(Module):
360
- """A standalone LoRA layer.
361
-
362
- Example usage::
363
-
364
- >>> import brainstate as brainstate
365
- >>> import jax, jax.numpy as jnp
366
- >>> layer = brainstate.nn.LoRA(3, 2, 4)
367
- >>> layer.weight.value
368
- {'lora_a': Array([[ 0.25141352, -0.09826107],
369
- [ 0.2328382 , 0.38869813],
370
- [ 0.27069277, 0.7678282 ]], dtype=float32),
371
- 'lora_b': Array([[-0.8372317 , 0.21012013, -0.52999765, -0.31939325],
372
- [ 0.64234126, -0.42980042, 1.2549229 , -0.47134295]], dtype=float32)}
373
- >>> # Wrap around existing layer
374
- >>> linear = brainstate.nn.Linear(3, 4)
375
- >>> wrapper = brainstate.nn.LoRA(3, 2, 4, base_module=linear)
376
- >>> assert wrapper.base_module == linear
377
- >>> y = layer(jnp.ones((16, 3)))
633
+ """
634
+ Low-Rank Adaptation (LoRA) layer.
635
+
636
+ Implements parameter-efficient fine-tuning using low-rank decomposition [1]_.
637
+ Can be used standalone or as a wrapper around an existing module.
638
+
639
+ Parameters
640
+ ----------
641
+ in_features : int
642
+ The number of input features.
643
+ lora_rank : int
644
+ The rank of the low-rank decomposition. Lower rank means fewer parameters.
645
+ out_features : int
646
+ The number of output features.
647
+ base_module : Module, optional
648
+ A base module to wrap. If provided, the LoRA output will be added to
649
+ the base module's output. Default is ``None``.
650
+ kernel_init : Callable or ArrayLike, optional
651
+ Initializer for the LoRA weight matrices. Default is ``LecunNormal()``.
652
+ param_type : type, optional
653
+ Type of parameter state. Default is ``ParamState``.
654
+
655
+ Attributes
656
+ ----------
657
+ in_size : int
658
+ Input feature size.
659
+ out_size : int
660
+ Output feature size.
661
+ in_features : int
662
+ Number of input features.
663
+ out_features : int
664
+ Number of output features.
665
+ base_module : Module or None
666
+ The wrapped base module if provided.
667
+ weight : ParamState
668
+ Parameter state containing 'lora_a' and 'lora_b' matrices.
669
+
670
+ References
671
+ ----------
672
+ .. [1] Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S.,
673
+ Wang, L., & Chen, W. (2021). LoRA: Low-Rank Adaptation of Large
674
+ Language Models. arXiv preprint arXiv:2106.09685.
675
+
676
+ Examples
677
+ --------
678
+ .. code-block:: python
679
+
680
+ >>> import brainstate as bst
681
+ >>> import jax.numpy as jnp
682
+ >>>
683
+ >>> # Standalone LoRA layer
684
+ >>> layer = bst.nn.LoRA(in_features=10, lora_rank=2, out_features=5)
685
+ >>> x = jnp.ones((32, 10))
686
+ >>> y = layer(x)
378
687
  >>> y.shape
379
- (16, 4)
380
-
381
- Args:
382
- in_features: the number of input features.
383
- lora_rank: the rank of the LoRA dimension.
384
- out_features: the number of output features.
385
- base_module: a base module to call and substitute, if possible.
386
- kernel_init: initializer function for the weight matrices.
387
- param_type: the type of the LoRA params.
688
+ (32, 5)
689
+ >>>
690
+ >>> # Wrap around existing linear layer
691
+ >>> base = bst.nn.Linear((10,), (5,))
692
+ >>> lora_layer = bst.nn.LoRA(in_features=10, lora_rank=2,
693
+ ... out_features=5, base_module=base)
694
+ >>> y = lora_layer(x)
695
+ >>> y.shape
696
+ (32, 5)
697
+ >>>
698
+ >>> # Check parameter count - LoRA has fewer parameters
699
+ >>> # Base layer: 10 * 5 = 50 parameters
700
+ >>> # LoRA: 10 * 2 + 2 * 5 = 30 parameters
388
701
  """
702
+ __module__ = 'brainstate.nn'
389
703
 
390
704
  def __init__(
391
705
  self,
@@ -396,6 +710,7 @@ class LoRA(Module):
396
710
  base_module: Optional[Module] = None,
397
711
  kernel_init: Union[Callable, ArrayLike] = init.LecunNormal(),
398
712
  param_type: type = ParamState,
713
+ in_size: Size = None,
399
714
  ):
400
715
  super().__init__()
401
716
 
@@ -415,6 +730,11 @@ class LoRA(Module):
415
730
  )
416
731
  self.weight = param_type(param)
417
732
 
733
+ # in_size
734
+ if in_size is not None:
735
+ self.in_size = in_size
736
+ self.out_size = tuple(self.in_size[:-1]) + (out_features,)
737
+
418
738
  def __call__(self, x: ArrayLike):
419
739
  out = x @ self.weight.value['lora_a'] @ self.weight.value['lora_b']
420
740
  if self.base_module is not None: