brainstate 0.0.1__py2.py3-none-any.whl → 0.0.1.post20240622__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57):
  1. brainstate/__init__.py +4 -5
  2. brainstate/_module.py +191 -48
  3. brainstate/_module_test.py +95 -21
  4. brainstate/_state.py +17 -0
  5. brainstate/environ.py +2 -2
  6. brainstate/functional/__init__.py +3 -2
  7. brainstate/functional/_activations.py +7 -26
  8. brainstate/functional/_normalization.py +3 -0
  9. brainstate/functional/_others.py +49 -0
  10. brainstate/functional/_spikes.py +0 -1
  11. brainstate/mixin.py +2 -2
  12. brainstate/nn/__init__.py +4 -0
  13. brainstate/nn/_base.py +10 -7
  14. brainstate/nn/_dynamics.py +20 -0
  15. brainstate/nn/_elementwise.py +5 -4
  16. brainstate/nn/_embedding.py +66 -0
  17. brainstate/nn/_misc.py +4 -3
  18. brainstate/nn/_others.py +3 -2
  19. brainstate/nn/_poolings.py +21 -20
  20. brainstate/nn/_poolings_test.py +4 -4
  21. brainstate/nn/_rate_rnns.py +17 -0
  22. brainstate/nn/_readout.py +6 -0
  23. brainstate/optim/__init__.py +0 -1
  24. brainstate/optim/_lr_scheduler_test.py +13 -0
  25. brainstate/optim/_sgd_optimizer.py +18 -17
  26. brainstate/transform/__init__.py +2 -3
  27. brainstate/transform/_autograd.py +1 -1
  28. brainstate/transform/_autograd_test.py +0 -2
  29. brainstate/transform/_jit.py +47 -21
  30. brainstate/transform/_jit_test.py +0 -3
  31. brainstate/transform/_make_jaxpr.py +164 -3
  32. brainstate/transform/_make_jaxpr_test.py +0 -2
  33. brainstate/transform/_progress_bar.py +1 -3
  34. brainstate/util.py +0 -1
  35. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/METADATA +9 -17
  36. brainstate-0.0.1.post20240622.dist-info/RECORD +64 -0
  37. brainstate/math/__init__.py +0 -21
  38. brainstate/math/_einops.py +0 -787
  39. brainstate/math/_einops_parsing.py +0 -169
  40. brainstate/math/_einops_parsing_test.py +0 -126
  41. brainstate/math/_einops_test.py +0 -346
  42. brainstate/math/_misc.py +0 -298
  43. brainstate/math/_misc_test.py +0 -58
  44. brainstate/nn/functional/__init__.py +0 -25
  45. brainstate/nn/functional/_activations.py +0 -754
  46. brainstate/nn/functional/_normalization.py +0 -69
  47. brainstate/nn/functional/_spikes.py +0 -90
  48. brainstate/nn/init/__init__.py +0 -26
  49. brainstate/nn/init/_base.py +0 -36
  50. brainstate/nn/init/_generic.py +0 -175
  51. brainstate/nn/init/_random_inits.py +0 -489
  52. brainstate/nn/init/_regular_inits.py +0 -109
  53. brainstate/nn/surrogate.py +0 -1740
  54. brainstate-0.0.1.dist-info/RECORD +0 -79
  55. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/LICENSE +0 -0
  56. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/WHEEL +0 -0
  57. {brainstate-0.0.1.dist-info → brainstate-0.0.1.post20240622.dist-info}/top_level.txt +0 -0
