brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,286 @@
|
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from functools import wraps
|
17
|
+
from typing import Sequence, Tuple, Hashable
|
18
|
+
|
19
|
+
from brainstate._state import StateTraceStack
|
20
|
+
from brainstate.typing import PyTree
|
21
|
+
from ._make_jaxpr import StatefulFunction
|
22
|
+
|
23
|
+
|
24
|
+
def wrap_single_fun_in_multi_branches(
    stateful_fun: StatefulFunction,
    merged_state_trace: StateTraceStack,
    read_state_vals: Sequence[PyTree | None],
    return_states: bool = True,
    cache_key: Hashable = None,
):
    """
    Wrap a stateful function for use in multi-branch control flow.

    This function creates a wrapper that allows a stateful function to be used
    in control flow operations where multiple functions share state. It manages
    state values by extracting only the states needed by this specific function
    from a merged state trace.

    The values handed to ``stateful_fun.jaxpr_call`` are matched to the
    function's own states by identity and passed in the order returned by
    ``stateful_fun.get_states_by_cache(cache_key)``. The merged trace may list
    shared states in a different order (it is built from several functions),
    so its order cannot be used directly.

    Parameters
    ----------
    stateful_fun : StatefulFunction
        The stateful function to be wrapped.
    merged_state_trace : StateTraceStack
        The merged state trace containing all states from multiple functions.
    read_state_vals : sequence of PyTree or None
        The original read state values for all states in the merged trace.
    return_states : bool, default True
        Whether to return updated state values along with the function output.
    cache_key : Hashable, optional
        The key passed to ``stateful_fun.get_states_by_cache`` to select which
        compiled variant's states this wrapper uses.

    Returns
    -------
    callable
        A wrapped function ``wrapped_branch(write_state_vals, *operands)`` that
        can be used in multi-branch control flow. If ``return_states`` is True
        it returns ``(write_state_vals, out)``, where ``write_state_vals`` is
        aligned with ``merged_state_trace.states`` and holds ``None`` for states
        that are never written; otherwise it returns only ``out``.

    Examples
    --------
    Usage in conditional execution:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> # Create states
        >>> state1 = brainstate.State(jnp.array([1.0]))
        >>> state2 = brainstate.State(jnp.array([2.0]))
        >>>
        >>> def branch_fn(x):
        ...     state1.value *= x
        ...     return state1.value + state2.value
        >>>
        >>> # During compilation, this wrapper allows the function
        >>> # to work with merged state traces from multiple branches
        >>> sf = brainstate.transform.StatefulFunction(branch_fn)
        >>> # wrapped_fn = wrap_single_fun_in_multi_branches(sf, merged_trace, read_vals)
    """
    # The function's own states, in the order its jaxpr expects their values.
    states_of_this_fun = tuple(stateful_fun.get_states_by_cache(cache_key))
    state_ids_belong_to_this_fun = {id(st): st for st in states_of_this_fun}

    @wraps(stateful_fun.fun)
    def wrapped_branch(write_state_vals, *operands):
        # "write_state_vals" should have the same length as "merged_state_trace.states"
        assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)

        # Pick the current value of every state this function uses (written value
        # if the merged trace writes it, else the original read value), keyed by
        # state identity so the merged ordering does not leak into the call.
        vals_by_id = {
            id(st): (val_w if write else val_r)
            for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
                                               merged_state_trace.states,
                                               write_state_vals,
                                               read_state_vals)
            if id(st) in state_ids_belong_to_this_fun
        }
        missing = [st for st in states_of_this_fun if id(st) not in vals_by_id]
        assert not missing, (
            f'{len(missing)} state(s) used by {stateful_fun.fun} are not present '
            f'in the merged state trace.'
        )
        # Re-order to the function's own state order before calling its jaxpr.
        st_vals_for_this_fun = [vals_by_id[id(st)] for st in states_of_this_fun]

        # call this function
        new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, *operands)
        assert len(new_state_vals) == len(st_vals_for_this_fun)

        if return_states:
            # New values come back in the function's state order; key them by identity.
            new_vals_by_id = {id(st): val for st, val in zip(states_of_this_fun, new_state_vals)}
            # Re-align with the merged trace: states written by this function get
            # their new value, written states owned by other branches pass through
            # unchanged, and never-written states stay None.
            write_state_vals = tuple([
                (new_vals_by_id[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
                if write else None
                for write, st, w_val in zip(merged_state_trace.been_writen,
                                            merged_state_trace.states,
                                            write_state_vals)
            ])
            return write_state_vals, out
        return out

    return wrapped_branch
|
111
|
+
|
112
|
+
|
113
|
+
def wrap_single_fun_in_multi_branches_while_loop(
    stateful_fun: StatefulFunction,
    merged_state_trace: StateTraceStack,
    read_state_vals: Sequence[PyTree | None],
    return_states: bool = True,
    cache_key: Hashable = None,
):
    """
    Wrap a stateful function for use in while loop control flow.

    This function creates a wrapper specifically designed for while loop operations
    where multiple functions share state. It manages state values by extracting only
    the states needed by this specific function from a merged state trace, with
    special handling for the loop's init_val structure.

    The values handed to ``stateful_fun.jaxpr_call`` are matched to the
    function's own states by identity and passed in the order returned by
    ``stateful_fun.get_states_by_cache(cache_key)``. The merged trace (built
    from e.g. the cond and body functions) may list shared states in a
    different order, so its order cannot be used directly.

    Parameters
    ----------
    stateful_fun : StatefulFunction
        The stateful function to be wrapped.
    merged_state_trace : StateTraceStack
        The merged state trace containing all states from multiple functions.
    read_state_vals : sequence of PyTree or None
        The original read state values for all states in the merged trace.
    return_states : bool, default True
        Whether to return updated state values along with the function output.
    cache_key : Hashable, optional
        The key passed to ``stateful_fun.get_states_by_cache`` to select which
        compiled variant's states this wrapper uses.

    Returns
    -------
    callable
        A wrapped function ``wrapped_branch(init_val)`` where ``init_val`` is the
        pair ``(write_state_vals, value)``. If ``return_states`` is True it
        returns ``(write_state_vals, out)`` aligned with
        ``merged_state_trace.states`` (``None`` for never-written states);
        otherwise it returns only ``out``.

    Examples
    --------
    Usage in while loop operations:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> # Create states
        >>> counter = brainstate.State(jnp.array([0]))
        >>> accumulator = brainstate.State(jnp.array([0.0]))
        >>>
        >>> def cond_fn(val):
        ...     return counter.value < 10
        >>>
        >>> def body_fn(val):
        ...     counter.value += 1
        ...     accumulator.value += val
        ...     return val * 2
        >>>
        >>> # During compilation, this wrapper allows the functions
        >>> # to work with merged state traces in while loops
        >>> sf_cond = brainstate.transform.StatefulFunction(cond_fn)
        >>> sf_body = brainstate.transform.StatefulFunction(body_fn)
        >>> # wrapped_cond = wrap_single_fun_in_multi_branches_while_loop(sf_cond, ...)
        >>> # wrapped_body = wrap_single_fun_in_multi_branches_while_loop(sf_body, ...)
    """
    # The function's own states, in the order its jaxpr expects their values.
    states_of_this_fun = tuple(stateful_fun.get_states_by_cache(cache_key))
    state_ids_belong_to_this_fun = {id(st): st for st in states_of_this_fun}

    @wraps(stateful_fun.fun)
    def wrapped_branch(init_val):
        # The loop carry is (state values aligned with the merged trace, user value).
        write_state_vals, init_val = init_val
        # "write_state_vals" should have the same length as "merged_state_trace.states"
        assert len(merged_state_trace.states) == len(write_state_vals) == len(read_state_vals)

        # Current value of each state this function uses (written value if the
        # merged trace writes it, else the original read value), keyed by identity.
        vals_by_id = {
            id(st): (val_w if write else val_r)
            for write, st, val_w, val_r in zip(merged_state_trace.been_writen,
                                               merged_state_trace.states,
                                               write_state_vals,
                                               read_state_vals)
            if id(st) in state_ids_belong_to_this_fun
        }
        missing = [st for st in states_of_this_fun if id(st) not in vals_by_id]
        assert not missing, (
            f'{len(missing)} state(s) used by {stateful_fun.fun} are not present '
            f'in the merged state trace.'
        )
        # Re-order to the function's own state order before calling its jaxpr.
        st_vals_for_this_fun = [vals_by_id[id(st)] for st in states_of_this_fun]

        # call this function
        new_state_vals, out = stateful_fun.jaxpr_call(st_vals_for_this_fun, init_val)
        assert len(new_state_vals) == len(st_vals_for_this_fun)

        if return_states:
            # New values come back in the function's state order; key them by identity.
            new_vals_by_id = {id(st): val for st, val in zip(states_of_this_fun, new_state_vals)}
            # Re-align with the merged trace: this function's written states get
            # new values, other written states pass through, unwritten stay None.
            write_state_vals = tuple([
                (new_vals_by_id[id(st)] if id(st) in state_ids_belong_to_this_fun else w_val)
                if write else None
                for write, st, w_val in zip(merged_state_trace.been_writen,
                                            merged_state_trace.states,
                                            write_state_vals)
            ])
            return write_state_vals, out
        return out

    return wrapped_branch
|
207
|
+
|
208
|
+
|
209
|
+
def wrap_single_fun(
    stateful_fun: StatefulFunction,
    been_writen: Sequence[bool],
    read_state_vals: Tuple[PyTree | None, ...],
):
    """
    Wrap a stateful function for use in scan operations.

    This function creates a wrapper specifically designed for scan operations.
    It manages state values by combining written and read states, calls the
    stateful function, and returns only the written states along with the
    carry and output values.

    Parameters
    ----------
    stateful_fun : StatefulFunction
        The stateful function to be wrapped for scan operations.
    been_writen : sequence of bool
        Boolean flags indicating which states have been written to.
    read_state_vals : tuple of PyTree or None
        The original read state values for all states.

    Returns
    -------
    callable
        A wrapped function ``wrapped_fun(new_carry, inputs)`` where
        ``new_carry`` is ``(written_state_vals, carry)``. It returns
        ``((written_state_vals, carry), out)``, where ``written_state_vals``
        holds the new value of each written state and ``None`` for read-only
        states, so only written states are threaded through the scan carry.

    Examples
    --------
    Usage in scan operations:

    .. code-block:: python

        >>> import brainstate
        >>> import jax.numpy as jnp
        >>>
        >>> # Create states
        >>> state1 = brainstate.State(jnp.array([0.0]))
        >>> state2 = brainstate.State(jnp.array([1.0]))
        >>>
        >>> def scan_fn(carry, x):
        ...     state1.value += x  # This state will be written
        ...     result = carry + state1.value + state2.value  # state2 is only read
        ...     return result, result ** 2
        >>>
        >>> # During compilation, this wrapper allows the function
        >>> # to work properly in scan operations
        >>> sf = brainstate.transform.StatefulFunction(scan_fn)
        >>> # wrapped_fn = wrap_single_fun(sf, been_written_flags, read_values)
        >>>
        >>> # The wrapped function handles state management automatically
        >>> xs = jnp.arange(5.0)
        >>> init_carry = 0.0
        >>> final_carry, ys = brainstate.transform.scan(scan_fn, init_carry, xs)
    """

    @wraps(stateful_fun.fun)
    def wrapped_fun(new_carry, inputs):
        writen_state_vals, carry = new_carry
        assert len(been_writen) == len(writen_state_vals) == len(read_state_vals), (
            f'State length mismatch: {len(been_writen)} write flags, '
            f'{len(writen_state_vals)} written values, {len(read_state_vals)} read values.'
        )

        # collect all written and read states: written states come from the
        # scan carry, read-only states keep their original values
        state_vals = [
            written_val if written else read_val
            for written, written_val, read_val in zip(been_writen, writen_state_vals, read_state_vals)
        ]

        # call the jaxpr; the user function returns (carry, out)
        state_vals, (carry, out) = stateful_fun.jaxpr_call(state_vals, carry, inputs)

        # only return the written states (read-only ones become None so they
        # are not carried through the scan)
        writen_state_vals = tuple([val if written else None for written, val in zip(been_writen, state_vals)])

        # return
        return (writen_state_vals, carry), out

    return wrapped_fun
|