brainstate 0.1.0.post20250105__py2.py3-none-any.whl → 0.1.0.post20250126__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. brainstate/__init__.py +1 -2
  2. brainstate/_state.py +77 -44
  3. brainstate/_state_test.py +0 -17
  4. brainstate/augment/__init__.py +10 -20
  5. brainstate/augment/_eval_shape.py +9 -10
  6. brainstate/augment/_eval_shape_test.py +1 -1
  7. brainstate/augment/_mapping.py +265 -277
  8. brainstate/augment/_mapping_test.py +147 -175
  9. brainstate/compile/__init__.py +18 -37
  10. brainstate/compile/_ad_checkpoint.py +6 -4
  11. brainstate/compile/_jit.py +37 -28
  12. brainstate/compile/_loop_collect_return.py +6 -3
  13. brainstate/compile/_loop_no_collection.py +2 -0
  14. brainstate/compile/_make_jaxpr.py +15 -4
  15. brainstate/compile/_make_jaxpr_test.py +10 -6
  16. brainstate/compile/_progress_bar.py +68 -40
  17. brainstate/compile/_unvmap.py +9 -6
  18. brainstate/graph/__init__.py +12 -16
  19. brainstate/graph/_graph_node.py +1 -23
  20. brainstate/graph/_graph_operation.py +1 -1
  21. brainstate/graph/_graph_operation_test.py +0 -159
  22. brainstate/nn/_dyn_impl/_inputs.py +124 -39
  23. brainstate/nn/_elementwise/_dropout_test.py +1 -1
  24. brainstate/nn/_interaction/_conv.py +4 -2
  25. brainstate/nn/_interaction/_linear.py +84 -10
  26. brainstate/random/_rand_funs.py +9 -2
  27. brainstate/random/_rand_seed.py +12 -2
  28. brainstate/random/_rand_state.py +50 -179
  29. brainstate/surrogate.py +5 -1
  30. brainstate/util/__init__.py +0 -4
  31. brainstate/util/_caller.py +1 -1
  32. brainstate/util/_dict.py +4 -1
  33. brainstate/util/_filter.py +1 -1
  34. brainstate/util/_pretty_repr.py +1 -1
  35. brainstate/util/_struct.py +1 -1
  36. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/METADATA +2 -1
  37. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/RECORD +40 -60
  38. brainstate/event/__init__.py +0 -29
  39. brainstate/event/_csr.py +0 -906
  40. brainstate/event/_csr_mv.py +0 -303
  41. brainstate/event/_csr_mv_benchmark.py +0 -14
  42. brainstate/event/_csr_mv_test.py +0 -118
  43. brainstate/event/_csr_test.py +0 -90
  44. brainstate/event/_fixedprob_mv.py +0 -730
  45. brainstate/event/_fixedprob_mv_benchmark.py +0 -128
  46. brainstate/event/_fixedprob_mv_test.py +0 -132
  47. brainstate/event/_linear_mv.py +0 -359
  48. brainstate/event/_linear_mv_benckmark.py +0 -82
  49. brainstate/event/_linear_mv_test.py +0 -117
  50. brainstate/event/_misc.py +0 -34
  51. brainstate/event/_xla_custom_op.py +0 -313
  52. brainstate/event/_xla_custom_op_test.py +0 -55
  53. brainstate/graph/_graph_context.py +0 -443
  54. brainstate/graph/_graph_context_test.py +0 -65
  55. brainstate/graph/_graph_convert.py +0 -246
  56. brainstate/util/_tracers.py +0 -68
  57. brainstate/util/_visualization.py +0 -47
  58. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/LICENSE +0 -0
  59. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/WHEEL +0 -0
  60. {brainstate-0.1.0.post20250105.dist-info → brainstate-0.1.0.post20250126.dist-info}/top_level.txt +0 -0
brainstate/__init__.py CHANGED
@@ -22,7 +22,6 @@ __version__ = "0.1.0"
22
22
  from . import augment
23
23
  from . import compile
24
24
  from . import environ
25
- from . import event
26
25
  from . import functional
