brainstate 0.1.0.post20241122__py2.py3-none-any.whl → 0.1.0.post20241129__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd.py +112 -114
- brainstate/augment/_autograd_test.py +97 -0
- brainstate/event/__init__.py +6 -6
- brainstate/event/_csr_mv_benchmark.py +14 -0
- brainstate/event/{_linear_test.py → _linear_mv_test.py} +1 -1
- brainstate/event/_xla_custom_op.py +5 -8
- brainstate/nn/_elementwise/_dropout.py +4 -4
- brainstate/nn/_interaction/_linear.py +2 -2
- brainstate/nn/_interaction/_normalizations.py +577 -55
- brainstate/optim/_optax_optimizer.py +1 -0
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/METADATA +2 -2
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/RECORD +23 -22
- /brainstate/event/{_csr.py → _csr_mv.py} +0 -0
- /brainstate/event/{_csr_test.py → _csr_mv_test.py} +0 -0
- /brainstate/event/{_fixed_probability.py → _fixedprob_mv.py} +0 -0
- /brainstate/event/{_fixed_probability_benchmark.py → _fixedprob_mv_benchmark.py} +0 -0
- /brainstate/event/{_fixed_probability_test.py → _fixedprob_mv_test.py} +0 -0
- /brainstate/event/{_linear.py → _linear_mv.py} +0 -0
- /brainstate/event/{_linear_benckmark.py → _linear_mv_benckmark.py} +0 -0
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20241122.dist-info → brainstate-0.1.0.post20241129.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
brainstate/augment/_autograd.py
CHANGED
@@ -29,15 +29,11 @@ The wrapped gradient transformations here are made possible by using the following
 
 from __future__ import annotations
 
-import
-from functools import partial, wraps
+from functools import wraps, partial
 from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
 
+import brainunit as u
 import jax
-from jax import numpy as jnp
-from jax._src.api import _vjp
-from jax.api_util import argnums_partial
-from jax.extend import linear_util
 
 from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
@@ -54,54 +50,15 @@ LossValue = PyTree
 AuxData = PyTree
 
 
-def
-
-
-
-
-
-
-
-
-def _check_callable(fun):
-    # In Python 3.10+, the only thing stopping us from supporting staticmethods
-    # is that we can't take weak references to them, which the C++ JIT requires.
-    if isinstance(fun, staticmethod):
-        raise TypeError(f"staticmethod arguments are not supported, got {fun}")
-    if not callable(fun):
-        raise TypeError(f"Expected a callable value, got {fun}")
-    if _isgeneratorfunction(fun):
-        raise TypeError(f"Expected a function, got a generator function: {fun}")
-
-
-def functional_vector_grad(func, argnums=0, return_value: bool = False, has_aux: bool = False):
-    """
-    Compute the gradient of a vector with respect to the input.
-    """
-    _check_callable(func)
-
-    @wraps(func)
-    def grad_fun(*args, **kwargs):
-        f = linear_util.wrap_init(func, kwargs)
-        f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
-        if has_aux:
-            y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
-        else:
-            y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
-        leaves, tree = jax.tree.flatten(y)
-        tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
-        grads = vjp_fn(tangents)
-        if isinstance(argnums, int):
-            grads = grads[0]
-        if has_aux:
-            return (grads, y, aux) if return_value else (grads, aux)
-        else:
-            return (grads, y) if return_value else grads
-
-    return grad_fun
-
-
-def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
+def _jacrev(
+    fun,
+    argnums=0,
+    holomorphic=False,
+    allow_int=False,
+    has_aux=False,
+    return_value=False,
+    unit_aware=False,
+):
     @wraps(fun)
     def fun_wrapped(*args, **kwargs):
         if has_aux:
@@ -117,7 +74,18 @@ def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
         else:
             return y, None
 
-
+    if unit_aware:
+        transform = u.autograd.jacrev(fun_wrapped,
+                                      argnums=argnums,
+                                      holomorphic=holomorphic,
+                                      allow_int=allow_int,
+                                      has_aux=True)
+    else:
+        transform = jax.jacrev(fun_wrapped,
+                               argnums=argnums,
+                               holomorphic=holomorphic,
+                               allow_int=allow_int,
+                               has_aux=True)
 
     @wraps(fun)
     def jacfun(*args, **kwargs):
@@ -130,7 +98,14 @@ def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
     return jacfun
 
 
-def _jacfwd(
+def _jacfwd(
+    fun,
+    argnums=0,
+    holomorphic=False,
+    has_aux=False,
+    return_value=False,
+    unit_aware=False,
+):
     @wraps(fun)
     def fun_wrapped(*args, **kwargs):
         if has_aux:
@@ -146,7 +121,16 @@ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
         else:
             return y, None
 
-
+    if unit_aware:
+        transform = u.autograd.jacfwd(fun_wrapped,
+                                      argnums=argnums,
+                                      holomorphic=holomorphic,
+                                      has_aux=True)
+    else:
+        transform = jax.jacfwd(fun_wrapped,
+                               argnums=argnums,
+                               holomorphic=holomorphic,
+                               has_aux=True)
 
     @wraps(fun)
     def jacfun(*args, **kwargs):
@@ -323,9 +307,9 @@ def grad(
     argnums: Optional[Union[int, Sequence[int]]] = None,
     holomorphic: Optional[bool] = False,
     allow_int: Optional[bool] = False,
-    reduce_axes: Optional[Sequence[str]] = (),
     has_aux: Optional[bool] = None,
     return_value: Optional[bool] = False,
+    unit_aware: bool = False,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
     """
     Compute the gradient of a scalar-valued function with respect to its arguments.
@@ -333,27 +317,24 @@ def grad(
     %s
 
     Args:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        differentiated and the second element is auxiliary data. Default False.
-      return_value: (bool) optional. Indicates whether to return the value of the
-        function along with the gradient. Default False.
+      fun: callable. the scalar-valued function to be differentiated.
+      allow_int: (bool) optional. Whether to allow differentiating with respect to
+        integer valued inputs. The gradient of an integer input will have a trivial
+        vector-space dtype (float0). Default False.
+      holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
+        Default False.
+      grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
+        in fun to take their gradients.
+      fun: the scalar-valued function to be differentiated.
+      argnums: (int or tuple of ints) optional. Specifies which positional
+        argument(s) to differentiate with respect to.
+      has_aux: (bool) optional. Indicates whether fun returns a pair where the
+        first element is considered the output of the mathematical function to be
+        differentiated and the second element is auxiliary data. Default False.
+      return_value: (bool) optional. Indicates whether to return the value of the
+        function along with the gradient. Default False.
+      unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+        mode. Default False.
 
     Returns:
       A function which computes the gradient of fun. The function takes the same
@@ -367,26 +348,24 @@ def grad(
     if isinstance(fun, Missing):
         def transform(fun) -> GradientTransform:
             return GradientTransform(target=fun,
-                                     transform=jax.grad,
+                                     transform=u.autograd.grad if unit_aware else jax.grad,
                                      grad_states=grad_states,
                                      argnums=argnums,
                                      return_value=return_value,
                                      has_aux=False if has_aux is None else has_aux,
                                      transform_params=dict(holomorphic=holomorphic,
-                                                           allow_int=allow_int
-                                                           reduce_axes=reduce_axes))
+                                                           allow_int=allow_int))
 
         return transform
 
     return GradientTransform(target=fun,
-                             transform=jax.grad,
+                             transform=u.autograd.grad if unit_aware else jax.grad,
                              grad_states=grad_states,
                              argnums=argnums,
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
                              transform_params=dict(holomorphic=holomorphic,
-                                                   allow_int=allow_int
-                                                   reduce_axes=reduce_axes))
+                                                   allow_int=allow_int))
 
 
 grad.__doc__ = grad.__doc__ % _doc_of_return
@@ -399,6 +378,7 @@ def vector_grad(
     argnums: Optional[Union[int, Sequence[int]]] = None,
     return_value: bool = False,
     has_aux: Optional[bool] = None,
+    unit_aware: bool = False,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
     """Take vector-valued gradients for function ``func``.
 
@@ -410,28 +390,30 @@ def vector_grad(
     Parameters
     ----------
     func: Callable
-
+      Function whose gradient is to be computed.
     grad_states : optional, ArrayType, sequence of ArrayType, dict
-
+      The variables in ``func`` to take their gradients.
     has_aux: optional, bool
-
-
-
+      Indicates whether ``fun`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
     return_value : bool
-
+      Whether return the loss value.
     argnums: Optional, integer or sequence of integers. Specifies which
-
+      positional argument(s) to differentiate with respect to (default ``0``).
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+      mode. Default False.
 
     Returns
     -------
     func : GradientTransform
-
+      The vector gradient function.
     """
 
     if isinstance(func, Missing):
         def transform(fun) -> GradientTransform:
             return GradientTransform(target=fun,
-                                     transform=
+                                     transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
                                      grad_states=grad_states,
                                      argnums=argnums,
                                      return_value=return_value,
@@ -441,7 +423,7 @@ def vector_grad(
 
     else:
         return GradientTransform(target=func,
-                                 transform=
+                                 transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
                                  grad_states=grad_states,
                                  argnums=argnums,
                                  return_value=return_value,
@@ -460,6 +442,7 @@ def jacrev(
     return_value: bool = False,
     holomorphic: bool = False,
     allow_int: bool = False,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """
     Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
@@ -473,25 +456,28 @@ def jacrev(
 
     Parameters
     ----------
-    fun:
+    fun: Callable
+      Function whose Jacobian is to be computed.
     grad_states : optional, ArrayType, sequence of ArrayType, dict
-
+      The variables in ``func`` to take their gradients.
    has_aux: optional, bool
-
-
-
+      Indicates whether ``fun`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
     return_value : bool
-
+      Whether return the loss value.
     argnums: Optional, integer or sequence of integers.
-
-
+      Specifies which
+      positional argument(s) to differentiate with respect to (default ``0``).
     holomorphic: Optional, bool.
-
-
+      Indicates whether ``fun`` is promised to be
+      holomorphic. Default False.
     allow_int: Optional, bool.
-
-
-
+      Whether to allow differentiating with
+      respect to integer valued inputs. The gradient of an integer input will
+      have a trivial vector-space dtype (float0). Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+      mode. Default False.
 
     Returns
     -------
@@ -505,7 +491,8 @@ def jacrev(
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
                              transform_params=dict(holomorphic=holomorphic,
-                                                   allow_int=allow_int
+                                                   allow_int=allow_int,
+                                                   unit_aware=unit_aware, ))
 
 
 jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
@@ -521,6 +508,7 @@ def jacfwd(
     has_aux: Optional[bool] = None,
     return_value: bool = False,
     holomorphic: bool = False,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
 
@@ -542,9 +530,11 @@ def jacfwd(
     return_value : bool
       Whether return the loss value.
     argnums: Optional, integer or sequence of integers. Specifies which
-
+      positional argument(s) to differentiate with respect to (default ``0``).
     holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
-
+      holomorphic. Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+      mode. Default False.
 
     Returns
     -------
@@ -558,7 +548,8 @@ def jacfwd(
                              argnums=argnums,
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
-                             transform_params=dict(holomorphic=holomorphic
+                             transform_params=dict(holomorphic=holomorphic,
+                                                   unit_aware=unit_aware))
 
 
 jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
@@ -569,9 +560,10 @@ def hessian(
     func: Callable,
     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
     argnums: Optional[Union[int, Sequence[int]]] = None,
-    has_aux: bool = False,
     return_value: bool = False,
     holomorphic: bool = False,
+    has_aux: Optional[bool] = None,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """
     Hessian of ``func`` as a dense array.
@@ -593,6 +585,12 @@ def hessian(
       Indicates whether ``fun`` is promised to be holomorphic. Default False.
     return_value : bool
       Whether return the hessian values.
+    has_aux: Optional, bool
+      Indicates whether ``fun`` returns a pair where the first element is considered
+      the output of the mathematical function to be differentiated and the second
+      element is auxiliary data. Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+      mode. Default False.
 
     Returns
     -------
@@ -600,7 +598,7 @@ def hessian(
       The transformed object.
     """
     return GradientTransform(target=func,
-                             transform=jax.hessian,
+                             transform=u.autograd.hessian if unit_aware else jax.hessian,
                              grad_states=grad_states,
                              argnums=argnums,
                              return_value=return_value,
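Taken together, these changes thread one new flag, unit_aware, through grad, vector_grad, jacrev, jacfwd and hessian, dispatching to brainunit's u.autograd transforms instead of the plain jax ones when the flag is set. A minimal usage sketch, not part of the diff, assuming this brainstate release and brainunit are installed (the quadratic loss is illustrative only):

import brainstate as bst
import brainunit as u
import jax.numpy as jnp

def loss(x):
    # scalar output carrying units of ms**2
    return u.math.sum(x ** 2)

x = jnp.array([1., 2., 3.]) * u.ms

# unit_aware=True routes through u.autograd.grad, so the gradient keeps its unit (ms here)
g_units = bst.augment.grad(loss, unit_aware=True)(x)

# the default path still uses jax.grad and works on plain, unitless arrays
g_plain = bst.augment.grad(lambda v: jnp.sum(v ** 2))(jnp.array([1., 2., 3.]))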
brainstate/augment/_autograd_test.py
CHANGED
@@ -19,6 +19,7 @@ from __future__ import annotations
 import unittest
 from pprint import pprint
 
+import brainunit as u
 import jax
 import jax.numpy as jnp
 import pytest
@@ -608,6 +609,8 @@ class TestClassFuncJacobian(unittest.TestCase):
         br = bst.augment.jacrev(t, grad_states=[t.x, t.y])()
         self.assertTrue((br[0] == _jr[0]).all())
         self.assertTrue((br[1] == _jr[1]).all())
+
+
     #
     # def test_jacfwd1(self):
     #   def f1(x, y):
@@ -1191,3 +1194,97 @@ class TestClassFuncJacobian(unittest.TestCase):
     #     self.assertTrue(file.read().strip() == expect_res.strip())
     #
     #
+
+
+class TestUnitAwareGrad(unittest.TestCase):
+    def test_grad1(self):
+        def f(x):
+            return u.math.sum(x ** 2)
+
+        x = jnp.array([1., 2., 3.]) * u.ms
+        g = bst.augment.grad(f, unit_aware=True)(x)
+        self.assertTrue(u.math.allclose(g, 2 * x))
+
+    def test_vector_grad1(self):
+        def f(x):
+            return x ** 3
+
+        x = jnp.array([1., 2., 3.]) * u.ms
+        g = bst.augment.vector_grad(f, unit_aware=True)(x)
+        self.assertTrue(u.math.allclose(g, 3 * x ** 2))
+
+    def test_jacrev1(self):
+        def f(x, y):
+            return u.math.asarray([x[0] * y[0],
+                                   5 * x[2] * y[1],
+                                   4 * x[1] ** 2, ])
+
+        _x = jnp.array([1., 2., 3.]) * u.ms
+        _y = jnp.array([10., 5.]) * u.ms
+
+        g = bst.augment.jacrev(f, unit_aware=True, argnums=(0, 1))(_x, _y)
+        self.assertTrue(
+            u.math.allclose(
+                g[0],
+                u.math.asarray([
+                    [10., 0., 0.],
+                    [0., 0., 25.],
+                    [0., 16., 0.]
+                ]) * u.ms
+            )
+        )
+
+        self.assertTrue(
+            u.math.allclose(
+                g[1],
+                u.math.asarray([
+                    [1., 0.],
+                    [0., 15.],
+                    [0., 0.]
+                ]) * u.ms
+            )
+        )
+
+    def test_jacfwd1(self):
+        def f(x, y):
+            return u.math.asarray([x[0] * y[0],
+                                   5 * x[2] * y[1],
+                                   4 * x[1] ** 2, ])
+
+        _x = jnp.array([1., 2., 3.]) * u.ms
+        _y = jnp.array([10., 5.]) * u.ms
+
+        g = bst.augment.jacfwd(f, unit_aware=True, argnums=(0, 1))(_x, _y)
+        self.assertTrue(
+            u.math.allclose(
+                g[0],
+                u.math.asarray([
+                    [10., 0., 0.],
+                    [0., 0., 25.],
+                    [0., 16., 0.]
+                ]) * u.ms
+            )
+        )
+
+        self.assertTrue(
+            u.math.allclose(
+                g[1],
+                u.math.asarray([
+                    [1., 0.],
+                    [0., 15.],
+                    [0., 0.]
+                ]) * u.ms
+            )
+        )
+
+    def test_hessian(self):
+        unit = u.ms
+
+        def scalar_function(x):
+            return x ** 3 + 3 * x * unit * unit + 2 * unit * unit * unit
+
+        hess = bst.augment.hessian(scalar_function, unit_aware=True)
+        x = jnp.array(1.0) * unit
+        res = hess(x)
+        expected_hessian = jnp.array([[6.0]]) * unit
+        assert u.math.allclose(res, expected_hessian)
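As a quick check on the expected value in test_hessian above: the test function is f(x) = x**3 + 3*x*ms**2 + 2*ms**3, whose second derivative is f''(x) = 6*x, so at x = 1 ms the Hessian is 6 ms. Dimensionally, f carries ms**3 and two derivatives with respect to a ms-valued input leave ms**3 / ms**2 = ms. The same mantissa falls out of plain JAX on the unit-stripped function (illustrative only, not part of the diff):

import jax
import jax.numpy as jnp

f = lambda x: x ** 3 + 3 * x + 2       # unit-stripped version of scalar_function
print(jax.hessian(f)(jnp.array(1.0)))  # 6.0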
brainstate/event/__init__.py
CHANGED
@@ -14,14 +14,14 @@
 # ==============================================================================
 
 
-from .
-from .
-from .
-from .
-from .
+from ._csr_mv import *
+from ._csr_mv import __all__ as __all_csr
+from ._fixedprob_mv import *
+from ._fixedprob_mv import __all__ as __all_fixed_probability
+from ._linear_mv import *
 from ._xla_custom_op import *
 from ._xla_custom_op import __all__ as __all_xla_custom_op
-from .
+from ._linear_mv import __all__ as __all_linear
 
 __all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op
 del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
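Note that the renames in this release only touch the private implementation modules (_csr.py, _fixed_probability.py and _linear.py become _csr_mv.py, _fixedprob_mv.py and _linear_mv.py); the wildcard re-exports above keep the public surface of brainstate.event unchanged, so only code that imported the old private module paths directly needs updating. A small sketch, not part of the diff, of how to confirm that from user code:

import brainstate

# the public names exposed by the package namespace are the same as before
print(brainstate.event.__all__)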
brainstate/event/_csr_mv_benchmark.py
ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
brainstate/event/_xla_custom_op.py
CHANGED
@@ -17,14 +17,8 @@ from jaxlib.hlo_helpers import custom_call
 
 numba_installed = importlib.util.find_spec('numba') is not None
 
-if numba_installed:
-    import numba  # pylint: disable=import-error
-    from numba import types, carray, cfunc  # pylint: disable=import-error
-    from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
-else:
-    numba = None
-
 __all__ = [
+    'defjvp',
     'XLACustomOp',
 ]
 
@@ -93,9 +87,12 @@ def _numba_mlir_cpu_translation_rule(
     *ins,
     **kwargs
 ):
-    if
+    if not numba_installed:
         raise ImportError('Numba is required to compile the CPU kernel for the custom operator.')
 
+    from numba import types, carray, cfunc  # pylint: disable=import-error
+    from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
+
     if not isinstance(kernel, Dispatcher):
         kernel = kernel(**kwargs)
     assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}'
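With this change Numba is no longer imported when brainstate.event is merely imported; only the cheap importlib.util.find_spec lookup runs at module load, and the heavy import happens inside the CPU translation rule when a kernel is actually compiled. A generic standalone sketch of the same lazy-import pattern (not brainstate code, names are illustrative):

import importlib.util

numba_installed = importlib.util.find_spec('numba') is not None  # no import, just a metadata lookup

def compile_cpu_kernel(kernel):
    if not numba_installed:
        raise ImportError('Numba is required to compile the CPU kernel.')
    # deferred import: only paid on first use
    from numba.core.dispatcher import Dispatcher
    if not isinstance(kernel, Dispatcher):
        raise TypeError(f'Expected a Numba dispatcher, got {kernel!r}')
    return kernel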
brainstate/nn/_elementwise/_dropout.py
CHANGED
@@ -88,7 +88,7 @@ class _DropoutNd(ElementWiseBlock):
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
-        assert 0. <= prob
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
         self.channel_axis = channel_axis
 
@@ -112,7 +112,7 @@ class _DropoutNd(ElementWiseBlock):
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
 
         # generate mask
-        if fit_phase
+        if fit_phase and self.prob < 1.:
             dtype = u.math.get_dtype(x)
             keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
             return jnp.where(keep_mask,
@@ -396,7 +396,7 @@ class DropoutFixed(ElementWiseBlock):
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
-        assert 0. <= prob
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
         self.in_size = in_size
         self.out_size = in_size
@@ -407,7 +407,7 @@ class DropoutFixed(ElementWiseBlock):
     def update(self, x):
         dtype = u.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
-        if fit_phase
+        if fit_phase and self.prob < 1.:
             if self.mask.value.shape != x.shape:
                 raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
                                  f"Please call `init_state()` method first.")
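The tightened assertions reject probabilities outside [0, 1] at construction time instead of letting a bad value flow into the Bernoulli mask, and the extra self.prob < 1. guard skips masking when nothing would be dropped (in this file prob acts as the keep probability, as random.bernoulli(self.prob, ...) above suggests). A hedged sketch of the new failure mode, not part of the diff, assuming DropoutFixed is exported from brainstate.nn with the in_size and prob keywords seen in its constructor:

import brainstate as bst

try:
    bst.nn.DropoutFixed(in_size=(10,), prob=1.5)  # out of range: now fails immediately
except AssertionError as err:
    print(err)  # Dropout probability must be in the range [0, 1]. But got 1.5.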
brainstate/nn/_interaction/_linear.py
CHANGED
@@ -79,7 +79,7 @@ class Linear(Module):
         weight = params['weight']
         if self.w_mask is not None:
             weight = weight * self.w_mask
-        y = u.
+        y = u.linalg.dot(x, weight)
         if 'bias' in params:
             y = y + params['bias']
         return y
@@ -192,7 +192,7 @@ class ScaledWSLinear(Module):
         w = functional.weight_standardization(w, self.eps, params.get('gain', None))
         if self.w_mask is not None:
             w = w * self.w_mask
-        y = u.
+        y = u.linalg.dot(x, w)
         if 'bias' in params:
             y = y + params['bias']
         return y
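Switching the matrix product to u.linalg.dot keeps unit bookkeeping inside the layer: when the input and the weights both carry units, the product's unit is tracked automatically. A small standalone check of that brainunit behaviour, not tied to the Linear module (units and shapes are illustrative):

import brainunit as u
import jax.numpy as jnp

x = jnp.ones((2, 3)) * u.mV        # unit-carrying inputs
w = jnp.ones((3, 4)) * u.siemens   # unit-carrying weights
y = u.linalg.dot(x, w)             # result carries the product unit (mV * siemens)
print(y.unit)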