brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +608 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/{nn/event → event}/__init__.py +6 -6
- brainstate/event/_csr.py +308 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +271 -0
- brainstate/event/_fixed_probability_test.py +128 -0
- brainstate/event/_linear.py +219 -0
- brainstate/event/_linear_test.py +112 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +320 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{_projection/__init__.py → _dyn_impl/_projection_alignpost.py} +6 -13
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +32 -0
- brainstate/nn/_interaction/_connections.py +726 -0
- brainstate/nn/_interaction/_connections_test.py +254 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +103 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/util/_error.py +28 -0
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/METADATA +11 -11
- brainstate-0.1.0.dist-info/RECORD +135 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/WHEEL +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,131 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
import unittest
|
17
|
-
|
18
|
-
import jax
|
19
|
-
import jax.numpy as jnp
|
20
|
-
import pytest
|
21
|
-
|
22
|
-
import brainstate as bc
|
23
|
-
|
24
|
-
|
25
|
-
class TestMakeJaxpr(unittest.TestCase):
|
26
|
-
|
27
|
-
def test_compar_jax_make_jaxpr(self):
|
28
|
-
def func4(arg): # Arg is a pair
|
29
|
-
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
30
|
-
c = bc.random.rand_like(arg[0])
|
31
|
-
return jnp.sum(temp + c)
|
32
|
-
|
33
|
-
key = bc.random.DEFAULT.value
|
34
|
-
jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
|
35
|
-
print(jaxpr)
|
36
|
-
self.assertTrue(len(jaxpr.in_avals) == 2)
|
37
|
-
self.assertTrue(len(jaxpr.consts) == 1)
|
38
|
-
self.assertTrue(len(jaxpr.out_avals) == 1)
|
39
|
-
self.assertTrue(jnp.allclose(jaxpr.consts[0], key))
|
40
|
-
|
41
|
-
jaxpr2, states = bc.transform.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
|
42
|
-
print(jaxpr2)
|
43
|
-
self.assertTrue(len(jaxpr2.in_avals) == 3)
|
44
|
-
self.assertTrue(len(jaxpr2.out_avals) == 2)
|
45
|
-
self.assertTrue(len(jaxpr2.consts) == 0)
|
46
|
-
|
47
|
-
def test_StatefulFunction_1(self):
|
48
|
-
def func4(arg): # Arg is a pair
|
49
|
-
temp = arg[0] + jnp.sin(arg[1]) * 3.
|
50
|
-
c = bc.random.rand_like(arg[0])
|
51
|
-
return jnp.sum(temp + c)
|
52
|
-
|
53
|
-
fun = bc.transform.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
|
54
|
-
print(fun.get_states())
|
55
|
-
print(fun.get_jaxpr())
|
56
|
-
|
57
|
-
def test_StatefulFunction_2(self):
|
58
|
-
st1 = bc.State(jnp.ones(10))
|
59
|
-
|
60
|
-
def f1(x):
|
61
|
-
st1.value = x + st1.value
|
62
|
-
|
63
|
-
def f2(x):
|
64
|
-
jaxpr = bc.transform.make_jaxpr(f1)(x)
|
65
|
-
c = 1. + x
|
66
|
-
return c
|
67
|
-
|
68
|
-
def f3(x):
|
69
|
-
jaxpr = bc.transform.make_jaxpr(f1)(x)
|
70
|
-
c = 1.
|
71
|
-
return c
|
72
|
-
|
73
|
-
print()
|
74
|
-
jaxpr = bc.transform.make_jaxpr(f1)(jnp.zeros(1))
|
75
|
-
print(jaxpr)
|
76
|
-
jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
|
77
|
-
print(jaxpr)
|
78
|
-
jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
|
79
|
-
print(jaxpr)
|
80
|
-
jaxpr, _ = bc.transform.make_jaxpr(f3)(jnp.zeros(1))
|
81
|
-
print(jaxpr)
|
82
|
-
self.assertTrue(jnp.allclose(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
83
|
-
f3(jnp.zeros(1))))
|
84
|
-
|
85
|
-
def test_compar_jax_make_jaxpr2(self):
|
86
|
-
st1 = bc.State(jnp.ones(10))
|
87
|
-
|
88
|
-
def fa(x):
|
89
|
-
st1.value = x + st1.value
|
90
|
-
|
91
|
-
def ffa(x):
|
92
|
-
jaxpr, states = bc.transform.make_jaxpr(fa)(x)
|
93
|
-
c = 1. + x
|
94
|
-
return c
|
95
|
-
|
96
|
-
jaxpr, states = bc.transform.make_jaxpr(ffa)(jnp.zeros(1))
|
97
|
-
print()
|
98
|
-
print(jaxpr)
|
99
|
-
print(states)
|
100
|
-
print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
|
101
|
-
jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
|
102
|
-
print(jaxpr)
|
103
|
-
print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
104
|
-
|
105
|
-
def test_compar_jax_make_jaxpr3(self):
|
106
|
-
def fa(x):
|
107
|
-
return 1.
|
108
|
-
|
109
|
-
jaxpr, states = bc.transform.make_jaxpr(fa)(jnp.zeros(1))
|
110
|
-
print()
|
111
|
-
print(jaxpr)
|
112
|
-
print(states)
|
113
|
-
# print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
114
|
-
jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
|
115
|
-
print(jaxpr)
|
116
|
-
# print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
117
|
-
|
118
|
-
|
119
|
-
def test_return_states():
|
120
|
-
import jax.numpy
|
121
|
-
|
122
|
-
import brainstate as bc
|
123
|
-
|
124
|
-
a = bc.State(jax.numpy.ones(3))
|
125
|
-
|
126
|
-
@bc.transform.jit
|
127
|
-
def f():
|
128
|
-
return a
|
129
|
-
|
130
|
-
with pytest.raises(ValueError):
|
131
|
-
f()
|
brainstate/transform/_mapping.py
DELETED
@@ -1,109 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
import jax
|
19
|
-
|
20
|
-
from ._loop_collect_return import scan
|
21
|
-
|
22
|
-
__all__ = [
|
23
|
-
'map',
|
24
|
-
]
|
25
|
-
|
26
|
-
|
27
|
-
def _batch_and_remainder(x, batch_size: int):
|
28
|
-
leaves, treedef = jax.tree.flatten(x)
|
29
|
-
|
30
|
-
scan_leaves = []
|
31
|
-
remainder_leaves = []
|
32
|
-
|
33
|
-
for leaf in leaves:
|
34
|
-
num_batches, _ = divmod(leaf.shape[0], batch_size)
|
35
|
-
total_batch_elems = num_batches * batch_size
|
36
|
-
scan_leaves.append(leaf[:total_batch_elems].reshape(num_batches, batch_size, *leaf.shape[1:]))
|
37
|
-
remainder_leaves.append(leaf[total_batch_elems:])
|
38
|
-
|
39
|
-
scan_tree = treedef.unflatten(scan_leaves)
|
40
|
-
remainder_tree = treedef.unflatten(remainder_leaves)
|
41
|
-
return scan_tree, remainder_tree
|
42
|
-
|
43
|
-
|
44
|
-
def map(
|
45
|
-
f,
|
46
|
-
xs,
|
47
|
-
*,
|
48
|
-
batch_size: int | None = None,
|
49
|
-
):
|
50
|
-
"""Map a function over leading array axes.
|
51
|
-
|
52
|
-
Like Python's builtin map, except inputs and outputs are in the form of
|
53
|
-
stacked arrays. Consider using the :func:`~jax.vmap` transform instead, unless you
|
54
|
-
need to apply a function element by element for reduced memory usage or
|
55
|
-
heterogeneous computation with other control flow primitives.
|
56
|
-
|
57
|
-
When ``xs`` is an array type, the semantics of :func:`~map` are given by this
|
58
|
-
Python implementation::
|
59
|
-
|
60
|
-
def map(f, xs):
|
61
|
-
return np.stack([f(x) for x in xs])
|
62
|
-
|
63
|
-
Like :func:`~scan`, :func:`~map` is implemented in terms of JAX primitives so
|
64
|
-
many of the same advantages over a Python loop apply: ``xs`` may be an
|
65
|
-
arbitrary nested pytree type, and the mapped computation is compiled only
|
66
|
-
once.
|
67
|
-
|
68
|
-
If ``batch_size`` is provided, the computation is executed in batches of that size
|
69
|
-
and parallelized using :func:`~jax.vmap`. This can be used as either a more performant
|
70
|
-
version of ``map`` or as a memory-efficient version of ``vmap``. If the axis is not
|
71
|
-
divisible by the batch size, the remainder is processed in a separate ``vmap`` and
|
72
|
-
concatenated to the result.
|
73
|
-
|
74
|
-
>>> x = jax.numpy.ones((10, 3, 4))
|
75
|
-
>>> def f(x):
|
76
|
-
... print('inner shape:', x.shape)
|
77
|
-
... return x + 1
|
78
|
-
>>> y = map(f, x, batch_size=3)
|
79
|
-
inner shape: (3, 4)
|
80
|
-
inner shape: (3, 4)
|
81
|
-
>>> y.shape
|
82
|
-
(10, 3, 4)
|
83
|
-
|
84
|
-
In the example above, "inner shape" is printed twice, once while tracing the batched
|
85
|
-
computation and once while tracing the remainder computation.
|
86
|
-
|
87
|
-
Args:
|
88
|
-
f: a Python function to apply element-wise over the first axis or axes of
|
89
|
-
``xs``.
|
90
|
-
xs: values over which to map along the leading axis.
|
91
|
-
batch_size: (optional) integer specifying the size of the batch for each step to execute
|
92
|
-
in parallel.
|
93
|
-
|
94
|
-
Returns:
|
95
|
-
Mapped values.
|
96
|
-
"""
|
97
|
-
if batch_size is not None:
|
98
|
-
scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size)
|
99
|
-
g = lambda _, x: ((), jax.vmap(f)(x))
|
100
|
-
_, scan_ys = scan(g, (), scan_xs)
|
101
|
-
remainder_ys = jax.vmap(f)(remainder_xs)
|
102
|
-
flatten = lambda x: x.reshape(-1, *x.shape[2:])
|
103
|
-
ys = jax.tree.map(
|
104
|
-
lambda x, y: jax.numpy.concatenate([flatten(x), y], axis=0), scan_ys, remainder_ys,
|
105
|
-
)
|
106
|
-
else:
|
107
|
-
g = lambda _, x: ((), f(x))
|
108
|
-
_, ys = scan(g, (), xs)
|
109
|
-
return ys
|
@@ -1,111 +0,0 @@
|
|
1
|
-
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
15
|
-
|
16
|
-
from __future__ import annotations
|
17
|
-
|
18
|
-
import copy
|
19
|
-
from typing import Optional
|
20
|
-
|
21
|
-
import jax
|
22
|
-
|
23
|
-
try:
|
24
|
-
from tqdm.auto import tqdm
|
25
|
-
except (ImportError, ModuleNotFoundError):
|
26
|
-
tqdm = None
|
27
|
-
|
28
|
-
__all__ = [
|
29
|
-
'ProgressBar',
|
30
|
-
]
|
31
|
-
|
32
|
-
|
33
|
-
class ProgressBar(object):
|
34
|
-
__module__ = "brainstate.transform"
|
35
|
-
|
36
|
-
def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
|
37
|
-
self.print_freq = freq
|
38
|
-
self.print_count = count
|
39
|
-
if self.print_freq is not None and self.print_count is not None:
|
40
|
-
raise ValueError("Cannot specify both count and freq.")
|
41
|
-
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
|
42
|
-
kwargs.pop(kwarg, None)
|
43
|
-
self.kwargs = kwargs
|
44
|
-
if tqdm is None:
|
45
|
-
raise ImportError("tqdm is not installed.")
|
46
|
-
|
47
|
-
def init(self, n: int):
|
48
|
-
kwargs = copy.copy(self.kwargs)
|
49
|
-
freq = self.print_freq
|
50
|
-
count = self.print_count
|
51
|
-
if count is not None:
|
52
|
-
freq, remainder = divmod(n, count)
|
53
|
-
if freq == 0:
|
54
|
-
raise ValueError(f"Count {count} is too large for n {n}.")
|
55
|
-
elif freq is None:
|
56
|
-
if n > 20:
|
57
|
-
freq = int(n / 20)
|
58
|
-
else:
|
59
|
-
freq = 1
|
60
|
-
remainder = n % freq
|
61
|
-
else:
|
62
|
-
if freq < 1:
|
63
|
-
raise ValueError(f"Print rate should be > 0 got {freq}")
|
64
|
-
elif freq > n:
|
65
|
-
raise ValueError("Print rate should be less than the "
|
66
|
-
f"number of steps {n}, got {freq}")
|
67
|
-
remainder = n % freq
|
68
|
-
desc = kwargs.pop("desc", f"Running for {n:,} iterations")
|
69
|
-
message = kwargs.pop("message", desc)
|
70
|
-
return ProgressBarRunner(n, message, freq, remainder, **kwargs)
|
71
|
-
|
72
|
-
|
73
|
-
class ProgressBarRunner(object):
|
74
|
-
__module__ = "brainstate.transform"
|
75
|
-
|
76
|
-
def __init__(self, n: int, message, print_freq: int, remainder: int, **kwargs):
|
77
|
-
self.tqdm_bars = {}
|
78
|
-
self.kwargs = kwargs
|
79
|
-
self.n = n
|
80
|
-
self.print_freq = print_freq
|
81
|
-
self.remainder = remainder
|
82
|
-
self.message = message
|
83
|
-
|
84
|
-
def _define_tqdm(self):
|
85
|
-
self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
|
86
|
-
self.tqdm_bars[0].set_description(self.message, refresh=False)
|
87
|
-
|
88
|
-
def _update_tqdm(self):
|
89
|
-
self.tqdm_bars[0].update(self.print_freq)
|
90
|
-
|
91
|
-
def _close_tqdm(self):
|
92
|
-
if self.remainder > 0:
|
93
|
-
self.tqdm_bars[0].update(self.remainder)
|
94
|
-
self.tqdm_bars[0].close()
|
95
|
-
|
96
|
-
def __call__(self, iter_num, *args, **kwargs):
|
97
|
-
_ = jax.lax.cond(
|
98
|
-
iter_num == 0,
|
99
|
-
lambda: jax.debug.callback(self._define_tqdm),
|
100
|
-
lambda: None,
|
101
|
-
)
|
102
|
-
_ = jax.lax.cond(
|
103
|
-
(iter_num + 1) % self.print_freq == 0,
|
104
|
-
lambda: jax.debug.callback(self._update_tqdm),
|
105
|
-
lambda: None,
|
106
|
-
)
|
107
|
-
_ = jax.lax.cond(
|
108
|
-
iter_num == self.n - 1,
|
109
|
-
lambda: jax.debug.callback(self._close_tqdm),
|
110
|
-
lambda: None,
|
111
|
-
)
|
brainstate/transform/_unvmap.py
DELETED
@@ -1,143 +0,0 @@
|
|
1
|
-
import jax
|
2
|
-
import jax.core
|
3
|
-
import jax.interpreters.batching as batching
|
4
|
-
import jax.interpreters.mlir as mlir
|
5
|
-
import jax.numpy as jnp
|
6
|
-
|
7
|
-
from brainstate._utils import set_module_as
|
8
|
-
|
9
|
-
__all__ = [
|
10
|
-
"unvmap",
|
11
|
-
]
|
12
|
-
|
13
|
-
|
14
|
-
@set_module_as('brainstate.transform')
|
15
|
-
def unvmap(x, op: str = 'any'):
|
16
|
-
if op == 'all':
|
17
|
-
return unvmap_all(x)
|
18
|
-
elif op == 'any':
|
19
|
-
return unvmap_any(x)
|
20
|
-
elif op == 'none':
|
21
|
-
return _without_vmap(x)
|
22
|
-
elif op == 'max':
|
23
|
-
return unvmap_max(x)
|
24
|
-
else:
|
25
|
-
raise ValueError(f'Do not support type: {op}')
|
26
|
-
|
27
|
-
|
28
|
-
# unvmap_all
|
29
|
-
|
30
|
-
unvmap_all_p = jax.core.Primitive("unvmap_all")
|
31
|
-
|
32
|
-
|
33
|
-
def unvmap_all(x):
|
34
|
-
"""As `jnp.all`, but ignores batch dimensions."""
|
35
|
-
return unvmap_all_p.bind(x)
|
36
|
-
|
37
|
-
|
38
|
-
def _unvmap_all_impl(x):
|
39
|
-
return jnp.all(x)
|
40
|
-
|
41
|
-
|
42
|
-
def _unvmap_all_abstract_eval(x):
|
43
|
-
return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype) # pyright: ignore
|
44
|
-
|
45
|
-
|
46
|
-
def _unvmap_all_batch(x, batch_axes):
|
47
|
-
(x,) = x
|
48
|
-
return unvmap_all(x), batching.not_mapped
|
49
|
-
|
50
|
-
|
51
|
-
unvmap_all_p.def_impl(_unvmap_all_impl)
|
52
|
-
unvmap_all_p.def_abstract_eval(_unvmap_all_abstract_eval)
|
53
|
-
batching.primitive_batchers[unvmap_all_p] = _unvmap_all_batch # pyright: ignore
|
54
|
-
mlir.register_lowering(
|
55
|
-
unvmap_all_p,
|
56
|
-
mlir.lower_fun(_unvmap_all_impl, multiple_results=False),
|
57
|
-
)
|
58
|
-
|
59
|
-
# unvmap_any
|
60
|
-
|
61
|
-
unvmap_any_p = jax.core.Primitive("unvmap_any")
|
62
|
-
|
63
|
-
|
64
|
-
def unvmap_any(x):
|
65
|
-
"""As `jnp.any`, but ignores batch dimensions."""
|
66
|
-
return unvmap_any_p.bind(x)
|
67
|
-
|
68
|
-
|
69
|
-
def _unvmap_any_impl(x):
|
70
|
-
return jnp.any(x)
|
71
|
-
|
72
|
-
|
73
|
-
def _unvmap_any_abstract_eval(x):
|
74
|
-
return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype) # pyright: ignore
|
75
|
-
|
76
|
-
|
77
|
-
def _unvmap_any_batch(x, batch_axes):
|
78
|
-
(x,) = x
|
79
|
-
return unvmap_any(x), batching.not_mapped
|
80
|
-
|
81
|
-
|
82
|
-
unvmap_any_p.def_impl(_unvmap_any_impl)
|
83
|
-
unvmap_any_p.def_abstract_eval(_unvmap_any_abstract_eval)
|
84
|
-
batching.primitive_batchers[unvmap_any_p] = _unvmap_any_batch # pyright: ignore
|
85
|
-
mlir.register_lowering(
|
86
|
-
unvmap_any_p,
|
87
|
-
mlir.lower_fun(_unvmap_any_impl, multiple_results=False),
|
88
|
-
)
|
89
|
-
|
90
|
-
# unvmap_max
|
91
|
-
|
92
|
-
unvmap_max_p = jax.core.Primitive("unvmap_max")
|
93
|
-
|
94
|
-
|
95
|
-
def unvmap_max(x):
|
96
|
-
"""As `jnp.max`, but ignores batch dimensions."""
|
97
|
-
return unvmap_max_p.bind(x)
|
98
|
-
|
99
|
-
|
100
|
-
def _unvmap_max_impl(x):
|
101
|
-
return jnp.max(x)
|
102
|
-
|
103
|
-
|
104
|
-
def _unvmap_max_abstract_eval(x):
|
105
|
-
return jax.core.ShapedArray(shape=(), dtype=x.dtype)
|
106
|
-
|
107
|
-
|
108
|
-
def _unvmap_max_batch(x, batch_axes):
|
109
|
-
(x,) = x
|
110
|
-
return unvmap_max(x), batching.not_mapped
|
111
|
-
|
112
|
-
|
113
|
-
unvmap_max_p.def_impl(_unvmap_max_impl)
|
114
|
-
unvmap_max_p.def_abstract_eval(_unvmap_max_abstract_eval)
|
115
|
-
batching.primitive_batchers[unvmap_max_p] = _unvmap_max_batch # pyright: ignore
|
116
|
-
mlir.register_lowering(
|
117
|
-
unvmap_max_p,
|
118
|
-
mlir.lower_fun(_unvmap_max_impl, multiple_results=False),
|
119
|
-
)
|
120
|
-
|
121
|
-
|
122
|
-
def _without_vmap(x):
|
123
|
-
return _no_vmap_prim.bind(x)
|
124
|
-
|
125
|
-
|
126
|
-
def _without_vmap_imp(x):
|
127
|
-
return x
|
128
|
-
|
129
|
-
|
130
|
-
def _without_vmap_abs(x):
|
131
|
-
return x
|
132
|
-
|
133
|
-
|
134
|
-
def _without_vmap_batch(x, batch_axes):
|
135
|
-
(x,) = x
|
136
|
-
return _without_vmap(x), batching.not_mapped
|
137
|
-
|
138
|
-
|
139
|
-
_no_vmap_prim = jax.core.Primitive('no_vmap')
|
140
|
-
_no_vmap_prim.def_impl(_without_vmap_imp)
|
141
|
-
_no_vmap_prim.def_abstract_eval(_without_vmap_abs)
|
142
|
-
batching.primitive_batchers[_no_vmap_prim] = _without_vmap_batch
|
143
|
-
mlir.register_lowering(_no_vmap_prim, mlir.lower_fun(_without_vmap_imp, multiple_results=False))
|