brainstate 0.2.1__py2.py3-none-any.whl → 0.2.2__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. brainstate/__init__.py +167 -169
  2. brainstate/_compatible_import.py +340 -340
  3. brainstate/_compatible_import_test.py +681 -681
  4. brainstate/_deprecation.py +210 -210
  5. brainstate/_deprecation_test.py +2297 -2319
  6. brainstate/_error.py +45 -45
  7. brainstate/_state.py +2157 -1652
  8. brainstate/_state_test.py +1129 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -1495
  11. brainstate/environ_test.py +1223 -1223
  12. brainstate/graph/__init__.py +22 -22
  13. brainstate/graph/_node.py +240 -240
  14. brainstate/graph/_node_test.py +589 -589
  15. brainstate/graph/_operation.py +1620 -1624
  16. brainstate/graph/_operation_test.py +1147 -1147
  17. brainstate/mixin.py +1447 -1433
  18. brainstate/mixin_test.py +1017 -1017
  19. brainstate/nn/__init__.py +146 -137
  20. brainstate/nn/_activations.py +1100 -1100
  21. brainstate/nn/_activations_test.py +354 -354
  22. brainstate/nn/_collective_ops.py +635 -633
  23. brainstate/nn/_collective_ops_test.py +774 -774
  24. brainstate/nn/_common.py +226 -226
  25. brainstate/nn/_common_test.py +134 -154
  26. brainstate/nn/_conv.py +2010 -2010
  27. brainstate/nn/_conv_test.py +849 -849
  28. brainstate/nn/_delay.py +575 -575
  29. brainstate/nn/_delay_test.py +243 -243
  30. brainstate/nn/_dropout.py +618 -618
  31. brainstate/nn/_dropout_test.py +480 -477
  32. brainstate/nn/_dynamics.py +870 -1267
  33. brainstate/nn/_dynamics_test.py +53 -67
  34. brainstate/nn/_elementwise.py +1298 -1298
  35. brainstate/nn/_elementwise_test.py +829 -829
  36. brainstate/nn/_embedding.py +408 -408
  37. brainstate/nn/_embedding_test.py +156 -156
  38. brainstate/nn/_event_fixedprob.py +233 -233
  39. brainstate/nn/_event_fixedprob_test.py +115 -115
  40. brainstate/nn/_event_linear.py +83 -83
  41. brainstate/nn/_event_linear_test.py +121 -121
  42. brainstate/nn/_exp_euler.py +254 -254
  43. brainstate/nn/_exp_euler_test.py +377 -377
  44. brainstate/nn/_linear.py +744 -744
  45. brainstate/nn/_linear_test.py +475 -475
  46. brainstate/nn/_metrics.py +1070 -1070
  47. brainstate/nn/_metrics_test.py +611 -611
  48. brainstate/nn/_module.py +391 -384
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -1334
  51. brainstate/nn/_normalizations_test.py +699 -699
  52. brainstate/nn/_paddings.py +1020 -1020
  53. brainstate/nn/_paddings_test.py +722 -722
  54. brainstate/nn/_poolings.py +2239 -2239
  55. brainstate/nn/_poolings_test.py +952 -952
  56. brainstate/nn/_rnns.py +946 -946
  57. brainstate/nn/_rnns_test.py +592 -592
  58. brainstate/nn/_utils.py +216 -216
  59. brainstate/nn/_utils_test.py +401 -401
  60. brainstate/nn/init.py +809 -809
  61. brainstate/nn/init_test.py +180 -180
  62. brainstate/random/__init__.py +270 -270
  63. brainstate/random/{_rand_funs.py → _fun.py} +3938 -3938
  64. brainstate/random/{_rand_funs_test.py → _fun_test.py} +638 -640
  65. brainstate/random/_impl.py +672 -0
  66. brainstate/random/{_rand_seed.py → _seed.py} +675 -675
  67. brainstate/random/{_rand_seed_test.py → _seed_test.py} +48 -48
  68. brainstate/random/{_rand_state.py → _state.py} +1320 -1617
  69. brainstate/random/{_rand_state_test.py → _state_test.py} +551 -551
  70. brainstate/transform/__init__.py +56 -59
  71. brainstate/transform/_ad_checkpoint.py +176 -176
  72. brainstate/transform/_ad_checkpoint_test.py +49 -49
  73. brainstate/transform/_autograd.py +1025 -1025
  74. brainstate/transform/_autograd_test.py +1289 -1289
  75. brainstate/transform/_conditions.py +316 -316
  76. brainstate/transform/_conditions_test.py +220 -220
  77. brainstate/transform/_error_if.py +94 -94
  78. brainstate/transform/_error_if_test.py +52 -52
  79. brainstate/transform/_find_state.py +200 -0
  80. brainstate/transform/_find_state_test.py +84 -0
  81. brainstate/transform/_jit.py +399 -399
  82. brainstate/transform/_jit_test.py +143 -143
  83. brainstate/transform/_loop_collect_return.py +675 -675
  84. brainstate/transform/_loop_collect_return_test.py +58 -58
  85. brainstate/transform/_loop_no_collection.py +283 -283
  86. brainstate/transform/_loop_no_collection_test.py +50 -50
  87. brainstate/transform/_make_jaxpr.py +2176 -2016
  88. brainstate/transform/_make_jaxpr_test.py +1634 -1510
  89. brainstate/transform/_mapping.py +607 -529
  90. brainstate/transform/_mapping_test.py +104 -194
  91. brainstate/transform/_progress_bar.py +255 -255
  92. brainstate/transform/_unvmap.py +256 -256
  93. brainstate/transform/_util.py +286 -286
  94. brainstate/typing.py +837 -837
  95. brainstate/typing_test.py +780 -780
  96. brainstate/util/__init__.py +27 -27
  97. brainstate/util/_others.py +1024 -1024
  98. brainstate/util/_others_test.py +962 -962
  99. brainstate/util/_pretty_pytree.py +1301 -1301
  100. brainstate/util/_pretty_pytree_test.py +675 -675
  101. brainstate/util/_pretty_repr.py +462 -462
  102. brainstate/util/_pretty_repr_test.py +696 -696
  103. brainstate/util/filter.py +945 -945
  104. brainstate/util/filter_test.py +911 -911
  105. brainstate/util/struct.py +910 -910
  106. brainstate/util/struct_test.py +602 -602
  107. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/METADATA +108 -108
  108. brainstate-0.2.2.dist-info/RECORD +111 -0
  109. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/transform/_eval_shape.py +0 -145
  111. brainstate/transform/_eval_shape_test.py +0 -38
  112. brainstate/transform/_random.py +0 -171
  113. brainstate-0.2.1.dist-info/RECORD +0 -111
  114. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/WHEEL +0 -0
  115. {brainstate-0.2.1.dist-info → brainstate-0.2.2.dist-info}/top_level.txt +0 -0
