brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. brainstate/__init__.py +130 -19
  2. brainstate/_compatible_import.py +201 -9
  3. brainstate/_compatible_import_test.py +681 -0
  4. brainstate/_deprecation.py +210 -0
  5. brainstate/_deprecation_test.py +2319 -0
  6. brainstate/{util/error.py → _error.py} +10 -20
  7. brainstate/_state.py +94 -47
  8. brainstate/_state_test.py +1 -1
  9. brainstate/_utils.py +1 -1
  10. brainstate/environ.py +1279 -347
  11. brainstate/environ_test.py +1187 -26
  12. brainstate/graph/__init__.py +6 -13
  13. brainstate/graph/_node.py +240 -0
  14. brainstate/graph/_node_test.py +589 -0
  15. brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
  16. brainstate/graph/_operation_test.py +1147 -0
  17. brainstate/mixin.py +1209 -141
  18. brainstate/mixin_test.py +991 -51
  19. brainstate/nn/__init__.py +74 -72
  20. brainstate/nn/_activations.py +587 -295
  21. brainstate/nn/_activations_test.py +109 -86
  22. brainstate/nn/_collective_ops.py +393 -274
  23. brainstate/nn/_collective_ops_test.py +746 -15
  24. brainstate/nn/_common.py +114 -66
  25. brainstate/nn/_common_test.py +154 -0
  26. brainstate/nn/_conv.py +1652 -143
  27. brainstate/nn/_conv_test.py +838 -227
  28. brainstate/nn/_delay.py +15 -28
  29. brainstate/nn/_delay_test.py +25 -20
  30. brainstate/nn/_dropout.py +359 -167
  31. brainstate/nn/_dropout_test.py +429 -52
  32. brainstate/nn/_dynamics.py +14 -90
  33. brainstate/nn/_dynamics_test.py +1 -12
  34. brainstate/nn/_elementwise.py +492 -313
  35. brainstate/nn/_elementwise_test.py +806 -145
  36. brainstate/nn/_embedding.py +369 -19
  37. brainstate/nn/_embedding_test.py +156 -0
  38. brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
  39. brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
  40. brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
  41. brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
  42. brainstate/nn/_exp_euler.py +200 -38
  43. brainstate/nn/_exp_euler_test.py +350 -8
  44. brainstate/nn/_linear.py +391 -71
  45. brainstate/nn/_linear_test.py +427 -59
  46. brainstate/nn/_metrics.py +1070 -0
  47. brainstate/nn/_metrics_test.py +611 -0
  48. brainstate/nn/_module.py +10 -3
  49. brainstate/nn/_module_test.py +1 -1
  50. brainstate/nn/_normalizations.py +688 -329
  51. brainstate/nn/_normalizations_test.py +663 -37
  52. brainstate/nn/_paddings.py +1020 -0
  53. brainstate/nn/_paddings_test.py +723 -0
  54. brainstate/nn/_poolings.py +1404 -342
  55. brainstate/nn/_poolings_test.py +828 -92
  56. brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
  57. brainstate/nn/_rnns_test.py +593 -0
  58. brainstate/nn/_utils.py +132 -5
  59. brainstate/nn/_utils_test.py +402 -0
  60. brainstate/{init/_random_inits.py → nn/init.py} +301 -45
  61. brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
  62. brainstate/random/__init__.py +247 -1
  63. brainstate/random/_rand_funs.py +668 -346
  64. brainstate/random/_rand_funs_test.py +74 -1
  65. brainstate/random/_rand_seed.py +541 -76
  66. brainstate/random/_rand_seed_test.py +1 -1
  67. brainstate/random/_rand_state.py +601 -393
  68. brainstate/random/_rand_state_test.py +551 -0
  69. brainstate/transform/__init__.py +59 -0
  70. brainstate/transform/_ad_checkpoint.py +176 -0
  71. brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
  72. brainstate/{augment → transform}/_autograd.py +360 -113
  73. brainstate/{augment → transform}/_autograd_test.py +2 -2
  74. brainstate/transform/_conditions.py +316 -0
  75. brainstate/{compile → transform}/_conditions_test.py +11 -11
  76. brainstate/{compile → transform}/_error_if.py +22 -20
  77. brainstate/{compile → transform}/_error_if_test.py +1 -1
  78. brainstate/transform/_eval_shape.py +145 -0
  79. brainstate/{augment → transform}/_eval_shape_test.py +1 -1
  80. brainstate/{compile → transform}/_jit.py +99 -46
  81. brainstate/{compile → transform}/_jit_test.py +3 -3
  82. brainstate/{compile → transform}/_loop_collect_return.py +219 -80
  83. brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
  84. brainstate/{compile → transform}/_loop_no_collection.py +133 -34
  85. brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
  86. brainstate/transform/_make_jaxpr.py +2016 -0
  87. brainstate/transform/_make_jaxpr_test.py +1510 -0
  88. brainstate/transform/_mapping.py +529 -0
  89. brainstate/transform/_mapping_test.py +194 -0
  90. brainstate/{compile → transform}/_progress_bar.py +78 -25
  91. brainstate/{augment → transform}/_random.py +65 -45
  92. brainstate/{compile → transform}/_unvmap.py +102 -5
  93. brainstate/transform/_util.py +286 -0
  94. brainstate/typing.py +594 -61
  95. brainstate/typing_test.py +780 -0
  96. brainstate/util/__init__.py +9 -32
  97. brainstate/util/_others.py +1025 -0
  98. brainstate/util/_others_test.py +962 -0
  99. brainstate/util/_pretty_pytree.py +1301 -0
  100. brainstate/util/_pretty_pytree_test.py +675 -0
  101. brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
  102. brainstate/util/_pretty_repr_test.py +696 -0
  103. brainstate/util/filter.py +557 -81
  104. brainstate/util/filter_test.py +912 -0
  105. brainstate/util/struct.py +769 -382
  106. brainstate/util/struct_test.py +602 -0
  107. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
  108. brainstate-0.2.0.dist-info/RECORD +111 -0
  109. brainstate/augment/__init__.py +0 -30
  110. brainstate/augment/_eval_shape.py +0 -99
  111. brainstate/augment/_mapping.py +0 -1060
  112. brainstate/augment/_mapping_test.py +0 -597
  113. brainstate/compile/__init__.py +0 -38
  114. brainstate/compile/_ad_checkpoint.py +0 -204
  115. brainstate/compile/_conditions.py +0 -256
  116. brainstate/compile/_make_jaxpr.py +0 -888
  117. brainstate/compile/_make_jaxpr_test.py +0 -156
  118. brainstate/compile/_util.py +0 -147
  119. brainstate/functional/__init__.py +0 -27
  120. brainstate/graph/_graph_node.py +0 -244
  121. brainstate/graph/_graph_node_test.py +0 -73
  122. brainstate/graph/_graph_operation_test.py +0 -563
  123. brainstate/init/__init__.py +0 -26
  124. brainstate/init/_base.py +0 -52
  125. brainstate/init/_generic.py +0 -244
  126. brainstate/init/_regular_inits.py +0 -105
  127. brainstate/init/_regular_inits_test.py +0 -50
  128. brainstate/nn/_inputs.py +0 -608
  129. brainstate/nn/_ltp.py +0 -28
  130. brainstate/nn/_neuron.py +0 -705
  131. brainstate/nn/_neuron_test.py +0 -161
  132. brainstate/nn/_others.py +0 -46
  133. brainstate/nn/_projection.py +0 -486
  134. brainstate/nn/_rate_rnns_test.py +0 -63
  135. brainstate/nn/_readout.py +0 -209
  136. brainstate/nn/_readout_test.py +0 -53
  137. brainstate/nn/_stp.py +0 -236
  138. brainstate/nn/_synapse.py +0 -505
  139. brainstate/nn/_synapse_test.py +0 -131
  140. brainstate/nn/_synaptic_projection.py +0 -423
  141. brainstate/nn/_synouts.py +0 -162
  142. brainstate/nn/_synouts_test.py +0 -57
  143. brainstate/nn/metrics.py +0 -388
  144. brainstate/optim/__init__.py +0 -38
  145. brainstate/optim/_base.py +0 -64
  146. brainstate/optim/_lr_scheduler.py +0 -448
  147. brainstate/optim/_lr_scheduler_test.py +0 -50
  148. brainstate/optim/_optax_optimizer.py +0 -152
  149. brainstate/optim/_optax_optimizer_test.py +0 -53
  150. brainstate/optim/_sgd_optimizer.py +0 -1104
  151. brainstate/random/_random_for_unit.py +0 -52
  152. brainstate/surrogate.py +0 -1957
  153. brainstate/transform.py +0 -23
  154. brainstate/util/caller.py +0 -98
  155. brainstate/util/others.py +0 -540
  156. brainstate/util/pretty_pytree.py +0 -945
  157. brainstate/util/pretty_pytree_test.py +0 -159
  158. brainstate/util/pretty_table.py +0 -2954
  159. brainstate/util/scaling.py +0 -258
  160. brainstate-0.1.10.dist-info/RECORD +0 -130
  161. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
  162. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
  163. {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,286 @@
1
+ # Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from functools import wraps
17
+ from typing import Sequence, Tuple, Hashable
18
+
19
+ from brainstate._state import StateTraceStack
20
+ from brainstate.typing import PyTree
21
+ from ._make_jaxpr import StatefulFunction
22
+
23
+
24
+ def wrap_single_fun_in_multi_branches(
25
+ stateful_fun: StatefulFunction,
26
+ merged_state_trace: StateTraceStack,
27
+ read_state_vals: Sequence[PyTree | None],
28
+ return_states: bool = True,
29
+ cache_key: Hashable = None,
30
+ ):
31
+ """
32
+ Wrap a stateful function for use in multi-branch control flow.
33
+
34
+ This function creates a wrapper that allows a stateful function to be used
35
+ in control flow operations where multiple functions share state. It manages
36
+ state values by extracting only the states needed by this specific function
37
+ from a merged state trace.
38
+
39
+ Parameters
40
+ ----------
41
+ stateful_fun : StatefulFunction
42
+ The stateful function to be wrapped.
43
+ merged_state_trace : StateTraceStack
44
+ The merged state trace containing all states from multiple functions.
45
+ read_state_vals : sequence of PyTree or None
46
+ The original read state values for all states in the merged trace.
47
+ return_states : bool, default True
48
+ Whether to return updated state values along with the function output.
49
+
50
+ Returns
51
+ -------
52
+ callable
53
+ A wrapped function that can be used in multi-branch control flow.
54
+
55
+ Examples
56
+ --------
57
+ Usage in conditional execution:
58
+
59
+ .. code-block:: python
60
+
61
+ >>> import brainstate
62
+ >>> import jax.numpy as jnp
63
+ >>>
64
+ >>> # Create states
65
+ >>> state1 = brainstate.State(jnp.array([1.0]))
66
+ >>> state2 = brainstate.State(jnp.array([2.0]))
67
+ >>>
68
+ >>> def branch_fn(x):
69
+ ... state1.value *= x
70
+ ... return state1.value + state2.value
71
+ >>>
72
+ >>> # During compilation, this wrapper allows the function
73
+ >>> # to work with merged state traces from multiple branches
74
+ >>> sf = brainstate.transform.StatefulFunction(branch_fn)
75
+ >>> # wrapped_fn = wrap_single_fun_in_multi_branches(sf, merged_trace, read_vals)
76
+ """
77
+ state_ids_belong_to_this_fun = {id(st): st for st in stateful_fun.get_states_by_cache(cache_key)}
78
+
79
+ @wraps(stateful_fun.fun)
80
+ def wrapped_branch(write_state_vals, *operands):
81
+ # "write_state_vals" should have the same length as "merged_state_trace.states"
82
+ assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)
83
+
84
+ # get all state values needed for this function, which is a subset of "write_state_vals"
85
+ st_vals_for_this_fun = []
86
+ for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
87
+ merged_state_trace.states,
88
+ write_state_vals,
89
+ read_state_vals):
90
+ if id(st) in state_ids_belong_to_this_fun:
91
+ st_vals_for_this_fun.append(val_w if write else val_r)
92
+
93
+ # call this function
94
+ new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, *operands)
95
+ assert len(new_state_vals) == len(st_vals_for_this_fun)
96
+
97
+ if return_states:
98
+ # get all written state values
99
+ new_state_vals = {id(st): val for st, val in zip(stateful_fun.get_states_by_cache(cache_key), new_state_vals)}
100
+ write_state_vals = tuple([
101
+ (new_state_vals[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
102
+ if write else None
103
+ for write, st, w_val in zip(merged_state_trace.been_writen,
104
+ merged_state_trace.states,
105
+ write_state_vals)
106
+ ])
107
+ return write_state_vals, out
108
+ return out
109
+
110
+ return wrapped_branch
111
+
112
+
113
+ def wrap_single_fun_in_multi_branches_while_loop(
114
+ stateful_fun: StatefulFunction,
115
+ merged_state_trace: StateTraceStack,
116
+ read_state_vals: Sequence[PyTree | None],
117
+ return_states: bool = True,
118
+ cache_key: Hashable = None,
119
+ ):
120
+ """
121
+ Wrap a stateful function for use in while loop control flow.
122
+
123
+ This function creates a wrapper specifically designed for while loop operations
124
+ where multiple functions share state. It manages state values by extracting only
125
+ the states needed by this specific function from a merged state trace, with
126
+ special handling for the loop's init_val structure.
127
+
128
+ Parameters
129
+ ----------
130
+ stateful_fun : StatefulFunction
131
+ The stateful function to be wrapped.
132
+ merged_state_trace : StateTraceStack
133
+ The merged state trace containing all states from multiple functions.
134
+ read_state_vals : sequence of PyTree or None
135
+ The original read state values for all states in the merged trace.
136
+ return_states : bool, default True
137
+ Whether to return updated state values along with the function output.
138
+
139
+ Returns
140
+ -------
141
+ callable
142
+ A wrapped function that can be used in while loop control flow.
143
+
144
+ Examples
145
+ --------
146
+ Usage in while loop operations:
147
+
148
+ .. code-block:: python
149
+
150
+ >>> import brainstate
151
+ >>> import jax.numpy as jnp
152
+ >>>
153
+ >>> # Create states
154
+ >>> counter = brainstate.State(jnp.array([0]))
155
+ >>> accumulator = brainstate.State(jnp.array([0.0]))
156
+ >>>
157
+ >>> def cond_fn(val):
158
+ ... return counter.value < 10
159
+ >>>
160
+ >>> def body_fn(val):
161
+ ... counter.value += 1
162
+ ... accumulator.value += val
163
+ ... return val * 2
164
+ >>>
165
+ >>> # During compilation, this wrapper allows the functions
166
+ >>> # to work with merged state traces in while loops
167
+ >>> sf_cond = brainstate.transform.StatefulFunction(cond_fn)
168
+ >>> sf_body = brainstate.transform.StatefulFunction(body_fn)
169
+ >>> # wrapped_cond = wrap_single_fun_in_multi_branches_while_loop(sf_cond, ...)
170
+ >>> # wrapped_body = wrap_single_fun_in_multi_branches_while_loop(sf_body, ...)
171
+ """
172
+ state_ids_belong_to_this_fun = {id(st): st for st in stateful_fun.get_states_by_cache(cache_key)}
173
+
174
+ @wraps(stateful_fun.fun)
175
+ def wrapped_branch(init_val):
176
+ write_state_vals, init_val = init_val
177
+ # "write_state_vals" should have the same length as "merged_state_trace.states"
178
+ assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)
179
+
180
+ # get all state values needed for this function, which is a subset of "write_state_vals"
181
+ st_vals_for_this_fun = []
182
+ for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
183
+ merged_state_trace.states,
184
+ write_state_vals,
185
+ read_state_vals):
186
+ if id(st) in state_ids_belong_to_this_fun:
187
+ st_vals_for_this_fun.append(val_w if write else val_r)
188
+
189
+ # call this function
190
+ new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, init_val)
191
+ assert len(new_state_vals) == len(st_vals_for_this_fun)
192
+
193
+ if return_states:
194
+ # get all written state values
195
+ new_state_vals = {id(st): val for st, val in zip(stateful_fun.get_states_by_cache(cache_key), new_state_vals)}
196
+ write_state_vals = tuple([
197
+ (new_state_vals[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
198
+ if write else None
199
+ for write, st, w_val in zip(merged_state_trace.been_writen,
200
+ merged_state_trace.states,
201
+ write_state_vals)
202
+ ])
203
+ return write_state_vals, out
204
+ return out
205
+
206
+ return wrapped_branch
207
+
208
+
209
+ def wrap_single_fun(
210
+ stateful_fun: StatefulFunction,
211
+ been_writen: Sequence[bool],
212
+ read_state_vals: Tuple[PyTree | None],
213
+ ):
214
+ """
215
+ Wrap a stateful function for use in scan operations.
216
+
217
+ This function creates a wrapper specifically designed for scan operations.
218
+ It manages state values by combining written and read states, calls the
219
+ stateful function, and returns only the written states along with the
220
+ carry and output values.
221
+
222
+ Parameters
223
+ ----------
224
+ stateful_fun : StatefulFunction
225
+ The stateful function to be wrapped for scan operations.
226
+ been_writen : sequence of bool
227
+ Boolean flags indicating which states have been written to.
228
+ read_state_vals : tuple of PyTree or None
229
+ The original read state values for all states.
230
+
231
+ Returns
232
+ -------
233
+ callable
234
+ A wrapped function that can be used in scan operations with proper
235
+ state management.
236
+
237
+ Examples
238
+ --------
239
+ Usage in scan operations:
240
+
241
+ .. code-block:: python
242
+
243
+ >>> import brainstate
244
+ >>> import jax.numpy as jnp
245
+ >>>
246
+ >>> # Create states
247
+ >>> state1 = brainstate.State(jnp.array([0.0]))
248
+ >>> state2 = brainstate.State(jnp.array([1.0]))
249
+ >>>
250
+ >>> def scan_fn(carry, x):
251
+ ... state1.value += x # This state will be written
252
+ ... result = carry + state1.value + state2.value # state2 is only read
253
+ ... return result, result ** 2
254
+ >>>
255
+ >>> # During compilation, this wrapper allows the function
256
+ >>> # to work properly in scan operations
257
+ >>> sf = brainstate.transform.StatefulFunction(scan_fn)
258
+ >>> # wrapped_fn = wrap_single_fun(sf, been_written_flags, read_values)
259
+ >>>
260
+ >>> # The wrapped function handles state management automatically
261
+ >>> xs = jnp.arange(5.0)
262
+ >>> init_carry = 0.0
263
+ final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs)
264
+ """
265
+
266
+ @wraps(stateful_fun.fun)
267
+ def wrapped_fun(new_carry, inputs):
268
+ writen_state_vals, carry = new_carry
269
+ assert len(been_writen) == len(writen_state_vals) == len(read_state_vals)
270
+
271
+ # collect all written and read states
272
+ state_vals = [
273
+ written_val if written else read_val
274
+ for written, written_val, read_val in zip(been_writen, writen_state_vals, read_state_vals)
275
+ ]
276
+
277
+ # call the jaxpr
278
+ state_vals, (carry, out) = stateful_fun.jaxpr_call(state_vals, carry, inputs)
279
+
280
+ # only return the written states
281
+ writen_state_vals = tuple([val if written else None for written, val in zip(been_writen, state_vals)])
282
+
283
+ # return
284
+ return (writen_state_vals, carry), out
285
+
286
+ return wrapped_fun