brainstate 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +45 -0
- brainstate/_module.py +1466 -0
- brainstate/_module_test.py +133 -0
- brainstate/_state.py +378 -0
- brainstate/_state_test.py +41 -0
- brainstate/_utils.py +21 -0
- brainstate/environ.py +375 -0
- brainstate/functional/__init__.py +25 -0
- brainstate/functional/_activations.py +754 -0
- brainstate/functional/_normalization.py +69 -0
- brainstate/functional/_spikes.py +90 -0
- brainstate/init/__init__.py +26 -0
- brainstate/init/_base.py +36 -0
- brainstate/init/_generic.py +175 -0
- brainstate/init/_random_inits.py +489 -0
- brainstate/init/_regular_inits.py +109 -0
- brainstate/math/__init__.py +21 -0
- brainstate/math/_einops.py +787 -0
- brainstate/math/_einops_parsing.py +169 -0
- brainstate/math/_einops_parsing_test.py +126 -0
- brainstate/math/_einops_test.py +346 -0
- brainstate/math/_misc.py +298 -0
- brainstate/math/_misc_test.py +58 -0
- brainstate/mixin.py +373 -0
- brainstate/mixin_test.py +73 -0
- brainstate/nn/__init__.py +68 -0
- brainstate/nn/_base.py +248 -0
- brainstate/nn/_connections.py +686 -0
- brainstate/nn/_dynamics.py +406 -0
- brainstate/nn/_elementwise.py +1437 -0
- brainstate/nn/_misc.py +132 -0
- brainstate/nn/_normalizations.py +389 -0
- brainstate/nn/_others.py +100 -0
- brainstate/nn/_poolings.py +1228 -0
- brainstate/nn/_poolings_test.py +231 -0
- brainstate/nn/_projection/__init__.py +32 -0
- brainstate/nn/_projection/_align_post.py +528 -0
- brainstate/nn/_projection/_align_pre.py +599 -0
- brainstate/nn/_projection/_delta.py +241 -0
- brainstate/nn/_projection/_utils.py +17 -0
- brainstate/nn/_projection/_vanilla.py +101 -0
- brainstate/nn/_rate_rnns.py +393 -0
- brainstate/nn/_readout.py +130 -0
- brainstate/nn/_synouts.py +166 -0
- brainstate/nn/functional/__init__.py +25 -0
- brainstate/nn/functional/_activations.py +754 -0
- brainstate/nn/functional/_normalization.py +69 -0
- brainstate/nn/functional/_spikes.py +90 -0
- brainstate/nn/init/__init__.py +26 -0
- brainstate/nn/init/_base.py +36 -0
- brainstate/nn/init/_generic.py +175 -0
- brainstate/nn/init/_random_inits.py +489 -0
- brainstate/nn/init/_regular_inits.py +109 -0
- brainstate/nn/surrogate.py +1740 -0
- brainstate/optim/__init__.py +23 -0
- brainstate/optim/_lr_scheduler.py +486 -0
- brainstate/optim/_lr_scheduler_test.py +36 -0
- brainstate/optim/_sgd_optimizer.py +1148 -0
- brainstate/random.py +5148 -0
- brainstate/random_test.py +576 -0
- brainstate/surrogate.py +1740 -0
- brainstate/transform/__init__.py +36 -0
- brainstate/transform/_autograd.py +585 -0
- brainstate/transform/_autograd_test.py +1183 -0
- brainstate/transform/_control.py +665 -0
- brainstate/transform/_controls_test.py +220 -0
- brainstate/transform/_jit.py +239 -0
- brainstate/transform/_jit_error.py +158 -0
- brainstate/transform/_jit_test.py +102 -0
- brainstate/transform/_make_jaxpr.py +573 -0
- brainstate/transform/_make_jaxpr_test.py +133 -0
- brainstate/transform/_progress_bar.py +113 -0
- brainstate/typing.py +69 -0
- brainstate/util.py +747 -0
- brainstate-0.0.1.dist-info/LICENSE +202 -0
- brainstate-0.0.1.dist-info/METADATA +101 -0
- brainstate-0.0.1.dist-info/RECORD +79 -0
- brainstate-0.0.1.dist-info/WHEEL +6 -0
- brainstate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,133 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
import unittest
|
17
|
+
|
18
|
+
import jax
|
19
|
+
import jax.numpy as jnp
|
20
|
+
import pytest
|
21
|
+
|
22
|
+
import brainstate as bc
|
23
|
+
|
24
|
+
|
25
|
+
class TestMakeJaxpr(unittest.TestCase):
    """Compare ``brainstate.transform.make_jaxpr`` with ``jax.make_jaxpr``."""

    def test_compar_jax_make_jaxpr(self):
        """The random-key State is a closed-over const for JAX but an explicit I/O for brainstate."""

        def func4(arg):  # Arg is a pair
            temp = arg[0] + jnp.sin(arg[1]) * 3.
            c = bc.random.rand_like(arg[0])
            return jnp.sum(temp + c)

        key = bc.random.DEFAULT.value
        jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
        print(jaxpr)
        # With plain JAX the default random key ends up as the single jaxpr constant.
        self.assertEqual(len(jaxpr.in_avals), 2)
        self.assertEqual(len(jaxpr.consts), 1)
        self.assertEqual(len(jaxpr.out_avals), 1)
        self.assertTrue(jnp.allclose(jaxpr.consts[0], key))

        # brainstate's make_jaxpr is expected to expose the key as an extra input
        # and an extra output instead of a constant.
        jaxpr2, _ = bc.transform.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
        print(jaxpr2)
        self.assertEqual(len(jaxpr2.in_avals), 3)
        self.assertEqual(len(jaxpr2.out_avals), 2)
        self.assertEqual(len(jaxpr2.consts), 0)

    def test_StatefulFunction_1(self):
        """Smoke test: StatefulFunction can trace a function that reads the random State."""

        def func4(arg):  # Arg is a pair
            temp = arg[0] + jnp.sin(arg[1]) * 3.
            c = bc.random.rand_like(arg[0])
            return jnp.sum(temp + c)

        fun = bc.transform.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
        print(fun.get_states())
        print(fun.get_jaxpr())

    def test_StatefulFunction_2(self):
        """Nested brainstate tracing of a State-writing function inside other traces."""
        st1 = bc.State(jnp.ones(10))

        def f1(x):
            st1.value = x + st1.value

        def f2(x):
            # The inner trace is performed only for its effect; its result is discarded.
            bc.transform.make_jaxpr(f1)(x)
            return 1. + x

        def f3(x):
            bc.transform.make_jaxpr(f1)(x)
            return 1.

        print()
        jaxpr = bc.transform.make_jaxpr(f1)(jnp.zeros(1))
        print(jaxpr)
        jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
        print(jaxpr)
        jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
        print(jaxpr)
        jaxpr, _ = bc.transform.make_jaxpr(f3)(jnp.zeros(1))
        print(jaxpr)
        self.assertTrue(jnp.allclose(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
                                     f3(jnp.zeros(1))))

    def test_compar_jax_make_jaxpr2(self):
        """Outer brainstate vs. outer JAX trace of a function that internally traces a State write."""
        st1 = bc.State(jnp.ones(10))

        def fa(x):
            st1.value = x + st1.value

        def ffa(x):
            bc.transform.make_jaxpr(fa)(x)
            return 1. + x

        jaxpr, states = bc.transform.make_jaxpr(ffa)(jnp.zeros(1))
        print()
        print(jaxpr)
        print(states)
        print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
        jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
        print(jaxpr)
        print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))

    def test_compar_jax_make_jaxpr3(self):
        """Tracing a function that returns a Python constant and touches no State."""

        def fa(x):
            return 1.

        jaxpr, states = bc.transform.make_jaxpr(fa)(jnp.zeros(1))
        print()
        print(jaxpr)
        print(states)
        jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
        print(jaxpr)
