brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,52 +1,52 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import unittest
17
-
18
- import jax
19
- import jax.numpy as jnp
20
-
21
- import brainstate
22
-
23
-
24
class TestJitError(unittest.TestCase):
    """
    Tests for ``jit_error_if``.

    The function is accessed through ``brainstate.transform``, the package that
    contains both this test module and ``_error_if.py``.
    """

    def test1(self):
        # A true condition with a plain string message must raise.
        with self.assertRaises(Exception):
            brainstate.transform.jit_error_if(True, 'error')

        def err_f(x):
            raise ValueError(f'error: {x}')

        # With a callable as the error argument: a false condition must not
        # raise, a true condition must.
        brainstate.transform.jit_error_if(False, err_f, 1.)
        with self.assertRaises(Exception):
            brainstate.transform.jit_error_if(True, err_f, 1.)

    def test_vmap(self):
        def f(x):
            brainstate.transform.jit_error_if(x, 'error: {x}', x=x)

        # Under ``jax.vmap`` the condition is batched: an all-False batch passes,
        # while a single True element is enough to raise.
        jax.vmap(f)(jnp.array([False, False, False]))
        with self.assertRaises(Exception):
            jax.vmap(f)(jnp.array([True, False, False]))

    def test_vmap_vmap(self):
        def f(x):
            brainstate.transform.jit_error_if(x, 'error: {x}', x=x)

        # Same check with two nested ``jax.vmap`` levels.
        jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                         [False, False, False]]))
        with self.assertRaises(Exception):
            jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                             [True, False, False]]))
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+
21
+ import brainstate
22
+
23
+
24
class TestJitError(unittest.TestCase):
    """
    Tests for ``jit_error_if``.

    The function is accessed through ``brainstate.transform``, the package that
    contains both this test module and ``_error_if.py``.
    """

    def test1(self):
        # A true condition with a plain string message must raise.
        with self.assertRaises(Exception):
            brainstate.transform.jit_error_if(True, 'error')

        def err_f(x):
            raise ValueError(f'error: {x}')

        # With a callable as the error argument: a false condition must not
        # raise, a true condition must.
        brainstate.transform.jit_error_if(False, err_f, 1.)
        with self.assertRaises(Exception):
            brainstate.transform.jit_error_if(True, err_f, 1.)

    def test_vmap(self):
        def f(x):
            brainstate.transform.jit_error_if(x, 'error: {x}', x=x)

        # Under ``jax.vmap`` the condition is batched: an all-False batch passes,
        # while a single True element is enough to raise.
        jax.vmap(f)(jnp.array([False, False, False]))
        with self.assertRaises(Exception):
            jax.vmap(f)(jnp.array([True, False, False]))

    def test_vmap_vmap(self):
        def f(x):
            brainstate.transform.jit_error_if(x, 'error: {x}', x=x)

        # Same check with two nested ``jax.vmap`` levels.
        jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                         [False, False, False]]))
        with self.assertRaises(Exception):
            jax.vmap(jax.vmap(f))(jnp.array([[False, False, False],
                                             [True, False, False]]))
