brainstate 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. brainstate/__init__.py +45 -0
  2. brainstate/_module.py +1466 -0
  3. brainstate/_module_test.py +133 -0
  4. brainstate/_state.py +378 -0
  5. brainstate/_state_test.py +41 -0
  6. brainstate/_utils.py +21 -0
  7. brainstate/environ.py +375 -0
  8. brainstate/functional/__init__.py +25 -0
  9. brainstate/functional/_activations.py +754 -0
  10. brainstate/functional/_normalization.py +69 -0
  11. brainstate/functional/_spikes.py +90 -0
  12. brainstate/init/__init__.py +26 -0
  13. brainstate/init/_base.py +36 -0
  14. brainstate/init/_generic.py +175 -0
  15. brainstate/init/_random_inits.py +489 -0
  16. brainstate/init/_regular_inits.py +109 -0
  17. brainstate/math/__init__.py +21 -0
  18. brainstate/math/_einops.py +787 -0
  19. brainstate/math/_einops_parsing.py +169 -0
  20. brainstate/math/_einops_parsing_test.py +126 -0
  21. brainstate/math/_einops_test.py +346 -0
  22. brainstate/math/_misc.py +298 -0
  23. brainstate/math/_misc_test.py +58 -0
  24. brainstate/mixin.py +373 -0
  25. brainstate/mixin_test.py +73 -0
  26. brainstate/nn/__init__.py +68 -0
  27. brainstate/nn/_base.py +248 -0
  28. brainstate/nn/_connections.py +686 -0
  29. brainstate/nn/_dynamics.py +406 -0
  30. brainstate/nn/_elementwise.py +1437 -0
  31. brainstate/nn/_misc.py +132 -0
  32. brainstate/nn/_normalizations.py +389 -0
  33. brainstate/nn/_others.py +100 -0
  34. brainstate/nn/_poolings.py +1228 -0
  35. brainstate/nn/_poolings_test.py +231 -0
  36. brainstate/nn/_projection/__init__.py +32 -0
  37. brainstate/nn/_projection/_align_post.py +528 -0
  38. brainstate/nn/_projection/_align_pre.py +599 -0
  39. brainstate/nn/_projection/_delta.py +241 -0
  40. brainstate/nn/_projection/_utils.py +17 -0
  41. brainstate/nn/_projection/_vanilla.py +101 -0
  42. brainstate/nn/_rate_rnns.py +393 -0
  43. brainstate/nn/_readout.py +130 -0
  44. brainstate/nn/_synouts.py +166 -0
  45. brainstate/nn/functional/__init__.py +25 -0
  46. brainstate/nn/functional/_activations.py +754 -0
  47. brainstate/nn/functional/_normalization.py +69 -0
  48. brainstate/nn/functional/_spikes.py +90 -0
  49. brainstate/nn/init/__init__.py +26 -0
  50. brainstate/nn/init/_base.py +36 -0
  51. brainstate/nn/init/_generic.py +175 -0
  52. brainstate/nn/init/_random_inits.py +489 -0
  53. brainstate/nn/init/_regular_inits.py +109 -0
  54. brainstate/nn/surrogate.py +1740 -0
  55. brainstate/optim/__init__.py +23 -0
  56. brainstate/optim/_lr_scheduler.py +486 -0
  57. brainstate/optim/_lr_scheduler_test.py +36 -0
  58. brainstate/optim/_sgd_optimizer.py +1148 -0
  59. brainstate/random.py +5148 -0
  60. brainstate/random_test.py +576 -0
  61. brainstate/surrogate.py +1740 -0
  62. brainstate/transform/__init__.py +36 -0
  63. brainstate/transform/_autograd.py +585 -0
  64. brainstate/transform/_autograd_test.py +1183 -0
  65. brainstate/transform/_control.py +665 -0
  66. brainstate/transform/_controls_test.py +220 -0
  67. brainstate/transform/_jit.py +239 -0
  68. brainstate/transform/_jit_error.py +158 -0
  69. brainstate/transform/_jit_test.py +102 -0
  70. brainstate/transform/_make_jaxpr.py +573 -0
  71. brainstate/transform/_make_jaxpr_test.py +133 -0
  72. brainstate/transform/_progress_bar.py +113 -0
  73. brainstate/typing.py +69 -0
  74. brainstate/util.py +747 -0
  75. brainstate-0.0.1.dist-info/LICENSE +202 -0
  76. brainstate-0.0.1.dist-info/METADATA +101 -0
  77. brainstate-0.0.1.dist-info/RECORD +79 -0
  78. brainstate-0.0.1.dist-info/WHEEL +6 -0
  79. brainstate-0.0.1.dist-info/top_level.txt +1 -0
