brainstate 0.0.1.post20240622__py2.py3-none-any.whl → 0.0.1.post20240708__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,7 +22,7 @@ from brainstate._module import (register_delay_of_target,
                                 ReceiveInputProj,
                                 ExtendedUpdateWithBA)
 from brainstate._utils import set_module_as
-from brainstate.mixin import (Mode, AllOfTypes, DelayedInitializer, BindCondData, AlignPost, UpdateReturn)
+from brainstate.mixin import (Mode, JointTypes, DelayedInitializer, BindCondData, AlignPost, UpdateReturn)
 from ._utils import is_instance
 
 __all__ = [
@@ -39,7 +39,7 @@ def align_post_add_bef_update(
     out_label: str,
     syn_desc,
     out_desc,
-    post: AllOfTypes[ReceiveInputProj, ExtendedUpdateWithBA],
+    post: JointTypes[ReceiveInputProj, ExtendedUpdateWithBA],
     proj_name: str
 ):
   # synapse and output initialization
@@ -60,7 +60,7 @@ class _AlignPost(Module):
   def __init__(
       self,
       syn: Module,
-      out: AllOfTypes[Dynamics, BindCondData]
+      out: JointTypes[Dynamics, BindCondData]
   ):
     super().__init__()
     self.syn = syn
@@ -140,7 +140,7 @@ class HalfProjAlignPostMg(Projection):
       comm: Module,
       syn: DelayedInitializer[AlignPost],
       out: DelayedInitializer[BindCondData],
-      post: AllOfTypes[ReceiveInputProj, ExtendedUpdateWithBA],
+      post: JointTypes[ReceiveInputProj, ExtendedUpdateWithBA],
       out_label: Optional[str] = None,
       name: Optional[str] = None,
       mode: Optional[Mode] = None,
@@ -150,7 +150,7 @@ class HalfProjAlignPostMg(Projection):
     # synaptic models
     is_instance(syn, DelayedInitializer[AlignPost])
     is_instance(out, DelayedInitializer[BindCondData])
-    is_instance(post, AllOfTypes[ReceiveInputProj, ExtendedUpdateWithBA])
+    is_instance(post, JointTypes[ReceiveInputProj, ExtendedUpdateWithBA])
 
     # synapse and output initialization
     syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name)
@@ -257,12 +257,12 @@ class FullProjAlignPostMg(Projection):
 
   def __init__(
       self,
-      pre: AllOfTypes[ExtendedUpdateWithBA, UpdateReturn],
+      pre: JointTypes[ExtendedUpdateWithBA, UpdateReturn],
       delay: Union[None, int, float],
      comm: Module,
       syn: DelayedInitializer[AlignPost],
       out: DelayedInitializer[BindCondData],
-      post: AllOfTypes[ReceiveInputProj, ExtendedUpdateWithBA],
+      post: JointTypes[ReceiveInputProj, ExtendedUpdateWithBA],
       out_label: Optional[str] = None,
       name: Optional[str] = None,
       mode: Optional[Mode] = None,
@@ -270,11 +270,11 @@ class FullProjAlignPostMg(Projection):
     super().__init__(name=name, mode=mode)
 
     # synaptic models
-    is_instance(pre, AllOfTypes[ExtendedUpdateWithBA, UpdateReturn])
+    is_instance(pre, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
     is_instance(comm, Module)
     is_instance(syn, DelayedInitializer[AlignPost])
     is_instance(out, DelayedInitializer[BindCondData])
-    is_instance(post, AllOfTypes[ReceiveInputProj, ExtendedUpdateWithBA])
+    is_instance(post, JointTypes[ReceiveInputProj, ExtendedUpdateWithBA])
     self.comm = comm
 
     # delay initialization
@@ -494,8 +494,8 @@ class FullProjAlignPost(Projection):
     # synaptic models
     is_instance(pre, UpdateReturn)
     is_instance(comm, Module)
-    is_instance(syn, AllOfTypes[Dynamics, AlignPost])
-    is_instance(out, AllOfTypes[Dynamics, BindCondData])
+    is_instance(syn, JointTypes[Dynamics, AlignPost])
+    is_instance(out, JointTypes[Dynamics, BindCondData])
     is_instance(post, ReceiveInputProj)
     self.comm = comm
     self.syn = syn
@@ -20,7 +20,7 @@ from brainstate._module import (Module, DelayAccess, Projection,
                                 ExtendedUpdateWithBA, ReceiveInputProj,
                                 register_delay_of_target)
 from brainstate._utils import set_module_as
-from brainstate.mixin import (DelayedInitializer, BindCondData, UpdateReturn, Mode, AllOfTypes)
+from brainstate.mixin import (DelayedInitializer, BindCondData, UpdateReturn, Mode, JointTypes)
 from ._utils import is_instance
 
 __all__ = [
@@ -284,7 +284,7 @@ class FullProjAlignPreDSMg(Projection):
 
   def __init__(
       self,
-      pre: AllOfTypes[ExtendedUpdateWithBA, UpdateReturn],
+      pre: JointTypes[ExtendedUpdateWithBA, UpdateReturn],
       delay: Union[None, int, float],
       syn: DelayedInitializer[UpdateReturn],
       comm: Module,
@@ -297,7 +297,7 @@ class FullProjAlignPreDSMg(Projection):
     super().__init__(name=name, mode=mode)
 
     # synaptic models
-    is_instance(pre, AllOfTypes[ExtendedUpdateWithBA, UpdateReturn])
+    is_instance(pre, JointTypes[ExtendedUpdateWithBA, UpdateReturn])
     is_instance(syn, DelayedInitializer[Module])
     is_instance(comm, Module)
     is_instance(out, BindCondData)
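
Every hunk above is the same mechanical rename: the intersection-type helper `AllOfTypes` from `brainstate.mixin` is now called `JointTypes`. For readers unfamiliar with the helper, here is a minimal stand-in that illustrates the intended semantics; the class below is a hypothetical sketch, not brainstate's implementation:

# Sketch: `JointTypes[A, B]` names values that implement *all* listed mixins.
class _JointTypesMeta(type):
  def __getitem__(cls, types):
    types = types if isinstance(types, tuple) else (types,)

    class _Joint:
      required = types

      @classmethod
      def check(cls, obj):
        # an object qualifies only if it is an instance of every listed type
        return all(isinstance(obj, t) for t in cls.required)

    return _Joint


class JointTypes(metaclass=_JointTypesMeta):
  pass


class ReceiveInputProj: pass
class ExtendedUpdateWithBA: pass
class Post(ReceiveInputProj, ExtendedUpdateWithBA): pass

assert JointTypes[ReceiveInputProj, ExtendedUpdateWithBA].check(Post())
assert not JointTypes[ReceiveInputProj, ExtendedUpdateWithBA].check(object())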
brainstate/random.py CHANGED
@@ -22,11 +22,16 @@ from functools import partial
 from operator import index
 from typing import Optional
 
+import brainunit as bu
 import jax
+import jax.numpy as jnp
+import jax.random as jr
 import numpy as np
-from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
+from jax import jit, vmap
+from jax import lax, core, dtypes
 
 from brainstate import environ
+from ._random_for_unit import uniform_for_unit, permutation_for_unit
 from ._state import State
 from .transform._jit_error import jit_error
 from .typing import DTypeLike, Size, SeedOrKey
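
The new `_random_for_unit` import supplies the unit-aware wrappers used throughout the hunks below. Their implementation is not part of this diff; the following is a hypothetical sketch of `uniform_for_unit`, inferred purely from its call sites: strip brainunit units from the bounds, sample with `jax.random.uniform`, and re-attach the dimension to the result.

import brainunit as bu
import jax.random as jr


def uniform_for_unit(key, shape=(), dtype=float, minval=0., maxval=1.):
  dim = bu.DIMENSIONLESS
  if isinstance(minval, bu.Quantity) or isinstance(maxval, bu.Quantity):
    minval = bu.math.asarray(minval)
    maxval = bu.math.asarray(maxval)
    bu.fail_for_dimension_mismatch(minval, maxval)  # bounds must share units
    dim = minval.dim if isinstance(minval, bu.Quantity) else maxval.dim
    minval = minval.value if isinstance(minval, bu.Quantity) else minval
    maxval = maxval.value if isinstance(maxval, bu.Quantity) else maxval
  r = jr.uniform(key, shape=shape, dtype=dtype, minval=minval, maxval=maxval)
  return r if dim == bu.DIMENSIONLESS else bu.Quantity(r, dim=dim)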
@@ -144,7 +149,7 @@ class RandomState(State):
144
149
  def rand(self, *dn, key: Optional[SeedOrKey] = None, dtype: DTypeLike = None):
145
150
  key = self.split_key() if key is None else _formalize_key(key)
146
151
  dtype = dtype or environ.dftype()
147
- r = jr.uniform(key, shape=dn, minval=0., maxval=1., dtype=dtype)
152
+ r = uniform_for_unit(key, shape=dn, minval=0., maxval=1., dtype=dtype)
148
153
  return r
149
154
 
150
155
  def randint(
@@ -207,7 +212,7 @@ class RandomState(State):
207
212
  dtype: DTypeLike = None):
208
213
  dtype = dtype or environ.dftype()
209
214
  key = self.split_key() if key is None else _formalize_key(key)
210
- r = jr.uniform(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
215
+ r = uniform_for_unit(key, shape=_size2shape(size), minval=0., maxval=1., dtype=dtype)
211
216
  return r
212
217
 
213
218
  def random_sample(self,
@@ -250,7 +255,7 @@ class RandomState(State):
250
255
  key: Optional[SeedOrKey] = None):
251
256
  x = _check_py_seq(x)
252
257
  key = self.split_key() if key is None else _formalize_key(key)
253
- r = jr.permutation(key, x, axis=axis, independent=independent)
258
+ r = permutation_for_unit(key, x, axis=axis, independent=independent)
254
259
  return r
255
260
 
256
261
  def shuffle(self,
@@ -258,7 +263,7 @@ class RandomState(State):
258
263
  axis=0,
259
264
  key: Optional[SeedOrKey] = None):
260
265
  key = self.split_key() if key is None else _formalize_key(key)
261
- x = jr.permutation(key, x, axis=axis)
266
+ x = permutation_for_unit(key, x, axis=axis)
262
267
  return x
263
268
 
264
269
  def beta(self,
@@ -458,7 +463,7 @@ class RandomState(State):
458
463
  size = lax.broadcast_shapes(jnp.shape(low), jnp.shape(high))
459
464
  key = self.split_key() if key is None else _formalize_key(key)
460
465
  dtype = dtype or environ.dftype()
461
- r = jr.uniform(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
466
+ r = uniform_for_unit(key, shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)
462
467
  return r
463
468
 
464
469
  def __norm_cdf(self, x, sqrt2, dtype):
@@ -481,36 +486,46 @@ class RandomState(State):
481
486
  scale = _check_py_seq(scale)
482
487
  dtype = dtype or environ.dftype()
483
488
 
484
- lower = lax.convert_element_type(lower, dtype)
485
- upper = lax.convert_element_type(upper, dtype)
486
- loc = lax.convert_element_type(loc, dtype)
487
- scale = lax.convert_element_type(scale, dtype)
489
+ lower = bu.math.asarray(lower, dtype=dtype)
490
+ upper = bu.math.asarray(upper, dtype=dtype)
491
+ loc = bu.math.asarray(loc, dtype=dtype)
492
+ scale = bu.math.asarray(scale, dtype=dtype)
493
+ bu.fail_for_dimension_mismatch(lower, upper)
494
+ bu.fail_for_dimension_mismatch(lower, loc)
495
+ bu.fail_for_dimension_mismatch(lower, scale)
496
+ dim = lower.dim if isinstance(lower, bu.Quantity) else bu.DIMENSIONLESS
497
+ lower = lower.value if isinstance(lower, bu.Quantity) else lower
498
+ upper = upper.value if isinstance(upper, bu.Quantity) else upper
499
+ loc = loc.value if isinstance(loc, bu.Quantity) else loc
500
+ scale = scale.value if isinstance(scale, bu.Quantity) else scale
488
501
 
489
502
  jit_error(
490
- jnp.any(jnp.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
503
+ bu.math.any(bu.math.logical_or(loc < lower - 2 * scale, loc > upper + 2 * scale)),
491
504
  "mean is more than 2 std from [lower, upper] in truncated_normal. "
492
505
  "The distribution of values may be incorrect."
493
506
  )
494
507
 
495
508
  if size is None:
496
- size = lax.broadcast_shapes(jnp.shape(lower),
497
- jnp.shape(upper),
498
- jnp.shape(loc),
499
- jnp.shape(scale))
509
+ size = bu.math.broadcast_shapes(jnp.shape(lower),
510
+ jnp.shape(upper),
511
+ jnp.shape(loc),
512
+ jnp.shape(scale))
500
513
 
501
514
  # Values are generated by using a truncated uniform distribution and
502
515
  # then using the inverse CDF for the normal distribution.
503
516
  # Get upper and lower cdf values
504
- sqrt2 = np.array(np.sqrt(2), dtype)
517
+ sqrt2 = np.array(np.sqrt(2), dtype=dtype)
505
518
  l = self.__norm_cdf((lower - loc) / scale, sqrt2, dtype)
506
519
  u = self.__norm_cdf((upper - loc) / scale, sqrt2, dtype)
507
520
 
508
521
  # Uniformly fill tensor with values from [l, u], then translate to
509
522
  # [2l-1, 2u-1].
510
523
  key = self.split_key() if key is None else _formalize_key(key)
511
- out = jr.uniform(key, size, dtype,
512
- minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
513
- maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype)))
524
+ out = uniform_for_unit(
525
+ key, size, dtype,
526
+ minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)),
527
+ maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype))
528
+ )
514
529
 
515
530
  # Use inverse cdf transform for normal distribution to get truncated
516
531
  # standard normal
@@ -523,7 +538,7 @@ class RandomState(State):
523
538
  out = jnp.clip(out,
524
539
  lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),
525
540
  lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))
526
- return out
541
+ return out if dim == bu.DIMENSIONLESS else bu.Quantity(out, dim=dim)
527
542
 
528
543
  def _check_p(self, p):
529
544
  raise ValueError(f'Parameter p should be within [0, 1], but we got {p}')
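
With the unit handling above, `truncated_normal` now accepts brainunit quantities and returns a `bu.Quantity` when its arguments carry a dimension. A hypothetical usage sketch (the keyword names follow the hunk above; `bu.mV` and the seed-constructed `RandomState` are assumptions):

import brainunit as bu
import brainstate as bst

rng = bst.random.RandomState(0)
# bounds, loc, and scale all carry the voltage dimension, so the result
# comes back as a bu.Quantity with that same dimension
r = rng.truncated_normal(lower=0. * bu.mV, upper=10. * bu.mV,
                         loc=5. * bu.mV, scale=1. * bu.mV, size=(3,))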
@@ -548,6 +563,13 @@ class RandomState(State):
                dtype: DTypeLike = None):
     mean = _check_py_seq(mean)
     sigma = _check_py_seq(sigma)
+    mean = bu.math.asarray(mean, dtype=dtype)
+    sigma = bu.math.asarray(sigma, dtype=dtype)
+    bu.fail_for_dimension_mismatch(mean, sigma)
+    dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
+    mean = mean.value if isinstance(mean, bu.Quantity) else mean
+    sigma = sigma.value if isinstance(sigma, bu.Quantity) else sigma
+
     if size is None:
       size = jnp.broadcast_shapes(jnp.shape(mean),
                                   jnp.shape(sigma))
@@ -556,7 +578,7 @@ class RandomState(State):
     samples = jr.normal(key, shape=_size2shape(size), dtype=dtype)
     samples = _loc_scale(mean, sigma, samples)
     samples = jnp.exp(samples)
-    return samples
+    return samples if dim == bu.DIMENSIONLESS else bu.Quantity(samples, dim=dim)
 
   def binomial(self,
                n,
@@ -614,7 +636,7 @@ class RandomState(State):
       size = jnp.shape(p)
     key = self.split_key() if key is None else _formalize_key(key)
     dtype = dtype or environ.dftype()
-    u = jr.uniform(key, size, dtype=dtype)
+    u = uniform_for_unit(key, size, dtype=dtype)
     r = jnp.floor(jnp.log1p(-u) / jnp.log1p(-p))
     return r
@@ -640,20 +662,28 @@ class RandomState(State):
     dtype = dtype or environ.ditype()
     return jnp.asarray(r, dtype=dtype)
 
-  def multivariate_normal(self,
-                          mean,
-                          cov,
-                          size: Optional[Size] = None,
-                          method: str = 'cholesky',
-                          key: Optional[SeedOrKey] = None,
-                          dtype: DTypeLike = None):
+  def multivariate_normal(
+      self,
+      mean,
+      cov,
+      size: Optional[Size] = None,
+      method: str = 'cholesky',
+      key: Optional[SeedOrKey] = None,
+      dtype: DTypeLike = None
+  ):
     if method not in {'svd', 'eigh', 'cholesky'}:
       raise ValueError("method must be one of {'svd', 'eigh', 'cholesky'}")
     dtype = dtype or environ.dftype()
-    mean = jnp.asarray(_check_py_seq(mean), dtype=dtype)
-    cov = jnp.asarray(_check_py_seq(cov), dtype=dtype)
-    key = self.split_key() if key is None else _formalize_key(key)
+    mean = bu.math.asarray(_check_py_seq(mean), dtype=dtype)
+    cov = bu.math.asarray(_check_py_seq(cov), dtype=dtype)
+    if isinstance(mean, bu.Quantity):
+      assert isinstance(cov, bu.Quantity)
+      assert mean.dim ** 2 == cov.dim
+    dim = mean.dim if isinstance(mean, bu.Quantity) else bu.DIMENSIONLESS
+    mean = mean.value if isinstance(mean, bu.Quantity) else mean
+    cov = cov.value if isinstance(cov, bu.Quantity) else cov
 
+    key = self.split_key() if key is None else _formalize_key(key)
     if not jnp.ndim(mean) >= 1:
       raise ValueError(f"multivariate_normal requires mean.ndim >= 1, got mean.ndim == {jnp.ndim(mean)}")
     if not jnp.ndim(cov) >= 2:
@@ -678,7 +708,7 @@ class RandomState(State):
       factor = jnp.linalg.cholesky(cov)
     normal_samples = jr.normal(key, size + mean.shape[-1:], dtype=dtype)
     r = mean + jnp.einsum('...ij,...j->...i', factor, normal_samples)
-    return r
+    return r if dim == bu.DIMENSIONLESS else bu.Quantity(r, dim=dim)
 
   def rayleigh(self,
                scale=1.0,
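
A corresponding usage sketch for the unit-aware `multivariate_normal` (hypothetical call; note that the assertion above requires the covariance to carry the squared dimension of the mean):

import jax.numpy as jnp
import brainunit as bu
import brainstate as bst

rng = bst.random.RandomState(42)
mean = jnp.zeros(2) * bu.mV      # dimension: volt
cov = jnp.eye(2) * bu.mV ** 2    # dimension: volt ** 2
samples = rng.multivariate_normal(mean, cov, size=(5,))
# samples is a bu.Quantity carrying the dimension of `mean`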
@@ -690,7 +720,7 @@ class RandomState(State):
       size = jnp.shape(scale)
     key = self.split_key() if key is None else _formalize_key(key)
     dtype = dtype or environ.dftype()
-    x = jnp.sqrt(-2. * jnp.log(jr.uniform(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
+    x = jnp.sqrt(-2. * jnp.log(uniform_for_unit(key, shape=_size2shape(size), minval=0, maxval=1, dtype=dtype)))
     r = x * scale
     return r
 
@@ -734,7 +764,7 @@ class RandomState(State):
       raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
     size = _size2shape(size)
     dtype = dtype or environ.dftype()
-    random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
+    random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
     r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
     return r
 
@@ -754,7 +784,7 @@ class RandomState(State):
       raise ValueError(f'"a" should be a scalar when "size" is provided. But we got {a}')
     size = _size2shape(size)
     dtype = dtype or environ.dftype()
-    random_uniform = jr.uniform(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
+    random_uniform = uniform_for_unit(key=key, shape=size, minval=0, maxval=1, dtype=dtype)
     r = jnp.power(-jnp.log1p(-random_uniform), 1.0 / a)
     if scale is not None:
       r /= scale
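
For completeness, a companion sketch of `permutation_for_unit`, again inferred purely from its call sites in this diff (hypothetical, not the shipped implementation):

import brainunit as bu
import jax.random as jr


def permutation_for_unit(key, x, axis=0, independent=False):
  # permute the raw values, then re-attach the unit if there was one
  if isinstance(x, bu.Quantity):
    r = jr.permutation(key, x.value, axis=axis, independent=independent)
    return bu.Quantity(r, dim=x.dim)
  return jr.permutation(key, x, axis=axis, independent=independent)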
@@ -15,14 +15,14 @@
 
 from __future__ import annotations
 
-from functools import wraps, partial
+import functools
+from functools import partial
 from typing import Callable, Union
 
 import jax
 from jax import numpy as jnp
 from jax.core import Primitive, ShapedArray
-from jax.interpreters import batching, mlir, xla
-from jax.lax import cond
+from jax.interpreters import batching, mlir
 
 from brainstate._utils import set_module_as
@@ -32,16 +32,39 @@ __all__ = [
 
 
 @set_module_as('brainstate.transform')
-def remove_vmap(x, op='any'):
+def remove_vmap(x, op: str = 'any'):
   if op == 'any':
     return _any_without_vmap(x)
   elif op == 'all':
     return _all_without_vmap(x)
+  elif op == 'none':
+    return _without_vmap(x)
   else:
     raise ValueError(f'Do not support type: {op}')
 
 
-_any_no_vmap_prim = Primitive('any_no_vmap')
+def _without_vmap(x):
+  return _no_vmap_prim.bind(x)
+
+
+def _without_vmap_imp(x):
+  return x
+
+
+def _without_vmap_abs(x):
+  return x
+
+
+def _without_vmap_batch(x, batch_axes):
+  (x,) = x
+  return _without_vmap(x), batching.not_mapped
+
+
+_no_vmap_prim = Primitive('no_vmap')
+_no_vmap_prim.def_impl(_without_vmap_imp)
+_no_vmap_prim.def_abstract_eval(_without_vmap_abs)
+batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
+mlir.register_lowering(_no_vmap_prim, mlir.lower_fun(_without_vmap_imp, multiple_results=False))
 
 
 def _any_without_vmap(x):
@@ -61,16 +84,12 @@ def _any_without_vmap_batch(x, batch_axes):
   return _any_without_vmap(x), batching.not_mapped
 
 
+_any_no_vmap_prim = Primitive('any_no_vmap')
 _any_no_vmap_prim.def_impl(_any_without_vmap_imp)
 _any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
 batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
-if hasattr(xla, "lower_fun"):
-  xla.register_translation(_any_no_vmap_prim,
-                           xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True))
 mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))
 
-_all_no_vmap_prim = Primitive('all_no_vmap')
-
 
 def _all_without_vmap(x):
   return _all_no_vmap_prim.bind(x)
@@ -89,50 +108,54 @@ def _all_without_vmap_batch(x, batch_axes):
   return _all_without_vmap(x), batching.not_mapped
 
 
+_all_no_vmap_prim = Primitive('all_no_vmap')
 _all_no_vmap_prim.def_impl(_all_without_vmap_imp)
 _all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
 batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
-if hasattr(xla, "lower_fun"):
-  xla.register_translation(_all_no_vmap_prim,
-                           xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
 mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))
 
 
-def _err_jit_true_branch(err_fun, x):
-  jax.debug.callback(err_fun, x)
-  return
+def _err_jit_true_branch(err_fun, args, kwargs):
+  jax.debug.callback(err_fun, *args, **kwargs)
 
 
-def _err_jit_false_branch(x):
-  return
+def _err_jit_false_branch(args, kwargs):
+  pass
 
 
-def _cond(err_fun, pred, err_arg):
-  @wraps(err_fun)
-  def true_err_fun(*arg):
-    err_fun(*arg)
-
-  cond(pred,
-       partial(_err_jit_true_branch, true_err_fun),
-       _err_jit_false_branch,
-       err_arg)
-
-
-def _error_msg(msg, *arg):
-  if len(arg) == 0:
-    raise ValueError(msg)
-  else:
-    raise ValueError(msg.format(arg))
+def _error_msg(msg, *arg, **kwargs):
+  if len(arg):
+    msg = msg % arg
+  if len(kwargs):
+    msg = msg.format(**kwargs)
+  raise ValueError(msg)
 
 
 @set_module_as('brainstate.transform')
-def jit_error(pred, err_fun: Union[Callable, str], err_arg=None, scope: str = 'any'):
-  """Check errors in a jit function.
+def jit_error(
+    pred,
+    err_fun: Union[Callable, str],
+    *err_args,
+    **err_kwargs,
+):
+  """
+  Check errors in a jit function.
+
+  Examples
+  --------
+
+  The error function can receive arguments passed in from the JIT-traced variables and raise an error:
 
-  >>> def error(arg):
-  >>>    raise ValueError(f'error {arg}')
+  >>> def error(x):
+  ...   raise ValueError(f'error {x}')
   >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
-  >>> jit_error(x.sum() < 5., error, err_arg=x)
+  >>> jit_error(x.sum() < 5., error, x)
+
+  Or `err_fun` can be a plain string message:
+
+  >>> x = jax.random.uniform(jax.random.PRNGKey(0), (10,))
+  >>> jit_error(x.sum() < 5., "Error: the sum is less than 5. Got {s}", s=x.sum())
 
   Parameters
   ----------
@@ -140,19 +163,18 @@ def jit_error(pred, err_fun: Union[Callable, str], err_arg=None, scope: str = 'a
     The boolean prediction.
   err_fun: callable
     The error function, which raises errors.
-  err_arg: any
+  err_args:
     The arguments which are passed into `err_fun`.
-  scope: str
-    The scope of the error message. Can be None, 'all' or 'any'.
+  err_kwargs:
+    The keyword arguments which are passed into `err_fun`.
   """
   if isinstance(err_fun, str):
     err_fun = partial(_error_msg, err_fun)
-  if scope is None:
-    pred = pred
-  elif scope == 'all':
-    pred = remove_vmap(pred, 'all')
-  elif scope == 'any':
-    pred = remove_vmap(pred, 'any')
-  else:
-    raise ValueError(f"Unknown scope: {scope}")
-  _cond(err_fun, pred, err_arg)
+
+  jax.lax.cond(
+    remove_vmap(pred, op='any'),
+    partial(_err_jit_true_branch, err_fun),
+    _err_jit_false_branch,
+    jax.tree.map(functools.partial(remove_vmap, op='none'), err_args),
+    jax.tree.map(functools.partial(remove_vmap, op='none'), err_kwargs),
+  )
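
To summarize the new interface: `jit_error` now forwards positional and keyword arguments to the error handler instead of taking a single `err_arg`, and the `scope` option is gone; a vmapped predicate is always reduced with 'any', while the arguments pass through the new 'none' primitive so they reach the callback unbatched. A usage sketch inside a jitted function (`safe_log` is an invented example):

import jax
import jax.numpy as jnp
import brainstate as bst


@jax.jit
def safe_log(x):
  # fires at run time when any entry is non-positive; `v` is substituted
  # into the message through the keyword path of _error_msg
  bst.transform.jit_error(jnp.any(x <= 0.), 'log() needs positive inputs, got {v}', v=x)
  return jnp.log(x)


safe_log(jnp.array([1., 2., 3.]))     # passes
# safe_log(jnp.array([1., -2., 3.]))  # raises an XlaRuntimeError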
@@ -0,0 +1,55 @@
+# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import unittest
+
+import jax
+import jax.numpy as jnp
+import jaxlib.xla_extension
+
+import brainstate as bst
+
+
+class TestJitError(unittest.TestCase):
+  def test1(self):
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      bst.transform.jit_error(True, 'error')
+
+    def err_f(x):
+      raise ValueError(f'error: {x}')
+
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      bst.transform.jit_error(True, err_f, 1.)
+
+  def test_vmap(self):
+
+    def f(x):
+      bst.transform.jit_error(x, 'error: {x}', x=x)
+
+    jax.vmap(f)(jnp.array([False, False, False]))
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      jax.vmap(f)(jnp.array([True, False, False]))
+
+  def test_vmap_vmap(self):
+
+    def f(x):
+      bst.transform.jit_error(x, 'error: {x}', x=x)
+
+    jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
+                                     [False, False, False]]))
+    with self.assertRaises(jaxlib.xla_extension.XlaRuntimeError):
+      jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
+                                       [True, False, False]]))
+
@@ -369,11 +369,11 @@ class StatefulFunction(object):
       self._state_trace[cache_key] = _state_trace
       with _state_trace:
         out = self.fun(*args, **kwargs)
-        state_values = _state_trace.collect_values('read', 'write')
+        state_values = _state_trace.collect_values('read', 'write', check_val_tree=True)
       _state_trace.recovery_original_values()
 
-      # return states is not allowed
-      # checking whether the states are returned
+      # State instances are not allowed as function returns.
+      # Check whether any states are returned.
       for leaf in jax.tree.leaves(out):
         if isinstance(leaf, State):
           leaf._raise_error_with_source_info(ValueError(f"State object is not allowed to be returned: {leaf}"))
@@ -399,7 +399,6 @@ class StatefulFunction(object):
     if cache_key not in self._state_trace:
       try:
         # jaxpr
-        # jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
         jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
           functools.partial(self._wrapped_fun_to_eval, cache_key),
           static_argnums=self.static_argnums,
@@ -435,8 +434,11 @@ class StatefulFunction(object):
     """
     # state checking
     cache_key = self.get_arg_cache_key(*args, **kwargs)
-    states = self.get_states(cache_key)
+    states: Sequence[State] = self.get_states(cache_key)
     assert len(state_vals) == len(states), 'State length mismatch.'
+    # No need to check here: make_jaxpr() has already verified that each value's tree is correct.
+    # for val, st in zip(state_vals, states):  # check state's value tree structure
+    #   st._check_value_tree(val)
 
     # parameters
     args = tuple(args[i] for i in range(len(args)) if i not in self.static_argnums)
@@ -450,6 +452,9 @@ class StatefulFunction(object):
     # output processing
     out, new_state_vals = out_treedef.unflatten(jaxpr_outs)
     assert len(new_state_vals) == len(state_vals), 'State length mismatch.'
+    # No need to check here: make_jaxpr() has already verified that each value's tree is correct.
+    # for val, st in zip(new_state_vals, states):  # check state's value tree structure
+    #   st._check_value_tree(val)
     return new_state_vals, out
 
   def jaxpr_call_auto(self, *args, **kwargs) -> Any:
brainstate/typing.py CHANGED
@@ -16,6 +16,7 @@
 
 from typing import Any, Sequence, Protocol, Union
 
+import brainunit as bu
 import jax
 import numpy as np
 
@@ -43,6 +44,7 @@ ArrayLike = Union[
   np.ndarray,  # NumPy array type
   np.bool_, np.number,  # NumPy scalar types
   bool, int, float, complex,  # Python scalar types
+  bu.Quantity,  # quantity
 ]
 
 # --- Dtype --- #
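
The practical effect is that `ArrayLike`-annotated APIs now also admit brainunit quantities. A small sketch (the `scale` function is illustrative only):

import brainunit as bu
import jax.numpy as jnp
from brainstate.typing import ArrayLike


def scale(x: ArrayLike, factor: float):
  return x * factor


scale(jnp.ones(3), 2.0)   # a plain array, as before
scale(3. * bu.mV, 2.0)    # a Quantity is now a valid ArrayLike too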
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: brainstate
-Version: 0.0.1.post20240622
+Version: 0.0.1.post20240708
 Summary: A State-based Transformation System for Brain Dynamics Programming.
 Home-page: https://github.com/brainpy/brainstate
 Author: BrainPy Team