brainstate 0.1.0.post20241122__py2.py3-none-any.whl → 0.1.0.post20241129__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
brainstate/_state.py CHANGED
@@ -679,7 +679,7 @@ class StateTraceStack(Generic[A]):
         """
         for st, val in zip(self.states, self._original_state_values):
             # internal use
-            st._value = val
+            st.restore_value(val)
 
     def merge(self, *traces) -> 'StateTraceStack':
         """
@@ -29,15 +29,11 @@ The wrapped gradient transformations here are made possible by using the followi
 
 from __future__ import annotations
 
-import inspect
-from functools import partial, wraps
+from functools import wraps, partial
 from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator
 
+import brainunit as u
 import jax
-from jax import numpy as jnp
-from jax._src.api import _vjp
-from jax.api_util import argnums_partial
-from jax.extend import linear_util
 
 from brainstate._state import State, StateTraceStack
 from brainstate._utils import set_module_as
@@ -54,54 +50,15 @@ LossValue = PyTree
 AuxData = PyTree
 
 
-def _isgeneratorfunction(fun):
-    # re-implemented here because of https://bugs.python.org/issue33261
-    while inspect.ismethod(fun):
-        fun = fun.__func__
-    while isinstance(fun, partial):
-        fun = fun.func
-    return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR)
-
-
-def _check_callable(fun):
-    # In Python 3.10+, the only thing stopping us from supporting staticmethods
-    # is that we can't take weak references to them, which the C++ JIT requires.
-    if isinstance(fun, staticmethod):
-        raise TypeError(f"staticmethod arguments are not supported, got {fun}")
-    if not callable(fun):
-        raise TypeError(f"Expected a callable value, got {fun}")
-    if _isgeneratorfunction(fun):
-        raise TypeError(f"Expected a function, got a generator function: {fun}")
-
-
-def functional_vector_grad(func, argnums=0, return_value: bool = False, has_aux: bool = False):
-    """
-    Compute the gradient of a vector with respect to the input.
-    """
-    _check_callable(func)
-
-    @wraps(func)
-    def grad_fun(*args, **kwargs):
-        f = linear_util.wrap_init(func, kwargs)
-        f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False)
-        if has_aux:
-            y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True)
-        else:
-            y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False)
-        leaves, tree = jax.tree.flatten(y)
-        tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
-        grads = vjp_fn(tangents)
-        if isinstance(argnums, int):
-            grads = grads[0]
-        if has_aux:
-            return (grads, y, aux) if return_value else (grads, aux)
-        else:
-            return (grads, y) if return_value else grads
-
-    return grad_fun
-
-
-def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False):
+def _jacrev(
+    fun,
+    argnums=0,
+    holomorphic=False,
+    allow_int=False,
+    has_aux=False,
+    return_value=False,
+    unit_aware=False,
+):
     @wraps(fun)
     def fun_wrapped(*args, **kwargs):
         if has_aux:
@@ -117,7 +74,18 @@ def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, r
         else:
             return y, None
 
-    transform = jax.jacrev(fun_wrapped, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, has_aux=True)
+    if unit_aware:
+        transform = u.autograd.jacrev(fun_wrapped,
+                                      argnums=argnums,
+                                      holomorphic=holomorphic,
+                                      allow_int=allow_int,
+                                      has_aux=True)
+    else:
+        transform = jax.jacrev(fun_wrapped,
+                               argnums=argnums,
+                               holomorphic=holomorphic,
+                               allow_int=allow_int,
+                               has_aux=True)
 
     @wraps(fun)
     def jacfun(*args, **kwargs):
@@ -130,7 +98,14 @@ def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, r
     return jacfun
 
 
-def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False):
+def _jacfwd(
+    fun,
+    argnums=0,
+    holomorphic=False,
+    has_aux=False,
+    return_value=False,
+    unit_aware=False,
+):
     @wraps(fun)
     def fun_wrapped(*args, **kwargs):
         if has_aux:
@@ -146,7 +121,16 @@ def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False
         else:
            return y, None
 
-    transform = jax.jacfwd(fun_wrapped, argnums=argnums, holomorphic=holomorphic, has_aux=True)
+    if unit_aware:
+        transform = u.autograd.jacfwd(fun_wrapped,
+                                      argnums=argnums,
+                                      holomorphic=holomorphic,
+                                      has_aux=True)
+    else:
+        transform = jax.jacfwd(fun_wrapped,
+                               argnums=argnums,
+                               holomorphic=holomorphic,
+                               has_aux=True)
 
     @wraps(fun)
     def jacfun(*args, **kwargs):
@@ -323,9 +307,9 @@ def grad(
     argnums: Optional[Union[int, Sequence[int]]] = None,
     holomorphic: Optional[bool] = False,
     allow_int: Optional[bool] = False,
-    reduce_axes: Optional[Sequence[str]] = (),
     has_aux: Optional[bool] = None,
     return_value: Optional[bool] = False,
+    unit_aware: bool = False,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
     """
     Compute the gradient of a scalar-valued function with respect to its arguments.
@@ -333,27 +317,24 @@ def grad(
     %s
 
     Args:
-      fun: callable. the scalar-valued function to be differentiated.
-      reduce_axes: (Sequence[str]) optional. Specifies the axes to reduce over when
-        differentiating with respect to array-valued arguments. The default, (),
-        means to differentiate each element of the output with respect to each
-        element of the argument. If the argument is an array, this argument controls
-        how many axes the output of grad has.
-      allow_int: (bool) optional. Whether to allow differentiating with respect to
-        integer valued inputs. The gradient of an integer input will have a trivial
-        vector-space dtype (float0). Default False.
-      holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
-        Default False.
-      grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
-        in fun to take their gradients.
-      fun: the scalar-valued function to be differentiated.
-      argnums: (int or tuple of ints) optional. Specifies which positional
-        argument(s) to differentiate with respect to.
-      has_aux: (bool) optional. Indicates whether fun returns a pair where the
-        first element is considered the output of the mathematical function to be
-        differentiated and the second element is auxiliary data. Default False.
-      return_value: (bool) optional. Indicates whether to return the value of the
-        function along with the gradient. Default False.
+        fun: callable. the scalar-valued function to be differentiated.
+        allow_int: (bool) optional. Whether to allow differentiating with respect to
+            integer valued inputs. The gradient of an integer input will have a trivial
+            vector-space dtype (float0). Default False.
+        holomorphic: (bool) optional. Whether fun is promised to be holomorphic.
+            Default False.
+        grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables
+            in fun to take their gradients.
+        fun: the scalar-valued function to be differentiated.
+        argnums: (int or tuple of ints) optional. Specifies which positional
+            argument(s) to differentiate with respect to.
+        has_aux: (bool) optional. Indicates whether fun returns a pair where the
+            first element is considered the output of the mathematical function to be
+            differentiated and the second element is auxiliary data. Default False.
+        return_value: (bool) optional. Indicates whether to return the value of the
+            function along with the gradient. Default False.
+        unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+            mode. Default False.
 
     Returns:
       A function which computes the gradient of fun. The function takes the same
@@ -367,26 +348,24 @@ def grad(
     if isinstance(fun, Missing):
         def transform(fun) -> GradientTransform:
             return GradientTransform(target=fun,
-                                     transform=jax.grad,
+                                     transform=u.autograd.grad if unit_aware else jax.grad,
                                      grad_states=grad_states,
                                      argnums=argnums,
                                      return_value=return_value,
                                      has_aux=False if has_aux is None else has_aux,
                                      transform_params=dict(holomorphic=holomorphic,
-                                                           allow_int=allow_int,
-                                                           reduce_axes=reduce_axes))
+                                                           allow_int=allow_int))
 
         return transform
 
     return GradientTransform(target=fun,
-                             transform=jax.grad,
+                             transform=u.autograd.grad if unit_aware else jax.grad,
                              grad_states=grad_states,
                              argnums=argnums,
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
                              transform_params=dict(holomorphic=holomorphic,
-                                                   allow_int=allow_int,
-                                                   reduce_axes=reduce_axes))
+                                                   allow_int=allow_int))
 
 
 grad.__doc__ = grad.__doc__ % _doc_of_return
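
Taken together with the tests added further down in this diff, the new `unit_aware` flag lets the gradient transform operate on `brainunit` quantities; a short usage sketch mirroring the new `test_grad1`:

    import jax.numpy as jnp
    import brainunit as u
    import brainstate as bst

    def f(x):
        return u.math.sum(x ** 2)

    x = jnp.array([1., 2., 3.]) * u.ms
    g = bst.augment.grad(f, unit_aware=True)(x)  # gradient is 2 * x and keeps the ms unit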
@@ -399,6 +378,7 @@ def vector_grad(
     argnums: Optional[Union[int, Sequence[int]]] = None,
     return_value: bool = False,
     has_aux: Optional[bool] = None,
+    unit_aware: bool = False,
 ) -> GradientTransform | Callable[[Callable], GradientTransform]:
     """Take vector-valued gradients for function ``func``.
 
@@ -410,28 +390,30 @@ def vector_grad(
     Parameters
     ----------
     func: Callable
-      Function whose gradient is to be computed.
+        Function whose gradient is to be computed.
     grad_states : optional, ArrayType, sequence of ArrayType, dict
-      The variables in ``func`` to take their gradients.
+        The variables in ``func`` to take their gradients.
     has_aux: optional, bool
-      Indicates whether ``fun`` returns a pair where the
-      first element is considered the output of the mathematical function to be
-      differentiated and the second element is auxiliary data. Default False.
+        Indicates whether ``fun`` returns a pair where the
+        first element is considered the output of the mathematical function to be
+        differentiated and the second element is auxiliary data. Default False.
     return_value : bool
-      Whether return the loss value.
+        Whether return the loss value.
     argnums: Optional, integer or sequence of integers. Specifies which
-      positional argument(s) to differentiate with respect to (default ``0``).
+        positional argument(s) to differentiate with respect to (default ``0``).
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+        mode. Default False.
 
     Returns
     -------
     func : GradientTransform
-      The vector gradient function.
+        The vector gradient function.
     """
 
     if isinstance(func, Missing):
         def transform(fun) -> GradientTransform:
             return GradientTransform(target=fun,
-                                     transform=functional_vector_grad,
+                                     transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
                                      grad_states=grad_states,
                                      argnums=argnums,
                                      return_value=return_value,
@@ -441,7 +423,7 @@ def vector_grad(
 
     else:
         return GradientTransform(target=func,
-                                 transform=functional_vector_grad,
+                                 transform=partial(u.autograd.vector_grad, unit_aware=unit_aware),
                                  grad_states=grad_states,
                                  argnums=argnums,
                                  return_value=return_value,
@@ -460,6 +442,7 @@ def jacrev(
     return_value: bool = False,
     holomorphic: bool = False,
     allow_int: bool = False,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """
     Extending automatic Jacobian (reverse-mode) of ``func`` to classes.
@@ -473,25 +456,28 @@ def jacrev(
 
     Parameters
     ----------
-    fun: Function whose Jacobian is to be computed.
+    fun: Callable
+        Function whose Jacobian is to be computed.
     grad_states : optional, ArrayType, sequence of ArrayType, dict
-      The variables in ``func`` to take their gradients.
+        The variables in ``func`` to take their gradients.
     has_aux: optional, bool
-      Indicates whether ``fun`` returns a pair where the
-      first element is considered the output of the mathematical function to be
-      differentiated and the second element is auxiliary data. Default False.
+        Indicates whether ``fun`` returns a pair where the
+        first element is considered the output of the mathematical function to be
+        differentiated and the second element is auxiliary data. Default False.
     return_value : bool
-      Whether return the loss value.
+        Whether return the loss value.
     argnums: Optional, integer or sequence of integers.
-      Specifies which
-      positional argument(s) to differentiate with respect to (default ``0``).
+        Specifies which
+        positional argument(s) to differentiate with respect to (default ``0``).
     holomorphic: Optional, bool.
-      Indicates whether ``fun`` is promised to be
-      holomorphic. Default False.
+        Indicates whether ``fun`` is promised to be
+        holomorphic. Default False.
     allow_int: Optional, bool.
-      Whether to allow differentiating with
-      respect to integer valued inputs. The gradient of an integer input will
-      have a trivial vector-space dtype (float0). Default False.
+        Whether to allow differentiating with
+        respect to integer valued inputs. The gradient of an integer input will
+        have a trivial vector-space dtype (float0). Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+        mode. Default False.
 
     Returns
     -------
@@ -505,7 +491,8 @@ def jacrev(
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
                              transform_params=dict(holomorphic=holomorphic,
-                                                   allow_int=allow_int))
+                                                   allow_int=allow_int,
+                                                   unit_aware=unit_aware, ))
 
 
 jacrev.__doc__ = jacrev.__doc__ % _doc_of_return
@@ -521,6 +508,7 @@ def jacfwd(
     has_aux: Optional[bool] = None,
     return_value: bool = False,
     holomorphic: bool = False,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """Extending automatic Jacobian (forward-mode) of ``func`` to classes.
 
@@ -542,9 +530,11 @@ def jacfwd(
     return_value : bool
       Whether return the loss value.
     argnums: Optional, integer or sequence of integers. Specifies which
-      positional argument(s) to differentiate with respect to (default ``0``).
+        positional argument(s) to differentiate with respect to (default ``0``).
     holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be
-      holomorphic. Default False.
+        holomorphic. Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+        mode. Default False.
 
     Returns
     -------
@@ -558,7 +548,8 @@ def jacfwd(
                              argnums=argnums,
                              return_value=return_value,
                              has_aux=False if has_aux is None else has_aux,
-                             transform_params=dict(holomorphic=holomorphic))
+                             transform_params=dict(holomorphic=holomorphic,
+                                                   unit_aware=unit_aware))
 
 
 jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return
@@ -569,9 +560,10 @@ def hessian(
     func: Callable,
     grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None,
     argnums: Optional[Union[int, Sequence[int]]] = None,
-    has_aux: bool = False,
     return_value: bool = False,
     holomorphic: bool = False,
+    has_aux: Optional[bool] = None,
+    unit_aware: bool = False,
 ) -> GradientTransform:
     """
     Hessian of ``func`` as a dense array.
@@ -593,6 +585,12 @@ def hessian(
       Indicates whether ``fun`` is promised to be holomorphic. Default False.
     return_value : bool
       Whether return the hessian values.
+    has_aux: Optional, bool
+        Indicates whether ``fun`` returns a pair where the first element is considered
+        the output of the mathematical function to be differentiated and the second
+        element is auxiliary data. Default False.
+    unit_aware: (bool) optional. Whether to return the gradient in the unit-aware
+        mode. Default False.
 
     Returns
     -------
@@ -600,7 +598,7 @@ def hessian(
       The transformed object.
     """
     return GradientTransform(target=func,
-                             transform=jax.hessian,
+                             transform=u.autograd.hessian if unit_aware else jax.hessian,
                              grad_states=grad_states,
                              argnums=argnums,
                              return_value=return_value,
@@ -19,6 +19,7 @@ from __future__ import annotations
 import unittest
 from pprint import pprint
 
+import brainunit as u
 import jax
 import jax.numpy as jnp
 import pytest
@@ -608,6 +609,8 @@ class TestClassFuncJacobian(unittest.TestCase):
         br = bst.augment.jacrev(t, grad_states=[t.x, t.y])()
         self.assertTrue((br[0] == _jr[0]).all())
         self.assertTrue((br[1] == _jr[1]).all())
+
+
     #
     # def test_jacfwd1(self):
     #   def f1(x, y):
@@ -1191,3 +1194,97 @@ class TestClassFuncJacobian(unittest.TestCase):
     #     self.assertTrue(file.read().strip() == expect_res.strip())
     #
     #
+
+
+class TestUnitAwareGrad(unittest.TestCase):
+    def test_grad1(self):
+        def f(x):
+            return u.math.sum(x ** 2)
+
+        x = jnp.array([1., 2., 3.]) * u.ms
+        g = bst.augment.grad(f, unit_aware=True)(x)
+        self.assertTrue(u.math.allclose(g, 2 * x))
+
+    def test_vector_grad1(self):
+        def f(x):
+            return x ** 3
+
+        x = jnp.array([1., 2., 3.]) * u.ms
+        g = bst.augment.vector_grad(f, unit_aware=True)(x)
+        self.assertTrue(u.math.allclose(g, 3 * x ** 2))
+
+    def test_jacrev1(self):
+        def f(x, y):
+            return u.math.asarray([x[0] * y[0],
+                                   5 * x[2] * y[1],
+                                   4 * x[1] ** 2, ])
+
+        _x = jnp.array([1., 2., 3.]) * u.ms
+        _y = jnp.array([10., 5.]) * u.ms
+
+        g = bst.augment.jacrev(f, unit_aware=True, argnums=(0, 1))(_x, _y)
+        self.assertTrue(
+            u.math.allclose(
+                g[0],
+                u.math.asarray([
+                    [10., 0., 0.],
+                    [0., 0., 25.],
+                    [0., 16., 0.]
+                ]) * u.ms
+            )
+        )
+
+        self.assertTrue(
+            u.math.allclose(
+                g[1],
+                u.math.asarray([
+                    [1., 0.],
+                    [0., 15.],
+                    [0., 0.]
+                ]) * u.ms
+            )
+        )
+
+    def test_jacfwd1(self):
+        def f(x, y):
+            return u.math.asarray([x[0] * y[0],
+                                   5 * x[2] * y[1],
+                                   4 * x[1] ** 2, ])
+
+        _x = jnp.array([1., 2., 3.]) * u.ms
+        _y = jnp.array([10., 5.]) * u.ms
+
+        g = bst.augment.jacfwd(f, unit_aware=True, argnums=(0, 1))(_x, _y)
+        self.assertTrue(
+            u.math.allclose(
+                g[0],
+                u.math.asarray([
+                    [10., 0., 0.],
+                    [0., 0., 25.],
+                    [0., 16., 0.]
+                ]) * u.ms
+            )
+        )
+
+        self.assertTrue(
+            u.math.allclose(
+                g[1],
+                u.math.asarray([
+                    [1., 0.],
+                    [0., 15.],
+                    [0., 0.]
+                ]) * u.ms
+            )
+        )
+
+    def test_hessian(self):
+        unit = u.ms
+
+        def scalar_function(x):
+            return x ** 3 + 3 * x * unit * unit + 2 * unit * unit * unit
+
+        hess = bst.augment.hessian(scalar_function, unit_aware=True)
+        x = jnp.array(1.0) * unit
+        res = hess(x)
+        expected_hessian = jnp.array([[6.0]]) * unit
+        assert u.math.allclose(res, expected_hessian)
@@ -14,14 +14,14 @@
 # ==============================================================================
 
 
-from ._csr import *
-from ._csr import __all__ as __all_csr
-from ._fixed_probability import *
-from ._fixed_probability import __all__ as __all_fixed_probability
-from ._linear import *
+from ._csr_mv import *
+from ._csr_mv import __all__ as __all_csr
+from ._fixedprob_mv import *
+from ._fixedprob_mv import __all__ as __all_fixed_probability
+from ._linear_mv import *
 from ._xla_custom_op import *
 from ._xla_custom_op import __all__ as __all_xla_custom_op
-from ._linear import __all__ as __all_linear
+from ._linear_mv import __all__ as __all_linear
 
 __all__ = __all_fixed_probability + __all_linear + __all_csr + __all_xla_custom_op
 del __all_fixed_probability, __all_linear, __all_csr, __all_xla_custom_op
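
The event kernels move from `_csr`/`_fixed_probability`/`_linear` to `_csr_mv`/`_fixedprob_mv`/`_linear_mv`. Code that imported the private modules directly has to follow the rename; importing through the package namespace is unaffected, since `__all__` is still aggregated and re-exported. A sketch, assuming `Linear` stays exported from `brainstate.event`:

    import brainstate as bst

    # stable: the public name is re-exported via brainstate.event.__all__
    Linear = bst.event.Linear

    # brittle: the private module path changed in this release
    # old: from brainstate.event._linear import Linear
    # new: from brainstate.event._linear_mv import Linear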
@@ -0,0 +1,14 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
@@ -20,7 +20,7 @@ import jax.numpy as jnp
 from absl.testing import parameterized
 
 import brainstate as bst
-from brainstate.event._linear import Linear
+from brainstate.event._linear_mv import Linear
 
 
 class TestEventLinear(parameterized.TestCase):
@@ -17,14 +17,8 @@ from jaxlib.hlo_helpers import custom_call
 
 numba_installed = importlib.util.find_spec('numba') is not None
 
-if numba_installed:
-    import numba  # pylint: disable=import-error
-    from numba import types, carray, cfunc  # pylint: disable=import-error
-    from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
-else:
-    numba = None
-
 __all__ = [
+    'defjvp',
     'XLACustomOp',
 ]
 
@@ -93,9 +87,12 @@ def _numba_mlir_cpu_translation_rule(
     *ins,
     **kwargs
 ):
-    if numba is None:
+    if not numba_installed:
         raise ImportError('Numba is required to compile the CPU kernel for the custom operator.')
 
+    from numba import types, carray, cfunc  # pylint: disable=import-error
+    from numba.core.dispatcher import Dispatcher  # pylint: disable=import-error
+
     if not isinstance(kernel, Dispatcher):
         kernel = kernel(**kwargs)
     assert isinstance(kernel, Dispatcher), f'The kernel should be a Numba dispatcher. But we got {kernel}'
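
The two hunks above replace the module-level Numba import with a deferred import inside the CPU translation rule, so importing the package no longer requires Numba; it is only needed when a CPU kernel is actually compiled. The pattern in isolation (illustrative names, not the library's code):

    import importlib.util

    # cheap availability check that does not import numba itself
    numba_installed = importlib.util.find_spec('numba') is not None

    def lower_cpu_kernel(py_kernel):
        if not numba_installed:
            raise ImportError('Numba is required to compile the CPU kernel.')
        # deferred: only executed when a CPU lowering is requested
        from numba import cfunc, carray, types  # noqa: F401
        return py_kernel  # placeholder: a real rule would compile py_kernel with numba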
@@ -88,7 +88,7 @@ class _DropoutNd(ElementWiseBlock):
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
-        assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
         self.channel_axis = channel_axis
 
@@ -112,7 +112,7 @@ class _DropoutNd(ElementWiseBlock):
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
 
         # generate mask
-        if fit_phase:
+        if fit_phase and self.prob < 1.:
             dtype = u.math.get_dtype(x)
             keep_mask = jnp.broadcast_to(random.bernoulli(self.prob, mask_shape), x.shape)
             return jnp.where(keep_mask,
@@ -396,7 +396,7 @@ class DropoutFixed(ElementWiseBlock):
         name: Optional[str] = None
     ) -> None:
         super().__init__(name=name)
-        assert 0. <= prob < 1., f"Dropout probability must be in the range [0, 1). But got {prob}."
+        assert 0. <= prob <= 1., f"Dropout probability must be in the range [0, 1]. But got {prob}."
         self.prob = prob
         self.in_size = in_size
         self.out_size = in_size
@@ -407,7 +407,7 @@ class DropoutFixed(ElementWiseBlock):
     def update(self, x):
         dtype = u.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
-        if fit_phase:
+        if fit_phase and self.prob < 1.:
            if self.mask.value.shape != x.shape:
                 raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
                                  f"Please call `init_state()` method first.")
@@ -79,7 +79,7 @@ class Linear(Module):
         weight = params['weight']
         if self.w_mask is not None:
             weight = weight * self.w_mask
-        y = u.math.dot(x, weight)
+        y = u.linalg.dot(x, weight)
         if 'bias' in params:
             y = y + params['bias']
         return y
@@ -192,7 +192,7 @@ class ScaledWSLinear(Module):
         w = functional.weight_standardization(w, self.eps, params.get('gain', None))
         if self.w_mask is not None:
             w = w * self.w_mask
-        y = u.math.dot(x, w)
+        y = u.linalg.dot(x, w)
         if 'bias' in params:
             y = y + params['bias']
         return y
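
The last two hunks swap `u.math.dot` for `u.linalg.dot`; presumably this is the same unit-aware dot product relocated to brainunit's `linalg` namespace, so units still propagate through the matrix product. A quick sketch (unit choice is arbitrary, for illustration only):

    import jax.numpy as jnp
    import brainunit as u

    x = jnp.array([1., 2., 3.]) * u.ms   # shape (3,)
    w = jnp.ones((3, 2)) * u.ms          # shape (3, 2)
    y = u.linalg.dot(x, w)               # shape (2,), carries ms ** 2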