brainstate 0.1.0.post20250104__py2.py3-none-any.whl → 0.1.0.post20250120__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +77 -44
- brainstate/_state_test.py +0 -17
- brainstate/augment/_eval_shape.py +9 -10
- brainstate/augment/_eval_shape_test.py +1 -1
- brainstate/augment/_mapping.py +265 -277
- brainstate/augment/_mapping_test.py +147 -175
- brainstate/compile/_ad_checkpoint.py +6 -4
- brainstate/compile/_error_if_test.py +1 -0
- brainstate/compile/_jit.py +37 -28
- brainstate/compile/_loop_collect_return.py +8 -5
- brainstate/compile/_loop_no_collection.py +2 -0
- brainstate/compile/_make_jaxpr.py +7 -3
- brainstate/compile/_make_jaxpr_test.py +2 -1
- brainstate/compile/_progress_bar.py +68 -40
- brainstate/compile/_unvmap.py +6 -2
- brainstate/environ.py +28 -18
- brainstate/environ_test.py +4 -0
- brainstate/event/__init__.py +0 -2
- brainstate/event/_csr.py +266 -23
- brainstate/event/_csr_test.py +187 -0
- brainstate/event/_fixedprob_mv.py +4 -2
- brainstate/event/_fixedprob_mv_test.py +2 -1
- brainstate/event/_xla_custom_op.py +16 -5
- brainstate/graph/__init__.py +8 -12
- brainstate/graph/_graph_node.py +1 -23
- brainstate/graph/_graph_operation.py +1 -1
- brainstate/graph/_graph_operation_test.py +0 -159
- brainstate/nn/_dyn_impl/_inputs.py +124 -39
- brainstate/nn/_interaction/_conv.py +4 -2
- brainstate/nn/_interaction/_linear.py +84 -10
- brainstate/random/_rand_funs.py +9 -2
- brainstate/random/_rand_seed.py +12 -2
- brainstate/random/_rand_state.py +50 -179
- brainstate/surrogate.py +5 -1
- brainstate/util/__init__.py +0 -4
- brainstate/util/_caller.py +1 -1
- brainstate/util/_dict.py +4 -1
- brainstate/util/_filter.py +1 -1
- brainstate/util/_pretty_repr.py +1 -1
- brainstate/util/_struct.py +1 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/METADATA +2 -1
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/RECORD +46 -52
- brainstate/event/_csr_mv_test.py +0 -118
- brainstate/graph/_graph_context.py +0 -443
- brainstate/graph/_graph_context_test.py +0 -65
- brainstate/graph/_graph_convert.py +0 -246
- brainstate/util/_tracers.py +0 -68
- brainstate/util/_visualization.py +0 -47
- /brainstate/event/{_csr_mv_benchmark.py → _csr_benchmark.py} +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250104.dist-info → brainstate-0.1.0.post20250120.dist-info}/top_level.txt +0 -0
brainstate/_state.py
CHANGED
@@ -28,13 +28,14 @@ from jax.api_util import shaped_abstractify
|
|
28
28
|
from jax.extend import source_info_util
|
29
29
|
|
30
30
|
from brainstate.typing import ArrayLike, PyTree, Missing
|
31
|
-
from brainstate.util import DictManager, PrettyRepr, PrettyType, PrettyAttr
|
32
|
-
from brainstate.util._tracers import StateJaxTracer
|
31
|
+
from brainstate.util import DictManager, PrettyRepr, PrettyType, PrettyAttr
|
33
32
|
|
34
33
|
__all__ = [
|
35
34
|
'State', 'ShortTermState', 'LongTermState', 'HiddenState', 'ParamState', 'TreefyState',
|
35
|
+
'FakedState',
|
36
36
|
|
37
37
|
'StateDictManager', 'StateTraceStack', 'check_state_value_tree', 'check_state_jax_tracer', 'catch_new_states',
|
38
|
+
'maybe_state',
|
38
39
|
]
|
39
40
|
|
40
41
|
A = TypeVar('A')
|
@@ -102,6 +103,7 @@ class Catcher:
|
|
102
103
|
"""
|
103
104
|
The catcher to catch the new states.
|
104
105
|
"""
|
106
|
+
|
105
107
|
def __init__(self, tag: str):
|
106
108
|
self.tag = tag
|
107
109
|
self.state_ids = set()
|
@@ -114,6 +116,13 @@ class Catcher:
|
|
114
116
|
state.tag = self.tag
|
115
117
|
|
116
118
|
|
119
|
+
def maybe_state(val: Any):
|
120
|
+
if isinstance(val, State):
|
121
|
+
return val.value
|
122
|
+
else:
|
123
|
+
return val
|
124
|
+
|
125
|
+
|
117
126
|
@contextlib.contextmanager
|
118
127
|
def check_state_jax_tracer(val: bool = True) -> None:
|
119
128
|
"""
|
@@ -197,7 +206,6 @@ class State(Generic[A], PrettyRepr):
|
|
197
206
|
value: PyTree. It can be anything as a pyTree.
|
198
207
|
"""
|
199
208
|
__module__ = 'brainstate'
|
200
|
-
_trace_state: StateJaxTracer
|
201
209
|
_level: int
|
202
210
|
_source_info: source_info_util.SourceInfo
|
203
211
|
_name: Optional[str]
|
@@ -213,9 +221,6 @@ class State(Generic[A], PrettyRepr):
|
|
213
221
|
):
|
214
222
|
tag = metadata.pop('tag', None)
|
215
223
|
|
216
|
-
# avoid using self._setattr to avoid the check
|
217
|
-
vars(self)['_trace_state'] = StateJaxTracer()
|
218
|
-
|
219
224
|
# set the value and metadata
|
220
225
|
if isinstance(value, StateMetadata):
|
221
226
|
metadata.update(dict(value.metadata))
|
@@ -237,31 +242,6 @@ class State(Generic[A], PrettyRepr):
|
|
237
242
|
# record the state initialization
|
238
243
|
record_state_init(self)
|
239
244
|
|
240
|
-
if not TYPE_CHECKING:
|
241
|
-
def __setattr__(self, name: str, value: Any) -> None:
|
242
|
-
return self._setattr(name, value)
|
243
|
-
|
244
|
-
def _setattr(self, name: str, value: Any):
|
245
|
-
"""
|
246
|
-
Check if the state is valid to mutate.
|
247
|
-
"""
|
248
|
-
if TRACE_CONTEXT.jax_tracer_check[-1]:
|
249
|
-
self.check_valid_trace(lambda: f'Cannot mutate {type(self).__name__} from a different trace level')
|
250
|
-
object.__setattr__(self, name, value)
|
251
|
-
|
252
|
-
def _setattr_no_check(self, name: str, value: Any):
|
253
|
-
"""
|
254
|
-
Set the attribute without checking the trace level.
|
255
|
-
"""
|
256
|
-
vars(self)[name] = value
|
257
|
-
|
258
|
-
def check_valid_trace(self, error_msg: Callable[[], str]):
|
259
|
-
"""
|
260
|
-
Check if the state is valid to trace.
|
261
|
-
"""
|
262
|
-
if not self._trace_state.is_valid():
|
263
|
-
raise TraceContextError(error_msg())
|
264
|
-
|
265
245
|
@property
|
266
246
|
def name(self) -> Optional[str]:
|
267
247
|
"""
|
@@ -274,7 +254,7 @@ class State(Generic[A], PrettyRepr):
|
|
274
254
|
"""
|
275
255
|
Set the name of the state.
|
276
256
|
"""
|
277
|
-
self.
|
257
|
+
self._name = name
|
278
258
|
|
279
259
|
@property
|
280
260
|
def value(self) -> PyTree[ArrayLike]:
|
@@ -295,6 +275,26 @@ class State(Generic[A], PrettyRepr):
|
|
295
275
|
"""
|
296
276
|
self.write_value(v)
|
297
277
|
|
278
|
+
@property
|
279
|
+
def stack_level(self):
|
280
|
+
"""
|
281
|
+
The stack level of the state.
|
282
|
+
|
283
|
+
Returns:
|
284
|
+
The stack level.
|
285
|
+
"""
|
286
|
+
return self._level
|
287
|
+
|
288
|
+
@stack_level.setter
|
289
|
+
def stack_level(self, level: int):
|
290
|
+
"""
|
291
|
+
Set the stack level of the state.
|
292
|
+
|
293
|
+
Args:
|
294
|
+
level: The stack level.
|
295
|
+
"""
|
296
|
+
self._level = level
|
297
|
+
|
298
298
|
def write_value(self, v) -> None:
|
299
299
|
# value checking
|
300
300
|
if isinstance(v, State):
|
@@ -338,11 +338,11 @@ class State(Generic[A], PrettyRepr):
|
|
338
338
|
in_tree = jax.tree.structure(v)
|
339
339
|
self_tree = jax.tree.structure(self._value)
|
340
340
|
if in_tree != self_tree:
|
341
|
-
self.
|
341
|
+
self.raise_error_with_source_info(
|
342
342
|
ValueError(f'The given value {in_tree} does not match with the origin tree structure {self_tree}.')
|
343
343
|
)
|
344
344
|
|
345
|
-
def
|
345
|
+
def raise_error_with_source_info(self, error: Exception):
|
346
346
|
"""
|
347
347
|
Raise an error with the source information for easy debugging.
|
348
348
|
"""
|
@@ -415,7 +415,6 @@ class State(Generic[A], PrettyRepr):
|
|
415
415
|
obj = object.__new__(type(self))
|
416
416
|
attributes = vars(self).copy()
|
417
417
|
# keep its own trace state and stack level
|
418
|
-
attributes['_trace_state'] = StateJaxTracer()
|
419
418
|
attributes['_level'] = _get_trace_stack_level()
|
420
419
|
attributes['_source_info'] = source_info_util.current()
|
421
420
|
attributes.pop('_been_writen', None)
|
@@ -426,8 +425,6 @@ class State(Generic[A], PrettyRepr):
|
|
426
425
|
def to_state_ref(self: State[A]) -> TreefyState[A]:
|
427
426
|
metadata = vars(self).copy()
|
428
427
|
del metadata['_value']
|
429
|
-
del metadata['_trace_state']
|
430
|
-
del metadata['_level']
|
431
428
|
return TreefyState(type(self), self._value, **metadata)
|
432
429
|
|
433
430
|
def __pretty_repr__(self):
|
@@ -442,7 +439,7 @@ class State(Generic[A], PrettyRepr):
|
|
442
439
|
name = 'name'
|
443
440
|
if name == 'tag' and value is None:
|
444
441
|
continue
|
445
|
-
if name in ['
|
442
|
+
if name in ['_level', '_source_info', '_been_writen']:
|
446
443
|
continue
|
447
444
|
yield PrettyAttr(name, repr(value))
|
448
445
|
|
@@ -458,7 +455,7 @@ class State(Generic[A], PrettyRepr):
|
|
458
455
|
name = 'name'
|
459
456
|
if name == 'tag' and value is None:
|
460
457
|
continue
|
461
|
-
if name in ['
|
458
|
+
if name in ['_level', '_source_info', '_been_writen']:
|
462
459
|
continue
|
463
460
|
children[name] = value
|
464
461
|
|
@@ -473,6 +470,12 @@ class State(Generic[A], PrettyRepr):
|
|
473
470
|
def __eq__(self, other: object) -> bool:
|
474
471
|
return type(self) is type(other) and vars(other) == vars(self)
|
475
472
|
|
473
|
+
def __hash__(self):
|
474
|
+
"""
|
475
|
+
Make the state hashable.
|
476
|
+
"""
|
477
|
+
return hash(id(self))
|
478
|
+
|
476
479
|
|
477
480
|
def record_state_init(st: State[A]):
|
478
481
|
trace: Catcher
|
@@ -482,13 +485,13 @@ def record_state_init(st: State[A]):
|
|
482
485
|
|
483
486
|
def record_state_value_read(st: State[A]):
|
484
487
|
trace: StateTraceStack
|
485
|
-
for trace in TRACE_CONTEXT.state_stack[st.
|
488
|
+
for trace in TRACE_CONTEXT.state_stack[st.stack_level:]:
|
486
489
|
trace.read_its_value(st)
|
487
490
|
|
488
491
|
|
489
492
|
def record_state_value_write(st: State[A]):
|
490
493
|
trace: StateTraceStack
|
491
|
-
for trace in TRACE_CONTEXT.state_stack[st.
|
494
|
+
for trace in TRACE_CONTEXT.state_stack[st.stack_level:]:
|
492
495
|
trace.write_its_value(st)
|
493
496
|
|
494
497
|
|
@@ -524,7 +527,6 @@ class BatchState(LongTermState):
|
|
524
527
|
__module__ = 'brainstate'
|
525
528
|
|
526
529
|
|
527
|
-
|
528
530
|
class HiddenState(ShortTermState):
|
529
531
|
"""
|
530
532
|
The hidden state, which is used to store the hidden data in a dynamic model.
|
@@ -541,6 +543,37 @@ class ParamState(LongTermState):
|
|
541
543
|
__module__ = 'brainstate'
|
542
544
|
|
543
545
|
|
546
|
+
class FakedState:
|
547
|
+
"""
|
548
|
+
The faked state, which is used to store the faked data in the program.
|
549
|
+
"""
|
550
|
+
|
551
|
+
__module__ = 'brainstate'
|
552
|
+
|
553
|
+
def __init__(self, value: Any, name: Optional[str] = None):
|
554
|
+
self._value = value
|
555
|
+
self._name = name
|
556
|
+
|
557
|
+
@property
|
558
|
+
def value(self) -> Any:
|
559
|
+
return self._value
|
560
|
+
|
561
|
+
@value.setter
|
562
|
+
def value(self, v) -> None:
|
563
|
+
self._value = v
|
564
|
+
|
565
|
+
def __repr__(self) -> str:
|
566
|
+
return f'FakedState(value={self._value})'
|
567
|
+
|
568
|
+
@property
|
569
|
+
def name(self) -> Optional[str]:
|
570
|
+
return self._name
|
571
|
+
|
572
|
+
@name.setter
|
573
|
+
def name(self, name: str) -> None:
|
574
|
+
self._name = name
|
575
|
+
|
576
|
+
|
544
577
|
class StateDictManager(DictManager):
|
545
578
|
"""
|
546
579
|
State stack, for collecting all :py:class:`~.State` used in the program.
|
@@ -802,7 +835,7 @@ class TreefyState(Generic[A], PrettyRepr):
|
|
802
835
|
continue
|
803
836
|
else:
|
804
837
|
name = 'name'
|
805
|
-
if name in ['
|
838
|
+
if name in ['_level', '_source_info', 'type']:
|
806
839
|
continue
|
807
840
|
yield PrettyAttr(name, repr(value))
|
808
841
|
|
@@ -835,7 +868,7 @@ class TreefyState(Generic[A], PrettyRepr):
|
|
835
868
|
# __init__ logic which should not be called twice
|
836
869
|
metadata = self.get_metadata()
|
837
870
|
state = object.__new__(self.type)
|
838
|
-
vars(state).update(metadata, _value=self.value,
|
871
|
+
vars(state).update(metadata, _value=self.value, _level=_get_trace_stack_level())
|
839
872
|
return state
|
840
873
|
|
841
874
|
def copy(self: TreefyState[A]) -> TreefyState[A]:
|
brainstate/_state_test.py
CHANGED
@@ -37,23 +37,6 @@ class TestStateSourceInfo(unittest.TestCase):
|
|
37
37
|
with self.assertRaises(ValueError):
|
38
38
|
state.value = (jnp.zeros((2, 3)), jnp.zeros((2, 3)))
|
39
39
|
|
40
|
-
def test_check_jax_tracer(self):
|
41
|
-
a = bst.ShortTermState(jnp.zeros((2, 3)))
|
42
|
-
|
43
|
-
@jax.jit
|
44
|
-
def run_state(b):
|
45
|
-
a.value = b
|
46
|
-
return a.value
|
47
|
-
|
48
|
-
# The following code will not raise an error, since the state is valid to trace.
|
49
|
-
run_state(jnp.ones((2, 3)))
|
50
|
-
|
51
|
-
with bst.check_state_jax_tracer():
|
52
|
-
# The line below will not raise an error.
|
53
|
-
with self.assertRaises(bst.util.TraceContextError):
|
54
|
-
# recompile the function
|
55
|
-
run_state(jnp.ones((2, 4)))
|
56
|
-
|
57
40
|
|
58
41
|
class TestStateRepr(unittest.TestCase):
|
59
42
|
|
@@ -20,18 +20,18 @@ from typing import Any, TypeVar, Callable, Sequence, Union
|
|
20
20
|
|
21
21
|
import jax
|
22
22
|
|
23
|
-
from brainstate.graph import
|
23
|
+
from brainstate.graph import Node, flatten, unflatten
|
24
24
|
from brainstate.random import DEFAULT, RandomState
|
25
25
|
from ._random import restore_rngs
|
26
26
|
|
27
27
|
__all__ = [
|
28
|
-
'
|
28
|
+
'abstract_init',
|
29
29
|
]
|
30
30
|
|
31
31
|
A = TypeVar('A')
|
32
32
|
|
33
33
|
|
34
|
-
def
|
34
|
+
def abstract_init(
|
35
35
|
fn: Callable[..., A],
|
36
36
|
*args: Any,
|
37
37
|
rngs: Union[RandomState, Sequence[RandomState]] = DEFAULT,
|
@@ -48,7 +48,7 @@ def eval_shape(
|
|
48
48
|
... self.dense1 = bst.nn.Linear(n_in, n_mid)
|
49
49
|
... self.dense2 = bst.nn.Linear(n_mid, n_out)
|
50
50
|
|
51
|
-
>>> r = bst.augment.
|
51
|
+
>>> r = bst.augment.abstract_init(lambda: MLP(1, 2, 3))
|
52
52
|
>>> r
|
53
53
|
MLP(
|
54
54
|
dense1=Linear(
|
@@ -87,16 +87,15 @@ def eval_shape(
|
|
87
87
|
Returns:
|
88
88
|
out: a nested PyTree containing :class:`jax.ShapeDtypeStruct` objects as leaves.
|
89
89
|
|
90
|
-
|
91
90
|
"""
|
92
91
|
|
93
92
|
@functools.wraps(fn)
|
94
93
|
@restore_rngs(rngs=rngs)
|
95
94
|
def _eval_shape_fn(*args_, **kwargs_):
|
96
|
-
args_, kwargs_ = tree_to_graph((args_, kwargs_))
|
97
95
|
out = fn(*args_, **kwargs_)
|
98
|
-
|
96
|
+
assert isinstance(out, Node), 'The output of the function must be Node'
|
97
|
+
graph_def, treefy_states = flatten(out)
|
98
|
+
return graph_def, treefy_states
|
99
99
|
|
100
|
-
|
101
|
-
|
102
|
-
return tree_to_graph(out)
|
100
|
+
graph_def_, treefy_states_ = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
|
101
|
+
return unflatten(graph_def_, treefy_states_)
|