brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. benchmark/COBA_2005.py +125 -0
  2. benchmark/CUBA_2005.py +149 -0
  3. brainstate/augment/_autograd.py +9 -6
  4. brainstate/event/__init__.py +4 -2
  5. brainstate/event/_csr.py +26 -18
  6. brainstate/event/_csr_benchmark.py +14 -0
  7. brainstate/event/_fixed_probability.py +589 -152
  8. brainstate/event/_fixed_probability_benchmark.py +128 -0
  9. brainstate/event/_fixed_probability_test.py +13 -10
  10. brainstate/event/_linear.py +267 -127
  11. brainstate/event/_linear_benckmark.py +82 -0
  12. brainstate/event/_linear_test.py +8 -3
  13. brainstate/event/_xla_custom_op.py +312 -0
  14. brainstate/event/_xla_custom_op_test.py +55 -0
  15. brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
  16. brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
  17. brainstate/nn/_dynamics/_projection_base.py +1 -1
  18. brainstate/nn/_exp_euler.py +1 -1
  19. brainstate/nn/_interaction/__init__.py +13 -4
  20. brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
  21. brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
  22. brainstate/nn/_interaction/_linear.py +582 -0
  23. brainstate/nn/_interaction/_linear_test.py +42 -0
  24. brainstate/optim/_lr_scheduler.py +1 -1
  25. brainstate/optim/_optax_optimizer.py +18 -0
  26. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +1 -1
  27. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/RECORD +30 -21
  28. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
  29. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
  30. {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -0,0 +1,582 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # -*- coding: utf-8 -*-
+
+ from __future__ import annotations
+
+ from typing import Callable, Union, Optional
+
+ import brainunit as u
+ import jax
+ import jax.numpy as jnp
+ from jax.experimental.sparse.coo import coo_matvec_p, coo_matmat_p, COOInfo
+ from jax.experimental.sparse.csr import csr_matvec_p, csr_matmat_p
+
+ from brainstate import init, functional
+ from brainstate._state import ParamState
+ from brainstate.nn._module import Module
+ from brainstate.typing import ArrayLike, Size
+
+ __all__ = [
+     'Linear',
+     'ScaledWSLinear',
+     'SignedWLinear',
+     'CSRLinear',
+     'CSCLinear',
+     'COOLinear',
+     'AllToAll',
+     'OneToOne',
+ ]
+
+
+ class Linear(Module):
+     """
+     Linear layer.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
+         b_init: Optional[Union[Callable, ArrayLike]] = init.ZeroInit(),
+         w_mask: Optional[Union[ArrayLike, Callable]] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # w_mask
+         self.w_mask = init.param(w_mask, self.in_size + self.out_size)
+
+         # weights
+         params = dict(weight=init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False))
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         params = self.weight.value
+         weight = params['weight']
+         if self.w_mask is not None:
+             weight = weight * self.w_mask
+         y = u.math.dot(x, weight)
+         if 'bias' in params:
+             y = y + params['bias']
+         return y
+
+
+ class SignedWLinear(Module):
+     """
+     Linear layer with signed weights.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
+         w_sign: Optional[ArrayLike] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # w_sign
+         self.w_sign = w_sign
+
+         # weights
+         weight = init.param(w_init, self.in_size + self.out_size, allow_none=False)
+         self.weight = ParamState(weight)
+
+     def update(self, x):
+         w = self.weight.value
+         if self.w_sign is None:
+             return u.math.matmul(x, u.math.abs(w))
+         else:
+             return u.math.matmul(x, u.math.abs(w) * self.w_sign)
+
+
+ class ScaledWSLinear(Module):
+     """
+     Linear Layer with Weight Standardization.
+
+     Applies weight standardization to the weights of the linear layer.
+
+     Parameters
+     ----------
+     in_size: int, sequence of int
+         The input size.
+     out_size: int, sequence of int
+         The output size.
+     w_init: Callable, ArrayLike
+         The initializer for the weights.
+     b_init: Callable, ArrayLike
+         The initializer for the bias.
+     w_mask: ArrayLike, Callable
+         The optional mask of the weights.
+     ws_gain: bool
+         Whether to use a learnable gain for the weights. The default is True.
+     eps: float
+         The epsilon value for the weight standardization.
+     name: str
+         The name of the object.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         w_init: Callable = init.KaimingNormal(),
+         b_init: Callable = init.ZeroInit(),
+         w_mask: Optional[Union[ArrayLike, Callable]] = None,
+         ws_gain: bool = True,
+         eps: float = 1e-4,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # w_mask
+         self.w_mask = init.param(w_mask, (self.in_size[0], 1))
+
+         # parameters
+         self.eps = eps
+
+         # weights
+         params = dict(weight=init.param(w_init, self.in_size + self.out_size, allow_none=False))
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size, allow_none=False)
+         # gain
+         if ws_gain:
+             s = params['weight'].shape
+             params['gain'] = jnp.ones((1,) * (len(s) - 1) + (s[-1],), dtype=params['weight'].dtype)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         params = self.weight.value
+         w = params['weight']
+         w = functional.weight_standardization(w, self.eps, params.get('gain', None))
+         if self.w_mask is not None:
+             w = w * self.w_mask
+         y = u.math.dot(x, w)
+         if 'bias' in params:
+             y = y + params['bias']
+         return y
+
+
+ def csr_matmat(data, indices, indptr, B: jax.Array, *, shape, transpose: bool = False) -> jax.Array:
+     """Product of CSR sparse matrix and a dense matrix.
+
+     Args:
+       data : array of shape ``(nse,)``.
+       indices : array of shape ``(nse,)``
+       indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+       B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+         dtype ``data.dtype``
+       shape : length-2 tuple representing the matrix shape
+       transpose : boolean specifying whether to transpose the sparse matrix
+         before computing.
+
+     Returns:
+       C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+         representing the matrix-matrix product.
+     """
+     return csr_matmat_p.bind(data, indices, indptr, B, shape=shape, transpose=transpose)
+
+
+ def csr_matvec(data, indices, indptr, v, *, shape, transpose=False) -> jax.Array:
+     """Product of CSR sparse matrix and a dense vector.
+
+     Args:
+       data : array of shape ``(nse,)``.
+       indices : array of shape ``(nse,)``
+       indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
+       v : array of shape ``(shape[0] if transpose else shape[1],)``
+         and dtype ``data.dtype``
+       shape : length-2 tuple representing the matrix shape
+       transpose : boolean specifying whether to transpose the sparse matrix
+         before computing.
+
+     Returns:
+       y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
+         the matrix vector product.
+     """
+     return csr_matvec_p.bind(data, indices, indptr, v, shape=shape, transpose=transpose)
+
+
+ class CSRLinear(Module):
+     """
+     Linear layer with Compressed Sparse Row (CSR) matrix.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         indptr: ArrayLike,
+         indices: ArrayLike,
+         weight: Union[Callable, ArrayLike],
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # CSR data structure
+         indptr = jnp.asarray(indptr)
+         indices = jnp.asarray(indices)
+         assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
+         assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
+         assert indptr.size == self.in_size[-1] + 1, f"indptr must have size {self.in_size[-1] + 1}. Got: {indptr.size}"
+         with jax.ensure_compile_time_eval():
+             self.indptr = u.math.asarray(indptr)
+             self.indices = u.math.asarray(indices)
+
+         # weights
+         weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
+         params = dict(weight=weight)
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         data = self.weight.value['weight']
+         data, w_unit = u.get_mantissa(data), u.get_unit(data)
+         x, x_unit = u.get_mantissa(x), u.get_unit(x)
+         shape = [self.in_size[-1], self.out_size[-1]]
+         if x.ndim == 1:
+             y = csr_matvec(data, self.indices, self.indptr, x, shape=shape)
+         elif x.ndim == 2:
+             y = csr_matmat(data, self.indices, self.indptr, x, shape=shape)
+         else:
+             raise NotImplementedError(f"matmul with object of shape {x.shape}")
+         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
+         if 'bias' in self.weight.value:
+             y = y + self.weight.value['bias']
+         return y
+
+
+ class CSCLinear(Module):
+     """
+     Linear layer with Compressed Sparse Column (CSC) matrix.
+     """
+     __module__ = 'brainstate.nn'
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         indptr: ArrayLike,
+         indices: ArrayLike,
+         weight: Union[Callable, ArrayLike],
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # CSC data structure
+         indptr = jnp.asarray(indptr)
+         indices = jnp.asarray(indices)
+         assert indptr.ndim == 1, f"indptr must be 1D. Got: {indptr.ndim}"
+         assert indices.ndim == 1, f"indices must be 1D. Got: {indices.ndim}"
+         assert indptr.size == self.out_size[-1] + 1, f"indptr must have size {self.out_size[-1] + 1}. Got: {indptr.size}"
+         with jax.ensure_compile_time_eval():
+             self.indptr = u.math.asarray(indptr)
+             self.indices = u.math.asarray(indices)
+
+         # weights
+         weight = init.param(weight, (len(indices),), allow_none=False, allow_scalar=False)
+         params = dict(weight=weight)
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         data = self.weight.value['weight']
+         data, w_unit = u.get_mantissa(data), u.get_unit(data)
+         x, x_unit = u.get_mantissa(x), u.get_unit(x)
+         shape = [self.out_size[-1], self.in_size[-1]]
+         if x.ndim == 1:
+             y = csr_matvec(data, self.indices, self.indptr, x, shape=shape, transpose=True)
+         elif x.ndim == 2:
+             y = csr_matmat(data, self.indices, self.indptr, x, shape=shape, transpose=True)
+         else:
+             raise NotImplementedError(f"matmul with object of shape {x.shape}")
+         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
+         if 'bias' in self.weight.value:
+             y = y + self.weight.value['bias']
+         return y
+
+
+ def coo_matvec(
+     data: jax.Array,
+     row: jax.Array,
+     col: jax.Array,
+     v: jax.Array, *,
+     spinfo: COOInfo,
+     transpose: bool = False
+ ) -> jax.Array:
+     """Product of COO sparse matrix and a dense vector.
+
+     Args:
+       data : array of shape ``(nse,)``.
+       row : array of shape ``(nse,)``
+       col : array of shape ``(nse,)`` and dtype ``row.dtype``
+       v : array of shape ``(shape[0] if transpose else shape[1],)`` and
+         dtype ``data.dtype``
+       spinfo : COOInfo object containing the shape of the matrix and the dtype
+       transpose : boolean specifying whether to transpose the sparse matrix
+         before computing.
+
+     Returns:
+       y : array of shape ``(shape[1] if transpose else shape[0],)`` representing
+         the matrix vector product.
+     """
+     return coo_matvec_p.bind(data, row, col, v, spinfo=spinfo, transpose=transpose)
+
+
+ def coo_matmat(
+     data: jax.Array, row: jax.Array, col: jax.Array, B: jax.Array, *,
+     spinfo: COOInfo, transpose: bool = False
+ ) -> jax.Array:
+     """Product of COO sparse matrix and a dense matrix.
+
+     Args:
+       data : array of shape ``(nse,)``.
+       row : array of shape ``(nse,)``
+       col : array of shape ``(nse,)`` and dtype ``row.dtype``
+       B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+         dtype ``data.dtype``
+       spinfo : COOInfo object containing the shape of the matrix and the dtype
+       transpose : boolean specifying whether to transpose the sparse matrix
+         before computing.
+
+     Returns:
+       C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+         representing the matrix-matrix product.
+     """
+     return coo_matmat_p.bind(data, row, col, B, spinfo=spinfo, transpose=transpose)
+
+
+ class COOLinear(Module):
+     """
+     Linear layer with a Coordinate (COO) sparse matrix.
+     """
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         row: ArrayLike,
+         col: ArrayLike,
+         weight: Union[Callable, ArrayLike],
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         rows_sorted: bool = False,
+         cols_sorted: bool = False,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # COO data structure
+         row = jnp.asarray(row)
+         col = jnp.asarray(col)
+         assert row.ndim == 1, f"row must be 1D. Got: {row.ndim}"
+         assert col.ndim == 1, f"col must be 1D. Got: {col.ndim}"
+         assert row.size == col.size, f"row and col must have the same size. Got: {row.size} and {col.size}"
+         with jax.ensure_compile_time_eval():
+             self.row = u.math.asarray(row)
+             self.col = u.math.asarray(col)
+
+         # COO structure information
+         self.rows_sorted = rows_sorted
+         self.cols_sorted = cols_sorted
+
+         # weights
+         weight = init.param(weight, (len(row),), allow_none=False, allow_scalar=False)
+         params = dict(weight=weight)
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, x):
+         data = self.weight.value['weight']
+         data, w_unit = u.get_mantissa(data), u.get_unit(data)
+         x, x_unit = u.get_mantissa(x), u.get_unit(x)
+         spinfo = COOInfo(
+             shape=(self.in_size[-1], self.out_size[-1]),
+             rows_sorted=self.rows_sorted,
+             cols_sorted=self.cols_sorted
+         )
+         if x.ndim == 1:
+             y = coo_matvec(data, self.row, self.col, x, spinfo=spinfo, transpose=False)
+         elif x.ndim == 2:
+             y = coo_matmat(data, self.row, self.col, x, spinfo=spinfo, transpose=False)
+         else:
+             raise NotImplementedError(f"matmul with object of shape {x.shape}")
+         y = u.maybe_decimal(u.Quantity(y, unit=w_unit * x_unit))
+         if 'bias' in self.weight.value:
+             y = y + self.weight.value['bias']
+         return y
+
+
+ class AllToAll(Module):
+     """
+     Synaptic matrix multiplication with All-to-All connections.
+
+     Args:
+         in_size: Size. The number of neurons in the presynaptic neuron group.
+         out_size: Size. The number of neurons in the postsynaptic neuron group.
+         w_init: The synaptic weight initializer.
+         include_self: bool. Whether to include the connection between neurons at the same position.
+         name: str. The object name.
+     """
+
+     def __init__(
+         self,
+         in_size: Size,
+         out_size: Size,
+         w_init: Union[Callable, ArrayLike] = init.KaimingNormal(),
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         include_self: bool = True,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = out_size
+         assert self.in_size[:-1] == self.out_size[:-1], ('The first n-1 dimensions of "in_size" '
+                                                          'and "out_size" must be the same.')
+
+         # others
+         self.include_self = include_self
+
+         # weights
+         weight = init.param(w_init, (self.in_size[-1], self.out_size[-1]), allow_none=False)
+         params = dict(weight=weight)
+         if b_init is not None:
+             params['bias'] = init.param(b_init, self.out_size[-1], allow_none=False)
+         self.weight = ParamState(params)
+
+     def update(self, pre_val):
+         params = self.weight.value
+         pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
+         w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])
+
+         if u.math.ndim(w_val) == 0:  # weight is a scalar
+             if pre_val.ndim == 1:
+                 post_val = u.math.sum(pre_val)
+             else:
+                 post_val = u.math.sum(pre_val, keepdims=True, axis=-1)
+             if not self.include_self:
+                 if self.in_size == self.out_size:
+                     post_val = post_val - pre_val
+                 elif self.in_size[-1] > self.out_size[-1]:
+                     val = pre_val[..., :self.out_size[-1]]
+                     post_val = post_val - val
+                 else:
+                     size = list(self.out_size)
+                     size[-1] = self.out_size[-1] - self.in_size[-1]
+                     val = u.math.concatenate([pre_val, u.math.zeros(size, dtype=pre_val.dtype)])
+                     post_val = post_val - val
+             post_val = w_val * post_val
+
+         else:  # weight is a matrix
+             assert u.math.ndim(w_val) == 2, '"weight" must be a 2D matrix.'
+             if not self.include_self:
+                 post_val = pre_val @ u.math.fill_diagonal(w_val, 0.)
+             else:
+                 post_val = pre_val @ w_val
+
+         post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
+         if 'bias' in params:
+             post_val = post_val + params['bias']
+         return post_val
+
+
+ class OneToOne(Module):
+     """
+     Synaptic matrix multiplication with One-to-One connections.
+
+     Args:
+         in_size: Size. The number of neurons in the presynaptic neuron group.
+         w_init: The synaptic weight initializer.
+         b_init: The synaptic bias initializer.
+         name: str. The object name.
+     """
+
+     def __init__(
+         self,
+         in_size: Size,
+         w_init: Union[Callable, ArrayLike] = init.Normal(),
+         b_init: Optional[Union[Callable, ArrayLike]] = None,
+         name: Optional[str] = None,
+     ):
+         super().__init__(name=name)
+
+         # input and output shape
+         self.in_size = in_size
+         self.out_size = in_size
+
+         # weights
+         param = dict(weight=init.param(w_init, self.in_size, allow_none=False))
+         if b_init is not None:
+             param['bias'] = init.param(b_init, self.out_size, allow_none=False)
+         # wrap the parameters in a ParamState so they are tracked like the
+         # other layers in this module
+         self.weight = ParamState(param)
+
+     def update(self, pre_val):
+         params = self.weight.value
+         pre_val, pre_unit = u.get_mantissa(pre_val), u.get_unit(pre_val)
+         w_val, w_unit = u.get_mantissa(params['weight']), u.get_unit(params['weight'])
+         post_val = pre_val * w_val
+         post_val = u.maybe_decimal(u.Quantity(post_val, unit=w_unit * pre_unit))
+         if 'bias' in params:
+             post_val = post_val + params['bias']
+         return post_val
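The sparse layers above share one pattern: the connectivity structure (indptr/indices, or row/col) is fixed at construction time, and only the per-connection weight vector lives in a `ParamState`. A minimal usage sketch for `CSRLinear`, with illustrative hand-built connectivity (not taken from the package):

    import jax.numpy as jnp
    import brainstate as bst

    # A square 4 -> 4 layer with 5 non-zero connections in CSR form:
    # len(indptr) == in_size + 1, one entry in `indices`/`weights` per connection.
    indptr = jnp.array([0, 2, 3, 3, 5])
    indices = jnp.array([0, 2, 1, 0, 3])
    weights = 0.1 * jnp.ones(5)

    layer = bst.nn.CSRLinear(4, 4, indptr, indices, weights)
    y = layer(jnp.ones(4))  # dense (4,) input -> dense (4,) output

`AllToAll` additionally accepts a scalar weight, in which case the dense product degenerates into a sum over presynaptic values (minus the self term when `include_self=False`); a sketch, assuming `init.param` keeps a plain scalar as a 0-d weight:

    all2all = bst.nn.AllToAll(4, 4, w_init=0.5, include_self=False)
    y = all2all(jnp.ones(4))  # each entry: 0.5 * (4 - 1) == 1.5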
@@ -0,0 +1,42 @@
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ from __future__ import annotations
+
+ import jax.numpy as jnp
+ import pytest
+ from absl.testing import absltest
+ from absl.testing import parameterized
+
+ import brainstate as bst
+
+
+ class TestDense(parameterized.TestCase):
+     @parameterized.product(
+         size=[(10,),
+               (20, 10),
+               (5, 8, 10)],
+         num_out=[20, ]
+     )
+     def test_Dense1(self, size, num_out):
+         f = bst.nn.Linear(10, num_out)
+         x = bst.random.random(size)
+         y = f(x)
+         self.assertTrue(y.shape == size[:-1] + (num_out,))
+
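The test above only exercises the dense path over different input ranks. The new `Linear` also accepts a `w_mask` that is multiplied into the weight on every forward pass; a short sketch with an illustrative binary mask:

    import jax.numpy as jnp
    import brainstate as bst

    # Keep the first 10 output columns and silence the rest; mask shape == (in, out).
    mask = jnp.concatenate([jnp.ones((10, 10)), jnp.zeros((10, 10))], axis=-1)
    layer = bst.nn.Linear(10, 20, w_mask=mask)
    y = layer(bst.random.random((4, 10)))  # y.shape == (4, 20)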
@@ -286,7 +286,7 @@ class CosineAnnealingLR(LearningRateScheduler):


  class CosineAnnealingWarmRestarts(CallBasedLRScheduler):
-     """Set the learning rate of each parameter group using a cosine annealing
+     r"""Set the learning rate of each parameter group using a cosine annealing
      schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
      is the number of epochs since the last restart and :math:`T_{i}` is the number
      of epochs between two warm restarts in SGDR:
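For reference, the schedule this docstring goes on to define (the `r"""` prefix keeps the backslashes in the `:math:` roles intact for Sphinx) is the standard SGDR rule of Loshchilov & Hutter (2017):

    \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)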
@@ -133,3 +133,21 @@ class OptaxOptimizer(Optimizer):
          for k, v in self.param_states.items():
              v.value = new_params[k]
          self.opt_state.value = new_opt_state
+
+
+ class LBFGS(OptaxOptimizer):
+     def __init__(
+         self,
+         lr: float,
+         memory_size: int = 10,
+         scale_init_precond: bool = True,
+     ):
+         import optax  # type: ignore[import-not-found,import-untyped]
+         super().__init__(
+             optax.lbfgs(
+                 lr,
+                 memory_size=memory_size,
+                 scale_init_precond=scale_init_precond,
+                 linesearch=None,
+             )
+         )
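The new `LBFGS` class is a thin wrapper that plugs `optax.lbfgs` into the `OptaxOptimizer` machinery shown above. A hypothetical usage sketch, assuming `LBFGS` is exported from `brainstate.optim` and following the `register_trainable_weights`/`update` workflow of the other optimizers:

    import brainstate as bst

    net = bst.nn.Linear(10, 1)
    opt = bst.optim.LBFGS(lr=1e-1, memory_size=10)
    opt.register_trainable_weights(net.states(bst.ParamState))
    # compute `grads` for the registered states elsewhere
    # (e.g. with bst.augment.grad), then take one L-BFGS step:
    # opt.update(grads)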
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: brainstate
- Version: 0.1.0
+ Version: 0.1.0.post20241122
  Summary: A ``State``-based Transformation System for Program Compilation and Augmentation.
  Home-page: https://github.com/chaobrain/brainstate
  Author: BrainState Developers