27
26
  from . import graph
28
27
  from . import init
@@ -39,7 +38,7 @@ from ._state import __all__ as _state_all
39
38
 
40
39
  __all__ = (
41
40
  [
42
- 'augment', 'compile', 'environ', 'event', 'functional',
41
+ 'augment', 'compile', 'environ', 'functional',
43
42
  'graph', 'init', 'mixin', 'nn', 'optim', 'random',
44
43
  'surrogate', 'typing', 'util',
45
44
  # deprecated
brainstate/_state.py CHANGED
@@ -28,13 +28,14 @@ from jax.api_util import shaped_abstractify
28
28
  from jax.extend import source_info_util
29
29
 
30
30
  from brainstate.typing import ArrayLike, PyTree, Missing
31
- from brainstate.util import DictManager, PrettyRepr, PrettyType, PrettyAttr, TraceContextError
32
- from brainstate.util._tracers import StateJaxTracer
31
+ from brainstate.util import DictManager, PrettyRepr, PrettyType, PrettyAttr
33
32
 
34
33
  __all__ = [
35
34
  'State', 'ShortTermState', 'LongTermState', 'HiddenState', 'ParamState', 'TreefyState',
35
+ 'FakedState',
36
36
 
37
37
  'StateDictManager', 'StateTraceStack', 'check_state_value_tree', 'check_state_jax_tracer', 'catch_new_states',
38
+ 'maybe_state',
38
39
  ]
39
40
 
40
41
  A = TypeVar('A')
@@ -102,6 +103,7 @@ class Catcher:
102
103
  """
103
104
  The catcher to catch the new states.
104
105
  """
106
+
105
107
  def __init__(self, tag: str):
106
108
  self.tag = tag
107
109
  self.state_ids = set()
@@ -114,6 +116,13 @@ class Catcher:
114
116
  state.tag = self.tag
115
117
 
116
118
 
119
+ def maybe_state(val: Any):
120
+ if isinstance(val, State):
121
+ return val.value
122
+ else:
123
+ return val
124
+
125
+
117
126
  @contextlib.contextmanager
118
127
  def check_state_jax_tracer(val: bool = True) -> None:
119
128
  """
@@ -197,7 +206,6 @@ class State(Generic[A], PrettyRepr):
197
206
  value: PyTree. It can be anything as a pyTree.
198
207
  """
199
208
  __module__ = 'brainstate'
200
- _trace_state: StateJaxTracer
201
209
  _level: int
202
210
  _source_info: source_info_util.SourceInfo
203
211
  _name: Optional[str]
@@ -213,9 +221,6 @@ class State(Generic[A], PrettyRepr):
213
221
  ):
214
222
  tag = metadata.pop('tag', None)
215
223
 
216
- # avoid using self._setattr to avoid the check
217
- vars(self)['_trace_state'] = StateJaxTracer()
218
-
219
224
  # set the value and metadata
220
225
  if isinstance(value, StateMetadata):
221
226
  metadata.update(dict(value.metadata))
@@ -237,31 +242,6 @@ class State(Generic[A], PrettyRepr):
237
242
  # record the state initialization
238
243
  record_state_init(self)
239
244
 
240
- if not TYPE_CHECKING:
241
- def __setattr__(self, name: str, value: Any) -> None:
242
- return self._setattr(name, value)
243
-
244
- def _setattr(self, name: str, value: Any):
245
- """
246
- Check if the state is valid to mutate.
247
- """
248
- if TRACE_CONTEXT.jax_tracer_check[-1]:
249
- self.check_valid_trace(lambda: f'Cannot mutate {type(self).__name__} from a different trace level')
250
- object.__setattr__(self, name, value)
251
-
252
- def _setattr_no_check(self, name: str, value: Any):
253
- """
254
- Set the attribute without checking the trace level.
255
- """
256
- vars(self)[name] = value
257
-
258
- def check_valid_trace(self, error_msg: Callable[[], str]):
259
- """
260
- Check if the state is valid to trace.
261
- """
262
- if not self._trace_state.is_valid():
263
- raise TraceContextError(error_msg())
264
-
265
245
  @property
266
246
  def name(self) -> Optional[str]:
267
247
  """
@@ -274,7 +254,7 @@ class State(Generic[A], PrettyRepr):
274
254
  """
275
255
  Set the name of the state.
276
256
  """
277
- self._setattr_no_check('_name', name)
257
+ self._name = name
278
258
 
279
259
  @property
280
260
  def value(self) -> PyTree[ArrayLike]:
@@ -295,6 +275,26 @@ class State(Generic[A], PrettyRepr):
295
275
  """
296
276
  self.write_value(v)
297
277
 
278
+ @property
279
+ def stack_level(self):
280
+ """
281
+ The stack level of the state.
282
+
283
+ Returns:
284
+ The stack level.
285
+ """
286
+ return self._level
287
+
288
+ @stack_level.setter
289
+ def stack_level(self, level: int):
290
+ """
291
+ Set the stack level of the state.
292
+
293
+ Args:
294
+ level: The stack level.
295
+ """
296
+ self._level = level
297
+
298
298
  def write_value(self, v) -> None:
299
299
  # value checking
300
300
  if isinstance(v, State):
@@ -338,11 +338,11 @@ class State(Generic[A], PrettyRepr):
338
338
  in_tree = jax.tree.structure(v)
339
339
  self_tree = jax.tree.structure(self._value)
340
340
  if in_tree != self_tree:
341
- self._raise_error_with_source_info(
341
+ self.raise_error_with_source_info(
342
342
  ValueError(f'The given value {in_tree} does not match with the origin tree structure {self_tree}.')
343
343
  )
344
344
 
345
- def _raise_error_with_source_info(self, error: Exception):
345
+ def raise_error_with_source_info(self, error: Exception):
346
346
  """
347
347
  Raise an error with the source information for easy debugging.
348
348
  """
@@ -415,7 +415,6 @@ class State(Generic[A], PrettyRepr):
415
415
  obj = object.__new__(type(self))
416
416
  attributes = vars(self).copy()
417
417
  # keep its own trace state and stack level
418
- attributes['_trace_state'] = StateJaxTracer()
419
418
  attributes['_level'] = _get_trace_stack_level()
420
419
  attributes['_source_info'] = source_info_util.current()
421
420
  attributes.pop('_been_writen', None)
@@ -426,8 +425,6 @@ class State(Generic[A], PrettyRepr):
426
425
  def to_state_ref(self: State[A]) -> TreefyState[A]:
427
426
  metadata = vars(self).copy()
428
427
  del metadata['_value']
429
- del metadata['_trace_state']
430
- del metadata['_level']
431
428
  return TreefyState(type(self), self._value, **metadata)
432
429
 
433
430
  def __pretty_repr__(self):
@@ -442,7 +439,7 @@ class State(Generic[A], PrettyRepr):
442
439
  name = 'name'
443
440
  if name == 'tag' and value is None:
444
441
  continue
445
- if name in ['_trace_state', '_level', '_source_info', '_been_writen']:
442
+ if name in ['_level', '_source_info', '_been_writen']:
446
443
  continue
447
444
  yield PrettyAttr(name, repr(value))
448
445
 
@@ -458,7 +455,7 @@ class State(Generic[A], PrettyRepr):
458
455
  name = 'name'
459
456
  if name == 'tag' and value is None:
460
457
  continue
461
- if name in ['_trace_state', '_level', '_source_info', '_been_writen']:
458
+ if name in ['_level', '_source_info', '_been_writen']:
462
459
  continue
463
460
  children[name] = value
464
461
 
@@ -473,6 +470,12 @@ class State(Generic[A], PrettyRepr):
473
470
  def __eq__(self, other: object) -> bool:
474
471
  return type(self) is type(other) and vars(other) == vars(self)
475
472
 
473
+ def __hash__(self):
474
+ """
475
+ Make the state hashable.
476
+ """
477
+ return hash(id(self))
478
+
476
479
 
477
480
  def record_state_init(st: State[A]):
478
481
  trace: Catcher
@@ -482,13 +485,13 @@ def record_state_init(st: State[A]):
482
485
 
483
486
  def record_state_value_read(st: State[A]):
484
487
  trace: StateTraceStack
485
- for trace in TRACE_CONTEXT.state_stack[st._level:]:
488
+ for trace in TRACE_CONTEXT.state_stack[st.stack_level:]:
486
489
  trace.read_its_value(st)
487
490
 
488
491
 
489
492
  def record_state_value_write(st: State[A]):
490
493
  trace: StateTraceStack
491
- for trace in TRACE_CONTEXT.state_stack[st._level:]:
494
+ for trace in TRACE_CONTEXT.state_stack[st.stack_level:]:
492
495
  trace.write_its_value(st)
493
496
 
494
497
 
@@ -524,7 +527,6 @@ class BatchState(LongTermState):
524
527
  __module__ = 'brainstate'
525
528
 
526
529
 
527
-
528
530
  class HiddenState(ShortTermState):
529
531
  """
530
532
  The hidden state, which is used to store the hidden data in a dynamic model.
@@ -541,6 +543,37 @@ class ParamState(LongTermState):
541
543
  __module__ = 'brainstate'
542
544
 
543
545
 
546
+ class FakedState:
547
+ """
548
+ The faked state, which is used to store the faked data in the program.
549
+ """
550
+
551
+ __module__ = 'brainstate'
552
+
553
+ def __init__(self, value: Any, name: Optional[str] = None):
554
+ self._value = value
555
+ self._name = name
556
+
557
+ @property
558
+ def value(self) -> Any:
559
+ return self._value
560
+
561
+ @value.setter
562
+ def value(self, v) -> None:
563
+ self._value = v
564
+
565
+ def __repr__(self) -> str:
566
+ return f'FakedState(value={self._value})'
567
+
568
+ @property
569
+ def name(self) -> Optional[str]:
570
+ return self._name
571
+
572
+ @name.setter
573
+ def name(self, name: str) -> None:
574
+ self._name = name
575
+
576
+
544
577
  class StateDictManager(DictManager):
545
578
  """
546
579
  State stack, for collecting all :py:class:`~.State` used in the program.
@@ -802,7 +835,7 @@ class TreefyState(Generic[A], PrettyRepr):
802
835
  continue
803
836
  else:
804
837
  name = 'name'
805
- if name in ['_trace_state', '_level', '_source_info', 'type']:
838
+ if name in ['_level', '_source_info', 'type']:
806
839
  continue
807
840
  yield PrettyAttr(name, repr(value))
808
841
 
@@ -835,7 +868,7 @@ class TreefyState(Generic[A], PrettyRepr):
835
868
  # __init__ logic which should not be called twice
836
869
  metadata = self.get_metadata()
837
870
  state = object.__new__(self.type)
838
- vars(state).update(metadata, _value=self.value, _trace_state=StateJaxTracer(), _level=_get_trace_stack_level())
871
+ vars(state).update(metadata, _value=self.value, _level=_get_trace_stack_level())
839
872
  return state
840
873
 
841
874
  def copy(self: TreefyState[A]) -> TreefyState[A]:
brainstate/_state_test.py CHANGED
@@ -37,23 +37,6 @@ class TestStateSourceInfo(unittest.TestCase):
37
37
  with self.assertRaises(ValueError):
38
38
  state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
39
39
 
40
- def test_check_jax_tracer(self):
41
- a = bst.ShortTermState(jnp.zeros((2, 3)))
42
-
43
- @jax.jit
44
- def run_state(b):
45
- a.value = b
46
- return a.value
47
-
48
- # The following code will not raise an error, since the state is valid to trace.
49
- run_state(jnp.ones((2, 3)))
50
-
51
- with bst.check_state_jax_tracer():
52
- # The line below will not raise an error.
53
- with self.assertRaises(bst.util.TraceContextError):
54
- # recompile the function
55
- run_state(jnp.ones((2, 4)))
56
-
57
40
 
58
41
  class TestStateRepr(unittest.TestCase):
59
42
 
@@ -17,24 +17,14 @@
17
17
  This module includes transformations for augmenting the functionalities of JAX code.
18
18
  """
19
19
 
20
- from ._autograd import *
21
- from ._autograd import __all__ as _autograd_all
22
- from ._eval_shape import *
23
- from ._eval_shape import __all__ as _eval_shape_all
24
- from ._mapping import *
25
- from ._mapping import __all__ as _mapping_all
26
- from ._random import *
27
- from ._random import __all__ as _random_all
20
+ from ._autograd import GradientTransform, grad, vector_grad, hessian, jacobian, jacrev, jacfwd
21
+ from ._eval_shape import abstract_init
22
+ from ._mapping import vmap, pmap, map
23
+ from ._random import restore_rngs
28
24
 
29
- __all__ = (
30
- _eval_shape_all
31
- + _autograd_all
32
- + _mapping_all
33
- + _random_all
34
- )
35
- del (
36
- _eval_shape_all,
37
- _autograd_all,
38
- _mapping_all,
39
- _random_all
40
- )
25
+ __all__ = [
26
+ 'GradientTransform', 'grad', 'vector_grad', 'hessian', 'jacobian', 'jacrev', 'jacfwd',
27
+ 'abstract_init',
28
+ 'vmap', 'pmap', 'map',
29
+ 'restore_rngs',
30
+ ]
@@ -20,18 +20,18 @@ from typing import Any, TypeVar, Callable, Sequence, Union
20
20
 
21
21
  import jax
22
22
 
23
- from brainstate.graph import graph_to_tree, tree_to_graph
23
+ from brainstate.graph import Node, flatten, unflatten
24
24
  from brainstate.random import DEFAULT, RandomState
25
25
  from ._random import restore_rngs
26
26
 
27
27
  __all__ = [
28
- 'eval_shape',
28
+ 'abstract_init',
29
29
  ]
30
30
 
31
31
  A = TypeVar('A')
32
32
 
33
33
 
34
- def eval_shape(
34
+ def abstract_init(
35
35
  fn: Callable[..., A],
36
36
  *args: Any,
37
37
  rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
@@ -48,7 +48,7 @@ def eval_shape(
48
48
  ... self.dense1 = bst.nn.Linear(n_in, n_mid)
49
49
  ... self.dense2 = bst.nn.Linear(n_mid, n_out)
50
50
 
51
- >>> r = bst.augment.eval_shape(lambda: MLP(1, 2, 3))
51
+ >>> r = bst.augment.abstract_init(lambda: MLP(1, 2, 3))
52
52
  >>> r
53
53
  MLP(
54
54
  dense1=Linear(
@@ -87,16 +87,15 @@ def eval_shape(
87
87
  Returns:
88
88
  out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.
89
89
 
90
-
91
90
  """
92
91
 
93
92
  @functools.wraps(fn)
94
93
  @restore_rngs(rngs=rngs)
95
94
  def _eval_shape_fn(*args_, **kwargs_):
96
- args_, kwargs_ = tree_to_graph((args_, kwargs_))
97
95
  out = fn(*args_, **kwargs_)
98
- return graph_to_tree(out)
96
+ assert isinstance(out, Node), 'The output of the function must be Node'
97
+ graph_def, treefy_states = flatten(out)
98
+ return graph_def, treefy_states
99
99
 
100
- args, kwargs = graph_to_tree((args, kwargs))
101
- out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
102
- return tree_to_graph(out)
100
+ graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
101
+ return unflatten(graph_def_, treefy_states_)
@@ -35,6 +35,6 @@ class TestEvalShape(unittest.TestCase):
35
35
  x = self.dense2(x)
36
36
  return x
37
37
 
38
- r = bst.augment.eval_shape(lambda: MLP(1, 2, 3))
38
+ r = bst.augment.abstract_init(lambda: MLP(1, 2, 3))
39
39
  print(r)
40
40
  print(bst.random.DEFAULT)