brainstate/transform/_make_jaxpr_test.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import unittest
17
+
18
+ import jax
19
+ import jax.numpy as jnp
20
+ import pytest
21
+
22
+ import brainstate as bc
23
+
24
+
25
class TestMakeJaxpr(unittest.TestCase):
    """Compare ``bc.transform.make_jaxpr`` / ``StatefulFunction`` with ``jax.make_jaxpr``.

    The traced functions read or write ``bc.State`` objects (including the
    random-number state used by ``bc.random.rand_like``), so the tests check how
    each tracer exposes that state in the resulting jaxpr.
    """

    def test_compar_jax_make_jaxpr(self):
        def func4(arg):  # ``arg`` is a pair of arrays
            temp = arg[0] + jnp.sin(arg[1]) * 3.
            c = bc.random.rand_like(arg[0])
            return jnp.sum(temp + c)

        key = bc.random.DEFAULT.value
        jaxpr = jax.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
        print(jaxpr)
        # Plain JAX sees only the two array inputs; the random key is captured
        # as the jaxpr's single constant.
        self.assertEqual(len(jaxpr.in_avals), 2)
        self.assertEqual(len(jaxpr.consts), 1)
        self.assertEqual(len(jaxpr.out_avals), 1)
        self.assertTrue(jnp.allclose(jaxpr.consts[0], key))

        # brainstate's tracer adds one extra input and one extra output
        # (presumably the random-key state) and leaves no constants behind.
        jaxpr2, _ = bc.transform.make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))
        print(jaxpr2)
        self.assertEqual(len(jaxpr2.in_avals), 3)
        self.assertEqual(len(jaxpr2.out_avals), 2)
        self.assertEqual(len(jaxpr2.consts), 0)

    def test_StatefulFunction_1(self):
        def func4(arg):  # ``arg`` is a pair of arrays
            temp = arg[0] + jnp.sin(arg[1]) * 3.
            c = bc.random.rand_like(arg[0])
            return jnp.sum(temp + c)

        fun = bc.transform.StatefulFunction(func4).make_jaxpr((jnp.zeros(8), jnp.ones(8)))
        print(fun.get_states())
        print(fun.get_jaxpr())

    def test_StatefulFunction_2(self):
        st1 = bc.State(jnp.ones(10))

        def f1(x):
            st1.value = x + st1.value

        def f2(x):
            # Trace ``f1`` nested inside an outer trace; the inner jaxpr is discarded.
            bc.transform.make_jaxpr(f1)(x)
            return 1. + x

        def f3(x):
            bc.transform.make_jaxpr(f1)(x)
            return 1.

        print()
        jaxpr = bc.transform.make_jaxpr(f1)(jnp.zeros(1))
        print(jaxpr)
        jaxpr = jax.make_jaxpr(f2)(jnp.zeros(1))
        print(jaxpr)
        jaxpr = jax.make_jaxpr(f3)(jnp.zeros(1))
        print(jaxpr)
        jaxpr, _ = bc.transform.make_jaxpr(f3)(jnp.zeros(1))
        print(jaxpr)
        # The brainstate jaxpr is called with (x, st1.value); its first output
        # must equal the direct result of ``f3``.
        self.assertTrue(jnp.allclose(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
                                     f3(jnp.zeros(1))))

    def test_compar_jax_make_jaxpr2(self):
        st1 = bc.State(jnp.ones(10))

        def fa(x):
            st1.value = x + st1.value

        def ffa(x):
            # Nested brainstate trace whose result is intentionally unused.
            bc.transform.make_jaxpr(fa)(x)
            return 1. + x

        jaxpr, states = bc.transform.make_jaxpr(ffa)(jnp.zeros(1))
        print()
        print(jaxpr)
        print(states)
        print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value))
        jaxpr = jax.make_jaxpr(ffa)(jnp.zeros(1))
        print(jaxpr)
        print(jax.core.jaxpr_as_fun(jaxpr)(jnp.zeros(1)))

    def test_compar_jax_make_jaxpr3(self):
        # A function whose output does not depend on its input.
        def fa(x):
            return 1.

        jaxpr, states = bc.transform.make_jaxpr(fa)(jnp.zeros(1))
        print()
        print(jaxpr)
        print(states)
        jaxpr = jax.make_jaxpr(fa)(jnp.zeros(1))
        print(jaxpr)