@@ -1,69 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
-
18
- from typing import Optional
19
-
20
- import jax
21
- import jax.numpy as jnp
22
-
23
# Public API of this module (names exported by ``from ... import *``).
__all__ = ['weight_standardization']
26
-
27
-
28
def weight_standardization(
    w: jax.typing.ArrayLike,
    eps: float = 1e-4,
    gain: Optional[jax.Array] = None,
    out_axis: int = -1,
) -> jax.Array:
    """
    Scaled Weight Standardization,
    see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.

    Every slice of ``w`` along ``out_axis`` is standardized independently:
    mean and variance are taken over all *other* axes, and with ``fan_in``
    the product of the sizes of those other axes the result is::

        scale = gain / sqrt(max(fan_in * var, eps))
        w_hat = (w - mean) * scale

    Parameters
    ----------
    w : jax.typing.ArrayLike
        The weight tensor. Anything accepted by ``jnp.asarray`` works
        (JAX arrays, NumPy arrays, nested Python lists).
    eps : float
        Lower bound applied to ``fan_in * var`` before the reciprocal square
        root, to avoid division by zero.
    gain : Array, optional
        Multiplicative gain applied to the scale. It must broadcast against
        the per-slice scale, which has size 1 on every reduced axis.
        Defaults to None (no gain).
    out_axis : int
        The output axis that is *not* reduced over; negative values count
        from the end. Defaults to -1.

    Returns
    -------
    jax.Array
        The standardized weight tensor, with the same shape as ``w``.

    Raises
    ------
    ValueError
        If ``out_axis`` is out of range for ``w``.
    """
    # Plain ArrayLike inputs (e.g. lists) have no ``.ndim`` / ``.shape``.
    w = jnp.asarray(w)
    ndim = w.ndim
    # Previously an out-of-range axis silently reduced over *every* axis.
    if not -ndim <= out_axis < ndim:
        raise ValueError(
            f'out_axis={out_axis} is out of range for a weight tensor with {ndim} dimension(s).'
        )
    out_axis = out_axis % ndim

    # Reduce over every axis except the output axis; fan_in is their total size.
    axes = tuple(i for i in range(ndim) if i != out_axis)
    fan_in = 1
    for i in axes:
        fan_in *= w.shape[i]

    # normalize the weight
    mean = jnp.mean(w, axis=axes, keepdims=True)
    var = jnp.var(w, axis=axes, keepdims=True)
    scale = jax.lax.rsqrt(jnp.maximum(var * fan_in, eps))
    if gain is not None:
        scale = gain * scale
    # w * scale - mean * scale == (w - mean) * scale
    shift = mean * scale
    return w * scale - shift
@@ -1,90 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- from __future__ import annotations
17
-
18
# Names exported by ``from ... import *`` (order preserved).
__all__ = [
    'spike_bitwise_or', 'spike_bitwise_and', 'spike_bitwise_iand',
    'spike_bitwise_not', 'spike_bitwise_xor', 'spike_bitwise_ixor',
    'spike_bitwise',
]
27
-
28
-
29
def spike_bitwise_or(x, y):
    """Element-wise OR of two spike tensors.

    Computed arithmetically as ``x + y - x * y``, which equals logical OR
    when both inputs hold 0/1 values.
    """
    overlap = x * y  # subtract the overlap so that 1 OR 1 stays 1
    return x + y - overlap
32
-
33
-
34
def spike_bitwise_and(x, y):
    """Element-wise AND of two spike tensors, computed as the product ``x * y``."""
    product = x * y
    return product
37
-
38
-
39
def spike_bitwise_iand(x, y):
    """Element-wise IAND ``(NOT x) AND y`` of two spike tensors: ``(1 - x) * y``."""
    not_x = 1 - x
    return not_x * y
42
-
43
-
44
def spike_bitwise_not(x):
    """Element-wise NOT of a spike tensor, computed as ``1 - x``."""
    flipped = 1 - x
    return flipped
47
-
48
-
49
def spike_bitwise_xor(x, y):
    """Element-wise XOR of two spike tensors.

    Computed as ``x + y - 2 * x * y``, which equals logical XOR for 0/1 inputs.
    """
    doubled_overlap = 2 * x * y
    return x + y - doubled_overlap
52
-
53
-
54
def spike_bitwise_ixor(x, y):
    """Element-wise IXOR of two spike tensors: ``x * (1 - y) + (1 - x) * y``.

    NOTE(review): this expands to ``x + y - 2xy``, i.e. it is algebraically the
    same as :func:`spike_bitwise_xor`; confirm whether a different operation
    (e.g. XNOR) was intended.
    """
    only_x = x * (1 - y)
    only_y = (1 - x) * y
    return only_x + only_y
57
-
58
-
59
def spike_bitwise(x, y, op: str):
    r"""Bitwise operation for spike tensors, selected by name.

    Each mode reproduces the Boolean operation with plain arithmetic when the
    inputs hold 0/1 values:

    .. math::

       \begin{array}{ccc}
       \hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
       \hline \text { AND } & x \cap y & x \cdot y \\
       \text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
       \text { OR } & x \cup y & (x+y)-(x \cdot y) \\
       \text { XOR } & x \oplus y & x+y-2 \cdot x \cdot y \\
       \text { IXOR } & \text{see note} & x \cdot (1-y)+(1-x) \cdot y \\
       \hline
       \end{array}

    NOTE(review): the IXOR code expands to ``x + y - 2xy``, the same as XOR;
    confirm the intended semantics.

    Args:
      x: A spike tensor.
      y: A spike tensor.
      op: Name of the operation, one of ``'or'``, ``'and'``, ``'iand'``,
        ``'xor'`` or ``'ixor'``. (There is no ``'add'`` mode.)

    Returns:
      The result of the selected element-wise operation.

    Raises:
      NotImplementedError: If ``op`` is not one of the supported names.
    """
    if op == 'or':
        return spike_bitwise_or(x, y)
    elif op == 'and':
        return spike_bitwise_and(x, y)
    elif op == 'iand':
        return spike_bitwise_iand(x, y)
    elif op == 'xor':
        return spike_bitwise_xor(x, y)
    elif op == 'ixor':
        return spike_bitwise_ixor(x, y)
    else:
        raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