|
117
|
+
|
118
|
+
|
119
|
+
def test_return_states():
    """Returning a ``State`` object from a ``bc.transform.jit`` function must raise ``ValueError``."""
    # Uses the module-level ``jnp``/``bc``/``pytest`` imports rather than re-importing locally.
    a = bc.State(jnp.ones(3))

    @bc.transform.jit
    def f():
        return a

    with pytest.raises(ValueError):
        f()
|
132
|
+
|
133
|
+
|
@@ -0,0 +1,113 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
import copy
|
18
|
+
from typing import Optional
|
19
|
+
|
20
|
+
import jax
|
21
|
+
|
22
|
+
from brainstate import environ
|
23
|
+
|
24
|
+
try:
|
25
|
+
from tqdm.auto import tqdm
|
26
|
+
except (ImportError, ModuleNotFoundError):
|
27
|
+
tqdm = None
|
28
|
+
|
29
|
+
__all__ = [
    'ProgressBar',
]


class ProgressBar(object):
    """Configuration for a tqdm progress bar driven from inside (traced) JAX loops.

    At most one of ``freq`` or ``count`` may be given:

    * ``freq``: advance the bar every ``freq`` iterations;
    * ``count``: split the run into ``count`` updates of ``n // count`` iterations;
    * neither: a frequency is chosen automatically in :meth:`init`.

    Remaining keyword arguments are forwarded to ``tqdm``. The options ``total``,
    ``mininterval``, ``maxinterval`` and ``miniters`` are dropped, because the runner
    sizes the bar from ``n`` and issues its updates explicitly.

    Raises:
        ValueError: If both ``freq`` and ``count`` are given.
        ImportError: If ``tqdm`` is not installed.
    """
    __module__ = "brainstate.transform"

    def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
        self.print_freq = freq
        self.print_count = count
        if self.print_freq is not None and self.print_count is not None:
            raise ValueError("Cannot specify both count and freq.")
        # These tqdm options would conflict with the explicit, externally driven updates.
        for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
            kwargs.pop(kwarg, None)
        self.kwargs = kwargs
        if tqdm is None:
            raise ImportError("tqdm is not installed.")

    def init(self, n: int):
        """Build a :class:`ProgressBarRunner` for a loop of ``n`` iterations.

        Args:
            n: Total number of loop iterations.

        Returns:
            A ``ProgressBarRunner``. It advances the bar by ``freq`` after every
            ``freq``-th iteration, then adds the leftover ``remainder`` when closing,
            so the bar finishes exactly at ``n``.

        Raises:
            ValueError: If ``count`` is not positive or is larger than ``n``, or if
                ``freq`` is not in ``[1, n]``.
        """
        kwargs = copy.copy(self.kwargs)  # do not mutate the stored configuration
        freq = self.print_freq
        count = self.print_count
        if count is not None:
            if count < 1:
                raise ValueError(f"Count should be > 0, got {count}.")
            freq = n // count
            if freq == 0:
                raise ValueError(f"Count {count} is too large for n {n}.")
        elif freq is None:
            freq = n // 20 if n > 20 else 1
        else:
            if freq < 1:
                raise ValueError(f"Print rate should be > 0 got {freq}")
            elif freq > n:
                raise ValueError("Print rate should be less than the "
                                 f"number of steps {n}, got {freq}")
        # The runner performs ``n // freq`` periodic updates of size ``freq``.
        # The iterations left over are therefore ``n % freq``, not ``n % count``.
        remainder = n % freq
        desc = kwargs.pop("desc", f"Running for {n:,} iterations")
        message = kwargs.pop("message", desc)
        return ProgressBarRunner(n, message, freq, remainder, **kwargs)


