brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,283 +1,283 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- import math
17
- from typing import Any, Callable, TypeVar
18
-
19
- import jax
20
-
21
- from brainstate._utils import set_module_as
22
- from ._loop_collect_return import _bounded_while_loop
23
- from ._make_jaxpr import StatefulFunction
24
- from ._util import wrap_single_fun_in_multi_branches_while_loop as wrap_fn
25
-
26
# Generic type variables shared by the loop signatures in this module.
T = TypeVar('T')
X = TypeVar('X')
Y = TypeVar('Y')
Carry = TypeVar('Carry')

# A Python ``bool`` or a boolean array; typed as ``Any`` so both are accepted.
BooleanNumeric = Any

# Public names exported by this module.
__all__ = [
    'while_loop',
    'bounded_while_loop',
]
35
-
36
-
37
@set_module_as('brainstate.transform')
def while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T:
    """
    Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

        while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation:

    .. code-block:: python

        >>> def while_loop(cond_fun, body_fun, init_val):
        ...     val = init_val
        ...     while cond_fun(val):
        ...         val = body_fun(val)
        ...     return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
    to a single WhileOp. That makes it useful for reducing compilation times
    for jit-compiled functions, since native Python loop constructs in an ``@jit``
    function are unrolled, leading to large XLA computations.

    Also unlike the Python analogue, the loop-carried value ``val`` must hold a
    fixed shape and dtype across all iterations (and not just be consistent up to
    NumPy rank/shape broadcasting and dtype promotion rules, for example). In
    other words, the type ``a`` in the type signature above represents an array
    with a fixed shape and dtype (or a nested tuple/list/dict container data
    structure with a fixed structure and arrays with fixed shape and dtype at the
    leaves).

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    States read by ``cond_fun``/``body_fun`` are passed into the loop, and states
    written by ``body_fun`` are carried through it and written back afterwards.

    Parameters
    ----------
    cond_fun : callable
        Function of type ``a -> Bool``. It must not write to any ``State``.
    body_fun : callable
        Function of type ``a -> a``.
    init_val : T
        Value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the initial
        loop carry value.

    Returns
    -------
    T
        The output from the final iteration of body_fun, of type ``a``.

    Raises
    ------
    TypeError
        If ``cond_fun`` or ``body_fun`` is not callable.
    ValueError
        If ``cond_fun`` writes to any ``State``.

    Examples
    --------
    Basic while loop operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def cond_fn(val):
        ...     return val < 10
        >>>
        >>> def body_fn(val):
        ...     return val + 1
        >>>
        >>> result = brainstate.transform.while_loop(cond_fn, body_fn, 0)
        >>> # result will be 10

    While loop with array state:

    .. code-block:: python

        >>> def cond_fn(state):
        ...     return jnp.sum(state) < 100
        >>>
        >>> def body_fn(state):
        ...     return state * 1.1
        >>>
        >>> init_state = jnp.array([1.0, 2.0, 3.0])
        >>> final_state = brainstate.transform.while_loop(cond_fn, body_fn, init_state)

    References
    ----------
    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    if not (callable(body_fun) and callable(cond_fun)):
        raise TypeError("while_loop: body_fun and cond_fun arguments should be callable.")
    if jax.config.jax_disable_jit:
        # With JIT disabled, run the loop eagerly in Python (easier to debug).
        try:
            val = init_val
            while cond_fun(val):
                val = body_fun(val)
            return val
        except jax.errors.ConcretizationTypeError:
            # Can't run this while_loop in Python (e.g. because there's a vmap
            # transformation on it), so we fall back to the primitive version.
            # ``jax.errors`` is the public location of this exception; the
            # ``jax.core`` alias is not part of the stable API.
            pass

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun, name='while:cond').make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun, name='while:body').make_jaxpr(init_val)
    cond_cache_key = stateful_cond.get_arg_cache_key(init_val)
    body_cache_key = stateful_body.get_arg_cache_key(init_val)
    if len(stateful_cond.get_write_states_by_cache(cond_cache_key)) != 0:
        raise ValueError("while_loop: cond_fun should not have any write states.")

    # state trace and state values (union of the states touched by both functions)
    state_trace = (stateful_cond.get_state_trace_by_cache(cond_cache_key) +
                   stateful_body.get_state_trace_by_cache(body_cache_key))
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False, cond_cache_key)
    new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True, body_cache_key)

    # while_loop: written state values are carried alongside the user value
    state_vals, final_val = jax.lax.while_loop(new_cond_fn, new_body_fn, (write_state_vals, init_val))

    # write back state values or restore them
    state_trace.assign_state_vals_v2(read_state_vals, state_vals)
    return final_val
165
-
166
-
167
@set_module_as('brainstate.transform')
def bounded_while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
    base: int = 16,
):
    """
    While loop with a bound on the maximum number of steps.

    This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.

    This function is useful when you want to ensure that a while loop terminates
    even if the condition function is never false. The function is implemented
    using a scan operation, so it is reverse-mode differentiable.

    Parameters
    ----------
    cond_fun : callable
        A function of type ``a -> Bool``. It must not write to any ``State``.
    body_fun : callable
        A function of type ``a -> a``.
    init_val : T
        The initial value of type ``a``.
    max_steps : int
        A bound on the maximum number of steps, after which the loop
        terminates unconditionally. ``body_fun`` is applied at most
        ``max_steps`` times.
    base : int, default 16
        Run time will increase slightly as `base` increases. Compilation time will
        decrease substantially as `math.ceil(math.log(max_steps, base))` decreases.
        (Which happens as `base` increases.) Must be an integer ``>= 2``.

    Returns
    -------
    T
        The final value, as if computed by a `lax.while_loop`.

    Raises
    ------
    ValueError
        If ``max_steps`` is not a non-negative integer, if ``base`` is not an
        integer ``>= 2``, or if ``cond_fun`` writes to any ``State``.

    Examples
    --------
    Basic bounded while loop:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def cond_fn(val):
        ...     return val < 1000  # This might never be false
        >>>
        >>> def body_fn(val):
        ...     return val * 2
        >>>
        >>> # Loop will terminate after at most 10 steps
        >>> result = brainstate.transform.bounded_while_loop(
        ...     cond_fn, body_fn, 1, max_steps=10
        ... )

    Bounded while loop with custom base:

    .. code-block:: python

        >>> # A smaller base trades slower compilation for slightly faster run time
        >>> result = brainstate.transform.bounded_while_loop(
        ...     cond_fn, body_fn, 1, max_steps=100, base=8
        ... )

    Bounded while loop with array state:

    .. code-block:: python

        >>> def cond_fn(state):
        ...     return jnp.max(state) < 100
        >>>
        >>> def body_fn(state):
        ...     return state + jnp.array([1.0, 2.0, 3.0])
        >>>
        >>> init_state = jnp.array([0.0, 0.0, 0.0])
        >>> final_state = brainstate.transform.bounded_while_loop(
        ...     cond_fn, body_fn, init_state, max_steps=50
        ... )
    """

    # checking
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    if not isinstance(base, int) or base < 2:
        raise ValueError("base must be an integer >= 2")
    init_val = jax.tree.map(jax.numpy.asarray, init_val)
    if max_steps == 0:
        return init_val

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun, name='bounded_while:cond').make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun, name='bounded_while:body').make_jaxpr(init_val)
    cond_cache_key = stateful_cond.get_arg_cache_key(init_val)
    body_cache_key = stateful_body.get_arg_cache_key(init_val)
    if len(stateful_cond.get_write_states_by_cache(cond_cache_key)) != 0:
        raise ValueError("bounded_while_loop: cond_fun should not have any write states.")

    # state trace and state values (looked up by the cache keys computed above)
    state_trace = (stateful_cond.get_state_trace_by_cache(cond_cache_key) +
                   stateful_body.get_state_trace_by_cache(body_cache_key))
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False, cond_cache_key)
    new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True, body_cache_key)

    # The nested-scan implementation runs for up to ``rounded_max_steps``
    # iterations (a power of ``base``), which may exceed ``max_steps``. Carry an
    # explicit step counter so ``body_fun`` runs at most ``max_steps`` times.
    def counted_cond(carry):
        step, inner = carry
        return jax.numpy.logical_and(new_cond_fn(inner), step < max_steps)

    def counted_body(carry):
        step, inner = carry
        return step + 1, new_body_fn(inner)

    # initial value: (step counter, (written state values, user value))
    init_carry = (jax.numpy.zeros((), dtype=jax.numpy.int32), (write_state_vals, init_val))

    # Smallest power of ``base`` that is >= max_steps, computed exactly with
    # integers (``math.log`` can round e.g. log(125, 5) up to 3.0000000000000004).
    rounded_max_steps = 1
    while rounded_max_steps < max_steps:
        rounded_max_steps *= base

    # while_loop
    _, (state_vals, val) = _bounded_while_loop(
        counted_cond, counted_body, init_carry, rounded_max_steps, base, None
    )

    # write back state values or restore them
    state_trace.assign_state_vals_v2(read_state_vals, state_vals)
    return val
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import math
17
+ from typing import Any, Callable, TypeVar
18
+
19
+ import jax
20
+
21
+ from brainstate._utils import set_module_as
22
+ from ._loop_collect_return import _bounded_while_loop
23
+ from ._make_jaxpr import StatefulFunction
24
+ from ._util import wrap_single_fun_in_multi_branches_while_loop as wrap_fn
25
+
26
# Generic type variables shared by the loop signatures in this module.
T = TypeVar('T')
X = TypeVar('X')
Y = TypeVar('Y')
Carry = TypeVar('Carry')

# A Python ``bool`` or a boolean array; typed as ``Any`` so both are accepted.
BooleanNumeric = Any

# Public names exported by this module.
__all__ = [
    'while_loop',
    'bounded_while_loop',
]
35
+
36
+
37
@set_module_as('brainstate.transform')
def while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T
) -> T:
    """
    Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.

    The `Haskell-like type signature`_ in brief is

    .. code-block:: haskell

        while_loop :: (a -> Bool) -> (a -> a) -> a -> a

    The semantics of ``while_loop`` are given by this Python implementation:

    .. code-block:: python

        >>> def while_loop(cond_fun, body_fun, init_val):
        ...     val = init_val
        ...     while cond_fun(val):
        ...         val = body_fun(val)
        ...     return val

    Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
    to a single WhileOp. That makes it useful for reducing compilation times
    for jit-compiled functions, since native Python loop constructs in an ``@jit``
    function are unrolled, leading to large XLA computations.

    Also unlike the Python analogue, the loop-carried value ``val`` must hold a
    fixed shape and dtype across all iterations (and not just be consistent up to
    NumPy rank/shape broadcasting and dtype promotion rules, for example). In
    other words, the type ``a`` in the type signature above represents an array
    with a fixed shape and dtype (or a nested tuple/list/dict container data
    structure with a fixed structure and arrays with fixed shape and dtype at the
    leaves).

    Another difference from using Python-native loop constructs is that
    ``while_loop`` is not reverse-mode differentiable because XLA computations
    require static bounds on memory requirements.

    States read by ``cond_fun``/``body_fun`` are passed into the loop, and states
    written by ``body_fun`` are carried through it and written back afterwards.

    Parameters
    ----------
    cond_fun : callable
        Function of type ``a -> Bool``. It must not write to any ``State``.
    body_fun : callable
        Function of type ``a -> a``.
    init_val : T
        Value of type ``a``, a type that can be a scalar, array, or any
        pytree (nested Python tuple/list/dict) thereof, representing the initial
        loop carry value.

    Returns
    -------
    T
        The output from the final iteration of body_fun, of type ``a``.

    Raises
    ------
    TypeError
        If ``cond_fun`` or ``body_fun`` is not callable.
    ValueError
        If ``cond_fun`` writes to any ``State``.

    Examples
    --------
    Basic while loop operation:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def cond_fn(val):
        ...     return val < 10
        >>>
        >>> def body_fn(val):
        ...     return val + 1
        >>>
        >>> result = brainstate.transform.while_loop(cond_fn, body_fn, 0)
        >>> # result will be 10

    While loop with array state:

    .. code-block:: python

        >>> def cond_fn(state):
        ...     return jnp.sum(state) < 100
        >>>
        >>> def body_fn(state):
        ...     return state * 1.1
        >>>
        >>> init_state = jnp.array([1.0, 2.0, 3.0])
        >>> final_state = brainstate.transform.while_loop(cond_fn, body_fn, init_state)

    References
    ----------
    .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
    """
    if not (callable(body_fun) and callable(cond_fun)):
        raise TypeError("while_loop: body_fun and cond_fun arguments should be callable.")
    if jax.config.jax_disable_jit:
        # With JIT disabled, run the loop eagerly in Python (easier to debug).
        try:
            val = init_val
            while cond_fun(val):
                val = body_fun(val)
            return val
        except jax.errors.ConcretizationTypeError:
            # Can't run this while_loop in Python (e.g. because there's a vmap
            # transformation on it), so we fall back to the primitive version.
            # ``jax.errors`` is the public location of this exception; the
            # ``jax.core`` alias is not part of the stable API.
            pass

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun, name='while:cond').make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun, name='while:body').make_jaxpr(init_val)
    cond_cache_key = stateful_cond.get_arg_cache_key(init_val)
    body_cache_key = stateful_body.get_arg_cache_key(init_val)
    if len(stateful_cond.get_write_states_by_cache(cond_cache_key)) != 0:
        raise ValueError("while_loop: cond_fun should not have any write states.")

    # state trace and state values (union of the states touched by both functions)
    state_trace = (stateful_cond.get_state_trace_by_cache(cond_cache_key) +
                   stateful_body.get_state_trace_by_cache(body_cache_key))
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False, cond_cache_key)
    new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True, body_cache_key)

    # while_loop: written state values are carried alongside the user value
    state_vals, final_val = jax.lax.while_loop(new_cond_fn, new_body_fn, (write_state_vals, init_val))

    # write back state values or restore them
    state_trace.assign_state_vals_v2(read_state_vals, state_vals)
    return final_val
165
+
166
+
167
@set_module_as('brainstate.transform')
def bounded_while_loop(
    cond_fun: Callable[[T], BooleanNumeric],
    body_fun: Callable[[T], T],
    init_val: T,
    *,
    max_steps: int,
    base: int = 16,
):
    """
    While loop with a bound on the maximum number of steps.

    This function is adapted from ``while_loop`` in `equinox <https://github.com/patrick-kidger/equinox/blob/main/equinox/internal/_loop/loop.py>`_.

    This function is useful when you want to ensure that a while loop terminates
    even if the condition function is never false. The function is implemented
    using a scan operation, so it is reverse-mode differentiable.

    Parameters
    ----------
    cond_fun : callable
        A function of type ``a -> Bool``. It must not write to any ``State``.
    body_fun : callable
        A function of type ``a -> a``.
    init_val : T
        The initial value of type ``a``.
    max_steps : int
        A bound on the maximum number of steps, after which the loop
        terminates unconditionally. ``body_fun`` is applied at most
        ``max_steps`` times.
    base : int, default 16
        Run time will increase slightly as `base` increases. Compilation time will
        decrease substantially as `math.ceil(math.log(max_steps, base))` decreases.
        (Which happens as `base` increases.) Must be an integer ``>= 2``.

    Returns
    -------
    T
        The final value, as if computed by a `lax.while_loop`.

    Raises
    ------
    ValueError
        If ``max_steps`` is not a non-negative integer, if ``base`` is not an
        integer ``>= 2``, or if ``cond_fun`` writes to any ``State``.

    Examples
    --------
    Basic bounded while loop:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> def cond_fn(val):
        ...     return val < 1000  # This might never be false
        >>>
        >>> def body_fn(val):
        ...     return val * 2
        >>>
        >>> # Loop will terminate after at most 10 steps
        >>> result = brainstate.transform.bounded_while_loop(
        ...     cond_fn, body_fn, 1, max_steps=10
        ... )

    Bounded while loop with custom base:

    .. code-block:: python

        >>> # A smaller base trades slower compilation for slightly faster run time
        >>> result = brainstate.transform.bounded_while_loop(
        ...     cond_fn, body_fn, 1, max_steps=100, base=8
        ... )

    Bounded while loop with array state:

    .. code-block:: python

        >>> def cond_fn(state):
        ...     return jnp.max(state) < 100
        >>>
        >>> def body_fn(state):
        ...     return state + jnp.array([1.0, 2.0, 3.0])
        >>>
        >>> init_state = jnp.array([0.0, 0.0, 0.0])
        >>> final_state = brainstate.transform.bounded_while_loop(
        ...     cond_fn, body_fn, init_state, max_steps=50
        ... )
    """

    # checking
    if not isinstance(max_steps, int) or max_steps < 0:
        raise ValueError("max_steps must be a non-negative integer")
    if not isinstance(base, int) or base < 2:
        raise ValueError("base must be an integer >= 2")
    init_val = jax.tree.map(jax.numpy.asarray, init_val)
    if max_steps == 0:
        return init_val

    # evaluate jaxpr
    stateful_cond = StatefulFunction(cond_fun, name='bounded_while:cond').make_jaxpr(init_val)
    stateful_body = StatefulFunction(body_fun, name='bounded_while:body').make_jaxpr(init_val)
    cond_cache_key = stateful_cond.get_arg_cache_key(init_val)
    body_cache_key = stateful_body.get_arg_cache_key(init_val)
    if len(stateful_cond.get_write_states_by_cache(cond_cache_key)) != 0:
        raise ValueError("bounded_while_loop: cond_fun should not have any write states.")

    # state trace and state values (looked up by the cache keys computed above)
    state_trace = (stateful_cond.get_state_trace_by_cache(cond_cache_key) +
                   stateful_body.get_state_trace_by_cache(body_cache_key))
    read_state_vals = state_trace.get_read_state_values(True)
    write_state_vals = state_trace.get_write_state_values(True)
    new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False, cond_cache_key)
    new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True, body_cache_key)

    # The nested-scan implementation runs for up to ``rounded_max_steps``
    # iterations (a power of ``base``), which may exceed ``max_steps``. Carry an
    # explicit step counter so ``body_fun`` runs at most ``max_steps`` times.
    def counted_cond(carry):
        step, inner = carry
        return jax.numpy.logical_and(new_cond_fn(inner), step < max_steps)

    def counted_body(carry):
        step, inner = carry
        return step + 1, new_body_fn(inner)

    # initial value: (step counter, (written state values, user value))
    init_carry = (jax.numpy.zeros((), dtype=jax.numpy.int32), (write_state_vals, init_val))

    # Smallest power of ``base`` that is >= max_steps, computed exactly with
    # integers (``math.log`` can round e.g. log(125, 5) up to 3.0000000000000004).
    rounded_max_steps = 1
    while rounded_max_steps < max_steps:
        rounded_max_steps *= base

    # while_loop
    _, (state_vals, val) = _bounded_while_loop(
        counted_cond, counted_body, init_carry, rounded_max_steps, base, None
    )

    # write back state values or restore them
    state_trace.assign_state_vals_v2(read_state_vals, state_vals)
    return val
@@ -1,50 +1,50 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from unittest import TestCase
18
-
19
- import brainstate
20
-
21
-
22
class TestWhileLoop(TestCase):
    """Tests for ``brainstate.transform.while_loop`` driven by ``State`` objects."""

    def test1(self):
        # The carry is ``None``; the loop is driven purely by State reads/writes.
        a = brainstate.State(1.)
        b = brainstate.State(20.)

        def cond(_):
            return a.value < b.value

        def body(_):
            a.value += 1.

        brainstate.transform.while_loop(cond, body, None)

        # ``a`` is incremented until it reaches ``b``; ``b`` is only read.
        self.assertEqual(float(a.value), 20.)
        self.assertEqual(float(b.value), 20.)

    def test2(self):
        # The carry ``x`` is returned unchanged and used as the increment.
        a = brainstate.State(1.)
        b = brainstate.State(20.)

        def cond(x):
            return a.value < b.value

        def body(x):
            a.value += x
            return x

        r = brainstate.transform.while_loop(cond, body, 1.)

        self.assertEqual(float(a.value), 20.)
        self.assertEqual(float(b.value), 20.)
        self.assertEqual(float(r), 1.)
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from unittest import TestCase
18
+
19
+ import brainstate
20
+
21
+
22
class TestWhileLoop(TestCase):
    """Tests for ``brainstate.transform.while_loop`` driven by ``State`` objects."""

    def test1(self):
        # The carry is ``None``; the loop is driven purely by State reads/writes.
        a = brainstate.State(1.)
        b = brainstate.State(20.)

        def cond(_):
            return a.value < b.value

        def body(_):
            a.value += 1.

        brainstate.transform.while_loop(cond, body, None)

        # ``a`` is incremented until it reaches ``b``; ``b`` is only read.
        self.assertEqual(float(a.value), 20.)
        self.assertEqual(float(b.value), 20.)

    def test2(self):
        # The carry ``x`` is returned unchanged and used as the increment.
        a = brainstate.State(1.)
        b = brainstate.State(20.)

        def cond(x):
            return a.value < b.value

        def body(x):
            a.value += x
            return x

        r = brainstate.transform.while_loop(cond, body, 1.)

        self.assertEqual(float(a.value), 20.)
        self.assertEqual(float(b.value), 20.)
        self.assertEqual(float(r), 1.)