brainstate 0.1.10__py2.py3-none-any.whl → 0.2.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +130 -19
- brainstate/_compatible_import.py +201 -9
- brainstate/_compatible_import_test.py +681 -0
- brainstate/_deprecation.py +210 -0
- brainstate/_deprecation_test.py +2319 -0
- brainstate/{util/error.py → _error.py} +10 -20
- brainstate/_state.py +94 -47
- brainstate/_state_test.py +1 -1
- brainstate/_utils.py +1 -1
- brainstate/environ.py +1279 -347
- brainstate/environ_test.py +1187 -26
- brainstate/graph/__init__.py +6 -13
- brainstate/graph/_node.py +240 -0
- brainstate/graph/_node_test.py +589 -0
- brainstate/graph/{_graph_operation.py → _operation.py} +632 -746
- brainstate/graph/_operation_test.py +1147 -0
- brainstate/mixin.py +1209 -141
- brainstate/mixin_test.py +991 -51
- brainstate/nn/__init__.py +74 -72
- brainstate/nn/_activations.py +587 -295
- brainstate/nn/_activations_test.py +109 -86
- brainstate/nn/_collective_ops.py +393 -274
- brainstate/nn/_collective_ops_test.py +746 -15
- brainstate/nn/_common.py +114 -66
- brainstate/nn/_common_test.py +154 -0
- brainstate/nn/_conv.py +1652 -143
- brainstate/nn/_conv_test.py +838 -227
- brainstate/nn/_delay.py +15 -28
- brainstate/nn/_delay_test.py +25 -20
- brainstate/nn/_dropout.py +359 -167
- brainstate/nn/_dropout_test.py +429 -52
- brainstate/nn/_dynamics.py +14 -90
- brainstate/nn/_dynamics_test.py +1 -12
- brainstate/nn/_elementwise.py +492 -313
- brainstate/nn/_elementwise_test.py +806 -145
- brainstate/nn/_embedding.py +369 -19
- brainstate/nn/_embedding_test.py +156 -0
- brainstate/nn/{_fixedprob.py → _event_fixedprob.py} +10 -16
- brainstate/nn/{_fixedprob_test.py → _event_fixedprob_test.py} +6 -5
- brainstate/nn/{_linear_mv.py → _event_linear.py} +2 -2
- brainstate/nn/{_linear_mv_test.py → _event_linear_test.py} +6 -5
- brainstate/nn/_exp_euler.py +200 -38
- brainstate/nn/_exp_euler_test.py +350 -8
- brainstate/nn/_linear.py +391 -71
- brainstate/nn/_linear_test.py +427 -59
- brainstate/nn/_metrics.py +1070 -0
- brainstate/nn/_metrics_test.py +611 -0
- brainstate/nn/_module.py +10 -3
- brainstate/nn/_module_test.py +1 -1
- brainstate/nn/_normalizations.py +688 -329
- brainstate/nn/_normalizations_test.py +663 -37
- brainstate/nn/_paddings.py +1020 -0
- brainstate/nn/_paddings_test.py +723 -0
- brainstate/nn/_poolings.py +1404 -342
- brainstate/nn/_poolings_test.py +828 -92
- brainstate/nn/{_rate_rnns.py → _rnns.py} +446 -54
- brainstate/nn/_rnns_test.py +593 -0
- brainstate/nn/_utils.py +132 -5
- brainstate/nn/_utils_test.py +402 -0
- brainstate/{init/_random_inits.py → nn/init.py} +301 -45
- brainstate/{init/_random_inits_test.py → nn/init_test.py} +51 -20
- brainstate/random/__init__.py +247 -1
- brainstate/random/_rand_funs.py +668 -346
- brainstate/random/_rand_funs_test.py +74 -1
- brainstate/random/_rand_seed.py +541 -76
- brainstate/random/_rand_seed_test.py +1 -1
- brainstate/random/_rand_state.py +601 -393
- brainstate/random/_rand_state_test.py +551 -0
- brainstate/transform/__init__.py +59 -0
- brainstate/transform/_ad_checkpoint.py +176 -0
- brainstate/{compile → transform}/_ad_checkpoint_test.py +1 -1
- brainstate/{augment → transform}/_autograd.py +360 -113
- brainstate/{augment → transform}/_autograd_test.py +2 -2
- brainstate/transform/_conditions.py +316 -0
- brainstate/{compile → transform}/_conditions_test.py +11 -11
- brainstate/{compile → transform}/_error_if.py +22 -20
- brainstate/{compile → transform}/_error_if_test.py +1 -1
- brainstate/transform/_eval_shape.py +145 -0
- brainstate/{augment → transform}/_eval_shape_test.py +1 -1
- brainstate/{compile → transform}/_jit.py +99 -46
- brainstate/{compile → transform}/_jit_test.py +3 -3
- brainstate/{compile → transform}/_loop_collect_return.py +219 -80
- brainstate/{compile → transform}/_loop_collect_return_test.py +1 -1
- brainstate/{compile → transform}/_loop_no_collection.py +133 -34
- brainstate/{compile → transform}/_loop_no_collection_test.py +2 -2
- brainstate/transform/_make_jaxpr.py +2016 -0
- brainstate/transform/_make_jaxpr_test.py +1510 -0
- brainstate/transform/_mapping.py +529 -0
- brainstate/transform/_mapping_test.py +194 -0
- brainstate/{compile → transform}/_progress_bar.py +78 -25
- brainstate/{augment → transform}/_random.py +65 -45
- brainstate/{compile → transform}/_unvmap.py +102 -5
- brainstate/transform/_util.py +286 -0
- brainstate/typing.py +594 -61
- brainstate/typing_test.py +780 -0
- brainstate/util/__init__.py +9 -32
- brainstate/util/_others.py +1025 -0
- brainstate/util/_others_test.py +962 -0
- brainstate/util/_pretty_pytree.py +1301 -0
- brainstate/util/_pretty_pytree_test.py +675 -0
- brainstate/util/{pretty_repr.py → _pretty_repr.py} +161 -27
- brainstate/util/_pretty_repr_test.py +696 -0
- brainstate/util/filter.py +557 -81
- brainstate/util/filter_test.py +912 -0
- brainstate/util/struct.py +769 -382
- brainstate/util/struct_test.py +602 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/METADATA +34 -17
- brainstate-0.2.0.dist-info/RECORD +111 -0
- brainstate/augment/__init__.py +0 -30
- brainstate/augment/_eval_shape.py +0 -99
- brainstate/augment/_mapping.py +0 -1060
- brainstate/augment/_mapping_test.py +0 -597
- brainstate/compile/__init__.py +0 -38
- brainstate/compile/_ad_checkpoint.py +0 -204
- brainstate/compile/_conditions.py +0 -256
- brainstate/compile/_make_jaxpr.py +0 -888
- brainstate/compile/_make_jaxpr_test.py +0 -156
- brainstate/compile/_util.py +0 -147
- brainstate/functional/__init__.py +0 -27
- brainstate/graph/_graph_node.py +0 -244
- brainstate/graph/_graph_node_test.py +0 -73
- brainstate/graph/_graph_operation_test.py +0 -563
- brainstate/init/__init__.py +0 -26
- brainstate/init/_base.py +0 -52
- brainstate/init/_generic.py +0 -244
- brainstate/init/_regular_inits.py +0 -105
- brainstate/init/_regular_inits_test.py +0 -50
- brainstate/nn/_inputs.py +0 -608
- brainstate/nn/_ltp.py +0 -28
- brainstate/nn/_neuron.py +0 -705
- brainstate/nn/_neuron_test.py +0 -161
- brainstate/nn/_others.py +0 -46
- brainstate/nn/_projection.py +0 -486
- brainstate/nn/_rate_rnns_test.py +0 -63
- brainstate/nn/_readout.py +0 -209
- brainstate/nn/_readout_test.py +0 -53
- brainstate/nn/_stp.py +0 -236
- brainstate/nn/_synapse.py +0 -505
- brainstate/nn/_synapse_test.py +0 -131
- brainstate/nn/_synaptic_projection.py +0 -423
- brainstate/nn/_synouts.py +0 -162
- brainstate/nn/_synouts_test.py +0 -57
- brainstate/nn/metrics.py +0 -388
- brainstate/optim/__init__.py +0 -38
- brainstate/optim/_base.py +0 -64
- brainstate/optim/_lr_scheduler.py +0 -448
- brainstate/optim/_lr_scheduler_test.py +0 -50
- brainstate/optim/_optax_optimizer.py +0 -152
- brainstate/optim/_optax_optimizer_test.py +0 -53
- brainstate/optim/_sgd_optimizer.py +0 -1104
- brainstate/random/_random_for_unit.py +0 -52
- brainstate/surrogate.py +0 -1957
- brainstate/transform.py +0 -23
- brainstate/util/caller.py +0 -98
- brainstate/util/others.py +0 -540
- brainstate/util/pretty_pytree.py +0 -945
- brainstate/util/pretty_pytree_test.py +0 -159
- brainstate/util/pretty_table.py +0 -2954
- brainstate/util/scaling.py +0 -258
- brainstate-0.1.10.dist-info/RECORD +0 -130
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/WHEEL +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {brainstate-0.1.10.dist-info → brainstate-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -22,7 +22,6 @@ from brainstate._utils import set_module_as
|
|
22
22
|
from ._loop_collect_return import _bounded_while_loop
|
23
23
|
from ._make_jaxpr import StatefulFunction
|
24
24
|
from ._util import wrap_single_fun_in_multi_branches_while_loop as wrap_fn
|
25
|
-
from ._util import write_back_state_values
|
26
25
|
|
27
26
|
X = TypeVar('X')
|
28
27
|
Y = TypeVar('Y')
|
@@ -35,7 +34,7 @@ __all__ = [
|
|
35
34
|
]
|
36
35
|
|
37
36
|
|
38
|
-
@set_module_as('brainstate.
|
37
|
+
@set_module_as('brainstate.transform')
|
39
38
|
def while_loop(
|
40
39
|
cond_fun: Callable[[T], BooleanNumeric],
|
41
40
|
body_fun: Callable[[T], T],
|
@@ -50,13 +49,15 @@ def while_loop(
|
|
50
49
|
|
51
50
|
while_loop :: (a -> Bool) -> (a -> a) -> a -> a
|
52
51
|
|
53
|
-
The semantics of ``while_loop`` are given by this Python implementation
|
52
|
+
The semantics of ``while_loop`` are given by this Python implementation:
|
54
53
|
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
54
|
+
.. code-block:: python
|
55
|
+
|
56
|
+
>>> def while_loop(cond_fun, body_fun, init_val):
|
57
|
+
... val = init_val
|
58
|
+
... while cond_fun(val):
|
59
|
+
... val = body_fun(val)
|
60
|
+
... return val
|
60
61
|
|
61
62
|
Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
|
62
63
|
to a single WhileOp. That makes it useful for reducing compilation times
|
@@ -75,16 +76,55 @@ def while_loop(
|
|
75
76
|
``while_loop`` is not reverse-mode differentiable because XLA computations
|
76
77
|
require static bounds on memory requirements.
|
77
78
|
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
79
|
+
Parameters
|
80
|
+
----------
|
81
|
+
cond_fun : callable
|
82
|
+
Function of type ``a -> Bool``.
|
83
|
+
body_fun : callable
|
84
|
+
Function of type ``a -> a``.
|
85
|
+
init_val : T
|
86
|
+
Value of type ``a``, a type that can be a scalar, array, or any
|
82
87
|
pytree (nested Python tuple/list/dict) thereof, representing the initial
|
83
88
|
loop carry value.
|
84
89
|
|
85
|
-
Returns
|
86
|
-
|
87
|
-
|
90
|
+
Returns
|
91
|
+
-------
|
92
|
+
T
|
93
|
+
The output from the final iteration of body_fun, of type ``a``.
|
94
|
+
|
95
|
+
Examples
|
96
|
+
--------
|
97
|
+
Basic while loop operation:
|
98
|
+
|
99
|
+
.. code-block:: python
|
100
|
+
|
101
|
+
>>> import brainstate
|
102
|
+
>>> import jax.numpy as jnp
|
103
|
+
>>>
|
104
|
+
>>> def cond_fn(val):
|
105
|
+
... return val < 10
|
106
|
+
>>>
|
107
|
+
>>> def body_fn(val):
|
108
|
+
... return val + 1
|
109
|
+
>>>
|
110
|
+
>>> result = brainstate.transform.while_loop(cond_fn, body_fn, 0)
|
111
|
+
>>> # result will be 10
|
112
|
+
|
113
|
+
While loop with array state:
|
114
|
+
|
115
|
+
.. code-block:: python
|
116
|
+
|
117
|
+
>>> def cond_fn(state):
|
118
|
+
... return jnp.sum(state) < 100
|
119
|
+
>>>
|
120
|
+
>>> def body_fn(state):
|
121
|
+
... return state * 1.1
|
122
|
+
>>>
|
123
|
+
>>> init_state = jnp.array([1.0, 2.0, 3.0])
|
124
|
+
>>> final_state = brainstate.transform.while_loop(cond_fn, body_fn, init_state)
|
125
|
+
|
126
|
+
References
|
127
|
+
----------
|
88
128
|
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
|
89
129
|
"""
|
90
130
|
if not (callable(body_fun) and callable(cond_fun)):
|
@@ -103,24 +143,28 @@ def while_loop(
|
|
103
143
|
# evaluate jaxpr
|
104
144
|
stateful_cond = StatefulFunction(cond_fun, name='while:cond').make_jaxpr(init_val)
|
105
145
|
stateful_body = StatefulFunction(body_fun, name='while:body').make_jaxpr(init_val)
|
106
|
-
|
146
|
+
cond_cache_key = stateful_cond.get_arg_cache_key(init_val)
|
147
|
+
body_cache_key = stateful_body.get_arg_cache_key(init_val)
|
148
|
+
if len(stateful_cond.get_write_states_by_cache(cond_cache_key)) != 0:
|
107
149
|
raise ValueError("while_loop: cond_fun should not have any write states.")
|
108
150
|
|
109
151
|
# state trace and state values
|
110
|
-
state_trace = stateful_cond.
|
152
|
+
state_trace = (stateful_cond.get_state_trace_by_cache(cond_cache_key) +
|
153
|
+
stateful_body.get_state_trace_by_cache(body_cache_key))
|
111
154
|
read_state_vals = state_trace.get_read_state_values(True)
|
112
155
|
write_state_vals = state_trace.get_write_state_values(True)
|
113
|
-
new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False)
|
114
|
-
new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True)
|
156
|
+
new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False, cond_cache_key)
|
157
|
+
new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True, body_cache_key)
|
115
158
|
|
116
159
|
# while_loop
|
117
160
|
state_vals, final_val = jax.lax.while_loop(new_cond_fn, new_body_fn, (write_state_vals, init_val))
|
118
161
|
|
119
162
|
# write back state values or restore them
|
120
|
-
|
163
|
+
state_trace.assign_state_vals_v2(read_state_vals, state_vals)
|
121
164
|
return final_val
|
122
165
|
|
123
166
|
|
167
|
+
@set_module_as('brainstate.transform')
|
124
168
|
def bounded_while_loop(
|
125
169
|
cond_fun: Callable[[T], BooleanNumeric],
|
126
170
|
body_fun: Callable[[T], T],
|
@@ -138,18 +182,70 @@ def bounded_while_loop(
|
|
138
182
|
even if the condition function is never false. The function is implemented
|
139
183
|
using a scan operation, so it is reverse-mode differentiable.
|
140
184
|
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
185
|
+
Parameters
|
186
|
+
----------
|
187
|
+
cond_fun : callable
|
188
|
+
A function of type ``a -> Bool``.
|
189
|
+
body_fun : callable
|
190
|
+
A function of type ``a -> a``.
|
191
|
+
init_val : T
|
192
|
+
The initial value of type ``a``.
|
193
|
+
max_steps : int
|
194
|
+
A bound on the maximum number of steps, after which the loop
|
146
195
|
terminates unconditionally.
|
147
|
-
|
196
|
+
base : int, default 16
|
197
|
+
Run time will increase slightly as `base` increases. Compilation time will
|
148
198
|
decrease substantially as `math.ceil(math.log(max_steps, base))` decreases.
|
149
199
|
(Which happens as `base` increases.)
|
150
200
|
|
151
|
-
Returns
|
152
|
-
|
201
|
+
Returns
|
202
|
+
-------
|
203
|
+
T
|
204
|
+
The final value, as if computed by a `lax.while_loop`.
|
205
|
+
|
206
|
+
Examples
|
207
|
+
--------
|
208
|
+
Basic bounded while loop:
|
209
|
+
|
210
|
+
.. code-block:: python
|
211
|
+
|
212
|
+
>>> import brainstate
|
213
|
+
>>> import jax.numpy as jnp
|
214
|
+
>>>
|
215
|
+
>>> def cond_fn(val):
|
216
|
+
... return val < 1000 # This might never be false
|
217
|
+
>>>
|
218
|
+
>>> def body_fn(val):
|
219
|
+
... return val * 2
|
220
|
+
>>>
|
221
|
+
>>> # Loop will terminate after at most 10 steps
|
222
|
+
>>> result = brainstate.transform.bounded_while_loop(
|
223
|
+
... cond_fn, body_fn, 1, max_steps=10
|
224
|
+
... )
|
225
|
+
|
226
|
+
Bounded while loop with custom base:
|
227
|
+
|
228
|
+
.. code-block:: python
|
229
|
+
|
230
|
+
>>> # A smaller base gives slightly faster run time at the cost of longer compilation
|
231
|
+
>>> result = brainstate.transform.bounded_while_loop(
|
232
|
+
... cond_fn, body_fn, 1, max_steps=100, base=8
|
233
|
+
... )
|
234
|
+
|
235
|
+
Bounded while loop with array state:
|
236
|
+
|
237
|
+
.. code-block:: python
|
238
|
+
|
239
|
+
>>> def cond_fn(state):
|
240
|
+
... return jnp.max(state) < 100
|
241
|
+
>>>
|
242
|
+
>>> def body_fn(state):
|
243
|
+
... return state + jnp.array([1.0, 2.0, 3.0])
|
244
|
+
>>>
|
245
|
+
>>> init_state = jnp.array([0.0, 0.0, 0.0])
|
246
|
+
>>> final_state = brainstate.transform.bounded_while_loop(
|
247
|
+
... cond_fn, body_fn, init_state, max_steps=50
|
248
|
+
... )
|
153
249
|
"""
|
154
250
|
|
155
251
|
# checking
|
@@ -162,15 +258,18 @@ def bounded_while_loop(
|
|
162
258
|
# evaluate jaxpr
|
163
259
|
stateful_cond = StatefulFunction(cond_fun, name='bounded_while:cond').make_jaxpr(init_val)
|
164
260
|
stateful_body = StatefulFunction(body_fun, name='bounded_while:body').make_jaxpr(init_val)
|
165
|
-
|
261
|
+
cond_cache_key = stateful_cond.get_arg_cache_key(init_val)
|
262
|
+
body_cache_key = stateful_body.get_arg_cache_key(init_val)
|
263
|
+
if len(stateful_cond.get_write_states_by_cache(cond_cache_key)) != 0:
|
166
264
|
raise ValueError("while_loop: cond_fun should not have any write states.")
|
167
265
|
|
168
266
|
# state trace and state values
|
169
|
-
state_trace = stateful_cond.get_state_trace() +
|
267
|
+
state_trace = (stateful_cond.get_state_trace(init_val) +
|
268
|
+
stateful_body.get_state_trace(init_val))
|
170
269
|
read_state_vals = state_trace.get_read_state_values(True)
|
171
270
|
write_state_vals = state_trace.get_write_state_values(True)
|
172
|
-
new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False)
|
173
|
-
new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True)
|
271
|
+
new_cond_fn = wrap_fn(stateful_cond, state_trace, read_state_vals, False, cond_cache_key)
|
272
|
+
new_body_fn = wrap_fn(stateful_body, state_trace, read_state_vals, True, body_cache_key)
|
174
273
|
|
175
274
|
# initial value
|
176
275
|
init_val = (write_state_vals, init_val)
|
@@ -180,5 +279,5 @@ def bounded_while_loop(
|
|
180
279
|
state_vals, val = _bounded_while_loop(new_cond_fn, new_body_fn, init_val, rounded_max_steps, base, None)
|
181
280
|
|
182
281
|
# write back state values or restore them
|
183
|
-
|
282
|
+
state_trace.assign_state_vals_v2(read_state_vals, state_vals)
|
184
283
|
return val
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright 2024
|
1
|
+
# Copyright 2024 BrainX Ecosystem Limited. All Rights Reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -16,7 +16,7 @@
|
|
16
16
|
|
17
17
|
from unittest import TestCase
|
18
18
|
|
19
|
-
import brainstate
|
19
|
+
import brainstate
|
20
20
|
|
21
21
|
|
22
22
|
class TestWhileLoop(TestCase):
|