class ProgressBarRunner(object):
    """Callable that drives a tqdm bar from a loop body through ``jax.debug.callback``.

    Call it once per iteration with the iteration index:

    * at iteration ``0`` the bar is created;
    * whenever ``iter_num + 1`` is a multiple of ``print_freq``, it advances by
      ``print_freq``;
    * at iteration ``n - 1`` it is topped up by ``remainder`` and closed.
    """
    __module__ = "brainstate.transform"

    def __init__(self, n: int, message, print_freq: int, remainder: int, **kwargs):
        self.tqdm_bars = {}
        self.kwargs = kwargs  # forwarded to the ``tqdm`` constructor
        self.n = n
        self.print_freq = print_freq
        self.remainder = remainder
        self.message = message

    def _define_tqdm(self):
        # Host-side callback: create the bar over ``range(n)`` and set its label.
        self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
        self.tqdm_bars[0].set_description(self.message, refresh=False)

    def _update_tqdm(self):
        self.tqdm_bars[0].update(self.print_freq)

    def _close_tqdm(self):
        # Account for the iterations not covered by the periodic updates, then close.
        if self.remainder > 0:
            self.tqdm_bars[0].update(self.remainder)
        self.tqdm_bars[0].close()

    def __call__(self, iter_num, *args, **kwargs):
        """Emit the progress-bar callbacks for iteration ``iter_num``.

        Extra positional and keyword arguments are accepted and ignored.
        """
        # Create the bar on the first iteration.
        _ = jax.lax.cond(
            iter_num == 0,
            lambda: jax.debug.callback(self._define_tqdm),
            lambda: None,
        )
        # Advance after every ``print_freq``-th iteration.
        _ = jax.lax.cond(
            (iter_num + 1) % self.print_freq == 0,
            lambda: jax.debug.callback(self._update_tqdm),
            lambda: None,
        )
        # Add the remainder and close on the last iteration.
        # NOTE(review): these are unordered ``jax.debug.callback`` calls. Confirm that
        # the final periodic update always runs before ``_close_tqdm``.
        _ = jax.lax.cond(
            iter_num == self.n - 1,
            lambda: jax.debug.callback(self._close_tqdm),
            lambda: None,
        )