117
+
118
+
119
def test_return_states():
    """Returning a ``bc.State`` object from a ``bc.transform.jit`` function raises ``ValueError``."""
    # The module already imports ``jnp`` and ``bc``; the former local re-imports
    # only shadowed those names.
    a = bc.State(jnp.ones(3))

    @bc.transform.jit
    def f():
        return a

    with pytest.raises(ValueError):
        f()
132
+
133
+
brainstate/transform/_progress_bar.py ADDED
@@ -0,0 +1,113 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from __future__ import annotations
17
+ import copy
18
+ from typing import Optional
19
+
20
+ import jax
21
+
22
+ from brainstate import environ
23
+
24
+ try:
25
+ from tqdm.auto import tqdm
26
+ except (ImportError, ModuleNotFoundError):
27
+ tqdm = None
28
+
29
__all__ = [
    'ProgressBar',
]


class ProgressBar(object):
    """Configuration for a tqdm progress bar displayed while a loop runs.

    At most one of ``freq`` / ``count`` may be given:

    - ``freq``: advance the bar every ``freq`` iterations.
    - ``count``: use ``freq = n // count``, i.e. at least ``count`` updates.
    - neither: use ``freq = n // 20`` when ``n > 20``, otherwise ``freq = 1``.

    Remaining keyword arguments are forwarded to ``tqdm``. The keys ``total``,
    ``mininterval``, ``maxinterval`` and ``miniters`` are dropped, because the
    bar length and update cadence are set by :meth:`init` instead. ``desc``,
    optionally overridden by ``message``, sets the bar description.

    Raises:
      ValueError: if both ``freq`` and ``count`` are given.
      ImportError: if ``tqdm`` is not installed.
    """
    __module__ = "brainstate.transform"

    def __init__(self, freq: Optional[int] = None, count: Optional[int] = None, **kwargs):
        self.print_freq = freq
        self.print_count = count
        if self.print_freq is not None and self.print_count is not None:
            raise ValueError("Cannot specify both count and freq.")
        for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
            kwargs.pop(kwarg, None)
        self.kwargs = kwargs
        if tqdm is None:
            raise ImportError("tqdm is not installed.")

    @staticmethod
    def _resolve_print_freq(n: int, freq: Optional[int], count: Optional[int]) -> tuple[int, int]:
        """Return ``(freq, remainder)`` for a loop of ``n`` iterations.

        ``freq`` is the number of iterations between bar updates. The bar is
        advanced ``n // freq`` times by ``freq``; ``remainder = n % freq`` is the
        leftover added when the bar closes, so the total is exactly ``n``.

        Raises:
          ValueError: if ``count`` is < 1 or larger than ``n``, or if ``freq``
            is < 1 or larger than ``n``.
        """
        if count is not None:
            if count < 1:
                raise ValueError(f"Count should be > 0, got {count}")
            freq = n // count
            if freq == 0:
                raise ValueError(f"Count {count} is too large for n {n}.")
        elif freq is None:
            freq = n // 20 if n > 20 else 1
        else:
            if freq < 1:
                raise ValueError(f"Print rate should be > 0 got {freq}")
            elif freq > n:
                raise ValueError("Print rate should be less than the "
                                 f"number of steps {n}, got {freq}")
        # The remainder must be taken modulo ``freq``, not ``count``. Otherwise,
        # e.g. n=10, count=4 would yield 5 updates of 2 plus an extra 2, i.e. 12.
        return freq, n % freq

    def init(self, n: int):
        """Create a :class:`ProgressBarRunner` for a loop of ``n`` iterations.

        Raises:
          ValueError: if the configured ``freq``/``count`` is invalid for ``n``.
        """
        kwargs = copy.copy(self.kwargs)
        freq, remainder = self._resolve_print_freq(n, self.print_freq, self.print_count)
        desc = kwargs.pop("desc", f"Running for {n:,} iterations")
        message = kwargs.pop("message", desc)
        return ProgressBarRunner(n, message, freq, remainder, **kwargs)