@@ -0,0 +1,200 @@
1
+ # Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Callable, Dict, Hashable, Literal, Sequence, Any
17
+
18
+ from brainstate._state import State
19
+ from brainstate.util.filter import Filter, to_predicate
20
+ from ._make_jaxpr import StatefulFunction
21
+
22
# Public API of this module: only ``StateFinder`` is exported by ``import *``.
__all__ = ['StateFinder']
25
+
26
+
27
class StateFinder:
    """
    Discover :class:`~brainstate.State` instances touched by a callable.

    ``StateFinder`` wraps a function in :class:`~brainstate.transform.StatefulFunction`
    and exposes the collection of states the function reads or writes. The finder
    can filter states by predicates, request only read or write states, and return
    the result in several convenient formats.

    Parameters
    ----------
    fn : callable
        Function whose state usage should be inspected.
    filter : Filter, optional
        Predicate (see :mod:`brainstate.util.filter`) used to select states.
        ``None`` (the default) keeps every state.
    usage : {'all', 'read', 'write', 'both'}, default 'all'
        Portion of the state trace to return. ``'both'`` returns a mapping with
        separate ``'read'`` and ``'write'`` entries.
    return_type : {'dict', 'list', 'tuple'}, default 'dict'
        Controls the container type returned for the selected states. When
        ``usage='both'``, the same container type is used for the ``'read'`` and
        ``'write'`` entries.
    key_fn : callable, optional
        Callable ``key_fn(index, state)`` that produces dictionary keys when
        ``return_type='dict'``. Defaults to using the positional index so existing
        code continues to receive integer keys. An unhashable key is converted
        with ``str``. A key that is ``None``, or that was already produced for an
        earlier state, is replaced by the state's ``name`` (or ``f"state_{index}"``
        when the state has no name). A numeric suffix ``_1``, ``_2``, ... is
        appended only when needed to make that replacement unique.

    Raises
    ------
    ValueError
        If ``usage`` or ``return_type`` is not one of the accepted values.

    Examples
    --------
    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> param = brainstate.ParamState(jnp.ones(()), name='weights')
        >>> bias = brainstate.ParamState(jnp.zeros(()), name='bias')
        >>>
        >>> def forward(x):
        ...     _ = bias.value  # read-only
        ...     param.value = param.value * x  # read + write
        ...     return param.value + bias.value
        >>>
        >>> finder = brainstate.transform.StateFinder(
        ...     forward,
        ...     filter=brainstate.ParamState,
        ...     usage='both',
        ...     key_fn=lambda i, st: st.name or i,
        ... )
        >>> finder(2.0)['write']  # doctest: +ELLIPSIS
        {'weights': ParamState(...)}

    Notes
    -----
    The underlying :class:`StatefulFunction` is cached, so subsequent calls with
    compatible arguments will reuse the compiled trace.
    """

    # Accepted values for the ``usage`` and ``return_type`` constructor arguments.
    _VALID_USAGE: tuple[str, ...] = ('all', 'read', 'write', 'both')
    _VALID_RETURN_TYPE: tuple[str, ...] = ('dict', 'list', 'tuple')

    def __init__(
        self,
        fn: Callable,
        filter: Filter | None = None,
        *,
        usage: Literal['all', 'read', 'write', 'both'] = 'all',
        return_type: Literal['dict', 'list', 'tuple'] = 'dict',
        key_fn: Callable[[int, State], Hashable] | None = None,
    ) -> None:
        # Validate eagerly so configuration mistakes surface at construction
        # time instead of on the first call.
        if usage not in self._VALID_USAGE:
            raise ValueError(f"Invalid usage '{usage}'. Expected one of {self._VALID_USAGE}.")
        if return_type not in self._VALID_RETURN_TYPE:
            raise ValueError(
                f"Invalid return_type '{return_type}'. Expected one of {self._VALID_RETURN_TYPE}."
            )

        self.fn = fn
        self._usage = usage
        self._return_type = return_type
        self._key_fn = key_fn if key_fn is not None else self._default_key_fn
        # ``None`` means "no filtering"; otherwise the filter is converted to a
        # predicate that ``_collect_states`` calls as ``predicate(path, state)``.
        self._filter = to_predicate(filter) if filter is not None else None
        self.stateful_fn = StatefulFunction(self.fn)

    def __call__(self, *args, **kwargs):
        """
        Invoke :meth:`find` to retrieve states touched by ``fn``.
        """
        return self.find(*args, **kwargs)

    def find(self, *args, **kwargs):
        """
        Execute the wrapped function symbolically and return the selected states.

        Parameters
        ----------
        *args, **kwargs
            Arguments forwarded to ``fn`` to determine the state trace.

        Returns
        -------
        Any
            For ``usage='both'``, a ``dict`` with ``'read'`` and ``'write'``
            entries, each formatted according to ``return_type``. Otherwise a
            single container (``dict``, ``list`` or ``tuple``) holding the
            selected states.
        """
        if self._usage == 'both':
            read_states = self._collect_states('read', *args, **kwargs)
            write_states = self._collect_states('write', *args, **kwargs)
            return {
                'read': self._format_states(read_states),
                'write': self._format_states(write_states),
            }

        states = self._collect_states(self._usage, *args, **kwargs)
        return self._format_states(states)

    def _collect_states(self, usage: str, *args, **kwargs) -> Sequence[State]:
        """Return the (filtered) states for a single usage: 'all', 'read' or 'write'."""
        usage_map = {
            'all': self.stateful_fn.get_states,
            'read': self.stateful_fn.get_read_states,
            'write': self.stateful_fn.get_write_states,
        }
        collector = usage_map.get(usage)
        if collector is None:
            raise ValueError(f"Unsupported usage '{usage}'.")
        states = collector(*args, **kwargs)
        if self._filter is None:
            return list(states)
        # States are not addressed by a pytree path here, so the predicate
        # receives an empty path.
        return [st for st in states if self._filter((), st)]

    def _format_states(self, states: Sequence[State]):
        """Convert ``states`` into the container selected by ``return_type``."""
        if self._return_type == 'list':
            return list(states)
        if self._return_type == 'tuple':
            return tuple(states)
        return self._states_to_dict(states)

    def _states_to_dict(self, states: Sequence[State]) -> Dict[Hashable, State]:
        """Build an ordered ``{key: state}`` mapping with hashable, unique keys."""
        result: Dict[Hashable, State] = {}
        used_keys: set[Hashable] = set()
        for idx, state in enumerate(states):
            key = self._key_fn(idx, state)
            key = self._ensure_hashable(key)
            key = self._ensure_unique_key(key, idx, state, used_keys)
            result[key] = state
            used_keys.add(key)
        return result

    @staticmethod
    def _default_key_fn(idx: int, state: State) -> Hashable:
        """Default key function: key each state by its position."""
        return idx

    @staticmethod
    def _ensure_hashable(key: Any) -> Hashable:
        """Return ``key`` unchanged if hashable (or ``None``); otherwise ``str(key)``."""
        if key is None:
            return None
        try:
            hash(key)
        except TypeError:
            return str(key)
        return key

    @staticmethod
    def _ensure_unique_key(key: Hashable, idx: int, state: State, used: set[Hashable]) -> Hashable:
        """
        Return ``key`` if it is usable, otherwise a unique fallback.

        A ``None`` key or one already in ``used`` is replaced by the state's
        ``name`` (or ``f"state_{idx}"`` when the name is missing or empty),
        with ``_1``, ``_2``, ... appended until the candidate is not in ``used``.
        """
        if key is None or key in used:
            base_name = getattr(state, 'name', None)
            base = base_name if base_name not in (None, '') else f"state_{idx}"
            candidate = base
            suffix = 1
            while candidate in used:
                candidate = f"{base}_{suffix}"
                suffix += 1
            return candidate
        return key