90
-
@@ -1,26 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from ._base import *
18
- from ._base import __all__ as _base_all
19
- from ._generic import *
20
- from ._generic import __all__ as _generic_all
21
- from ._random_inits import *
22
- from ._random_inits import __all__ as _random_inits_all
23
- from ._regular_inits import *
24
- from ._regular_inits import __all__ as _regular_inits_all
25
-
26
# Public API: the exported names of every submodule, concatenated in this order.
__all__ = (
    _generic_all
    + _base_all
    + _regular_inits_all
    + _random_inits_all
)
@@ -1,36 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
-
17
- from typing import Optional, Tuple
18
-
19
- import numpy as np
20
-
21
# Names exported by ``from ... import *``.
__all__ = [
    'Initializer',
    'to_size',
]
22
-
23
-
24
class Initializer:
    """Base class for parameter initializers.

    Concrete initializers override :meth:`__call__`; the base implementation
    only raises :class:`NotImplementedError`.
    """

    def __call__(self, *args, **kwargs):
        # Hook to be provided by subclasses.
        raise NotImplementedError
27
-
28
-
29
def to_size(x) -> Optional[Tuple[int, ...]]:
    """Normalize a size specification into a tuple.

    Parameters
    ----------
    x : int, np.integer, tuple, list or None
        The size. A single integer becomes a 1-tuple; a tuple or list is
        converted to a tuple (its elements are not validated); None is passed
        through unchanged.

    Returns
    -------
    tuple or None
        The size as a tuple of any length, or None.

    Raises
    ------
    ValueError
        If ``x`` is of any other type.
    """
    if x is None:
        return None
    if isinstance(x, (tuple, list)):
        return tuple(x)
    # NOTE: ``bool`` is a subclass of ``int`` and is therefore accepted here.
    if isinstance(x, (int, np.integer)):
        return (x,)
    raise ValueError(f'Cannot make a size for {x}')
@@ -1,175 +0,0 @@
1
- # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- # -*- coding: utf-8 -*-
17
-
18
- import numbers
19
- from typing import Union, Callable, Optional, Sequence
20
-
21
- import jax
22
- import jax.numpy as jnp
23
- import numpy as np
24
-
25
- from brainstate._state import State
26
- from ._base import to_size
27
-
28
# Names exported by ``from ... import *`` (order preserved).
__all__ = ['param', 'state', 'noise']
33
-
34
-
35
def _is_scalar(x):
    """Tell whether ``x`` is a plain number, i.e. any ``numbers.Number`` (bools included)."""
    number_abc = numbers.Number
    return isinstance(x, number_abc)
37
-
38
-
39
def param(
    param: Union[Callable, np.ndarray, jax.Array, float, int, bool],
    sizes: Union[int, Sequence[int]],
    batch_size: Optional[int] = None,
    allow_none: bool = True,
    allow_scalar: bool = True,
):
    """Initialize parameters.

    Parameters
    ----------
    param: callable, Initializer, State, np.ndarray, jax.Array, float, int, bool
        The initialization of the parameter.

        - If it is None, None is returned (an error is raised instead when
          ``allow_none`` is False).
        - If it is a scalar number and ``allow_scalar`` is True, it is returned unchanged.
        - If it is callable (e.g. an ``Initializer``), ``param(shape)`` is
          returned, where ``shape`` is ``sizes`` with ``batch_size`` prepended
          when given.
        - If it is an array, it is converted with ``jnp.asarray``; when
          ``batch_size`` is given it is repeated along a new leading axis.
        - If it is a ``State``, its value is repeated along a new leading axis
          when ``batch_size`` is given (wrapped in a new state of the same type).

        For arrays and states the resulting shape must equal the expected
        shape, unless ``allow_scalar`` is True and the shape is ``()`` or ``(1,)``.
    sizes: int, sequence of int
        The shape of the parameter (without the batch dimension).
    batch_size: int, optional
        The batch size, prepended as the leading dimension.
    allow_none: bool
        Whether the parameter is allowed to be None.
    allow_scalar: bool
        Whether the parameter is allowed to be a scalar value.

    Returns
    -------
    param: ArrayType, State, float, int, bool, None
        The initialized parameter.

    Raises
    ------
    ValueError
        If ``param`` is None while ``allow_none`` is False, if its type is not
        supported, or if its shape does not match the expected shape.

    See Also
    --------
    noise, state
    """
    if param is None:
        if allow_none:
            return None
        raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or '
                         f'Callable function, but we got None. ')
    sizes = list(to_size(sizes))
    if allow_scalar and _is_scalar(param):
        return param

    # The batch dimension is always the leading axis.
    if batch_size is not None:
        sizes.insert(0, batch_size)

    if callable(param):
        return param(sizes)
    elif isinstance(param, (np.ndarray, jax.Array)):
        param = jnp.asarray(param)
        if batch_size is not None:
            param = jnp.repeat(jnp.expand_dims(param, axis=0), batch_size, axis=0)
        shape = param.shape
    elif isinstance(param, State):
        if batch_size is not None:
            # Batch along axis 0, matching the array branch and ``sizes`` above
            # (this line used to reference an undefined ``batch_axis``).
            param = type(param)(jnp.repeat(jnp.expand_dims(param.value, axis=0), batch_size, axis=0))
        # Read the shape from the wrapped value rather than assuming the
        # state object itself exposes ``.shape``.
        shape = jnp.shape(param.value)
    else:
        raise ValueError(f'Unknown parameter type: {type(param)}')

    if allow_scalar and shape in ((), (1,)):
        return param
    if shape != tuple(sizes):
        raise ValueError(f'The shape of the parameters should be {sizes}, but we got {shape}')
    return param