72
+
73
+
74
class ProgressBarRunner(object):
    """A tqdm bar on the host, driven by iteration indices from JAX code.

    Each call with ``iter_num`` may emit ``jax.debug.callback`` host calls:

    - at ``iter_num == 0`` the bar is created;
    - whenever ``(iter_num + 1)`` is a multiple of ``print_freq``, the bar
      advances by ``print_freq``;
    - at ``iter_num == n - 1`` the bar advances by ``remainder`` (if positive)
      and is closed.
    """
    __module__ = "brainstate.transform"

    def __init__(self, n: int, message, print_freq: int, remainder: int, **kwargs):
        self.tqdm_bars = {}  # only key 0 is ever used: the single live bar
        self.kwargs = kwargs  # forwarded verbatim to ``tqdm``
        self.n = n
        self.print_freq = print_freq
        self.remainder = remainder
        self.message = message

    def _define_tqdm(self):
        bar = self.tqdm_bars[0] = tqdm(range(self.n), **self.kwargs)
        bar.set_description(self.message, refresh=False)

    def _update_tqdm(self):
        self.tqdm_bars[0].update(self.print_freq)

    def _close_tqdm(self):
        bar = self.tqdm_bars[0]
        if self.remainder > 0:
            bar.update(self.remainder)
        bar.close()

    @staticmethod
    def _host_call_if(pred, host_fn):
        """Invoke ``host_fn`` via ``jax.debug.callback`` only on the branch where ``pred`` is true."""
        jax.lax.cond(pred, lambda: jax.debug.callback(host_fn), lambda: None)

    def __call__(self, iter_num, *args, **kwargs):
        """Emit the host callbacks appropriate for iteration ``iter_num``; returns None.

        Extra positional/keyword arguments are accepted and ignored.
        """
        # NOTE(review): these callbacks use jax.debug.callback's default
        # (unordered) mode; confirm define -> update -> close cannot be
        # reordered when several fire in the same iteration (e.g. print_freq == 1).
        self._host_call_if(iter_num == 0, self._define_tqdm)
        self._host_call_if((iter_num + 1) % self.print_freq == 0, self._update_tqdm)
        self._host_call_if(iter_num == self.n - 1, self._close_tqdm)
brainstate/typing.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from typing import Any, Sequence, Protocol, Union
18
+
19
+ import jax
20
+ import numpy as np
21
+
22
__all__ = [
    'Size', 'Axes', 'SeedOrKey',
    'ArrayLike', 'DType', 'DTypeLike',
]

# A size/shape: a single int or a sequence of ints.
Size = Union[int, Sequence[int]]
# One axis index or a sequence of axis indices.
Axes = Union[int, Sequence[int]]
# An integer seed, or an array (presumably an existing PRNG key).
SeedOrKey = Union[int, jax.Array, np.ndarray]

# --- Array --- #

# Objects that can be implicitly converted to a standard JAX array. This does
# not cover non-standard array types such as KeyArray or BInt. It differs from
# ``np.typing.ArrayLike`` in rejecting arbitrary sequences and string data.
ArrayLike = Union[
    # array types: JAX, then NumPy
    jax.Array,
    np.ndarray,
    # NumPy scalar types
    np.bool_,
    np.number,
    # builtin Python scalars
    bool,
    int,
    float,
    complex,
]

# --- Dtype --- #

DType = np.dtype


class SupportsDType(Protocol):
    """Structural type: any object exposing a ``dtype`` attribute."""

    @property
    def dtype(self) -> DType: ...


# Inputs to ``np.dtype`` that yield a valid JAX dtype. Compared with
# ``numpy.typing.DTypeLike``, object and structured dtypes are not supported,
# and ``None`` is deliberately excluded: write ``Optional[DTypeLike]``
# explicitly wherever None is acceptable.
DTypeLike = Union[
    str,            # e.g. 'float32', 'int32'
    type[Any],      # e.g. np.float32, np.int32, float, int
    np.dtype,       # e.g. np.dtype('float32'), np.dtype('int32')
    SupportsDType,  # e.g. jnp.float32, jnp.int32
]