@@ -1,286 +1,286 @@
1
- # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from functools import wraps
17
- from typing import Sequence, Tuple, Hashable
18
-
19
- from brainstate._state import StateTraceStack
20
- from brainstate.typing import PyTree
21
- from ._make_jaxpr import StatefulFunction
22
-
23
-
24
- def wrap_single_fun_in_multi_branches(
25
- stateful_fun: StatefulFunction,
26
- merged_state_trace: StateTraceStack,
27
- read_state_vals: Sequence[PyTree | None],
28
- return_states: bool = True,
29
- cache_key: Hashable = None,
30
- ):
31
- """
32
- Wrap a stateful function for use in multi-branch control flow.
33
-
34
- This function creates a wrapper that allows a stateful function to be used
35
- in control flow operations where multiple functions share state. It manages
36
- state values by extracting only the states needed by this specific function
37
- from a merged state trace.
38
-
39
- Parameters
40
- ----------
41
- stateful_fun : StatefulFunction
42
- The stateful function to be wrapped.
43
- merged_state_trace : StateTraceStack
44
- The merged state trace containing all states from multiple functions.
45
- read_state_vals : sequence of PyTree or None
46
- The original read state values for all states in the merged trace.
47
- return_states : bool, default True
48
- Whether to return updated state values along with the function output.
49
-
50
- Returns
51
- -------
52
- callable
53
- A wrapped function that can be used in multi-branch control flow.
54
-
55
- Examples
56
- --------
57
- Usage in conditional execution:
58
-
59
- .. code-block:: python
60
-
61
- >>> import brainstate
62
- >>> import jax.numpy as jnp
63
- >>>
64
- >>> # Create states
65
- >>> state1 = brainstate.State(jnp.array([1.0]))
66
- >>> state2 = brainstate.State(jnp.array([2.0]))
67
- >>>
68
- >>> def branch_fn(x):
69
- ... state1.value *= x
70
- ... return state1.value + state2.value
71
- >>>
72
- >>> # During compilation, this wrapper allows the function
73
- >>> # to work with merged state traces from multiple branches
74
- >>> sf = brainstate.transform.StatefulFunction(branch_fn)
75
- >>> # wrapped_fn = wrap_single_fun_in_multi_branches(sf, merged_trace, read_vals)
76
- """
77
- state_ids_belong_to_this_fun = {id(st): st for st in stateful_fun.get_states_by_cache(cache_key)}
78
-
79
- @wraps(stateful_fun.fun)
80
- def wrapped_branch(write_state_vals, *operands):
81
- # "write_state_vals" should have the same length as "merged_state_trace.states"
82
- assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)
83
-
84
- # get all state values needed for this function, which is a subset of "write_state_vals"
85
- st_vals_for_this_fun = []
86
- for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
87
- merged_state_trace.states,
88
- write_state_vals,
89
- read_state_vals):
90
- if id(st) in state_ids_belong_to_this_fun:
91
- st_vals_for_this_fun.append(val_w if write else val_r)
92
-
93
- # call this function
94
- new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, *operands)
95
- assert len(new_state_vals) == len(st_vals_for_this_fun)
96
-
97
- if return_states:
98
- # get all written state values
99
- new_state_vals = {id(st): val for st, val in zip(stateful_fun.get_states_by_cache(cache_key), new_state_vals)}
100
- write_state_vals = tuple([
101
- (new_state_vals[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
102
- if write else None
103
- for write, st, w_val in zip(merged_state_trace.been_writen,
104
- merged_state_trace.states,
105
- write_state_vals)
106
- ])
107
- return write_state_vals, out
108
- return out
109
-
110
- return wrapped_branch
111
-
112
-
113
- def wrap_single_fun_in_multi_branches_while_loop(
114
- stateful_fun: StatefulFunction,
115
- merged_state_trace: StateTraceStack,
116
- read_state_vals: Sequence[PyTree | None],
117
- return_states: bool = True,
118
- cache_key: Hashable = None,
119
- ):
120
- """
121
- Wrap a stateful function for use in while loop control flow.
122
-
123
- This function creates a wrapper specifically designed for while loop operations
124
- where multiple functions share state. It manages state values by extracting only
125
- the states needed by this specific function from a merged state trace, with
126
- special handling for the loop's init_val structure.
127
-
128
- Parameters
129
- ----------
130
- stateful_fun : StatefulFunction
131
- The stateful function to be wrapped.
132
- merged_state_trace : StateTraceStack
133
- The merged state trace containing all states from multiple functions.
134
- read_state_vals : sequence of PyTree or None
135
- The original read state values for all states in the merged trace.
136
- return_states : bool, default True
137
- Whether to return updated state values along with the function output.
138
-
139
- Returns
140
- -------
141
- callable
142
- A wrapped function that can be used in while loop control flow.
143
-
144
- Examples
145
- --------
146
- Usage in while loop operations:
147
-
148
- .. code-block:: python
149
-
150
- >>> import brainstate
151
- >>> import jax.numpy as jnp
152
- >>>
153
- >>> # Create states
154
- >>> counter = brainstate.State(jnp.array([0]))
155
- >>> accumulator = brainstate.State(jnp.array([0.0]))
156
- >>>
157
- >>> def cond_fn(val):
158
- ... return counter.value < 10
159
- >>>
160
- >>> def body_fn(val):
161
- ... counter.value += 1
162
- ... accumulator.value += val
163
- ... return val * 2
164
- >>>
165
- >>> # During compilation, this wrapper allows the functions
166
- >>> # to work with merged state traces in while loops
167
- >>> sf_cond = brainstate.transform.StatefulFunction(cond_fn)
168
- >>> sf_body = brainstate.transform.StatefulFunction(body_fn)
169
- >>> # wrapped_cond = wrap_single_fun_in_multi_branches_while_loop(sf_cond, ...)
170
- >>> # wrapped_body = wrap_single_fun_in_multi_branches_while_loop(sf_body, ...)
171
- """
172
- state_ids_belong_to_this_fun = {id(st): st for st in stateful_fun.get_states_by_cache(cache_key)}
173
-
174
- @wraps(stateful_fun.fun)
175
- def wrapped_branch(init_val):
176
- write_state_vals, init_val = init_val
177
- # "write_state_vals" should have the same length as "merged_state_trace.states"
178
- assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)
179
-
180
- # get all state values needed for this function, which is a subset of "write_state_vals"
181
- st_vals_for_this_fun = []
182
- for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
183
- merged_state_trace.states,
184
- write_state_vals,
185
- read_state_vals):
186
- if id(st) in state_ids_belong_to_this_fun:
187
- st_vals_for_this_fun.append(val_w if write else val_r)
188
-
189
- # call this function
190
- new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, init_val)
191
- assert len(new_state_vals) == len(st_vals_for_this_fun)
192
-
193
- if return_states:
194
- # get all written state values
195
- new_state_vals = {id(st): val for st, val in zip(stateful_fun.get_states_by_cache(cache_key), new_state_vals)}
196
- write_state_vals = tuple([
197
- (new_state_vals[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
198
- if write else None
199
- for write, st, w_val in zip(merged_state_trace.been_writen,
200
- merged_state_trace.states,
201
- write_state_vals)
202
- ])
203
- return write_state_vals, out
204
- return out
205
-
206
- return wrapped_branch
207
-
208
-
209
- def wrap_single_fun(
210
- stateful_fun: StatefulFunction,
211
- been_writen: Sequence[bool],
212
- read_state_vals: Tuple[PyTree | None],
213
- ):
214
- """
215
- Wrap a stateful function for use in scan operations.
216
-
217
- This function creates a wrapper specifically designed for scan operations.
218
- It manages state values by combining written and read states, calls the
219
- stateful function, and returns only the written states along with the
220
- carry and output values.
221
-
222
- Parameters
223
- ----------
224
- stateful_fun : StatefulFunction
225
- The stateful function to be wrapped for scan operations.
226
- been_writen : sequence of bool
227
- Boolean flags indicating which states have been written to.
228
- read_state_vals : tuple of PyTree or None
229
- The original read state values for all states.
230
-
231
- Returns
232
- -------
233
- callable
234
- A wrapped function that can be used in scan operations with proper
235
- state management.
236
-
237
- Examples
238
- --------
239
- Usage in scan operations:
240
-
241
- .. code-block:: python
242
-
243
- >>> import brainstate
244
- >>> import jax.numpy as jnp
245
- >>>
246
- >>> # Create states
247
- >>> state1 = brainstate.State(jnp.array([0.0]))
248
- >>> state2 = brainstate.State(jnp.array([1.0]))
249
- >>>
250
- >>> def scan_fn(carry, x):
251
- ... state1.value += x # This state will be written
252
- ... result = carry + state1.value + state2.value # state2 is only read
253
- ... return result, result ** 2
254
- >>>
255
- >>> # During compilation, this wrapper allows the function
256
- >>> # to work properly in scan operations
257
- >>> sf = brainstate.transform.StatefulFunction(scan_fn)
258
- >>> # wrapped_fn = wrap_single_fun(sf, been_written_flags, read_values)
259
- >>>
260
- >>> # The wrapped function handles state management automatically
261
- >>> xs = jnp.arange(5.0)
262
- >>> init_carry = 0.0
263
- final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs)
264
- """
265
-
266
- @wraps(stateful_fun.fun)
267
- def wrapped_fun(new_carry, inputs):
268
- writen_state_vals, carry = new_carry
269
- assert len(been_writen) == len(writen_state_vals) == len(read_state_vals)
270
-
271
- # collect all written and read states
272
- state_vals = [
273
- written_val if written else read_val
274
- for written, written_val, read_val in zip(been_writen, writen_state_vals, read_state_vals)
275
- ]
276
-
277
- # call the jaxpr
278
- state_vals, (carry, out) = stateful_fun.jaxpr_call(state_vals, carry, inputs)
279
-
280
- # only return the written states
281
- writen_state_vals = tuple([val if written else None for written, val in zip(been_writen, state_vals)])
282
-
283
- # return
284
- return (writen_state_vals, carry), out
285
-
286
- return wrapped_fun
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from functools import wraps
17
+ from typing import Sequence, Tuple, Hashable
18
+
19
+ from brainstate._state import StateTraceStack
20
+ from brainstate.typing import PyTree
21
+ from ._make_jaxpr import StatefulFunction
22
+
23
+
24
+ def wrap_single_fun_in_multi_branches(
25
+ stateful_fun: StatefulFunction,
26
+ merged_state_trace: StateTraceStack,
27
+ read_state_vals: Sequence[PyTree | None],
28
+ return_states: bool = True,
29
+ cache_key: Hashable = None,
30
+ ):
31
+ """
32
+ Wrap a stateful function for use in multi-branch control flow.
33
+
34
+ This function creates a wrapper that allows a stateful function to be used
35
+ in control flow operations where multiple functions share state. It manages
36
+ state values by extracting only the states needed by this specific function
37
+ from a merged state trace.
38
+
39
+ Parameters
40
+ ----------
41
+ stateful_fun : StatefulFunction
42
+ The stateful function to be wrapped.
43
+ merged_state_trace : StateTraceStack
44
+ The merged state trace containing all states from multiple functions.
45
+ read_state_vals : sequence of PyTree or None
46
+ The original read state values for all states in the merged trace.
47
+ return_states : bool, default True
48
+ Whether to return updated state values along with the function output.
49
+
50
+ Returns
51
+ -------
52
+ callable
53
+ A wrapped function that can be used in multi-branch control flow.
54
+
55
+ Examples
56
+ --------
57
+ Usage in conditional execution:
58
+
59
+ .. code-block:: python
60
+
61
+ >>> import brainstate
62
+ >>> import jax.numpy as jnp
63
+ >>>
64
+ >>> # Create states
65
+ >>> state1 = brainstate.State(jnp.array([1.0]))
66
+ >>> state2 = brainstate.State(jnp.array([2.0]))
67
+ >>>
68
+ >>> def branch_fn(x):
69
+ ... state1.value *= x
70
+ ... return state1.value + state2.value
71
+ >>>
72
+ >>> # During compilation, this wrapper allows the function
73
+ >>> # to work with merged state traces from multiple branches
74
+ >>> sf = brainstate.transform.StatefulFunction(branch_fn)
75
+ >>> # wrapped_fn = wrap_single_fun_in_multi_branches(sf, merged_trace, read_vals)
76
+ """
77
+ state_ids_belong_to_this_fun = {id(st): st for st in stateful_fun.get_states_by_cache(cache_key)}
78
+
79
+ @wraps(stateful_fun.fun)
80
+ def wrapped_branch(write_state_vals, *operands):
81
+ # "write_state_vals" should have the same length as "merged_state_trace.states"
82
+ assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)
83
+
84
+ # get all state values needed for this function, which is a subset of "write_state_vals"
85
+ st_vals_for_this_fun = []
86
+ for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
87
+ merged_state_trace.states,
88
+ write_state_vals,
89
+ read_state_vals):
90
+ if id(st) in state_ids_belong_to_this_fun:
91
+ st_vals_for_this_fun.append(val_w if write else val_r)
92
+
93
+ # call this function
94
+ new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, *operands)
95
+ assert len(new_state_vals) == len(st_vals_for_this_fun)
96
+
97
+ if return_states:
98
+ # get all written state values
99
+ new_state_vals = {id(st): val for st, val in zip(stateful_fun.get_states_by_cache(cache_key), new_state_vals)}
100
+ write_state_vals = tuple([
101
+ (new_state_vals[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
102
+ if write else None
103
+ for write, st, w_val in zip(merged_state_trace.been_writen,
104
+ merged_state_trace.states,
105
+ write_state_vals)
106
+ ])
107
+ return write_state_vals, out
108
+ return out
109
+
110
+ return wrapped_branch
111
+
112
+
113
+ def wrap_single_fun_in_multi_branches_while_loop(
114
+ stateful_fun: StatefulFunction,
115
+ merged_state_trace: StateTraceStack,
116
+ read_state_vals: Sequence[PyTree | None],
117
+ return_states: bool = True,
118
+ cache_key: Hashable = None,
119
+ ):
120
+ """
121
+ Wrap a stateful function for use in while loop control flow.
122
+
123
+ This function creates a wrapper specifically designed for while loop operations
124
+ where multiple functions share state. It manages state values by extracting only
125
+ the states needed by this specific function from a merged state trace, with
126
+ special handling for the loop's init_val structure.
127
+
128
+ Parameters
129
+ ----------
130
+ stateful_fun : StatefulFunction
131
+ The stateful function to be wrapped.
132
+ merged_state_trace : StateTraceStack
133
+ The merged state trace containing all states from multiple functions.
134
+ read_state_vals : sequence of PyTree or None
135
+ The original read state values for all states in the merged trace.
136
+ return_states : bool, default True
137
+ Whether to return updated state values along with the function output.
138
+
139
+ Returns
140
+ -------
141
+ callable
142
+ A wrapped function that can be used in while loop control flow.
143
+
144
+ Examples
145
+ --------
146
+ Usage in while loop operations:
147
+
148
+ .. code-block:: python
149
+
150
+ >>> import brainstate
151
+ >>> import jax.numpy as jnp
152
+ >>>
153
+ >>> # Create states
154
+ >>> counter = brainstate.State(jnp.array([0]))
155
+ >>> accumulator = brainstate.State(jnp.array([0.0]))
156
+ >>>
157
+ >>> def cond_fn(val):
158
+ ... return counter.value < 10
159
+ >>>
160
+ >>> def body_fn(val):
161
+ ... counter.value += 1
162
+ ... accumulator.value += val
163
+ ... return val * 2
164
+ >>>
165
+ >>> # During compilation, this wrapper allows the functions
166
+ >>> # to work with merged state traces in while loops
167
+ >>> sf_cond = brainstate.transform.StatefulFunction(cond_fn)
168
+ >>> sf_body = brainstate.transform.StatefulFunction(body_fn)
169
+ >>> # wrapped_cond = wrap_single_fun_in_multi_branches_while_loop(sf_cond, ...)
170
+ >>> # wrapped_body = wrap_single_fun_in_multi_branches_while_loop(sf_body, ...)
171
+ """
172
+ state_ids_belong_to_this_fun = {id(st): st for st in stateful_fun.get_states_by_cache(cache_key)}
173
+
174
+ @wraps(stateful_fun.fun)
175
+ def wrapped_branch(init_val):
176
+ write_state_vals, init_val = init_val
177
+ # "write_state_vals" should have the same length as "merged_state_trace.states"
178
+ assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)
179
+
180
+ # get all state values needed for this function, which is a subset of "write_state_vals"
181
+ st_vals_for_this_fun = []
182
+ for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
183
+ merged_state_trace.states,
184
+ write_state_vals,
185
+ read_state_vals):
186
+ if id(st) in state_ids_belong_to_this_fun:
187
+ st_vals_for_this_fun.append(val_w if write else val_r)
188
+
189
+ # call this function
190
+ new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, init_val)
191
+ assert len(new_state_vals) == len(st_vals_for_this_fun)
192
+
193
+ if return_states:
194
+ # get all written state values
195
+ new_state_vals = {id(st): val for st, val in zip(stateful_fun.get_states_by_cache(cache_key), new_state_vals)}
196
+ write_state_vals = tuple([
197
+ (new_state_vals[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
198
+ if write else None
199
+ for write, st, w_val in zip(merged_state_trace.been_writen,
200
+ merged_state_trace.states,
201
+ write_state_vals)
202
+ ])
203
+ return write_state_vals, out
204
+ return out
205
+
206
+ return wrapped_branch
207
+
208
+
209
+ def wrap_single_fun(
210
+ stateful_fun: StatefulFunction,
211
+ been_writen: Sequence[bool],
212
+ read_state_vals: Tuple[PyTree | None],
213
+ ):
214
+ """
215
+ Wrap a stateful function for use in scan operations.
216
+
217
+ This function creates a wrapper specifically designed for scan operations.
218
+ It manages state values by combining written and read states, calls the
219
+ stateful function, and returns only the written states along with the
220
+ carry and output values.
221
+
222
+ Parameters
223
+ ----------
224
+ stateful_fun : StatefulFunction
225
+ The stateful function to be wrapped for scan operations.
226
+ been_writen : sequence of bool
227
+ Boolean flags indicating which states have been written to.
228
+ read_state_vals : tuple of PyTree or None
229
+ The original read state values for all states.
230
+
231
+ Returns
232
+ -------
233
+ callable
234
+ A wrapped function that can be used in scan operations with proper
235
+ state management.
236
+
237
+ Examples
238
+ --------
239
+ Usage in scan operations:
240
+
241
+ .. code-block:: python
242
+
243
+ >>> import brainstate
244
+ >>> import jax.numpy as jnp
245
+ >>>
246
+ >>> # Create states
247
+ >>> state1 = brainstate.State(jnp.array([0.0]))
248
+ >>> state2 = brainstate.State(jnp.array([1.0]))
249
+ >>>
250
+ >>> def scan_fn(carry, x):
251
+ ... state1.value += x # This state will be written
252
+ ... result = carry + state1.value + state2.value # state2 is only read
253
+ ... return result, result ** 2
254
+ >>>
255
+ >>> # During compilation, this wrapper allows the function
256
+ >>> # to work properly in scan operations
257
+ >>> sf = brainstate.transform.StatefulFunction(scan_fn)
258
+ >>> # wrapped_fn = wrap_single_fun(sf, been_written_flags, read_values)
259
+ >>>
260
+ >>> # The wrapped function handles state management automatically
261
+ >>> xs = jnp.arange(5.0)
262
+ >>> init_carry = 0.0
263
+ final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs)
264
+ """
265
+
266
+ @wraps(stateful_fun.fun)
267
+ def wrapped_fun(new_carry, inputs):
268
+ writen_state_vals, carry = new_carry
269
+ assert len(been_writen) == len(writen_state_vals) == len(read_state_vals)
270
+
271
+ # collect all written and read states
272
+ state_vals = [
273
+ written_val if written else read_val
274
+ for written, written_val, read_val in zip(been_writen, writen_state_vals, read_state_vals)
275
+ ]
276
+
277
+ # call the jaxpr
278
+ state_vals, (carry, out) = stateful_fun.jaxpr_call(state_vals, carry, inputs)
279
+
280
+ # only return the written states
281
+ writen_state_vals = tuple([val if written else None for written, val in zip(been_writen, state_vals)])
282
+
283
+ # return
284
+ return (writen_state_vals, carry), out
285
+
286
+ return wrapped_fun