brainstate 0.2.0__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +2 -4
- brainstate/_deprecation_test.py +2 -24
- brainstate/_state.py +540 -35
- brainstate/_state_test.py +1085 -8
- brainstate/graph/_operation.py +1 -5
- brainstate/mixin.py +14 -0
- brainstate/nn/__init__.py +42 -33
- brainstate/nn/_collective_ops.py +2 -0
- brainstate/nn/_common_test.py +0 -20
- brainstate/nn/_delay.py +1 -1
- brainstate/nn/_dropout_test.py +9 -6
- brainstate/nn/_dynamics.py +67 -464
- brainstate/nn/_dynamics_test.py +0 -14
- brainstate/nn/_embedding.py +7 -7
- brainstate/nn/_exp_euler.py +9 -9
- brainstate/nn/_linear.py +21 -21
- brainstate/nn/_module.py +25 -18
- brainstate/nn/_normalizations.py +27 -27
- brainstate/random/__init__.py +6 -6
- brainstate/random/{_rand_funs.py → _fun.py} +1 -1
- brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
- brainstate/random/_impl.py +672 -0
- brainstate/random/{_rand_seed.py → _seed.py} +1 -1
- brainstate/random/{_rand_state.py → _state.py} +121 -418
- brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
- brainstate/transform/__init__.py +6 -9
- brainstate/transform/_conditions.py +2 -2
- brainstate/transform/_find_state.py +200 -0
- brainstate/transform/_find_state_test.py +84 -0
- brainstate/transform/_make_jaxpr.py +221 -61
- brainstate/transform/_make_jaxpr_test.py +125 -1
- brainstate/transform/_mapping.py +287 -209
- brainstate/transform/_mapping_test.py +94 -184
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
- brainstate/transform/_eval_shape.py +0 -145
- brainstate/transform/_eval_shape_test.py +0 -38
- brainstate/transform/_random.py +0 -171
- /brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,7 @@ import jax.numpy as jnp
|
|
19
19
|
import jax.random as jr
|
20
20
|
import numpy as np
|
21
21
|
|
22
|
-
from brainstate.random.
|
22
|
+
from brainstate.random._state import RandomState, DEFAULT, formalize_key, _size2shape, _check_py_seq
|
23
23
|
|
24
24
|
|
25
25
|
class TestRandomStateInitialization(unittest.TestCase):
|
@@ -437,32 +437,32 @@ class TestUtilityFunctions(unittest.TestCase):
|
|
437
437
|
|
438
438
|
def test_formalize_key_with_int(self):
|
439
439
|
"""Test _formalize_key with integer."""
|
440
|
-
key =
|
440
|
+
key = formalize_key(42)
|
441
441
|
expected = jr.PRNGKey(42)
|
442
442
|
np.testing.assert_array_equal(key, expected)
|
443
443
|
|
444
444
|
def test_formalize_key_with_array(self):
|
445
445
|
"""Test _formalize_key with array."""
|
446
446
|
input_key = jr.PRNGKey(123)
|
447
|
-
key =
|
447
|
+
key = formalize_key(input_key, True)
|
448
448
|
np.testing.assert_array_equal(key, input_key)
|
449
449
|
|
450
450
|
def test_formalize_key_with_uint32_array(self):
|
451
451
|
"""Test _formalize_key with uint32 array."""
|
452
452
|
input_array = np.array([123, 456], dtype=np.uint32)
|
453
|
-
key =
|
453
|
+
key = formalize_key(input_array)
|
454
454
|
np.testing.assert_array_equal(key, input_array)
|
455
455
|
|
456
456
|
def test_formalize_key_invalid_input(self):
|
457
457
|
"""Test _formalize_key with invalid input."""
|
458
458
|
with self.assertRaises(TypeError):
|
459
|
-
|
459
|
+
formalize_key("invalid")
|
460
460
|
|
461
461
|
with self.assertRaises(TypeError):
|
462
|
-
|
462
|
+
formalize_key(np.array([1, 2, 3], dtype=np.uint32)) # Wrong size
|
463
463
|
|
464
464
|
with self.assertRaises(TypeError):
|
465
|
-
|
465
|
+
formalize_key(np.array([1, 2], dtype=np.int32)) # Wrong dtype
|
466
466
|
|
467
467
|
def test_size2shape(self):
|
468
468
|
"""Test _size2shape function."""
|
brainstate/transform/__init__.py
CHANGED
@@ -22,8 +22,8 @@ from ._conditions import *
|
|
22
22
|
from ._conditions import __all__ as _conditions_all
|
23
23
|
from ._error_if import *
|
24
24
|
from ._error_if import __all__ as _error_if_all
|
25
|
-
from .
|
26
|
-
from .
|
25
|
+
from ._find_state import *
|
26
|
+
from ._find_state import __all__ as _find_all
|
27
27
|
from ._jit import *
|
28
28
|
from ._jit import __all__ as _jit_all
|
29
29
|
from ._loop_collect_return import *
|
@@ -36,24 +36,21 @@ from ._mapping import *
|
|
36
36
|
from ._mapping import __all__ as _mapping_all
|
37
37
|
from ._progress_bar import *
|
38
38
|
from ._progress_bar import __all__ as _progress_bar_all
|
39
|
-
from ._random import *
|
40
|
-
from ._random import __all__ as _random_all
|
41
39
|
from ._unvmap import *
|
42
40
|
from ._unvmap import __all__ as _unvmap_all
|
43
41
|
|
44
|
-
__all__ = _ad_checkpoint_all + _autograd_all + _conditions_all + _error_if_all
|
45
|
-
__all__ +=
|
46
|
-
__all__ += _make_jaxpr_all + _mapping_all + _progress_bar_all +
|
42
|
+
__all__ = _ad_checkpoint_all + _autograd_all + _conditions_all + _error_if_all + _find_all
|
43
|
+
__all__ += _jit_all + _loop_collect_return_all + _loop_no_collection_all
|
44
|
+
__all__ += _make_jaxpr_all + _mapping_all + _progress_bar_all + _unvmap_all
|
45
|
+
del _find_all
|
47
46
|
del _ad_checkpoint_all
|
48
47
|
del _autograd_all
|
49
48
|
del _conditions_all
|
50
49
|
del _error_if_all
|
51
|
-
del _eval_shape_all
|
52
50
|
del _jit_all
|
53
51
|
del _loop_collect_return_all
|
54
52
|
del _loop_no_collection_all
|
55
53
|
del _make_jaxpr_all
|
56
54
|
del _mapping_all
|
57
55
|
del _progress_bar_all
|
58
|
-
del _random_all
|
59
56
|
del _unvmap_all
|
@@ -218,7 +218,7 @@ def switch(index, branches: Sequence[Callable], *operands):
|
|
218
218
|
index = jax.lax.clamp(lo, index, hi)
|
219
219
|
|
220
220
|
# not jit
|
221
|
-
if jax.config.jax_disable_jit and isinstance(
|
221
|
+
if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(index), Tracer):
|
222
222
|
return branches[int(index)](*operands)
|
223
223
|
|
224
224
|
# evaluate jaxpr
|
@@ -311,6 +311,6 @@ def ifelse(conditions, branches, *operands, check_cond: bool = True):
|
|
311
311
|
# format index
|
312
312
|
conditions = jnp.asarray(conditions, np.int32)
|
313
313
|
if check_cond:
|
314
|
-
jit_error_if(jnp.sum(conditions) != 1, "Only one condition can be True. But got {}.",
|
314
|
+
jit_error_if(jnp.sum(conditions) != 1, "Only one condition can be True. But got {c}.", c=conditions)
|
315
315
|
index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
|
316
316
|
return switch(index, branches, *operands)
|
@@ -0,0 +1,200 @@
|
|
1
|
+
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from typing import Callable, Dict, Hashable, Literal, Sequence, Any
|
17
|
+
|
18
|
+
from brainstate._state import State
|
19
|
+
from brainstate.util.filter import Filter, to_predicate
|
20
|
+
from ._make_jaxpr import StatefulFunction
|
21
|
+
|
22
|
+
__all__ = [
|
23
|
+
'StateFinder'
|
24
|
+
]
|
25
|
+
|
26
|
+
|
27
|
+
class StateFinder:
    """
    Discover :class:`~brainstate.State` instances touched by a callable.

    ``StateFinder`` wraps a function in :class:`~brainstate.transform.StatefulFunction`
    and exposes the collection of states the function reads or writes. The finder
    can filter states by predicates, request only read or write states, and return
    the result in several convenient formats.

    Parameters
    ----------
    fn : callable
        Function whose state usage should be inspected.
    filter : Filter, optional
        Predicate (see :mod:`brainstate.util.filter`) used to select states.
        When ``None`` (the default), every discovered state is kept.
    usage : {'all', 'read', 'write', 'both'}, default 'all'
        Portion of the state trace to return. ``'both'`` returns a mapping with
        separate ``'read'`` and ``'write'`` entries.
    return_type : {'dict', 'list', 'tuple'}, default 'dict'
        Controls the container type returned for the selected states. When
        ``usage='both'``, the same container type is used for the ``'read'`` and
        ``'write'`` entries.
    key_fn : callable, optional
        Callable ``key_fn(index, state)`` that produces dictionary keys when
        ``return_type='dict'``. Defaults to using the positional index so existing
        code continues to receive integer keys. Unhashable keys are converted
        with :func:`str`. A key that is ``None``, or that repeats an earlier key,
        is replaced by the state's ``name`` (or ``"state_<index>"`` when the
        state has no name), with ``_1``, ``_2``, ... appended until it is unique.

    Raises
    ------
    ValueError
        If ``usage`` or ``return_type`` is not one of the accepted values.
    TypeError
        If ``key_fn`` is given but is not callable.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> param = brainstate.ParamState(jnp.ones(()), name='weights')
        >>> bias = brainstate.ParamState(jnp.zeros(()), name='bias')
        >>>
        >>> def forward(x):
        ...     _ = bias.value  # read-only
        ...     param.value = param.value * x  # read + write
        ...     return param.value + bias.value
        >>>
        >>> finder = brainstate.transform.StateFinder(
        ...     forward,
        ...     filter=brainstate.ParamState,
        ...     usage='both',
        ...     key_fn=lambda i, st: st.name or i,
        ... )
        >>> finder(2.0)['write']  # doctest: +ELLIPSIS
        {'weights': ParamState(...)}

    Notes
    -----
    The underlying :class:`StatefulFunction` is cached, so subsequent calls with
    compatible arguments will reuse the compiled trace.
    """

    _VALID_USAGE: tuple[str, ...] = ('all', 'read', 'write', 'both')
    _VALID_RETURN_TYPE: tuple[str, ...] = ('dict', 'list', 'tuple')

    def __init__(
        self,
        fn: Callable,
        filter: Filter | None = None,
        *,
        usage: Literal['all', 'read', 'write', 'both'] = 'all',
        return_type: Literal['dict', 'list', 'tuple'] = 'dict',
        key_fn: Callable[[int, State], Hashable] | None = None,
    ) -> None:
        if usage not in self._VALID_USAGE:
            raise ValueError(f"Invalid usage '{usage}'. Expected one of {self._VALID_USAGE}.")
        if return_type not in self._VALID_RETURN_TYPE:
            raise ValueError(
                f"Invalid return_type '{return_type}'. Expected one of {self._VALID_RETURN_TYPE}."
            )
        # Fail at construction instead of deep inside ``find`` on first use.
        if key_fn is not None and not callable(key_fn):
            raise TypeError(f"key_fn must be callable, got {type(key_fn).__name__}.")

        self.fn = fn
        self._usage = usage
        self._return_type = return_type
        self._key_fn = key_fn if key_fn is not None else self._default_key_fn
        # Normalise any Filter form into a ``(path, value) -> bool`` predicate once.
        self._filter = to_predicate(filter) if filter is not None else None
        self.stateful_fn = StatefulFunction(self.fn)

    def __call__(self, *args, **kwargs):
        """
        Invoke :meth:`find` to retrieve states touched by ``fn``.
        """
        return self.find(*args, **kwargs)

    def find(self, *args, **kwargs):
        """
        Execute the wrapped function symbolically and return the selected states.

        Parameters
        ----------
        *args, **kwargs
            Arguments forwarded to ``fn`` to determine the state trace.

        Returns
        -------
        Any
            Container holding the requested states as configured by ``usage`` and
            ``return_type``. For ``usage='both'`` this is a ``dict`` with keys
            ``'read'`` and ``'write'``, each holding such a container.
        """
        if self._usage == 'both':
            read_states = self._collect_states('read', *args, **kwargs)
            write_states = self._collect_states('write', *args, **kwargs)
            return {
                'read': self._format_states(read_states),
                'write': self._format_states(write_states),
            }

        return self._format_states(self._collect_states(self._usage, *args, **kwargs))

    def _collect_states(self, usage: str, *args, **kwargs) -> Sequence[State]:
        """Query ``stateful_fn`` for the states matching ``usage`` and apply the filter."""
        collectors = {
            'all': self.stateful_fn.get_states,
            'read': self.stateful_fn.get_read_states,
            'write': self.stateful_fn.get_write_states,
        }
        collector = collectors.get(usage)
        if collector is None:
            raise ValueError(f"Unsupported usage '{usage}'.")
        states = list(collector(*args, **kwargs))
        if self._filter is not None:
            # The predicate is called as ``(path, state)``; an empty path is used here.
            states = [st for st in states if self._filter((), st)]
        return states

    def _format_states(self, states: Sequence[State]):
        """Pack ``states`` into the container selected by ``return_type``."""
        if self._return_type == 'list':
            return list(states)
        if self._return_type == 'tuple':
            return tuple(states)
        return self._states_to_dict(states)

    def _states_to_dict(self, states: Sequence[State]) -> Dict[Hashable, State]:
        """Build a dict keyed by ``key_fn`` output, made hashable and unique."""
        result: Dict[Hashable, State] = {}
        used_keys: set[Hashable] = set()
        for idx, state in enumerate(states):
            key = self._key_fn(idx, state)
            key = self._ensure_hashable(key)
            key = self._ensure_unique_key(key, idx, state, used_keys)
            result[key] = state
            used_keys.add(key)
        return result

    @staticmethod
    def _default_key_fn(idx: int, state: State) -> Hashable:
        """Default key: the state's position in the collected sequence."""
        return idx

    @staticmethod
    def _ensure_hashable(key: Any) -> Hashable:
        """Return ``key`` unchanged when hashable (``None`` included), else ``str(key)``."""
        try:
            hash(key)
        except TypeError:
            return str(key)
        return key

    @staticmethod
    def _ensure_unique_key(key: Hashable, idx: int, state: State, used: set[Hashable]) -> Hashable:
        """
        Return ``key`` if usable, otherwise a unique name-based replacement.

        A ``None`` or already-used key is replaced by the state's ``name`` (or
        ``"state_<idx>"`` when the name is missing/empty), suffixed with
        ``_1``, ``_2``, ... until it is not in ``used``.
        """
        if key is None or key in used:
            base_name = getattr(state, 'name', None)
            base = base_name if base_name not in (None, '') else f"state_{idx}"
            candidate = base
            suffix = 1
            while candidate in used:
                candidate = f"{base}_{suffix}"
                suffix += 1
            return candidate
        return key
|
200
|
+
|
@@ -0,0 +1,84 @@
|
|
1
|
+
import unittest
|
2
|
+
|
3
|
+
import jax.numpy as jnp
|
4
|
+
|
5
|
+
import brainstate as bst
|
6
|
+
from brainstate.transform import StateFinder
|
7
|
+
|
8
|
+
|
9
|
+
class TestStateFinder(unittest.TestCase):
    """Tests for :class:`brainstate.transform.StateFinder`."""

    def test_default_dictionary_output(self):
        read_state = bst.State(jnp.array(0.0), name='read_state')
        param_state = bst.ParamState(jnp.array(1.0), name='param_state')

        def fn(scale):
            _ = read_state.value
            param_state.value = param_state.value * scale
            return param_state.value + _

        finder = StateFinder(fn)
        result = finder(2.0)
        self.assertEqual(len(result), 2)
        self.assertEqual(set(result.values()), {read_state, param_state})

    def test_filter_and_usage_read(self):
        buffer_state = bst.State(jnp.array(1.0), name='buffer')
        param_state = bst.ParamState(jnp.array(3.0), name='param')

        def fn(offset):
            _ = buffer_state.value
            param_state.value = param_state.value + offset
            return param_state.value

        read_finder = StateFinder(fn, usage='read', return_type='list')
        reads = read_finder(1.0)
        self.assertEqual(reads, [buffer_state])

        param_finder = StateFinder(fn, filter=bst.ParamState, usage='all')
        param_states = param_finder(1.0)
        self.assertEqual(list(param_states.values()), [param_state])

    def test_usage_write_with_custom_key(self):
        param_state = bst.ParamState(jnp.array(5.0), name='param')

        def fn(scale):
            param_state.value = param_state.value * scale
            return param_state.value

        finder = StateFinder(fn, usage='write', return_type='dict', key_fn=lambda idx, st: f"w_{idx}")
        write_states = finder(2.0)
        self.assertIn('w_0', write_states)
        self.assertIs(write_states['w_0'], param_state)

    def test_usage_both_returns_separated_collections(self):
        read_state = bst.State(jnp.array(2.0), name='read')
        write_state = bst.ParamState(jnp.array(4.0), name='write')

        def fn(delta):
            _ = read_state.value
            write_state.value = write_state.value + delta
            return write_state.value

        finder = StateFinder(fn, usage='both', return_type='tuple')
        result = finder(1.5)
        self.assertEqual(set(result.keys()), {'read', 'write'})
        self.assertEqual(result['read'], (read_state,))
        self.assertEqual(result['write'], (write_state,))

    def test_duplicate_names_are_disambiguated(self):
        first = bst.State(jnp.array(0.0), name='dup')
        second = bst.State(jnp.array(1.0), name='dup')

        def fn():
            _ = first.value
            _ = second.value
            return None

        # Key by name so the identical names really collide; the default
        # index-based keys (0, 1) would never reach the disambiguation path.
        finder = StateFinder(fn, key_fn=lambda idx, st: st.name)
        states = finder()
        self.assertEqual(len(states), 2)
        self.assertEqual(set(states.keys()), {'dup', 'dup_1'})
        self.assertEqual(set(states.values()), {first, second})

    def test_none_key_falls_back_to_state_name(self):
        named = bst.State(jnp.array(0.0), name='named')

        def fn():
            return named.value

        finder = StateFinder(fn, key_fn=lambda idx, st: None)
        states = finder()
        self.assertEqual(list(states.keys()), ['named'])
        self.assertIs(states['named'], named)

    def test_unhashable_key_is_stringified(self):
        state = bst.State(jnp.array(0.0), name='s')

        def fn():
            return state.value

        finder = StateFinder(fn, key_fn=lambda idx, st: [idx])
        states = finder()
        self.assertEqual(list(states.keys()), ['[0]'])
        self.assertIs(states['[0]'], state)

    def test_invalid_arguments_raise(self):
        with self.assertRaises(ValueError):
            StateFinder(lambda: None, usage='invalid')
        with self.assertRaises(ValueError):
            StateFinder(lambda: None, return_type='set')


if __name__ == "__main__":
    unittest.main()
|