brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/__init__.py +31 -11
  4. brainstate/_state.py +760 -316
  5. brainstate/_state_test.py +41 -12
  6. brainstate/_utils.py +31 -4
  7. brainstate/augment/__init__.py +40 -0
  8. brainstate/augment/_autograd.py +611 -0
  9. brainstate/augment/_autograd_test.py +1193 -0
  10. brainstate/augment/_eval_shape.py +102 -0
  11. brainstate/augment/_eval_shape_test.py +40 -0
  12. brainstate/augment/_mapping.py +525 -0
  13. brainstate/augment/_mapping_test.py +210 -0
  14. brainstate/augment/_random.py +99 -0
  15. brainstate/{transform → compile}/__init__.py +25 -13
  16. brainstate/compile/_ad_checkpoint.py +204 -0
  17. brainstate/compile/_ad_checkpoint_test.py +51 -0
  18. brainstate/compile/_conditions.py +259 -0
  19. brainstate/compile/_conditions_test.py +221 -0
  20. brainstate/compile/_error_if.py +94 -0
  21. brainstate/compile/_error_if_test.py +54 -0
  22. brainstate/compile/_jit.py +314 -0
  23. brainstate/compile/_jit_test.py +143 -0
  24. brainstate/compile/_loop_collect_return.py +516 -0
  25. brainstate/compile/_loop_collect_return_test.py +59 -0
  26. brainstate/compile/_loop_no_collection.py +185 -0
  27. brainstate/compile/_loop_no_collection_test.py +51 -0
  28. brainstate/compile/_make_jaxpr.py +756 -0
  29. brainstate/compile/_make_jaxpr_test.py +134 -0
  30. brainstate/compile/_progress_bar.py +111 -0
  31. brainstate/compile/_unvmap.py +159 -0
  32. brainstate/compile/_util.py +147 -0
  33. brainstate/environ.py +408 -381
  34. brainstate/environ_test.py +34 -32
  35. brainstate/event/__init__.py +27 -0
  36. brainstate/event/_csr.py +316 -0
  37. brainstate/event/_csr_benchmark.py +14 -0
  38. brainstate/event/_csr_test.py +118 -0
  39. brainstate/event/_fixed_probability.py +708 -0
  40. brainstate/event/_fixed_probability_benchmark.py +128 -0
  41. brainstate/event/_fixed_probability_test.py +131 -0
  42. brainstate/event/_linear.py +359 -0
  43. brainstate/event/_linear_benckmark.py +82 -0
  44. brainstate/event/_linear_test.py +117 -0
  45. brainstate/{nn/event → event}/_misc.py +7 -7
  46. brainstate/event/_xla_custom_op.py +312 -0
  47. brainstate/event/_xla_custom_op_test.py +55 -0
  48. brainstate/functional/_activations.py +521 -511
  49. brainstate/functional/_activations_test.py +300 -300
  50. brainstate/functional/_normalization.py +43 -43
  51. brainstate/functional/_others.py +15 -15
  52. brainstate/functional/_spikes.py +49 -49
  53. brainstate/graph/__init__.py +33 -0
  54. brainstate/graph/_graph_context.py +443 -0
  55. brainstate/graph/_graph_context_test.py +65 -0
  56. brainstate/graph/_graph_convert.py +246 -0
  57. brainstate/graph/_graph_node.py +300 -0
  58. brainstate/graph/_graph_node_test.py +75 -0
  59. brainstate/graph/_graph_operation.py +1746 -0
  60. brainstate/graph/_graph_operation_test.py +724 -0
  61. brainstate/init/_base.py +28 -10
  62. brainstate/init/_generic.py +175 -172
  63. brainstate/init/_random_inits.py +470 -415
  64. brainstate/init/_random_inits_test.py +150 -0
  65. brainstate/init/_regular_inits.py +66 -69
  66. brainstate/init/_regular_inits_test.py +51 -0
  67. brainstate/mixin.py +236 -244
  68. brainstate/mixin_test.py +44 -46
  69. brainstate/nn/__init__.py +26 -51
  70. brainstate/nn/_collective_ops.py +199 -0
  71. brainstate/nn/_dyn_impl/__init__.py +46 -0
  72. brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
  73. brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
  74. brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
  75. brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
  76. brainstate/nn/_dyn_impl/_inputs.py +154 -0
  77. brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
  78. brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
  79. brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
  80. brainstate/nn/_dyn_impl/_readout.py +128 -0
  81. brainstate/nn/_dyn_impl/_readout_test.py +54 -0
  82. brainstate/nn/_dynamics/__init__.py +37 -0
  83. brainstate/nn/_dynamics/_dynamics_base.py +631 -0
  84. brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
  85. brainstate/nn/_dynamics/_projection_base.py +346 -0
  86. brainstate/nn/_dynamics/_state_delay.py +453 -0
  87. brainstate/nn/_dynamics/_synouts.py +161 -0
  88. brainstate/nn/_dynamics/_synouts_test.py +58 -0
  89. brainstate/nn/_elementwise/__init__.py +22 -0
  90. brainstate/nn/_elementwise/_dropout.py +418 -0
  91. brainstate/nn/_elementwise/_dropout_test.py +100 -0
  92. brainstate/nn/_elementwise/_elementwise.py +1122 -0
  93. brainstate/nn/_elementwise/_elementwise_test.py +171 -0
  94. brainstate/nn/_exp_euler.py +97 -0
  95. brainstate/nn/_exp_euler_test.py +36 -0
  96. brainstate/nn/_interaction/__init__.py +41 -0
  97. brainstate/nn/_interaction/_conv.py +499 -0
  98. brainstate/nn/_interaction/_conv_test.py +239 -0
  99. brainstate/nn/_interaction/_embedding.py +59 -0
  100. brainstate/nn/_interaction/_linear.py +582 -0
  101. brainstate/nn/_interaction/_linear_test.py +42 -0
  102. brainstate/nn/_interaction/_normalizations.py +388 -0
  103. brainstate/nn/_interaction/_normalizations_test.py +75 -0
  104. brainstate/nn/_interaction/_poolings.py +1179 -0
  105. brainstate/nn/_interaction/_poolings_test.py +219 -0
  106. brainstate/nn/_module.py +328 -0
  107. brainstate/nn/_module_test.py +211 -0
  108. brainstate/nn/metrics.py +309 -309
  109. brainstate/optim/__init__.py +14 -2
  110. brainstate/optim/_base.py +66 -0
  111. brainstate/optim/_lr_scheduler.py +363 -400
  112. brainstate/optim/_lr_scheduler_test.py +25 -24
  113. brainstate/optim/_optax_optimizer.py +121 -176
  114. brainstate/optim/_optax_optimizer_test.py +41 -1
  115. brainstate/optim/_sgd_optimizer.py +950 -1025
  116. brainstate/random/_rand_funs.py +3269 -3268
  117. brainstate/random/_rand_funs_test.py +568 -0
  118. brainstate/random/_rand_seed.py +149 -117
  119. brainstate/random/_rand_seed_test.py +50 -0
  120. brainstate/random/_rand_state.py +1356 -1321
  121. brainstate/random/_random_for_unit.py +13 -13
  122. brainstate/surrogate.py +1262 -1243
  123. brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
  124. brainstate/typing.py +157 -130
  125. brainstate/util/__init__.py +52 -0
  126. brainstate/util/_caller.py +100 -0
  127. brainstate/util/_dict.py +734 -0
  128. brainstate/util/_dict_test.py +160 -0
  129. brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
  130. brainstate/util/_filter.py +178 -0
  131. brainstate/util/_others.py +497 -0
  132. brainstate/util/_pretty_repr.py +208 -0
  133. brainstate/util/_scaling.py +260 -0
  134. brainstate/util/_struct.py +524 -0
  135. brainstate/util/_tracers.py +75 -0
  136. brainstate/{_visualization.py → util/_visualization.py} +16 -16
  137. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
  138. brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
  139. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  140. brainstate/_module.py +0 -1637
  141. brainstate/_module_test.py +0 -207
  142. brainstate/nn/_base.py +0 -251
  143. brainstate/nn/_connections.py +0 -686
  144. brainstate/nn/_dynamics.py +0 -426
  145. brainstate/nn/_elementwise.py +0 -1438
  146. brainstate/nn/_embedding.py +0 -66
  147. brainstate/nn/_misc.py +0 -133
  148. brainstate/nn/_normalizations.py +0 -389
  149. brainstate/nn/_others.py +0 -101
  150. brainstate/nn/_poolings.py +0 -1229
  151. brainstate/nn/_poolings_test.py +0 -231
  152. brainstate/nn/_projection/_align_post.py +0 -546
  153. brainstate/nn/_projection/_align_pre.py +0 -599
  154. brainstate/nn/_projection/_delta.py +0 -241
  155. brainstate/nn/_projection/_vanilla.py +0 -101
  156. brainstate/nn/_rate_rnns.py +0 -410
  157. brainstate/nn/_readout.py +0 -136
  158. brainstate/nn/_synouts.py +0 -166
  159. brainstate/nn/event/csr.py +0 -312
  160. brainstate/nn/event/csr_test.py +0 -118
  161. brainstate/nn/event/fixed_probability.py +0 -276
  162. brainstate/nn/event/fixed_probability_test.py +0 -127
  163. brainstate/nn/event/linear.py +0 -220
  164. brainstate/nn/event/linear_test.py +0 -111
  165. brainstate/random/random_test.py +0 -593
  166. brainstate/transform/_autograd.py +0 -585
  167. brainstate/transform/_autograd_test.py +0 -1181
  168. brainstate/transform/_conditions.py +0 -334
  169. brainstate/transform/_conditions_test.py +0 -220
  170. brainstate/transform/_error_if.py +0 -94
  171. brainstate/transform/_error_if_test.py +0 -55
  172. brainstate/transform/_jit.py +0 -265
  173. brainstate/transform/_jit_test.py +0 -118
  174. brainstate/transform/_loop_collect_return.py +0 -502
  175. brainstate/transform/_loop_no_collection.py +0 -170
  176. brainstate/transform/_make_jaxpr.py +0 -739
  177. brainstate/transform/_make_jaxpr_test.py +0 -131
  178. brainstate/transform/_mapping.py +0 -109
  179. brainstate/transform/_progress_bar.py +0 -111
  180. brainstate/transform/_unvmap.py +0 -143
  181. brainstate/util.py +0 -746
  182. brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
  183. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  184. {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
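The list above reflects the 0.1.0 reorganization: the old brainstate.transform package is split into brainstate.compile (jit, conditionals, loops, make_jaxpr) and a new brainstate.augment package (autograd, mapping, eval_shape), the brainstate/nn/event kernels move to a top-level brainstate.event package, and the monolithic brainstate/nn modules are regrouped under _dyn_impl, _dynamics, _elementwise, and _interaction. A minimal sketch of what the rename means for downstream imports, assuming the public aliases mirror the private module names shown here (the exported names are not confirmed by this diff):

import brainstate as bst

# State-aware JIT now lives under brainstate.compile (formerly brainstate.transform).
step = bst.compile.jit(lambda x: x * 2.0)
y = step(3.0)

# Gradient and mapping transforms now live under brainstate.augment
# (see augment/_autograd.py and augment/_mapping.py in the list above).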
brainstate/nn/_interaction/_linear.py
@@ -0,0 +1,582 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ from typing import Callable, Union, Optional
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental.sparse.coo import coo_matvec_p, coo_matmat_p, COOInfo
+ from jax.experimental.sparse.csr import csr_matvec_p, csr_matmat_p
+
+ from brainstate import init, functional
+ from brainstate._state import ParamState
+ from brainstate.nn._module import Module
+ from brainstate.typing import ArrayLike, Size
+
+ __all__ = [
+     'Linear',
+     'ScaledWSLinear',
+     'SignedWLinear',
+     'CSRLinear',
+     'CSCLinear',
+     'COOLinear',
+     'AllToAll',
+     'OneToOne',
+ ]
+
+
+ class Linear(Module):
+     """
+     Linear layer.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
+         b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
+         w_mask: Optional[Union[ArrayLike, Callable]] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # w_mask
+         self.w_mask = init.param(w_mask, self.in_size + self.out_size)
+
+         # weights
+         params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         params = self.weight.value
+         weight = params['weight']
+         if self.w_mask is not None:
+             weight = weight * self.w_mask
+         y = u.math.dot(x, weight)
+         if 'bias' in params:
+             y = y + params['bias']
+         return y
+
+
+ class SignedWLinear(Module):
+     """
+     Linear layer with signed weights.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
+         w_sign: Optional[ArrayLike] = None,
+         name: Optional[str] = None,
+
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # w_mask
+         self.w_sign = w_sign
+
+         # weights
+         weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
+         self.weight = ParamState(weight)
+
+     def update(self, x):
+         w = self.weight.value
+         if self.w_sign is None:
+             return u.math.matmul(x, u.math.abs(w))
+         else:
+             return u.math.matmul(x, u.math.abs(w) * self.w_sign)
+
+
+ class ScaledWSLinear(Module):
+     """
+     Linear Layer with Weight Standardization.
+
+     Applies weight standardization to the weights of the linear layer.
+
+     Parameters
+     ----------
+     in_size: int, sequence of int
+         The input size.
+     out_size: int, sequence of int
+         The output size.
+     w_init: Callable, ArrayLike
+         The initializer for the weights.
+     b_init: Callable, ArrayLike
+         The initializer for the bias.
+     w_mask: ArrayLike, Callable
+         The optional mask of the weights.
+     ws_gain: bool
+         Whether to use gain for the weights. The default is True.
+     eps: float
+         The epsilon value for the weight standardization.
+     name: str
+         The name of the object.
+
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         w_init: Callable = init.KaimingNormal(),
+         b_init: Callable = init.ZeroInit(),
+         w_mask: Optional[Union[ArrayLike, Callable]] = None,
+         ws_gain: bool = True,
+         eps: float = 1e-4,
+         name: str = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # w_mask
+         self.w_mask = init.param(w_mask, (self.in_size[0], 1))
+
+         # parameters
+         self.eps = eps
+
+         # weights
+         params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size, allow_none=False)
+         # gain
+         if ws_gain:
+             s = params['weight'].shape
+             params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         params = self.weight.value
+         w = params['weight']
+         w = functional.weight_standardization(w, self.eps, params.get('gain', None))
+         if self.w_mask is not None:
+             w = w * self.w_mask
+         y = u.math.dot(x, w)
+         if 'bias' in params:
+             y = y + params['bias']
+         return y
+
+
+ def csr_matmat(data, indices, indptr, B: jax.Array, *, shape, transpose: bool = False) -> jax.Array:
+     """Product of CSR sparse matrix and a dense matrix.
+
+     Args:
+       data : array of shape ``(nse,)``.
+       indices : array of shape ``(nse,)``
+       indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+       B : array of shape ``(mat.shape[0] if transpose else mat.shape[1], cols)`` and
+         dtype ``mat.dtype``
+       transpose : boolean specifying whether to transpose the sparse matrix
+         before computing.
+
+     Returns:
+       C : array of shape ``(mat.shape[1] if transpose else mat.shape[0], cols)``
+         representing the matrix vector product.
+     """
+     return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)
+
+
+ def csr_matvec(data, indices, indptr, v, *, shape, transpose=False) -> jax.Array:
+     """Product of CSR sparse matrix and a dense vector.
+
+     Args:
+       data : array of shape ``(nse,)``.
+       indices : array of shape ``(nse,)``
+       indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+       v : array of shape ``(shape[0] if transpose else shape[1],)``
+         and dtype ``data.dtype``
+       shape : length-2 tuple representing the matrix shape
+       transpose : boolean specifying whether to transpose the sparse matrix
+         before computing.
+
+     Returns:
+       y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
+         the matrix vector product.
+     """
+     return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)
+
+
+ class CSRLinear(Module):
+     """
+     Linear layer with Compressed Sparse Row (CSR) matrix.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         indptr: ArrayLike,
+         indices: ArrayLike,
+         weight: Union[Callable, ArrayLike],
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # CSR data structure
+         indptr = jnp.asarray(indptr)
+         indices = jnp.asarray(indices)
+         assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
+         assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
+         assert indptr.size == self.in_size[-1] + 1, f"indptr must have size {self.in_size[-1] + 1}. Got: {indptr.size}"
+         with jax.ensure_compile_time_eval():
+             self.indptr = u.math.asarray(indptr)
+             self.indices = u.math.asarray(indices)
+
+         # weights
+         weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
+         params = dict(weight=weight)
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         data = self.weight.value['weight']
+         data, w_unit = u.get_mantissa(data), u.get_unit(data)
+         x, x_unit = u.get_mantissa(x), u.get_unit(x)
+         shape = [self.in_size[-1], self.out_size[-1]]
+         if x.ndim == 1:
+             y = csr_matvec(data, self.indices, self.indptr, x, shape=shape)
+         elif x.ndim == 2:
+             y = csr_matmat(data, self.indices, self.indptr, x, shape=shape)
+         else:
+             raise NotImplementedError(f"matmul with object of shape {x.shape}")
+         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
+         if 'bias' in self.weight.value:
+             y = y + self.weight.value['bias']
+         return y
+
+
+ class CSCLinear(Module):
+     """
+     Linear layer with Compressed Sparse Column (CSC) matrix.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         indptr: ArrayLike,
+         indices: ArrayLike,
+         weight: Union[Callable, ArrayLike],
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # CSR data structure
+         indptr = jnp.asarray(indptr)
+         indices = jnp.asarray(indices)
+         assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
+         assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
+         assert indptr.size == self.in_size[-1] + 1, f"indptr must have size {self.in_size[-1] + 1}. Got: {indptr.size}"
+         with jax.ensure_compile_time_eval():
+             self.indptr = u.math.asarray(indptr)
+             self.indices = u.math.asarray(indices)
+
+         # weights
+         weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
+         params = dict(weight=weight)
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         data = self.weight.value['weight']
+         data, w_unit = u.get_mantissa(data), u.get_unit(data)
+         x, x_unit = u.get_mantissa(x), u.get_unit(x)
+         shape = [self.out_size[-1], self.in_size[-1]]
+         if x.ndim == 1:
+             y = csr_matvec(data, self.indices, self.indptr, x, shape=shape, transpose=True)
+         elif x.ndim == 2:
+             y = csr_matmat(data, self.indices, self.indptr, x, shape=shape, transpose=True)
+         else:
+             raise NotImplementedError(f"matmul with object of shape {x.shape}")
+         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
+         if 'bias' in self.weight.value:
+             y = y + self.weight.value['bias']
+         return y
+
+
+ def coo_matvec(
+     data: jax.Array,
+     row: jax.Array,
+     col: jax.Array,
+     v: jax.Array, *,
+     spinfo: COOInfo,
+     transpose: bool = False
+ ) -> jax.Array:
+     """Product of COO sparse matrix and a dense vector.
+
+     Args:
+       data : array of shape ``(nse,)``.
+       row : array of shape ``(nse,)``
+       col : array of shape ``(nse,)`` and dtype ``row.dtype``
+       v : array of shape ``(shape[0] if transpose else shape[1],)`` and
+         dtype ``data.dtype``
+       spinfo : COOInfo object containing the shape of the matrix and the dtype
+       transpose : boolean specifying whether to transpose the sparse matrix
+         before computing.
+
+     Returns:
+       y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
+         the matrix vector product.
+     """
+     return coo_matvec_p.bind(data, row, col, v, spinfo=spinfo, transpose=transpose)
+
+
+ def coo_matmat(
+     data: jax.Array, row: jax.Array, col: jax.Array, B: jax.Array, *,
+     spinfo: COOInfo, transpose: bool = False
+ ) -> jax.Array:
+     """Product of COO sparse matrix and a dense matrix.
+
+     Args:
+       data : array of shape ``(nse,)``.
+       row : array of shape ``(nse,)``
+       col : array of shape ``(nse,)`` and dtype ``row.dtype``
+       B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+         dtype ``data.dtype``
+       spinfo : COOInfo object containing the shape of the matrix and the dtype
+       transpose : boolean specifying whether to transpose the sparse matrix
+         before computing.
+
+     Returns:
+       C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+         representing the matrix vector product.
+     """
+     return coo_matmat_p.bind(data, row, col, B, spinfo=spinfo, transpose=transpose)
+
+
+ class COOLinear(Module):
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         row: ArrayLike,
+         col: ArrayLike,
+         weight: Union[Callable, ArrayLike],
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         rows_sorted: bool = False,
+         cols_sorted: bool = False,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # COO data structure
+         row = jnp.asarray(row)
+         col = jnp.asarray(col)
+         assert row.ndim == 1, f"row must be 1D. Got: {row.ndim}"
+         assert col.ndim == 1, f"col must be 1D. Got: {col.ndim}"
+         assert row.size == col.size, f"row and col must have the same size. Got: {row.size} and {col.size}"
+         with jax.ensure_compile_time_eval():
+             self.row = u.math.asarray(row)
+             self.col = u.math.asarray(col)
+
+         # COO structure information
+         self.rows_sorted = rows_sorted
+         self.cols_sorted = cols_sorted
+
+         # weights
+         weight = init.param(weight, (len(row),), allow_none=False, allow_scalar=False)
+         params = dict(weight=weight)
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         data = self.weight.value['weight']
+         data, w_unit = u.get_mantissa(data), u.get_unit(data)
+         x, x_unit = u.get_mantissa(x), u.get_unit(x)
+         spinfo = COOInfo(
+             shape=(self.in_size[-1], self.out_size[-1]),
+             rows_sorted=self.rows_sorted,
+             cols_sorted=self.cols_sorted
+         )
+         if x.ndim == 1:
+             y = coo_matvec(data, self.row, self.col, x, spinfo=spinfo, transpose=False)
+         elif x.ndim == 2:
+             y = coo_matmat(data, self.row, self.col, x, spinfo=spinfo, transpose=False)
+         else:
+             raise NotImplementedError(f"matmul with object of shape {x.shape}")
+         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
+         if 'bias' in self.weight.value:
+             y = y + self.weight.value['bias']
+         return y
+
+
+ class AllToAll(Module):
+     """
+     Synaptic matrix multiplication with All-to-All connections.
+
+     Args:
+       in_size: Size. The number of neurons in the pre-synaptic neuron group.
+       out_size: Size. The number of neurons in the postsynaptic neuron group.
+       w_init: The synaptic weight initializer.
+       include_self: bool. Whether connect the neuron with at the same position.
+       name: str. The object name.
+     """
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         include_self: bool = True,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # others
+         self.include_self = include_self
+
+         # weights
+         weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
+         params = dict(weight=weight)
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, pre_val):
+         params = self.weight.value
+         pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
+         w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])
+
+         if u.math.ndim(w_val) == 0:  # weight is a scalar
+             if pre_val.ndim == 1:
+                 post_val = u.math.sum(pre_val)
+             else:
+                 post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
+             if not self.include_self:
+                 if self.in_size == self.out_size:
+                     post_val = post_val - pre_val
+                 elif self.in_size[-1] > self.out_size[-1]:
+                     val = pre_val[..., :self.out_size[-1]]
+                     post_val = post_val - val
+                 else:
+                     size = list(self.out_size)
+                     size[-1] = self.out_size[-1] - self.in_size[-1]
+                     val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
+                     post_val = post_val - val
+             post_val = w_val * post_val
+
+         else:  # weight is a matrix
+             assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
+             if not self.include_self:
+                 post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
+             else:
+                 post_val = pre_val @ w_val
+
+         post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
+         if 'bias' in params:
+             post_val = post_val + params['bias']
+         return post_val
+
+
+ class OneToOne(Module):
+     """
+     Synaptic matrix multiplication with One2One connection.
+
+     Args:
+       in_size: Size. The number of neurons in the pre-synaptic neuron group.
+       w_init: The synaptic weight initializer.
+       b_init: The synaptic bias initializer.
+       name: str. The object name.
+     """
+
+     def __init__(
+         self,
+         in_size: Size,
+         w_init: Union[Callable, ArrayLike] = init.Normal(),
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = in_size
+
+         # weights
+         param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
+         if b_init is not None:
+             param['bias'] = init.param(b_init, self.out_size, allow_none=False)
+         self.weight = param
+
+     def update(self, pre_val):
+         pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
+         w_val, w_unit = u.get_mantissa(self.weight['weight']), u.get_unit(self.weight['weight'])
+         post_val = pre_val * w_val
+         post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
+         if 'bias' in self.weight:
+             post_val = post_val + self.weight['bias']
+         return post_val
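The test file that follows constructs these layers through the brainstate.nn namespace (bst.nn.Linear), so the sparse variants are presumably exposed the same way. A minimal usage sketch under that assumption; the 3x3 CSR pattern is chosen here purely for illustration and is not part of the package:

import jax.numpy as jnp
import brainstate as bst

# Dense layer: kernel and bias live together in a single ParamState dict.
layer = bst.nn.Linear(10, 20)
x = bst.random.random((32, 10))
y = layer(x)                               # shape (32, 20)
kernel = layer.weight.value['weight']      # (10, 20) array

# Sparse CSR layer: connectivity (indptr/indices) is fixed at construction,
# only the stored nonzero values are parameters.
indptr = jnp.array([0, 2, 3, 4])           # in_size[-1] + 1 entries
indices = jnp.array([0, 2, 1, 2])          # column index of each nonzero
values = jnp.array([1.0, 0.5, 2.0, -1.0])  # one weight per nonzero
csr = bst.nn.CSRLinear(3, 3, indptr, indices, values)
z = csr(bst.random.random((3,)))           # csr_matvec path, shape (3,)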
brainstate/nn/_interaction/_linear_test.py
@@ -0,0 +1,42 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from __future__ import annotations
+
+ import jax.numpy as jnp
+ import pytest
+ from absl.testing import absltest
+ from absl.testing import parameterized
+
+ import brainstate as bst
+
+
+
+
+
+ class TestDense(parameterized.TestCase):
+     @parameterized.product(
+         size=[(10,),
+               (20, 10),
+               (5, 8, 10)],
+         num_out=[20, ]
+     )
+     def test_Dense1(self, size, num_out):
+         f = bst.nn.Linear(10, num_out)
+         x = bst.random.random(size)
+         y = f(x)
+         self.assertTrue(y.shape == size[:-1] + (num_out,))
+
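The released test only covers the dense Linear path. A companion sketch in the same parameterized style, not part of the published wheel, that would exercise the signed-weight variant; the bst.nn.SignedWLinear entry point is assumed from the __all__ list above:

import brainstate as bst
from absl.testing import absltest, parameterized


class TestSignedWLinear(parameterized.TestCase):
    @parameterized.product(size=[(10,), (20, 10), (5, 8, 10)], num_out=[20])
    def test_signed(self, size, num_out):
        f = bst.nn.SignedWLinear(10, num_out)   # effective weight is |W|, optionally scaled by w_sign
        y = f(bst.random.random(size))
        self.assertTrue(y.shape == size[:-1] + (num_out,))


if __name__ == '__main__':
    absltest.main()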