brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +191 -48
  3. brainstate/_module_test.py +95 -21
  4. brainstate/_state.py +17 -0
  5. brainstate/environ.py +2 -2
  6. brainstate/functional/__init__.py +3 -2
  7. brainstate/functional/_activations.py +7 -26
  8. brainstate/functional/_normalization.py +3 -0
  9. brainstate/functional/_others.py +49 -0
  10. brainstate/functional/_spikes.py +0 -1
  11. brainstate/mixin.py +2 -2
  12. brainstate/nn/__init__.py +4 -0
  13. brainstate/nn/_base.py +10 -7
  14. brainstate/nn/_dynamics.py +20 -0
  15. brainstate/nn/_elementwise.py +5 -4
  16. brainstate/nn/_embedding.py +66 -0
  17. brainstate/nn/_misc.py +4 -3
  18. brainstate/nn/_others.py +3 -2
  19. brainstate/nn/_poolings.py +21 -20
  20. brainstate/nn/_poolings_test.py +4 -4
  21. brainstate/nn/_rate_rnns.py +17 -0
  22. brainstate/nn/_readout.py +6 -0
  23. brainstate/optim/__init__.py +0 -1
  24. brainstate/optim/_lr_scheduler_test.py +13 -0
  25. brainstate/optim/_sgd_optimizer.py +18 -17
  26. brainstate/transform/__init__.py +2 -3
  27. brainstate/transform/_autograd.py +1 -1
  28. brainstate/transform/_autograd_test.py +0 -2
  29. brainstate/transform/_jit.py +47 -21
  30. brainstate/transform/_jit_test.py +0 -3
  31. brainstate/transform/_make_jaxpr.py +164 -3
  32. brainstate/transform/_make_jaxpr_test.py +0 -2
  33. brainstate/transform/_progress_bar.py +1 -3
  34. brainstate/util.py +0 -1
  35. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
  36. brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
  37. brainstate/math/__init__.py +0 -21
  38. brainstate/math/_einops.py +0 -787
  39. brainstate/math/_einops_parsing.py +0 -169
  40. brainstate/math/_einops_parsing_test.py +0 -126
  41. brainstate/math/_einops_test.py +0 -346
  42. brainstate/math/_misc.py +0 -298
  43. brainstate/math/_misc_test.py +0 -58
  44. brainstate/nn/functional/__init__.py +0 -25
  45. brainstate/nn/functional/_activations.py +0 -754
  46. brainstate/nn/functional/_normalization.py +0 -69
  47. brainstate/nn/functional/_spikes.py +0 -90
  48. brainstate/nn/init/__init__.py +0 -26
  49. brainstate/nn/init/_base.py +0 -36
  50. brainstate/nn/init/_generic.py +0 -175
  51. brainstate/nn/init/_random_inits.py +0 -489
  52. brainstate/nn/init/_regular_inits.py +0 -109
  53. brainstate/nn/surrogate.py +0 -1740
  54. brainstate-0.0.1.dist-info/RECORD +0 -79
  55. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
  56. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
  57. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
brainstate/nn/_rate_rnns.py CHANGED
@@ -90,6 +90,9 @@ class ValinaRNNCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, self.num_out, batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, self.num_out, batch_size)
+
   def update(self, x):
     xh = jnp.concatenate([x, self.h.value], axis=-1)
     h = self.W(xh)
@@ -147,6 +150,9 @@ class GRUCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     old_h = self.h.value
     xh = jnp.concatenate([x, old_h], axis=-1)
@@ -224,6 +230,9 @@ class MGUCell(RNNCell):
   def init_state(self, batch_size: int = None, **kwargs):
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     old_h = self.h.value
     xh = jnp.concatenate([x, old_h], axis=-1)
@@ -327,6 +336,10 @@ class LSTMCell(RNNCell):
     self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x):
     h, c = self.h.value, self.c.value
     xh = jnp.concat([x, h], axis=-1)
@@ -379,6 +392,10 @@ class URLSTMCell(RNNCell):
     self.c = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
     self.h = ShortTermState(init.param(self._state_initializer, [self.num_out], batch_size))
 
+  def reset_state(self, batch_size: int = None, **kwargs):
+    self.c.value = init.param(self._state_initializer, [self.num_out], batch_size)
+    self.h.value = init.param(self._state_initializer, [self.num_out], batch_size)
+
   def update(self, x: ArrayLike) -> ArrayLike:
     h, c = self.h.value, self.c.value
     xh = jnp.concat([x, h], axis=-1)
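A note on the new `reset_state` methods above: `init_state` creates the `ShortTermState` container, while `reset_state` only re-assigns its `.value`, so a cell can be re-initialized between trials without rebuilding the module. A minimal usage sketch follows; the constructor arguments of `ValinaRNNCell` are an assumption for illustration, not verified API.

```python
# Hypothetical usage sketch based on this diff; the constructor arguments
# (num_in, num_out) are assumptions, not verified API.
import brainstate as bst

cell = bst.nn.ValinaRNNCell(num_in=4, num_out=8)

cell.init_state(batch_size=32)   # allocates self.h as a ShortTermState
# ... run one trial: call cell.update(x) for each time step ...

cell.reset_state(batch_size=32)  # re-initializes self.h.value in place
# ... run the next trial with a fresh hidden state ...
```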
brainstate/nn/_readout.py CHANGED
@@ -66,6 +66,9 @@ class LeakyRateReadout(DnnLayer):
   def init_state(self, batch_size=None, **kwargs):
     self.r = ShortTermState(init.param(init.Constant(0.), self.out_size, batch_size))
 
+  def reset_state(self, batch_size=None, **kwargs):
+    self.r.value = init.param(init.Constant(0.), self.out_size, batch_size)
+
   def update(self, x):
     r = self.decay * self.r.value + x @ self.weight.value
     self.r.value = r
@@ -109,6 +112,9 @@ class LeakySpikeReadout(Neuron):
   def init_state(self, batch_size, **kwargs):
     self.V = ShortTermState(init.param(init.Constant(0.), self.varshape, batch_size))
 
+  def reset_state(self, batch_size, **kwargs):
+    self.V.value = init.param(init.Constant(0.), self.varshape, batch_size)
+
   @property
   def spike(self):
     return self.get_spike(self.V.value)
brainstate/optim/__init__.py CHANGED
@@ -20,4 +20,3 @@ from ._sgd_optimizer import *
 from ._sgd_optimizer import __all__ as optimizer_all
 
 __all__ = scheduler_all + optimizer_all
-
brainstate/optim/_lr_scheduler_test.py CHANGED
@@ -34,3 +34,16 @@ class TestMultiStepLR(unittest.TestCase):
         self.assertTrue(jnp.allclose(r, 0.001))
       else:
         self.assertTrue(jnp.allclose(r, 0.0001))
+
+  def test2(self):
+    lr = bst.transform.jit(bst.optim.MultiStepLR(0.1, [10, 20, 30], gamma=0.1))
+    for i in range(40):
+      r = lr(i)
+      if i < 10:
+        self.assertEqual(r, 0.1)
+      elif i < 20:
+        self.assertTrue(jnp.allclose(r, 0.01))
+      elif i < 30:
+        self.assertTrue(jnp.allclose(r, 0.001))
+      else:
+        self.assertTrue(jnp.allclose(r, 0.0001))
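The new `test2` checks that the scheduler still produces the expected values after being wrapped in `bst.transform.jit`. For reference, the schedule it exercises follows the usual MultiStepLR rule: the rate starts at the base value and is multiplied by `gamma` each time a milestone is passed. A small standalone sketch of that rule (plain Python, independent of the brainstate implementation):

```python
# Plain-Python sketch of the MultiStepLR rule the test exercises:
# lr(step) = base_lr * gamma ** (number of milestones <= step)
def multistep_lr(base_lr, milestones, gamma, step):
  passed = sum(1 for m in milestones if step >= m)
  return base_lr * gamma ** passed

assert abs(multistep_lr(0.1, [10, 20, 30], 0.1, 5) - 0.1) < 1e-12
assert abs(multistep_lr(0.1, [10, 20, 30], 0.1, 25) - 0.001) < 1e-12
```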
brainstate/optim/_sgd_optimizer.py CHANGED
@@ -18,11 +18,12 @@
 import functools
 from typing import Union, Dict, Optional, Tuple, Any, TypeVar
 
+import brainunit as bu
 import jax
 import jax.numpy as jnp
 
 from ._lr_scheduler import make_schedule, LearningRateScheduler
-from .. import environ, math
+from .. import environ
 from .._module import Module
 from .._state import State, LongTermState, StateDictManager, visible_state_dict
 
@@ -282,7 +283,7 @@ class Momentum(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -349,7 +350,7 @@ class MomentumNesterov(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -417,7 +418,7 @@ class Adagrad(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -500,8 +501,8 @@ class Adadelta(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.delta_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.delta_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
@@ -574,7 +575,7 @@ class RMSProp(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -647,8 +648,8 @@ class Adam(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -730,7 +731,7 @@ class LARS(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -835,10 +836,10 @@ class Adan(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.exp_avg_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.exp_avg_sq_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.exp_avg_diff_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.pre_grad_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.exp_avg_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.exp_avg_sq_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.exp_avg_diff_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.pre_grad_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -989,10 +990,10 @@ class AdamW(_WeightDecayOptimizer):
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
       if self.amsgrad:
-        self.vmax_states[k] = OptimState(math.tree_zeros_like(v.value))
+        self.vmax_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr_old = self.lr()
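The only functional change in the optimizers is that the zero-initialized moment/cache pytrees now come from `brainunit.math.tree_zeros_like` instead of the removed `brainstate.math` module. Conceptually this is a pytree-wide `zeros_like`; a rough stand-in is shown below for illustration only (brainunit's actual implementation additionally handles unit-carrying arrays).

```python
# Rough stand-in for bu.math.tree_zeros_like, for illustration only:
# map zeros_like over every leaf of a pytree of arrays.
import jax
import jax.numpy as jnp

def tree_zeros_like(tree):
  return jax.tree.map(jnp.zeros_like, tree)

params = {'w': jnp.ones((3, 2)), 'b': jnp.ones((2,))}
zeros = tree_zeros_like(params)   # same structure, all-zero leaves
```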
brainstate/transform/__init__.py CHANGED
@@ -17,10 +17,10 @@
 This module contains the functions for the transformation of the brain data.
 """
 
-from ._control import *
-from ._control import __all__ as _controls_all
 from ._autograd import *
 from ._autograd import __all__ as _gradients_all
+from ._control import *
+from ._control import __all__ as _controls_all
 from ._jit import *
 from ._jit import __all__ as _jit_all
 from ._jit_error import *
@@ -33,4 +33,3 @@ from ._progress_bar import __all__ as _progress_bar_all
 __all__ = _gradients_all + _jit_error_all + _controls_all + _make_jaxpr_all + _jit_all + _progress_bar_all
 
 del _gradients_all, _jit_error_all, _controls_all, _make_jaxpr_all, _jit_all, _progress_bar_all
-
brainstate/transform/_autograd.py CHANGED
@@ -25,8 +25,8 @@ from jax._src.api import _vjp
 from jax.api_util import argnums_partial
 from jax.extend import linear_util
 
-from brainstate._utils import set_module_as
 from brainstate._state import State, StateTrace, StateDictManager
+from brainstate._utils import set_module_as
 
 __all__ = [
   'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
brainstate/transform/_autograd_test.py CHANGED
@@ -537,7 +537,6 @@ class TestPureFuncJacobian(unittest.TestCase):
 
   def test_jacrev_return_aux1(self):
     with bc.environ.context(precision=64):
-
       def f1(x, y):
         a = 4 * x[1] ** 2 - 2 * x[2]
         r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
@@ -564,7 +563,6 @@ class TestPureFuncJacobian(unittest.TestCase):
       assert (vec == _r).all()
 
 
-
 class TestClassFuncJacobian(unittest.TestCase):
   def test_jacrev1(self):
     def f1(x, y):
brainstate/transform/_jit.py CHANGED
@@ -23,8 +23,8 @@ import jax
 from jax._src import sharding_impls
 from jax.lib import xla_client as xc
 
-from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 from brainstate._utils import set_module_as
+from ._make_jaxpr import StatefulFunction, _ensure_index_tuple, _assign_state_values
 
 __all__ = ['jit']
 
@@ -33,10 +33,13 @@ class JittedFunction(Callable):
   """
   A wrapped version of ``fun``, set up for just-in-time compilation.
   """
-  origin_fun: Callable   # the original function
+  origin_fun: Callable  # the original function
   stateful_fun: StatefulFunction  # the stateful function for extracting states
   jitted_fun: jax.stages.Wrapped  # the jitted function
-  clear_cache: Callable   # clear the cache of the jitted function
+  clear_cache: Callable  # clear the cache of the jitted function
+
+  def __call__(self, *args, **kwargs):
+    pass
 
 
 def _get_jitted_fun(
@@ -85,12 +88,16 @@ def _get_jitted_fun(
     jit_fun.clear_cache()
 
   jitted_fun: JittedFunction
+
   # the original function
   jitted_fun.origin_fun = fun.fun
+
   # the stateful function for extracting states
   jitted_fun.stateful_fun = fun
+
   # the jitted function
   jitted_fun.jitted_fun = jit_fun
+
   # clear cache
   jitted_fun.clear_cache = clear_cache
 
@@ -99,18 +106,18 @@
 
 @set_module_as('brainstate.transform')
 def jit(
-  fun: Callable = None,
-  in_shardings=sharding_impls.UNSPECIFIED,
-  out_shardings=sharding_impls.UNSPECIFIED,
-  static_argnums: int | Sequence[int] | None = None,
-  donate_argnums: int | Sequence[int] | None = None,
-  donate_argnames: str | Iterable[str] | None = None,
-  keep_unused: bool = False,
-  device: xc.Device | None = None,
-  backend: str | None = None,
-  inline: bool = False,
-  abstracted_axes: Any | None = None,
-  **kwargs
+    fun: Callable = None,
+    in_shardings=sharding_impls.UNSPECIFIED,
+    out_shardings=sharding_impls.UNSPECIFIED,
+    static_argnums: int | Sequence[int] | None = None,
+    donate_argnums: int | Sequence[int] | None = None,
+    donate_argnames: str | Iterable[str] | None = None,
+    keep_unused: bool = False,
+    device: xc.Device | None = None,
+    backend: str | None = None,
+    inline: bool = False,
+    abstracted_axes: Any | None = None,
+    **kwargs
 ) -> Union[JittedFunction, Callable[[Callable], JittedFunction]]:
   """
   Sets up ``fun`` for just-in-time compilation with XLA.
@@ -228,12 +235,31 @@ def jit(
 
   if fun is None:
     def wrapper(fun_again: Callable) -> JittedFunction:
-      return _get_jitted_fun(fun_again, in_shardings, out_shardings, static_argnums,
-                             donate_argnums, donate_argnames, keep_unused,
-                             device, backend, inline, abstracted_axes, **kwargs)
+      return _get_jitted_fun(fun_again,
+                             in_shardings,
+                             out_shardings,
+                             static_argnums,
+                             donate_argnums,
+                             donate_argnames,
+                             keep_unused,
+                             device,
+                             backend,
+                             inline,
+                             abstracted_axes,
+                             **kwargs)
+
     return wrapper
 
   else:
-    return _get_jitted_fun(fun, in_shardings, out_shardings, static_argnums,
-                           donate_argnums, donate_argnames, keep_unused,
-                           device, backend, inline, abstracted_axes, **kwargs)
+    return _get_jitted_fun(fun,
+                           in_shardings,
+                           out_shardings,
+                           static_argnums,
+                           donate_argnums,
+                           donate_argnames,
+                           keep_unused,
+                           device,
+                           backend,
+                           inline,
+                           abstracted_axes,
+                           **kwargs)
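The `JittedFunction` protocol now declares `__call__`, and `jit` still works both as a decorator and as a direct call. A minimal usage sketch is given below; the function body is hypothetical, while `origin_fun`, `jitted_fun`, and `clear_cache` are the attributes shown in this diff.

```python
# Usage sketch for brainstate.transform.jit, based on the attributes in this diff.
# The decorated function body is a hypothetical example.
import brainstate as bst
import jax.numpy as jnp

@bst.transform.jit
def f(x):
  return jnp.sin(x) + 1.0

y = f(jnp.arange(3.0))      # compiled on first call
f.clear_cache()             # drop the cached trace/compilation
assert callable(f.origin_fun) and callable(f.jitted_fun)
```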
brainstate/transform/_jit_test.py CHANGED
@@ -16,7 +16,6 @@
 import unittest
 
 import jax.numpy as jnp
-import jax.stages
 
 import brainstate as bc
 
@@ -90,7 +89,6 @@ class TestJIT(unittest.TestCase):
     self.assertTrue(len(compiling) == 2)
 
   def test_jit_attribute_origin_fun(self):
-
     def fun1(x):
       return x
 
@@ -99,4 +97,3 @@ class TestJIT(unittest.TestCase):
     self.assertTrue(isinstance(jitted_fun.stateful_fun, bc.transform.StatefulFunction))
     self.assertTrue(callable(jitted_fun.jitted_fun))
     self.assertTrue(callable(jitted_fun.clear_cache))
-
brainstate/transform/_make_jaxpr.py CHANGED
@@ -54,20 +54,26 @@ function.
 from __future__ import annotations
 
 import functools
+import inspect
 import operator
 from collections.abc import Hashable, Iterable, Sequence
+from contextlib import ExitStack
 from typing import Any, Callable, Tuple, Union, Dict, Optional
 
 import jax
 from jax._src import source_info_util
+from jax._src.linear_util import annotate
+from jax._src.traceback_util import api_boundary
+from jax.extend.linear_util import transformation_with_aux, wrap_init
 from jax.interpreters import partial_eval as pe
-from jax.util import wraps
 from jax.interpreters.xla import abstractify
+from jax.util import wraps
 
 from brainstate._state import State, StateTrace
 from brainstate._utils import set_module_as
 
 PyTree = Any
+AxisName = Hashable
 
 __all__ = [
   "StatefulFunction",
@@ -393,7 +399,8 @@ class StatefulFunction(object):
     if cache_key not in self._state_trace:
       try:
         # jaxpr
-        jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        # jaxpr, (out_shapes, state_shapes) = jax.make_jaxpr(
+        jaxpr, (out_shapes, state_shapes) = _make_jaxpr(
           functools.partial(self._wrapped_fun_to_eval, cache_key),
           static_argnums=self.static_argnums,
          axis_env=self.axis_env,
@@ -474,7 +481,8 @@ def make_jaxpr(
   state_returns: Union[str, Tuple[str, ...]] = ('read', 'write')
 ) -> Callable[..., (Tuple[jax.core.ClosedJaxpr, Tuple[State, ...]] |
                     Tuple[jax.core.ClosedJaxpr, Tuple[State, ...], PyTree])]:
-  """Creates a function that produces its jaxpr given example args.
+  """
+  Creates a function that produces its jaxpr given example args.
 
   Args:
     fun: The function whose ``jaxpr`` is to be computed. Its positional
@@ -571,3 +579,156 @@ def make_jaxpr(
   if hasattr(fun, "__name__"):
     make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
   return make_jaxpr_f
+
+
+def _check_callable(fun):
+  # In Python 3.10+, the only thing stopping us from supporting staticmethods
+  # is that we can't take weak references to them, which the C++ JIT requires.
+  if isinstance(fun, staticmethod):
+    raise TypeError(f"staticmethod arguments are not supported, got {fun}")
+  if not callable(fun):
+    raise TypeError(f"Expected a callable value, got {fun}")
+  if inspect.isgeneratorfunction(fun):
+    raise TypeError(f"Expected a function, got a generator function: {fun}")
+
+
+def _broadcast_prefix(
+    prefix_tree: Any,
+    full_tree: Any,
+    is_leaf: Callable[[Any], bool] | None = None
+) -> list[Any]:
+  # If prefix_tree is not a tree prefix of full_tree, this code can raise a
+  # ValueError; use prefix_errors to find disagreements and raise more precise
+  # error messages.
+  result = []
+  num_leaves = lambda t: jax.tree.structure(t).num_leaves
+  add_leaves = lambda x, subtree: result.extend([x] * num_leaves(subtree))
+  jax.tree.map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
+  return result
+
+
+def _flat_axes_specs(
+    abstracted_axes, *args, **kwargs
+) -> list[pe.AbstractedAxesSpec]:
+  if kwargs:
+    raise NotImplementedError
+
+  def ax_leaf(l):
+    return (isinstance(l, dict) and jax.tree_util.all_leaves(l.values()) or
+            isinstance(l, tuple) and jax.tree_util.all_leaves(l, lambda x: x is None))
+
+  return _broadcast_prefix(abstracted_axes, args, ax_leaf)
+
+
+@transformation_with_aux
+def _flatten_fun(in_tree, *args_flat):
+  py_args, py_kwargs = jax.tree.unflatten(in_tree, args_flat)
+  ans = yield py_args, py_kwargs
+  yield jax.tree.flatten(ans)
+
+
+def _make_jaxpr(
+    fun: Callable,
+    static_argnums: int | Iterable[int] = (),
+    axis_env: Sequence[tuple[AxisName, int]] | None = None,
+    return_shape: bool = False,
+    abstracted_axes: Any | None = None,
+) -> Callable[..., (jax.core.ClosedJaxpr | tuple[jax.core.ClosedJaxpr, Any])]:
+  """Creates a function that produces its jaxpr given example args.
+
+  Args:
+    fun: The function whose ``jaxpr`` is to be computed. Its positional
+      arguments and return value should be arrays, scalars, or standard Python
+      containers (tuple/list/dict) thereof.
+    static_argnums: See the :py:func:`jax.jit` docstring.
+    axis_env: Optional, a sequence of pairs where the first element is an axis
+      name and the second element is a positive integer representing the size of
+      the mapped axis with that name. This parameter is useful when lowering
+      functions that involve parallel communication collectives, and it
+      specifies the axis name/size environment that would be set up by
+      applications of :py:func:`jax.pmap`.
+    return_shape: Optional boolean, defaults to ``False``. If ``True``, the
+      wrapped function returns a pair where the first element is the
+      ``ClosedJaxpr`` representation of ``fun`` and the second element is a
+      pytree with the same structure as the output of ``fun`` and where the
+      leaves are objects with ``shape``, ``dtype``, and ``named_shape``
+      attributes representing the corresponding types of the output leaves.
+
+  Returns:
+    A wrapped version of ``fun`` that when applied to example arguments returns
+    a ``ClosedJaxpr`` representation of ``fun`` on those arguments. If the
+    argument ``return_shape`` is ``True``, then the returned function instead
+    returns a pair where the first element is the ``ClosedJaxpr``
+    representation of ``fun`` and the second element is a pytree representing
+    the structure, shape, dtypes, and named shapes of the output of ``fun``.
+
+  A ``jaxpr`` is JAX's intermediate representation for program traces. The
+  ``jaxpr`` language is based on the simply-typed first-order lambda calculus
+  with let-bindings. :py:func:`make_jaxpr` adapts a function to return its
+  ``jaxpr``, which we can inspect to understand what JAX is doing internally.
+  The ``jaxpr`` returned is a trace of ``fun`` abstracted to
+  :py:class:`ShapedArray` level. Other levels of abstraction exist internally.
+
+  We do not describe the semantics of the ``jaxpr`` language in detail here, but
+  instead give a few examples.
+
+  >>> import jax
+  >>>
+  >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
+  >>> print(f(3.0))
+  -0.83602
+  >>> _make_jaxpr(f)(3.0)
+  { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }
+  >>> _make_jaxpr(jax.grad(f))(3.0)
+  { lambda ; a:f32[]. let
+      b:f32[] = cos a
+      c:f32[] = sin a
+      _:f32[] = sin b
+      d:f32[] = cos b
+      e:f32[] = mul 1.0 d
+      f:f32[] = neg e
+      g:f32[] = mul f c
+    in (g,) }
+  """
+  _check_callable(fun)
+  static_argnums = _ensure_index_tuple(static_argnums)
+
+  def _abstractify(args, kwargs):
+    flat_args, in_tree = jax.tree.flatten((args, kwargs))
+    if abstracted_axes is None:
+      return map(jax.api_util.shaped_abstractify, flat_args), in_tree, [True] * len(flat_args)
+    else:
+      axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)
+      in_type = pe.infer_lambda_input_type(axes_specs, flat_args)
+      in_avals, keep_inputs = jax.util.unzip2(in_type)
+      return in_avals, in_tree, keep_inputs
+
+  @wraps(fun)
+  @api_boundary
+  def make_jaxpr_f(*args, **kwargs):
+    f = wrap_init(fun)
+    if static_argnums:
+      dyn_argnums = [i for i in range(len(args)) if i not in static_argnums]
+      f, args = jax.api_util.argnums_partial(f, dyn_argnums, args)
+    in_avals, in_tree, keep_inputs = _abstractify(args, kwargs)
+    in_type = tuple(jax.util.safe_zip(in_avals, keep_inputs))
+    f, out_tree = _flatten_fun(f, in_tree)
+    f = annotate(f, in_type)
+    debug_info = pe.debug_info(fun, in_tree, out_tree, True, 'make_jaxpr')
+    with ExitStack() as stack:
+      for axis_name, size in axis_env or []:
+        stack.enter_context(jax.core.extend_axis_env(axis_name, size, None))
+      jaxpr, out_type, consts = pe.trace_to_jaxpr_dynamic2(f, debug_info=debug_info)
+    closed_jaxpr = jax.core.ClosedJaxpr(jaxpr, consts)
+    if return_shape:
+      out_avals, _ = jax.util.unzip2(out_type)
+      out_shapes_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
+      return closed_jaxpr, jax.tree.unflatten(out_tree(), out_shapes_flat)
+    return closed_jaxpr
+
+  make_jaxpr_f.__module__ = "brainstate.transform"
+  if hasattr(fun, "__qualname__"):
+    make_jaxpr_f.__qualname__ = f"make_jaxpr({fun.__qualname__})"
+  if hasattr(fun, "__name__"):
+    make_jaxpr_f.__name__ = f"make_jaxpr({fun.__name__})"
+  return make_jaxpr_f
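`StatefulFunction` now traces through this vendored `_make_jaxpr` instead of `jax.make_jaxpr`, keeping the tracing path under brainstate's control while preserving the same call convention. The public `brainstate.transform.make_jaxpr` keeps its documented behaviour of returning the jaxpr together with the `State` objects touched during tracing; a small sketch, assuming the `(ClosedJaxpr, states)` return convention shown in the annotation above:

```python
# Sketch of brainstate.transform.make_jaxpr, assuming the
# (ClosedJaxpr, states) return convention annotated in this diff.
import brainstate as bst
import jax.numpy as jnp

st = bst.State(jnp.zeros(3))    # a State read and written inside fn

def fn(x):
  st.value = st.value + x
  return st.value * 2.0

jaxpr, states = bst.transform.make_jaxpr(fn)(jnp.ones(3))
print(jaxpr)    # ClosedJaxpr with the State value threaded through as an input
print(states)   # the State objects touched while tracing fn
```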
brainstate/transform/_make_jaxpr_test.py CHANGED
@@ -129,5 +129,3 @@ def test_return_states():
 
   with pytest.raises(ValueError):
     f()
-
-
brainstate/transform/_progress_bar.py CHANGED
@@ -14,13 +14,12 @@
 # ==============================================================================
 
 from __future__ import annotations
+
 import copy
 from typing import Optional
 
 import jax
 
-from brainstate import environ
-
 try:
   from tqdm.auto import tqdm
 except (ImportError, ModuleNotFoundError):
@@ -95,7 +94,6 @@ class ProgressBarRunner(object):
     self.tqdm_bars[0].close()
 
   def __call__(self, iter_num, *args, **kwargs):
-
     _ = jax.lax.cond(
       iter_num == 0,
       lambda: jax.debug.callback(self._define_tqdm),
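The `__call__` shown above drives a host-side tqdm bar from inside traced code: `jax.lax.cond` picks a branch at run time, and `jax.debug.callback` executes the chosen Python function on the host. A self-contained sketch of that pattern, simplified and independent of `ProgressBarRunner`:

```python
# Minimal sketch of the lax.cond + jax.debug.callback pattern used above,
# simplified and independent of brainstate's ProgressBarRunner.
import jax
import jax.numpy as jnp

def report(i):
  print(f"step {int(i)}")          # runs on the host, outside the trace

def step(carry, i):
  # only report every 10th iteration; the branch is resolved at run time
  _ = jax.lax.cond(i % 10 == 0,
                   lambda: jax.debug.callback(report, i),
                   lambda: None)
  return carry + 1, None

final, _ = jax.lax.scan(step, jnp.int32(0), jnp.arange(50))
```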
brainstate/util.py CHANGED
@@ -26,7 +26,6 @@ from jax.lib import xla_bridge
 
 from ._utils import set_module_as
 
-
 __all__ = [
   'unique_name',
   'clear_buffer_memory',
{brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA CHANGED
@@ -1,8 +1,8 @@
 Metadata-Version: 2.1
 Name: brainstate
- Version: 0.0.1
+ Version: 0.0.1.post20240622
 Summary: A State-based Transformation System for Brain Dynamics Programming.
- Home-page: https://github.com/brainpy/braincore
+ Home-page: https://github.com/brainpy/brainstate
 Author: BrainPy Team
 Author-email: BrainPy Team <chao.brain@qq.com>
 License: Apache-2.0 license
@@ -31,21 +31,11 @@ License-File: LICENSE
 Requires-Dist: jax
 Requires-Dist: jaxlib
 Requires-Dist: numpy
+ Requires-Dist: brainunit
 Provides-Extra: cpu
 Requires-Dist: jaxlib ; extra == 'cpu'
- Requires-Dist: brainpylib ; extra == 'cpu'
- Provides-Extra: cpu_mini
- Requires-Dist: jaxlib ; extra == 'cpu_mini'
- Provides-Extra: cuda11
- Requires-Dist: jaxlib[cuda11_pip] ; extra == 'cuda11'
- Requires-Dist: brainpylib ; extra == 'cuda11'
- Provides-Extra: cuda11_mini
- Requires-Dist: jaxlib[cuda11_pip] ; extra == 'cuda11_mini'
 Provides-Extra: cuda12
 Requires-Dist: jaxlib[cuda12_pip] ; extra == 'cuda12'
- Requires-Dist: brainpylib ; extra == 'cuda12'
- Provides-Extra: cuda12_mini
- Requires-Dist: jaxlib[cuda12_pip] ; extra == 'cuda12_mini'
 Provides-Extra: testing
 Requires-Dist: pytest ; extra == 'testing'
 Provides-Extra: tpu
@@ -90,12 +80,14 @@ The official documentation is hosted on Read the Docs: [https://brainstate.readt
 
 ## See also the BDP ecosystem
 
- - [``brainpy``](https://github.com/brainpy/BrainPy): The solution for the general-purpose brain dynamics programming.
+ - [``brainstate``](https://github.com/brainpy/brainstate): A ``State``-based transformation system for brain dynamics programming.
 
- - [``brainstate``](https://github.com/brainpy/brainstate): The core system for the next generation of BrainPy framework.
+ - [``brainunit``](https://github.com/brainpy/brainunit): The unit system for brain dynamics programming.
 
- - [``braintools``](https://github.com/brainpy/braintools): The tools for the brain dynamics simulation and analysis.
+ - [``braintaichi``](https://github.com/brainpy/braintaichi): Leveraging Taichi Lang to customize brain dynamics operators.
 
- - [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning for biological spiking neural networks.
+ - [``brainscale``](https://github.com/brainpy/brainscale): The scalable online learning framework for biological neural networks.
+
+ - [``braintools``](https://github.com/brainpy/braintools): The toolbox for the brain dynamics simulation, training and analysis.
 