brainstate 0.1.10__py2.py3-none-any.whl → 0.2.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +169 -58
  2. brainstate/_compatible_import.py +340 -148
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +45 -55
  7. brainstate/_state.py +1652 -1605
  8. brainstate/_state_test.py +52 -52
  9. brainstate/_utils.py +47 -47
  10. brainstate/environ.py +1495 -563
  11. brainstate/environ_test.py +1223 -62
  12. brainstate/graph/__init__.py +22 -29
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +1624 -1738
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1433 -365
  18. brainstate/mixin_test.py +1017 -77
  19. brainstate/nn/__init__.py +137 -135
  20. brainstate/nn/_activations.py +1100 -808
  21. brainstate/nn/_activations_test.py +354 -331
  22. brainstate/nn/_collective_ops.py +633 -514
  23. brainstate/nn/_collective_ops_test.py +774 -43
  24. brainstate/nn/_common.py +226 -178
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +2010 -501
  27. brainstate/nn/_conv_test.py +849 -238
  28. brainstate/nn/_delay.py +575 -588
  29. brainstate/nn/_delay_test.py +243 -238
  30. brainstate/nn/_dropout.py +618 -426
  31. brainstate/nn/_dropout_test.py +477 -100
  32. brainstate/nn/_dynamics.py +1267 -1343
  33. brainstate/nn/_dynamics_test.py +67 -78
  34. brainstate/nn/_elementwise.py +1298 -1119
  35. brainstate/nn/_elementwise_test.py +830 -169
  36. brainstate/nn/_embedding.py +408 -58
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +233 -239
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +115 -114
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +83 -83
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +121 -120
  42. brainstate/nn/_exp_euler.py +254 -92
  43. brainstate/nn/_exp_euler_test.py +377 -35
  44. brainstate/nn/_linear.py +744 -424
  45. brainstate/nn/_linear_test.py +475 -107
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +384 -377
  49. brainstate/nn/_module_test.py +40 -40
  50. brainstate/nn/_normalizations.py +1334 -975
  51. brainstate/nn/_normalizations_test.py +699 -73
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +2239 -1177
  55. brainstate/nn/_poolings_test.py +953 -217
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +946 -554
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +216 -89
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +809 -553
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +180 -149
  62. brainstate/random/__init__.py +270 -24
  63. brainstate/random/_rand_funs.py +3938 -3616
  64. brainstate/random/_rand_funs_test.py +640 -567
  65. brainstate/random/_rand_seed.py +675 -210
  66. brainstate/random/_rand_seed_test.py +48 -48
  67. brainstate/random/_rand_state.py +1617 -1409
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +49 -49
  72. brainstate/{augment → transform}/_autograd.py +1025 -778
  73. brainstate/{augment → transform}/_autograd_test.py +1289 -1289
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +220 -220
  76. brainstate/{compile → transform}/_error_if.py +94 -92
  77. brainstate/{compile → transform}/_error_if_test.py +52 -52
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +38 -38
  80. brainstate/{compile → transform}/_jit.py +399 -346
  81. brainstate/{compile → transform}/_jit_test.py +143 -143
  82. brainstate/{compile → transform}/_loop_collect_return.py +675 -536
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +58 -58
  84. brainstate/{compile → transform}/_loop_no_collection.py +283 -184
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +50 -50
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +255 -202
  91. brainstate/{augment → transform}/_random.py +171 -151
  92. brainstate/{compile → transform}/_unvmap.py +256 -159
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +837 -304
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +27 -50
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +462 -328
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +945 -469
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +910 -523
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/METADATA +108 -91
  108. brainstate-0.2.1.dist-info/RECORD +111 -0
  109. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/licenses/LICENSE +202 -202
  110. brainstate/augment/__init__.py +0 -30
  111. brainstate/augment/_eval_shape.py +0 -99
  112. brainstate/augment/_mapping.py +0 -1060
  113. brainstate/augment/_mapping_test.py +0 -597
  114. brainstate/compile/__init__.py +0 -38
  115. brainstate/compile/_ad_checkpoint.py +0 -204
  116. brainstate/compile/_conditions.py +0 -256
  117. brainstate/compile/_make_jaxpr.py +0 -888
  118. brainstate/compile/_make_jaxpr_test.py +0 -156
  119. brainstate/compile/_util.py +0 -147
  120. brainstate/functional/__init__.py +0 -27
  121. brainstate/graph/_graph_node.py +0 -244
  122. brainstate/graph/_graph_node_test.py +0 -73
  123. brainstate/graph/_graph_operation_test.py +0 -563
  124. brainstate/init/__init__.py +0 -26
  125. brainstate/init/_base.py +0 -52
  126. brainstate/init/_generic.py +0 -244
  127. brainstate/init/_regular_inits.py +0 -105
  128. brainstate/init/_regular_inits_test.py +0 -50
  129. brainstate/nn/_inputs.py +0 -608
  130. brainstate/nn/_ltp.py +0 -28
  131. brainstate/nn/_neuron.py +0 -705
  132. brainstate/nn/_neuron_test.py +0 -161
  133. brainstate/nn/_others.py +0 -46
  134. brainstate/nn/_projection.py +0 -486
  135. brainstate/nn/_rate_rnns_test.py +0 -63
  136. brainstate/nn/_readout.py +0 -209
  137. brainstate/nn/_readout_test.py +0 -53
  138. brainstate/nn/_stp.py +0 -236
  139. brainstate/nn/_synapse.py +0 -505
  140. brainstate/nn/_synapse_test.py +0 -131
  141. brainstate/nn/_synaptic_projection.py +0 -423
  142. brainstate/nn/_synouts.py +0 -162
  143. brainstate/nn/_synouts_test.py +0 -57
  144. brainstate/nn/metrics.py +0 -388
  145. brainstate/optim/__init__.py +0 -38
  146. brainstate/optim/_base.py +0 -64
  147. brainstate/optim/_lr_scheduler.py +0 -448
  148. brainstate/optim/_lr_scheduler_test.py +0 -50
  149. brainstate/optim/_optax_optimizer.py +0 -152
  150. brainstate/optim/_optax_optimizer_test.py +0 -53
  151. brainstate/optim/_sgd_optimizer.py +0 -1104
  152. brainstate/random/_random_for_unit.py +0 -52
  153. brainstate/surrogate.py +0 -1957
  154. brainstate/transform.py +0 -23
  155. brainstate/util/caller.py +0 -98
  156. brainstate/util/others.py +0 -540
  157. brainstate/util/pretty_pytree.py +0 -945
  158. brainstate/util/pretty_pytree_test.py +0 -159
  159. brainstate/util/pretty_table.py +0 -2954
  160. brainstate/util/scaling.py +0 -258
  161. brainstate-0.1.10.dist-info/RECORD +0 -130
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/WHEEL +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,286 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from functools import wraps
17
+ from typing import Sequence, Tuple, Hashable
18
+
19
+ from brainstate._state import StateTraceStack
20
+ from brainstate.typing import PyTree
21
+ from ._make_jaxpr import StatefulFunction
22
+
23
+
24
+ def wrap_single_fun_in_multi_branches(
25
+ stateful_fun: StatefulFunction,
26
+ merged_state_trace: StateTraceStack,
27
+ read_state_vals: Sequence[PyTree | None],
28
+ return_states: bool = True,
29
+ cache_key: Hashable = None,
30
+ ):
31
+ """
32
+ Wrap a stateful function for use in multi-branch control flow.
33
+
34
+ This function creates a wrapper that allows a stateful function to be used
35
+ in control flow operations where multiple functions share state. It manages
36
+ state values by extracting only the states needed by this specific function
37
+ from a merged state trace.
38
+
39
+ Parameters
40
+ ----------
41
+ stateful_fun : StatefulFunction
42
+ The stateful function to be wrapped.
43
+ merged_state_trace : StateTraceStack
44
+ The merged state trace containing all states from multiple functions.
45
+ read_state_vals : sequence of PyTree or None
46
+ The original read state values for all states in the merged trace.
47
+ return_states : bool, default True
48
+ Whether to return updated state values along with the function output.
49
+
50
+ Returns
51
+ -------
52
+ callable
53
+ A wrapped function that can be used in multi-branch control flow.
54
+
55
+ Examples
56
+ --------
57
+ Usage in conditional execution:
58
+
59
+ .. code-block:: python
60
+
61
+ >>> import brainstate
62
+ >>> import jax.numpy as jnp
63
+ >>>
64
+ >>> # Create states
65
+ >>> state1 = brainstate.State(jnp.array([1.0]))
66
+ >>> state2 = brainstate.State(jnp.array([2.0]))
67
+ >>>
68
+ >>> def branch_fn(x):
69
+ ... state1.value *= x
70
+ ... return state1.value + state2.value
71
+ >>>
72
+ >>> # During compilation, this wrapper allows the function
73
+ >>> # to work with merged state traces from multiple branches
74
+ >>> sf = brainstate.transform.StatefulFunction(branch_fn)
75
+ >>> # wrapped_fn = wrap_single_fun_in_multi_branches(sf, merged_trace, read_vals)
76
+ """
77
+ state_ids_belong_to_this_fun = {id(st): st for st in stateful_fun.get_states_by_cache(cache_key)}
78
+
79
+ @wraps(stateful_fun.fun)
80
+ def wrapped_branch(write_state_vals, *operands):
81
+ # "write_state_vals" should have the same length as "merged_state_trace.states"
82
+ assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)
83
+
84
+ # get all state values needed for this function, which is a subset of "write_state_vals"
85
+ st_vals_for_this_fun = []
86
+ for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
87
+ merged_state_trace.states,
88
+ write_state_vals,
89
+ read_state_vals):
90
+ if id(st) in state_ids_belong_to_this_fun:
91
+ st_vals_for_this_fun.append(val_w if write else val_r)
92
+
93
+ # call this function
94
+ new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, *operands)
95
+ assert len(new_state_vals) == len(st_vals_for_this_fun)
96
+
97
+ if return_states:
98
+ # get all written state values
99
+ new_state_vals = {id(st): val for st, val in zip(stateful_fun.get_states_by_cache(cache_key), new_state_vals)}
100
+ write_state_vals = tuple([
101
+ (new_state_vals[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
102
+ if write else None
103
+ for write, st, w_val in zip(merged_state_trace.been_writen,
104
+ merged_state_trace.states,
105
+ write_state_vals)
106
+ ])
107
+ return write_state_vals, out
108
+ return out
109
+
110
+ return wrapped_branch
111
+
112
+
113
+ def wrap_single_fun_in_multi_branches_while_loop(
114
+ stateful_fun: StatefulFunction,
115
+ merged_state_trace: StateTraceStack,
116
+ read_state_vals: Sequence[PyTree | None],
117
+ return_states: bool = True,
118
+ cache_key: Hashable = None,
119
+ ):
120
+ """
121
+ Wrap a stateful function for use in while loop control flow.
122
+
123
+ This function creates a wrapper specifically designed for while loop operations
124
+ where multiple functions share state. It manages state values by extracting only
125
+ the states needed by this specific function from a merged state trace, with
126
+ special handling for the loop's init_val structure.
127
+
128
+ Parameters
129
+ ----------
130
+ stateful_fun : StatefulFunction
131
+ The stateful function to be wrapped.
132
+ merged_state_trace : StateTraceStack
133
+ The merged state trace containing all states from multiple functions.
134
+ read_state_vals : sequence of PyTree or None
135
+ The original read state values for all states in the merged trace.
136
+ return_states : bool, default True
137
+ Whether to return updated state values along with the function output.
138
+
139
+ Returns
140
+ -------
141
+ callable
142
+ A wrapped function that can be used in while loop control flow.
143
+
144
+ Examples
145
+ --------
146
+ Usage in while loop operations:
147
+
148
+ .. code-block:: python
149
+
150
+ >>> import brainstate
151
+ >>> import jax.numpy as jnp
152
+ >>>
153
+ >>> # Create states
154
+ >>> counter = brainstate.State(jnp.array([0]))
155
+ >>> accumulator = brainstate.State(jnp.array([0.0]))
156
+ >>>
157
+ >>> def cond_fn(val):
158
+ ... return counter.value < 10
159
+ >>>
160
+ >>> def body_fn(val):
161
+ ... counter.value += 1
162
+ ... accumulator.value += val
163
+ ... return val * 2
164
+ >>>
165
+ >>> # During compilation, this wrapper allows the functions
166
+ >>> # to work with merged state traces in while loops
167
+ >>> sf_cond = brainstate.transform.StatefulFunction(cond_fn)
168
+ >>> sf_body = brainstate.transform.StatefulFunction(body_fn)
169
+ >>> # wrapped_cond = wrap_single_fun_in_multi_branches_while_loop(sf_cond, ...)
170
+ >>> # wrapped_body = wrap_single_fun_in_multi_branches_while_loop(sf_body, ...)
171
+ """
172
+ state_ids_belong_to_this_fun = {id(st): st for st in stateful_fun.get_states_by_cache(cache_key)}
173
+
174
+ @wraps(stateful_fun.fun)
175
+ def wrapped_branch(init_val):
176
+ write_state_vals, init_val = init_val
177
+ # "write_state_vals" should have the same length as "merged_state_trace.states"
178
+ assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)
179
+
180
+ # get all state values needed for this function, which is a subset of "write_state_vals"
181
+ st_vals_for_this_fun = []
182
+ for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
183
+ merged_state_trace.states,
184
+ write_state_vals,
185
+ read_state_vals):
186
+ if id(st) in state_ids_belong_to_this_fun:
187
+ st_vals_for_this_fun.append(val_w if write else val_r)
188
+
189
+ # call this function
190
+ new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, init_val)
191
+ assert len(new_state_vals) == len(st_vals_for_this_fun)
192
+
193
+ if return_states:
194
+ # get all written state values
195
+ new_state_vals = {id(st): val for st, val in zip(stateful_fun.get_states_by_cache(cache_key), new_state_vals)}
196
+ write_state_vals = tuple([
197
+ (new_state_vals[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
198
+ if write else None
199
+ for write, st, w_val in zip(merged_state_trace.been_writen,
200
+ merged_state_trace.states,
201
+ write_state_vals)
202
+ ])
203
+ return write_state_vals, out
204
+ return out
205
+
206
+ return wrapped_branch
207
+
208
+
209
+ def wrap_single_fun(
210
+ stateful_fun: StatefulFunction,
211
+ been_writen: Sequence[bool],
212
+ read_state_vals: Tuple[PyTree | None],
213
+ ):
214
+ """
215
+ Wrap a stateful function for use in scan operations.
216
+
217
+ This function creates a wrapper specifically designed for scan operations.
218
+ It manages state values by combining written and read states, calls the
219
+ stateful function, and returns only the written states along with the
220
+ carry and output values.
221
+
222
+ Parameters
223
+ ----------
224
+ stateful_fun : StatefulFunction
225
+ The stateful function to be wrapped for scan operations.
226
+ been_writen : sequence of bool
227
+ Boolean flags indicating which states have been written to.
228
+ read_state_vals : tuple of PyTree or None
229
+ The original read state values for all states.
230
+
231
+ Returns
232
+ -------
233
+ callable
234
+ A wrapped function that can be used in scan operations with proper
235
+ state management.
236
+
237
+ Examples
238
+ --------
239
+ Usage in scan operations:
240
+
241
+ .. code-block:: python
242
+
243
+ >>> import brainstate
244
+ >>> import jax.numpy as jnp
245
+ >>>
246
+ >>> # Create states
247
+ >>> state1 = brainstate.State(jnp.array([0.0]))
248
+ >>> state2 = brainstate.State(jnp.array([1.0]))
249
+ >>>
250
+ >>> def scan_fn(carry, x):
251
+ ... state1.value += x # This state will be written
252
+ ... result = carry + state1.value + state2.value # state2 is only read
253
+ ... return result, result ** 2
254
+ >>>
255
+ >>> # During compilation, this wrapper allows the function
256
+ >>> # to work properly in scan operations
257
+ >>> sf = brainstate.transform.StatefulFunction(scan_fn)
258
+ >>> # wrapped_fn = wrap_single_fun(sf, been_written_flags, read_values)
259
+ >>>
260
+ >>> # The wrapped function handles state management automatically
261
+ >>> xs = jnp.arange(5.0)
262
+ >>> init_carry = 0.0
263
+ final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs)
264
+ """
265
+
266
+ @wraps(stateful_fun.fun)
267
+ def wrapped_fun(new_carry, inputs):
268
+ writen_state_vals, carry = new_carry
269
+ assert len(been_writen) == len(writen_state_vals) == len(read_state_vals)
270
+
271
+ # collect all written and read states
272
+ state_vals = [
273
+ written_val if written else read_val
274
+ for written, written_val, read_val in zip(been_writen, writen_state_vals, read_state_vals)
275
+ ]
276
+
277
+ # call the jaxpr
278
+ state_vals, (carry, out) = stateful_fun.jaxpr_call(state_vals, carry, inputs)
279
+
280
+ # only return the written states
281
+ writen_state_vals = tuple([val if written else None for written, val in zip(been_writen, state_vals)])
282
+
283
+ # return
284
+ return (writen_state_vals, carry), out
285
+
286
+ return wrapped_fun