+ return key
200
+
@@ -0,0 +1,84 @@
1
+ import unittest
2
+
3
+ import jax.numpy as jnp
4
+
5
+ import brainstate as bst
6
+ from brainstate.transform import StateFinder
7
+
8
+
9
class TestStateFinder(unittest.TestCase):
    """Behavioural tests for :class:`brainstate.transform.StateFinder`."""

    def test_default_dictionary_output(self):
        read_state = bst.State(jnp.array(0.0), name='read_state')
        param_state = bst.ParamState(jnp.array(1.0), name='param_state')

        def fn(scale):
            _ = read_state.value
            param_state.value = param_state.value * scale
            return param_state.value + _

        finder = StateFinder(fn)
        result = finder(2.0)
        self.assertEqual(len(result), 2)
        self.assertEqual(set(result.values()), {read_state, param_state})
        # The default key_fn keys the dictionary by positional index.
        self.assertEqual(set(result.keys()), {0, 1})

    def test_filter_and_usage_read(self):
        buffer_state = bst.State(jnp.array(1.0), name='buffer')
        param_state = bst.ParamState(jnp.array(3.0), name='param')

        def fn(offset):
            _ = buffer_state.value
            param_state.value = param_state.value + offset
            return param_state.value

        read_finder = StateFinder(fn, usage='read', return_type='list')
        reads = read_finder(1.0)
        self.assertEqual(reads, [buffer_state])

        param_finder = StateFinder(fn, filter=bst.ParamState, usage='all')
        param_states = param_finder(1.0)
        self.assertEqual(list(param_states.values()), [param_state])

    def test_usage_write_with_custom_key(self):
        param_state = bst.ParamState(jnp.array(5.0), name='param')

        def fn(scale):
            param_state.value = param_state.value * scale
            return param_state.value

        finder = StateFinder(fn, usage='write', return_type='dict', key_fn=lambda idx, st: f"w_{idx}")
        write_states = finder(2.0)
        self.assertIn('w_0', write_states)
        self.assertIs(write_states['w_0'], param_state)

    def test_usage_both_returns_separated_collections(self):
        read_state = bst.State(jnp.array(2.0), name='read')
        write_state = bst.ParamState(jnp.array(4.0), name='write')

        def fn(delta):
            _ = read_state.value
            write_state.value = write_state.value + delta
            return write_state.value

        finder = StateFinder(fn, usage='both', return_type='tuple')
        result = finder(1.5)
        self.assertEqual(set(result.keys()), {'read', 'write'})
        self.assertEqual(result['read'], (read_state,))
        self.assertEqual(result['write'], (write_state,))

    def test_duplicate_names_are_disambiguated(self):
        first = bst.State(jnp.array(0.0), name='dup')
        second = bst.State(jnp.array(1.0), name='dup')

        def fn():
            _ = first.value
            _ = second.value
            return None

        finder = StateFinder(fn)
        states = finder()
        self.assertEqual(len(states), 2)
        self.assertEqual(set(states.values()), {first, second})

        # The default key_fn uses positional indices, which never collide, so
        # key by name to actually exercise the disambiguation path.
        named_finder = StateFinder(fn, key_fn=lambda idx, st: st.name)
        named_states = named_finder()
        self.assertEqual(set(named_states.keys()), {'dup', 'dup_1'})
        self.assertEqual(set(named_states.values()), {first, second})

    def test_none_key_falls_back_to_state_name(self):
        alpha = bst.State(jnp.array(0.0), name='alpha')
        beta = bst.State(jnp.array(1.0), name='beta')

        def fn():
            return alpha.value + beta.value

        finder = StateFinder(fn, key_fn=lambda idx, st: None)
        states = finder()
        self.assertEqual(set(states.keys()), {'alpha', 'beta'})
        self.assertIs(states['alpha'], alpha)
        self.assertIs(states['beta'], beta)

    def test_invalid_configuration_raises(self):
        def fn():
            return None

        # Both options are validated in the constructor, before any tracing.
        with self.assertRaises(ValueError):
            StateFinder(fn, usage='invalid')
        with self.assertRaises(ValueError):
            StateFinder(fn, return_type='set')


if __name__ == "__main__":
    unittest.main()