brainstate-0.2.0-py2.py3-none-any.whl → brainstate-0.2.2-py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. brainstate/__init__.py +2 -4
  2. brainstate/_deprecation_test.py +2 -24
  3. brainstate/_state.py +540 -35
  4. brainstate/_state_test.py +1085 -8
  5. brainstate/graph/_operation.py +1 -5
  6. brainstate/mixin.py +14 -0
  7. brainstate/nn/__init__.py +42 -33
  8. brainstate/nn/_collective_ops.py +2 -0
  9. brainstate/nn/_common_test.py +0 -20
  10. brainstate/nn/_delay.py +1 -1
  11. brainstate/nn/_dropout_test.py +9 -6
  12. brainstate/nn/_dynamics.py +67 -464
  13. brainstate/nn/_dynamics_test.py +0 -14
  14. brainstate/nn/_embedding.py +7 -7
  15. brainstate/nn/_exp_euler.py +9 -9
  16. brainstate/nn/_linear.py +21 -21
  17. brainstate/nn/_module.py +25 -18
  18. brainstate/nn/_normalizations.py +27 -27
  19. brainstate/random/__init__.py +6 -6
  20. brainstate/random/{_rand_funs.py → _fun.py} +1 -1
  21. brainstate/random/{_rand_funs_test.py → _fun_test.py} +0 -2
  22. brainstate/random/_impl.py +672 -0
  23. brainstate/random/{_rand_seed.py → _seed.py} +1 -1
  24. brainstate/random/{_rand_state.py → _state.py} +121 -418
  25. brainstate/random/{_rand_state_test.py → _state_test.py} +7 -7
  26. brainstate/transform/__init__.py +6 -9
  27. brainstate/transform/_conditions.py +2 -2
  28. brainstate/transform/_find_state.py +200 -0
  29. brainstate/transform/_find_state_test.py +84 -0
  30. brainstate/transform/_make_jaxpr.py +221 -61
  31. brainstate/transform/_make_jaxpr_test.py +125 -1
  32. brainstate/transform/_mapping.py +287 -209
  33. brainstate/transform/_mapping_test.py +94 -184
  34. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/METADATA +1 -1
  35. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/RECORD +39 -39
  36. brainstate/transform/_eval_shape.py +0 -145
  37. brainstate/transform/_eval_shape_test.py +0 -38
  38. brainstate/transform/_random.py +0 -171
  39. brainstate/random/{_rand_seed_test.py → _seed_test.py} +0 -0
  40. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  41. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +0 -0
  42. {brainstate-0.2.0.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,7 @@ import jax.numpy as jnp
19
19
  import jax.random as jr
20
20
  import numpy as np
21
21
 
22
- from brainstate.random._rand_state import RandomState, DEFAULT, _formalize_key, _size2shape, _check_py_seq
22
+ from brainstate.random._state import RandomState, DEFAULT, formalize_key, _size2shape, _check_py_seq
23
23
 
24
24
 
25
25
  class TestRandomStateInitialization(unittest.TestCase):
@@ -437,32 +437,32 @@ class TestUtilityFunctions(unittest.TestCase):
437
437
 
438
438
  def test_formalize_key_with_int(self):
439
439
  """Test _formalize_key with integer."""
440
- key = _formalize_key(42)
440
+ key = formalize_key(42)
441
441
  expected = jr.PRNGKey(42)
442
442
  np.testing.assert_array_equal(key, expected)
443
443
 
444
444
  def test_formalize_key_with_array(self):
445
445
  """Test _formalize_key with array."""
446
446
  input_key = jr.PRNGKey(123)
447
- key = _formalize_key(input_key)
447
+ key = formalize_key(input_key, True)
448
448
  np.testing.assert_array_equal(key, input_key)
449
449
 
450
450
  def test_formalize_key_with_uint32_array(self):
451
451
  """Test _formalize_key with uint32 array."""
452
452
  input_array = np.array([123, 456], dtype=np.uint32)
453
- key = _formalize_key(input_array)
453
+ key = formalize_key(input_array)
454
454
  np.testing.assert_array_equal(key, input_array)
455
455
 
456
456
  def test_formalize_key_invalid_input(self):
457
457
  """Test _formalize_key with invalid input."""
458
458
  with self.assertRaises(TypeError):
459
- _formalize_key("invalid")
459
+ formalize_key("invalid")
460
460
 
461
461
  with self.assertRaises(TypeError):
462
- _formalize_key(np.array([1, 2, 3], dtype=np.uint32)) # Wrong size
462
+ formalize_key(np.array([1, 2, 3], dtype=np.uint32)) # Wrong size
463
463
 
464
464
  with self.assertRaises(TypeError):
465
- _formalize_key(np.array([1, 2], dtype=np.int32)) # Wrong dtype
465
+ formalize_key(np.array([1, 2], dtype=np.int32)) # Wrong dtype
466
466
 
467
467
  def test_size2shape(self):
468
468
  """Test _size2shape function."""
@@ -22,8 +22,8 @@ from ._conditions import *
22
22
  from ._conditions import __all__ as _conditions_all
23
23
  from ._error_if import *
24
24
  from ._error_if import __all__ as _error_if_all
25
- from ._eval_shape import *
26
- from ._eval_shape import __all__ as _eval_shape_all
25
+ from ._find_state import *
26
+ from ._find_state import __all__ as _find_all
27
27
  from ._jit import *
28
28
  from ._jit import __all__ as _jit_all
29
29
  from ._loop_collect_return import *
@@ -36,24 +36,21 @@ from ._mapping import *
36
36
  from ._mapping import __all__ as _mapping_all
37
37
  from ._progress_bar import *
38
38
  from ._progress_bar import __all__ as _progress_bar_all
39
- from ._random import *
40
- from ._random import __all__ as _random_all
41
39
  from ._unvmap import *
42
40
  from ._unvmap import __all__ as _unvmap_all
43
41
 
44
- __all__ = _ad_checkpoint_all + _autograd_all + _conditions_all + _error_if_all
45
- __all__ += _eval_shape_all + _jit_all + _loop_collect_return_all + _loop_no_collection_all
46
- __all__ += _make_jaxpr_all + _mapping_all + _progress_bar_all + _random_all + _unvmap_all
42
+ __all__ = _ad_checkpoint_all + _autograd_all + _conditions_all + _error_if_all + _find_all
43
+ __all__ += _jit_all + _loop_collect_return_all + _loop_no_collection_all
44
+ __all__ += _make_jaxpr_all + _mapping_all + _progress_bar_all + _unvmap_all
45
+ del _find_all
47
46
  del _ad_checkpoint_all
48
47
  del _autograd_all
49
48
  del _conditions_all
50
49
  del _error_if_all
51
- del _eval_shape_all
52
50
  del _jit_all
53
51
  del _loop_collect_return_all
54
52
  del _loop_no_collection_all
55
53
  del _make_jaxpr_all
56
54
  del _mapping_all
57
55
  del _progress_bar_all
58
- del _random_all
59
56
  del _unvmap_all
@@ -218,7 +218,7 @@ def switch(index, branches: Sequence[Callable], *operands):
218
218
  index = jax.lax.clamp(lo, index, hi)
219
219
 
220
220
  # not jit
221
- if jax.config.jax_disable_jit and isinstance(jax.core.core.get_aval(index), jax.core.ConcreteArray):
221
+ if jax.config.jax_disable_jit and not isinstance(to_concrete_aval(index), Tracer):
222
222
  return branches[int(index)](*operands)
223
223
 
224
224
  # evaluate jaxpr
@@ -311,6 +311,6 @@ def ifelse(conditions, branches, *operands, check_cond: bool = True):
311
311
  # format index
312
312
  conditions = jnp.asarray(conditions, np.int32)
313
313
  if check_cond:
314
- jit_error_if(jnp.sum(conditions) != 1, "Only one condition can be True. But got {}.", err_arg=conditions)
314
+ jit_error_if(jnp.sum(conditions) != 1, "Only one condition can be True. But got {c}.", c=conditions)
315
315
  index = jnp.where(conditions, size=1, fill_value=len(conditions) - 1)[0][0]
316
316
  return switch(index, branches, *operands)
@@ -0,0 +1,200 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Callable, Dict, Hashable, Literal, Sequence, Any
17
+
18
+ from brainstate._state import State
19
+ from brainstate.util.filter import Filter, to_predicate
20
+ from ._make_jaxpr import StatefulFunction
21
+
22
+ __all__ = [
23
+ 'StateFinder'
24
+ ]
25
+
26
+
27
+ class StateFinder:
28
+ """
29
+ Discover :class:`~brainstate.State` instances touched by a callable.
30
+
31
+ ``StateFinder`` wraps a function in :class:`~brainstate.transform.StatefulFunction`
32
+ and exposes the collection of states the function reads or writes. The finder
33
+ can filter states by predicates, request only read or write states, and return
34
+ the result in several convenient formats.
35
+
36
+ Parameters
37
+ ----------
38
+ fn : callable
39
+ Function whose state usage should be inspected.
40
+ filter : Filter, optional
41
+ Predicate (see :mod:`brainstate.util.filter`) used to select states.
42
+ usage : {'all', 'read', 'write', 'both'}, default 'all'
43
+ Portion of the state trace to return. ``'both'`` returns a mapping with
44
+ separate read and write entries.
45
+ return_type : {'dict', 'list', 'tuple'}, default 'dict'
46
+ Controls the container type returned for the selected states. When
47
+ ``usage='both'``, the same container type is used for the ``'read'`` and
48
+ ``'write'`` entries.
49
+ key_fn : callable, optional
50
+ Callable ``key_fn(index, state)`` that produces dictionary keys when
51
+ ``return_type='dict'``. Defaults to using the positional index so existing
52
+ code continues to receive integer keys.
53
+
54
+ Examples
55
+ --------
56
+ .. code-block:: python
57
+
58
+ >>> import brainstate
59
+ >>> import jax.numpy as jnp
60
+ >>>
61
+ >>> param = brainstate.ParamState(jnp.ones(()), name='weights')
62
+ >>> bias = brainstate.ParamState(jnp.zeros(()), name='bias')
63
+ >>>
64
+ >>> def forward(x):
65
+ ... _ = bias.value # read-only
66
+ ... param.value = param.value * x # read + write
67
+ ... return param.value + bias.value
68
+ >>>
69
+ >>> finder = brainstate.transform.StateFinder(
70
+ ... forward,
71
+ ... filter=brainstate.ParamState,
72
+ ... usage='both',
73
+ ... key_fn=lambda i, st: st.name or i,
74
+ ... )
75
+ >>> finder(2.0)['write'] # doctest: +ELLIPSIS
76
+ {'weights': ParamState(...}
77
+
78
+ Notes
79
+ -----
80
+ The underlying :class:`StatefulFunction` is cached, so subsequent calls with
81
+ compatible arguments will reuse the compiled trace.
82
+ """
83
+
84
+ _VALID_USAGE: tuple[str, ...] = ('all', 'read', 'write', 'both')
85
+ _VALID_RETURN_TYPE: tuple[str, ...] = ('dict', 'list', 'tuple')
86
+
87
+ def __init__(
88
+ self,
89
+ fn: Callable,
90
+ filter: Filter = None,
91
+ *,
92
+ usage: Literal['all', 'read', 'write', 'both'] = 'all',
93
+ return_type: Literal['dict', 'list', 'tuple'] = 'dict',
94
+ key_fn: Callable[[int, State], Hashable] | None = None,
95
+ ) -> None:
96
+ if usage not in self._VALID_USAGE:
97
+ raise ValueError(f"Invalid usage '{usage}'. Expected one of {self._VALID_USAGE}.")
98
+ if return_type not in self._VALID_RETURN_TYPE:
99
+ raise ValueError(
100
+ f"Invalid return_type '{return_type}'. Expected one of {self._VALID_RETURN_TYPE}."
101
+ )
102
+
103
+ self.fn = fn
104
+ self._usage = usage
105
+ self._return_type = return_type
106
+ self._key_fn = key_fn if key_fn is not None else self._default_key_fn
107
+ self._filter = to_predicate(filter) if filter is not None else None
108
+ self.stateful_fn = StatefulFunction(self.fn)
109
+
110
+ def __call__(self, *args, **kwargs):
111
+ """
112
+ Invoke :meth:`find` to retrieve states touched by ``fn``.
113
+ """
114
+ return self.find(*args, **kwargs)
115
+
116
+ def find(self, *args, **kwargs):
117
+ """
118
+ Execute the wrapped function symbolically and return the selected states.
119
+
120
+ Parameters
121
+ ----------
122
+ *args, **kwargs
123
+ Arguments forwarded to ``fn`` to determine the state trace.
124
+
125
+ Returns
126
+ -------
127
+ Any
128
+ Container holding the requested states as configured by ``usage`` and
129
+ ``return_type``.
130
+ """
131
+ if self._usage == 'both':
132
+ read_states = self._collect_states('read', *args, **kwargs)
133
+ write_states = self._collect_states('write', *args, **kwargs)
134
+ return {
135
+ 'read': self._format_states(read_states),
136
+ 'write': self._format_states(write_states),
137
+ }
138
+
139
+ states = self._collect_states(self._usage, *args, **kwargs)
140
+ return self._format_states(states)
141
+
142
+ def _collect_states(self, usage: str, *args, **kwargs) -> Sequence[State]:
143
+ usage_map = {
144
+ 'all': self.stateful_fn.get_states,
145
+ 'read': self.stateful_fn.get_read_states,
146
+ 'write': self.stateful_fn.get_write_states,
147
+ }
148
+ collector = usage_map.get(usage)
149
+ if collector is None:
150
+ raise ValueError(f"Unsupported usage '{usage}'.")
151
+ states = list(collector(*args, **kwargs))
152
+ if self._filter is not None:
153
+ states = [st for st in states if self._filter(tuple(), st)]
154
+ return states
155
+
156
+ def _format_states(self, states: Sequence[State]):
157
+ if self._return_type == 'list':
158
+ return list(states)
159
+ if self._return_type == 'tuple':
160
+ return tuple(states)
161
+ return self._states_to_dict(states)
162
+
163
+ def _states_to_dict(self, states: Sequence[State]) -> Dict[Hashable, State]:
164
+ result: Dict[Hashable, State] = {}
165
+ used_keys: set[Hashable] = set()
166
+ for idx, state in enumerate(states):
167
+ key = self._key_fn(idx, state)
168
+ key = self._ensure_hashable(key)
169
+ key = self._ensure_unique_key(key, idx, state, used_keys)
170
+ result[key] = state
171
+ used_keys.add(key)
172
+ return result
173
+
174
+ @staticmethod
175
+ def _default_key_fn(idx: int, state: State) -> Hashable:
176
+ return idx
177
+
178
+ @staticmethod
179
+ def _ensure_hashable(key: Any) -> Hashable:
180
+ if key is None:
181
+ return None
182
+ try:
183
+ hash(key)
184
+ except TypeError:
185
+ return str(key)
186
+ return key
187
+
188
+ @staticmethod
189
+ def _ensure_unique_key(key: Hashable, idx: int, state: State, used: set[Hashable]) -> Hashable:
190
+ if key is None or key in used:
191
+ base_name = getattr(state, 'name', None)
192
+ base = base_name if base_name not in (None, '') else f"state_{idx}"
193
+ candidate = base
194
+ suffix = 1
195
+ while candidate in used:
196
+ candidate = f"{base}_{suffix}"
197
+ suffix += 1
198
+ return candidate
199
+ return key
200
+
@@ -0,0 +1,84 @@
1
+ import unittest
2
+
3
+ import jax.numpy as jnp
4
+
5
+ import brainstate as bst
6
+ from brainstate.transform import StateFinder
7
+
8
+
9
+ class TestStateFinder(unittest.TestCase):
10
+ def test_default_dictionary_output(self):
11
+ read_state = bst.State(jnp.array(0.0), name='read_state')
12
+ param_state = bst.ParamState(jnp.array(1.0), name='param_state')
13
+
14
+ def fn(scale):
15
+ _ = read_state.value
16
+ param_state.value = param_state.value * scale
17
+ return param_state.value + _
18
+
19
+ finder = StateFinder(fn)
20
+ result = finder(2.0)
21
+ self.assertEqual(len(result), 2)
22
+ self.assertEqual(set(result.values()), {read_state, param_state})
23
+
24
+ def test_filter_and_usage_read(self):
25
+ buffer_state = bst.State(jnp.array(1.0), name='buffer')
26
+ param_state = bst.ParamState(jnp.array(3.0), name='param')
27
+
28
+ def fn(offset):
29
+ _ = buffer_state.value
30
+ param_state.value = param_state.value + offset
31
+ return param_state.value
32
+
33
+ read_finder = StateFinder(fn, usage='read', return_type='list')
34
+ reads = read_finder(1.0)
35
+ self.assertEqual(reads, [buffer_state])
36
+
37
+ param_finder = StateFinder(fn, filter=bst.ParamState, usage='all')
38
+ param_states = param_finder(1.0)
39
+ self.assertEqual(list(param_states.values()), [param_state])
40
+
41
+ def test_usage_write_with_custom_key(self):
42
+ param_state = bst.ParamState(jnp.array(5.0), name='param')
43
+
44
+ def fn(scale):
45
+ param_state.value = param_state.value * scale
46
+ return param_state.value
47
+
48
+ finder = StateFinder(fn, usage='write', return_type='dict', key_fn=lambda idx, st: f"w_{idx}")
49
+ write_states = finder(2.0)
50
+ self.assertIn('w_0', write_states)
51
+ self.assertIs(write_states['w_0'], param_state)
52
+
53
+ def test_usage_both_returns_separated_collections(self):
54
+ read_state = bst.State(jnp.array(2.0), name='read')
55
+ write_state = bst.ParamState(jnp.array(4.0), name='write')
56
+
57
+ def fn(delta):
58
+ _ = read_state.value
59
+ write_state.value = write_state.value + delta
60
+ return write_state.value
61
+
62
+ finder = StateFinder(fn, usage='both', return_type='tuple')
63
+ result = finder(1.5)
64
+ self.assertEqual(set(result.keys()), {'read', 'write'})
65
+ self.assertEqual(result['read'], (read_state,))
66
+ self.assertEqual(result['write'], (write_state,))
67
+
68
+ def test_duplicate_names_are_disambiguated(self):
69
+ first = bst.State(jnp.array(0.0), name='dup')
70
+ second = bst.State(jnp.array(1.0), name='dup')
71
+
72
+ def fn():
73
+ _ = first.value
74
+ _ = second.value
75
+ return None
76
+
77
+ finder = StateFinder(fn)
78
+ states = finder()
79
+ self.assertEqual(len(states), 2)
80
+ self.assertEqual(set(states.values()), {first, second})
81
+
82
+
83
+ if __name__ == "__main__":
84
+ unittest.main()