brainstate 0.1.0__py2.py3-none-any.whl → 0.1.0.post20241125__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/_state.py +1 -1
- brainstate/augment/_autograd.py +121 -120
- brainstate/augment/_autograd_test.py +97 -0
- brainstate/event/__init__.py +10 -8
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/{_csr.py → _csr_mv.py} +26 -18
- brainstate/event/_csr_mv_benchmark.py +14 -0
- brainstate/event/_fixedprob_mv.py +708 -0
- brainstate/event/_fixedprob_mv_benchmark.py +128 -0
- brainstate/event/{_fixed_probability_test.py → _fixedprob_mv_test.py} +13 -10
- brainstate/event/_linear_mv.py +359 -0
- brainstate/event/_linear_mv_benckmark.py +82 -0
- brainstate/event/{_linear_test.py → _linear_mv_test.py} +9 -4
- brainstate/event/_xla_custom_op.py +309 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +6 -11
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -1
- brainstate/nn/_dynamics/_projection_base.py +1 -1
- brainstate/nn/_exp_euler.py +1 -1
- brainstate/nn/_interaction/__init__.py +13 -4
- brainstate/nn/_interaction/{_connections.py → _conv.py} +0 -227
- brainstate/nn/_interaction/{_connections_test.py → _conv_test.py} +0 -15
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/optim/_lr_scheduler.py +1 -1
- brainstate/optim/_optax_optimizer.py +19 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/METADATA +2 -2
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/RECORD +34 -24
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/top_level.txt +1 -0
- brainstate/event/_fixed_probability.py +0 -271
- brainstate/event/_linear.py +0 -219
- /brainstate/event/{_csr_test.py → _csr_mv_test.py} +0 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.dist-info → brainstate-0.1.0.post20241125.dist-info}/WHEEL +0 -0
brainstate/augment/_autograd.py
CHANGED
@@ -29,15 +29,11 @@ The wrapped gradient transformations here are made possible by using the followi
 
 from __future__ import annotations
 
-import
-from functools import partial, wraps
+from functools import wraps, partial
 from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
 
+import brainunit as u
 import jax
-from jax import numpy as jnp
-from jax._src.api import _vjp
-from jax.api_util import argnums_partial
-from jax.extend import linear_util
 
 from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
@@ -45,7 +41,7 @@ from brainstate.typing import PyTree, Missing
 from brainstate.util import PrettyType, PrettyAttr, PrettyRepr
 
 __all__ = [
-    'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
+    'GradientTransform', 'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
 ]
 
 A = TypeVar('A')
@@ -54,54 +50,15 @@ LossValue = PyTree
 AuxData = PyTree
 
 
-def
-
-
-
-
-
-
-
-
-def _check_callable(fun):
-    # In Python 3.10+, the only thing stopping us from supporting staticmethods
-    # is that we can't take weak references to them, which the C++ JIT requires.
-    if isinstance(fun, staticmethod):
-        raise TypeError(f"staticmethod arguments are not supported, got {fun}")
-    if not callable(fun):
-        raise TypeError(f"Expected a callable value, got {fun}")
-    if _isgeneratorfunction(fun):
-        raise TypeError(f"Expected a function, got a generator function: {fun}")
-
-
-def functional_vector_grad(func, argnums=0, return_value: bool = False, has_aux: bool = False):
-    """
-    Compute the gradient of a vector with respect to the input.
-    """
-    _check_callable(func)
-
-    @wraps(func)
-    def grad_fun(*args, **kwargs):
-        f = linear_util.wrap_init(func, kwargs)
-        f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
-        if has_aux:
-            y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
-        else:
-            y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
-        leaves, tree = jax.tree.flatten(y)
-        tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
-        grads = vjp_fn(tangents)
-        if isinstance(argnums, int):
-            grads = grads[0]
-        if has_aux:
-            return (grads, y, aux) if return_value else (grads, aux)
-        else:
-            return (grads, y) if return_value else grads
-
-    return grad_fun
-
-
-def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
+def _jacrev(
+    fun,
+    argnums=0,
+    holomorphic=False,
+    allow_int=False,
+    has_aux=False,
+    return_value=False,
+    unit_aware=False,
+):
     @wraps(fun)
     def fun_wrapped(*args, **kwargs):
         if has_aux:
@@ -117,7 +74,18 @@ def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, r
         else:
             return y, None
 
-
+    if unit_aware:
+        transform = u.autograd.jacrev(fun_wrapped,
+                                      argnums=argnums,
+                                      holomorphic=holomorphic,
+                                      allow_int=allow_int,
+                                      has_aux=True)
+    else:
+        transform = jax.jacrev(fun_wrapped,
+                               argnums=argnums,
+                               holomorphic=holomorphic,
+                               allow_int=allow_int,
+                               has_aux=True)
 
     @wraps(fun)
     def jacfun(*args, **kwargs):
@@ -130,7 +98,14 @@ def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, r
     return jacfun
 
 
-def _jacfwd(
+def _jacfwd(
+    fun,
+    argnums=0,
+    holomorphic=False,
+    has_aux=False,
+    return_value=False,
+    unit_aware=False,
+):
     @wraps(fun)
     def fun_wrapped(*args, **kwargs):
         if has_aux:
@@ -146,7 +121,16 @@ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False
         else:
             return y, None
 
-
+    if unit_aware:
+        transform = u.autograd.jacfwd(fun_wrapped,
+                                      argnums=argnums,
+                                      holomorphic=holomorphic,
+                                      has_aux=True)
+    else:
+        transform = jax.jacfwd(fun_wrapped,
+                               argnums=argnums,
+                               holomorphic=holomorphic,
+                               has_aux=True)
 
     @wraps(fun)
     def jacfun(*args, **kwargs):
@@ -159,6 +143,9 @@ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False
     return jacfun
 
 
+TransformFn = Callable
+
+
 class GradientTransform(PrettyRepr):
     """
     Automatic Differentiation Transformations for the ``State`` system.
@@ -168,11 +155,11 @@ class GradientTransform(PrettyRepr):
     def __init__(
         self,
         target: Callable,
-        transform:
-        grad_states:
-        argnums: Optional[Union[int, Sequence[int]]],
-        return_value: bool,
-        has_aux: bool,
+        transform: TransformFn,
+        grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
+        argnums: Optional[Union[int, Sequence[int]]] = None,
+        return_value: bool = False,
+        has_aux: bool = False,
         transform_params: Optional[Dict[str, Any]] = None,
     ):
         # gradient variables
@@ -320,9 +307,9 @@ def grad(
     argnums: Optional[Union[int, Sequence[int]]] = None,
     holomorphic: Optional[bool] = False,
     allow_int: Optional[bool] = False,
-    reduce_axes: Optional[Sequence[str]] = (),
     has_aux: Optional[bool] = None,
     return_value: Optional[bool] = False,
+    unit_aware: bool = False,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
     """
     Compute the gradient of a scalar-valued function with respect to its arguments.
@@ -330,27 +317,24 @@ def grad(
     %s
 
     Args:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        differentiated and the second element is auxiliary data. Default False.
-      return_value: (bool) optional. Indicates whether to return the value of the
-        function along with the gradient. Default False.
+      fun: callable. the scalar-valued function to be differentiated.
+      allow_int: (bool) optional. Whether to allow differentiating with respect to
+        integer valued inputs. The gradient of an integer input will have a trivial
+        vector-space dtype (float0). Default False.
+      holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
+        Default False.
+      grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
+        in fun to take their gradients.
+      fun: the scalar-valued function to be differentiated.
+      argnums: (int or tuple of ints) optional. Specifies which positional
+        argument(s) to differentiate with respect to.
+      has_aux: (bool) optional. Indicates whether fun returns a pair where the
+        first element is considered the output of the mathematical function to be
+        differentiated and the second element is auxiliary data. Default False.
+      return_value: (bool) optional. Indicates whether to return the value of the
+        function along with the gradient. Default False.
+      unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+        mode. Default False.
 
     Returns:
       A function which computes the gradient of fun. The function takes the same
@@ -364,26 +348,24 @@ def grad(
     if isinstance(fun, Missing):
         def transform(fun) -> GradientTransform:
             return GradientTransform(target=fun,
-                                     transform=jax.grad,
+                                     transform=u.autograd.grad if unit_aware else jax.grad,
                                      grad_states=grad_states,
                                      argnums=argnums,
                                      return_value=return_value,
                                      has_aux=False if has_aux is None else has_aux,
                                      transform_params=dict(holomorphic=holomorphic,
-                                                           allow_int=allow_int
-                                                           reduce_axes=reduce_axes))
+                                                           allow_int=allow_int))
 
         return transform
 
     return GradientTransform(target=fun,
-                             transform=jax.grad,
+                             transform=u.autograd.grad if unit_aware else jax.grad,
                              grad_states=grad_states,
                              argnums=argnums,
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
                              transform_params=dict(holomorphic=holomorphic,
-                                                   allow_int=allow_int
-                                                   reduce_axes=reduce_axes))
+                                                   allow_int=allow_int))
 
 
 grad.__doc__ = grad.__doc__ % _doc_of_return
@@ -396,6 +378,7 @@ def vector_grad(
     argnums: Optional[Union[int, Sequence[int]]] = None,
     return_value: bool = False,
     has_aux: Optional[bool] = None,
+    unit_aware: bool = False,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
     """Take vector-valued gradients for function ``func``.
 
@@ -407,28 +390,30 @@ def vector_grad(
     Parameters
     ----------
     func: Callable
-
+      Function whose gradient is to be computed.
     grad_states : optional, ArrayType, sequence of ArrayType, dict
-
+      The variables in ``func`` to take their gradients.
     has_aux: optional, bool
-
-
-
+      Indicates whether ``fun`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
     return_value : bool
-
+      Whether return the loss value.
     argnums: Optional, integer or sequence of integers. Specifies which
-
+      positional argument(s) to differentiate with respect to (default ``0``).
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+      mode. Default False.
 
     Returns
     -------
     func : GradientTransform
-
+      The vector gradient function.
     """
 
     if isinstance(func, Missing):
         def transform(fun) -> GradientTransform:
             return GradientTransform(target=fun,
-                                     transform=
+                                     transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
                                      grad_states=grad_states,
                                      argnums=argnums,
                                      return_value=return_value,
@@ -438,7 +423,7 @@ def vector_grad(
 
     else:
         return GradientTransform(target=func,
-                                 transform=
+                                 transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
                                  grad_states=grad_states,
                                  argnums=argnums,
                                  return_value=return_value,
@@ -457,6 +442,7 @@ def jacrev(
     return_value: bool = False,
     holomorphic: bool = False,
     allow_int: bool = False,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """
     Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
@@ -470,25 +456,28 @@ def jacrev(
 
     Parameters
     ----------
-    fun:
+    fun: Callable
+      Function whose Jacobian is to be computed.
     grad_states : optional, ArrayType, sequence of ArrayType, dict
-
+      The variables in ``func`` to take their gradients.
     has_aux: optional, bool
-
-
-
+      Indicates whether ``fun`` returns a pair where the
+      first element is considered the output of the mathematical function to be
+      differentiated and the second element is auxiliary data. Default False.
     return_value : bool
-
+      Whether return the loss value.
    argnums: Optional, integer or sequence of integers.
-
-
+      Specifies which
+      positional argument(s) to differentiate with respect to (default ``0``).
     holomorphic: Optional, bool.
-
-
+      Indicates whether ``fun`` is promised to be
+      holomorphic. Default False.
     allow_int: Optional, bool.
-
-
-
+      Whether to allow differentiating with
+      respect to integer valued inputs. The gradient of an integer input will
+      have a trivial vector-space dtype (float0). Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+      mode. Default False.
 
     Returns
     -------
@@ -502,7 +491,8 @@ def jacrev(
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
                              transform_params=dict(holomorphic=holomorphic,
-                                                   allow_int=allow_int
+                                                   allow_int=allow_int,
+                                                   unit_aware=unit_aware, ))
 
 
 jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
@@ -518,6 +508,7 @@ def jacfwd(
     has_aux: Optional[bool] = None,
     return_value: bool = False,
     holomorphic: bool = False,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
 
@@ -539,9 +530,11 @@ def jacfwd(
     return_value : bool
       Whether return the loss value.
     argnums: Optional, integer or sequence of integers. Specifies which
-
+      positional argument(s) to differentiate with respect to (default ``0``).
     holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
-
+      holomorphic. Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+      mode. Default False.
 
     Returns
     -------
@@ -555,7 +548,8 @@ def jacfwd(
                              argnums=argnums,
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
-                             transform_params=dict(holomorphic=holomorphic
+                             transform_params=dict(holomorphic=holomorphic,
+                                                   unit_aware=unit_aware))
 
 
 jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
@@ -566,9 +560,10 @@ def hessian(
     func: Callable,
     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
     argnums: Optional[Union[int, Sequence[int]]] = None,
-    has_aux: bool = False,
     return_value: bool = False,
     holomorphic: bool = False,
+    has_aux: Optional[bool] = None,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """
     Hessian of ``func`` as a dense array.
@@ -590,6 +585,12 @@ def hessian(
       Indicates whether ``fun`` is promised to be holomorphic. Default False.
     return_value : bool
      Whether return the hessian values.
+    has_aux: Optional, bool
+      Indicates whether ``fun`` returns a pair where the first element is considered
+      the output of the mathematical function to be differentiated and the second
+      element is auxiliary data. Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+      mode. Default False.
 
     Returns
     -------
@@ -597,7 +598,7 @@ def hessian(
       The transformed object.
     """
     return GradientTransform(target=func,
-                             transform=jax.hessian,
+                             transform=u.autograd.hessian if unit_aware else jax.hessian,
                              grad_states=grad_states,
                              argnums=argnums,
                              return_value=return_value,
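For orientation, a minimal usage sketch of the updated call surface shown in the hunks above. It assumes brainstate is imported as bst and brainunit as u; the loss function and the ms quantity are illustrative only. The dispatch claims come straight from the diff (transform=u.autograd.grad if unit_aware else jax.grad) and from the removal of reduce_axes in grad's signature.

import brainunit as u
import brainstate as bst
import jax.numpy as jnp

def loss(x):
    # scalar loss over a unit-carrying array (milliseconds)
    return u.math.sum(x ** 2)

x = jnp.array([1., 2., 3.]) * u.ms

# unit_aware=True routes the transform through u.autograd.grad, so the gradient keeps
# its unit (2 * x, in ms); the default unit_aware=False falls back to plain jax.grad.
g = bst.augment.grad(loss, unit_aware=True)(x)

# The reduce_axes keyword was dropped from grad's signature in this release, so passing
# it is expected to raise a TypeError:
# bst.augment.grad(loss, reduce_axes=('batch',))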
brainstate/augment/_autograd_test.py
CHANGED
@@ -19,6 +19,7 @@ from __future__ import annotations
 import unittest
 from pprint import pprint
 
+import brainunit as u
 import jax
 import jax.numpy as jnp
 import pytest
@@ -608,6 +609,8 @@ class TestClassFuncJacobian(unittest.TestCase):
         br = bst.augment.jacrev(t, grad_states=[t.x, t.y])()
         self.assertTrue((br[0] == _jr[0]).all())
         self.assertTrue((br[1] == _jr[1]).all())
+
+
     #
     # def test_jacfwd1(self):
     #   def f1(x, y):
@@ -1191,3 +1194,97 @@ class TestClassFuncJacobian(unittest.TestCase):
 #     self.assertTrue(file.read().strip() == expect_res.strip())
 #
 #
+
+
+class TestUnitAwareGrad(unittest.TestCase):
+    def test_grad1(self):
+        def f(x):
+            return u.math.sum(x ** 2)
+
+        x = jnp.array([1., 2., 3.]) * u.ms
+        g = bst.augment.grad(f, unit_aware=True)(x)
+        self.assertTrue(u.math.allclose(g, 2 * x))
+
+    def test_vector_grad1(self):
+        def f(x):
+            return x ** 3
+
+        x = jnp.array([1., 2., 3.]) * u.ms
+        g = bst.augment.vector_grad(f, unit_aware=True)(x)
+        self.assertTrue(u.math.allclose(g, 3 * x ** 2))
+
+    def test_jacrev1(self):
+        def f(x, y):
+            return u.math.asarray([x[0] * y[0],
+                                   5 * x[2] * y[1],
+                                   4 * x[1] ** 2, ])
+
+        _x = jnp.array([1., 2., 3.]) * u.ms
+        _y = jnp.array([10., 5.]) * u.ms
+
+        g = bst.augment.jacrev(f, unit_aware=True, argnums=(0, 1))(_x, _y)
+        self.assertTrue(
+            u.math.allclose(
+                g[0],
+                u.math.asarray([
+                    [10., 0., 0.],
+                    [0., 0., 25.],
+                    [0., 16., 0.]
+                ]) * u.ms
+            )
+        )
+
+        self.assertTrue(
+            u.math.allclose(
+                g[1],
+                u.math.asarray([
+                    [1., 0.],
+                    [0., 15.],
+                    [0., 0.]
+                ]) * u.ms
+            )
+        )
+
+    def test_jacfwd1(self):
+        def f(x, y):
+            return u.math.asarray([x[0] * y[0],
+                                   5 * x[2] * y[1],
+                                   4 * x[1] ** 2, ])
+
+        _x = jnp.array([1., 2., 3.]) * u.ms
+        _y = jnp.array([10., 5.]) * u.ms
+
+        g = bst.augment.jacfwd(f, unit_aware=True, argnums=(0, 1))(_x, _y)
+        self.assertTrue(
+            u.math.allclose(
+                g[0],
+                u.math.asarray([
+                    [10., 0., 0.],
+                    [0., 0., 25.],
+                    [0., 16., 0.]
+                ]) * u.ms
+            )
+        )
+
+        self.assertTrue(
+            u.math.allclose(
+                g[1],
+                u.math.asarray([
+                    [1., 0.],
+                    [0., 15.],
+                    [0., 0.]
+                ]) * u.ms
+            )
+        )
+
+    def test_hessian(self):
+        unit = u.ms
+
+        def scalar_function(x):
+            return x ** 3 + 3 * x * unit * unit + 2 * unit * unit * unit
+
+        hess = bst.augment.hessian(scalar_function, unit_aware=True)
+        x = jnp.array(1.0) * unit
+        res = hess(x)
+        expected_hessian = jnp.array([[6.0]]) * unit
+        assert u.math.allclose(res, expected_hessian)
brainstate/event/__init__.py
CHANGED
@@ -14,12 +14,14 @@
 # ==============================================================================
 
 
-from .
-from .
-from .
-from .
-from .
-from .
+from ._csr_mv import *
+from ._csr_mv import __all__ as __all_csr
+from ._fixedprob_mv import *
+from ._fixedprob_mv import __all__ as __all_fixed_probability
+from ._linear_mv import *
+from ._xla_custom_op import *
+from ._xla_custom_op import __all__ as __all_xla_custom_op
+from ._linear_mv import __all__ as __all_linear
 
-__all__ = __all_fixed_probability + __all_linear + __all_csr
-del __all_fixed_probability, __all_linear, __all_csr
+__all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op
+del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
(new 14-line file; Apache-2.0 license header only, filename not shown)
@@ -0,0 +1,14 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================