106
-
107
-
108
def state(
    init: Union[Callable, np.ndarray, jax.Array],
    sizes: Optional[Union[int, Sequence[int]]] = None,
    batch_size: Optional[int] = None,
):
    """
    Initialize a :class:`~.State` from a callable function or from data.

    Parameters
    ----------
    init : callable, np.ndarray, jax.Array
        Either a function called with the target shape (a list), or the data
        itself.
    sizes : int, sequence of int, optional
        The shape without the batch dimension. Required when ``init`` is
        callable; when ``init`` is data, it is checked against ``init``'s
        shape if given.
    batch_size : int, optional
        If an integer (Python or NumPy), a leading batch dimension of this
        size is added: it is prepended to the shape passed to a callable
        ``init``, or the data is repeated along a new leading axis.

    Returns
    -------
    State
        The new state wrapping the initialized value.

    Raises
    ------
    ValueError
        If ``init`` is callable and ``sizes`` is None, or if the shape of the
        data does not match ``sizes``.
    """
    sizes = to_size(sizes)
    # Accept NumPy integers too (as ``to_size`` does); a plain ``int`` check
    # silently ignored batch sizes such as ``np.int64(8)``.
    has_batch = isinstance(batch_size, (int, np.integer))

    if callable(init):
        if sizes is None:
            raise ValueError('"sizes" cannot be None when "init" is a callable function.')
        shape = list(sizes)
        if has_batch:
            shape.insert(0, batch_size)
        return State(init(shape))

    if sizes is not None and jnp.shape(init) != sizes:
        raise ValueError(f'The shape of "init" {jnp.shape(init)} does not match with "sizes" {sizes}')
    if has_batch:
        return State(jnp.repeat(jnp.expand_dims(init, axis=0), batch_size, axis=0))
    return State(init)
135
-
136
-
137
def noise(
    noises: Optional[Union[int, float, np.ndarray, jax.Array, Callable]],
    size: Union[int, Sequence[int]],
    num_vars: int = 1,
    noise_idx: int = 0,
) -> Optional[Callable]:
    """Initialize a noise function.

    Parameters
    ----------
    noises: int, float, np.ndarray, jax.Array, callable, None
        The noise specification. A callable is returned as-is, and None yields
        None. Any other value is initialized with :func:`param` (with
        ``allow_none=False``) and wrapped in a constant function.
    size: int, sequence of int
        The size of the noise, forwarded to :func:`param`.
    num_vars: int
        The number of variables.
    noise_idx: int
        The index of the current noise among all noise variables.

    Returns
    -------
    noise_func: callable, None
        ``noises`` itself if it is callable; None if it is None. Otherwise, a
        function accepting any arguments that returns the initialized noise.
        When ``num_vars > 1``, that function instead returns a tuple of length
        ``num_vars`` with the noise at position ``noise_idx`` and None elsewhere.

    See Also
    --------
    param, state
    """
    if callable(noises):
        return noises
    if noises is None:
        return None

    value = param(noises, size, allow_none=False)
    if num_vars > 1:
        # Only the variable at ``noise_idx`` receives this noise; the others get None.
        slots = [None] * num_vars
        slots[noise_idx] = value
        value = tuple(slots)
    return lambda *args, **kwargs: value