|
brainstate/typing.py
ADDED
@@ -0,0 +1,69 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
from typing import Any, Sequence, Protocol, Union
|
18
|
+
|
19
|
+
import jax
|
20
|
+
import numpy as np
|
21
|
+
|
22
|
+
__all__ = ['Size', 'Axes', 'SeedOrKey', 'ArrayLike', 'DType', 'DTypeLike']

# ---------------------------------------------------------------------------
# Generic aliases
# ---------------------------------------------------------------------------

# A single integer or a sequence of integers (presumably a shape -- confirm).
Size = Union[int, Sequence[int]]
# A single axis index or a sequence of axis indices.
Axes = Union[int, Sequence[int]]
# A Python int, a JAX array or a NumPy array (presumably a PRNG seed or key).
SeedOrKey = Union[int, jax.Array, np.ndarray]

# ---------------------------------------------------------------------------
# Arrays
# ---------------------------------------------------------------------------

# Objects that can be implicitly converted to a standard JAX array. Future
# non-standard array types such as KeyArray and BInt are excluded. It differs
# from ``np.typing.ArrayLike``: arbitrary sequences and string data are not
# accepted.
ArrayLike = Union[
    jax.Array,  # JAX arrays
    np.ndarray,  # NumPy arrays
    np.bool_, np.number,  # NumPy scalars
    bool, int, float, complex,  # Python scalars
]

# ---------------------------------------------------------------------------
# Dtypes
# ---------------------------------------------------------------------------

DType = np.dtype


class SupportsDType(Protocol):
    """Structural type: any object exposing a ``dtype`` property."""

    @property
    def dtype(self) -> DType: ...


# Inputs to ``np.dtype`` that yield a valid JAX dtype. It differs from
# ``numpy.typing.DTypeLike`` because JAX supports neither object nor structured
# dtypes. ``None`` is deliberately excluded; annotate ``Optional[DTypeLike]``
# explicitly where ``None`` is acceptable.
DTypeLike = Union[
    str,  # e.g. 'float32', 'int32'
    type[Any],  # e.g. np.float32, np.int32, float, int
    np.dtype,  # e.g. np.dtype('float32'), np.dtype('int32')
    SupportsDType,  # e.g. jnp.float32, jnp.int32
]
|