brainstate 0.0.2.post20241010__py2.py3-none-any.whl → 0.1.0.post20241122__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmark/COBA_2005.py +125 -0
- benchmark/CUBA_2005.py +149 -0
- brainstate/__init__.py +31 -11
- brainstate/_state.py +760 -316
- brainstate/_state_test.py +41 -12
- brainstate/_utils.py +31 -4
- brainstate/augment/__init__.py +40 -0
- brainstate/augment/_autograd.py +611 -0
- brainstate/augment/_autograd_test.py +1193 -0
- brainstate/augment/_eval_shape.py +102 -0
- brainstate/augment/_eval_shape_test.py +40 -0
- brainstate/augment/_mapping.py +525 -0
- brainstate/augment/_mapping_test.py +210 -0
- brainstate/augment/_random.py +99 -0
- brainstate/{transform → compile}/__init__.py +25 -13
- brainstate/compile/_ad_checkpoint.py +204 -0
- brainstate/compile/_ad_checkpoint_test.py +51 -0
- brainstate/compile/_conditions.py +259 -0
- brainstate/compile/_conditions_test.py +221 -0
- brainstate/compile/_error_if.py +94 -0
- brainstate/compile/_error_if_test.py +54 -0
- brainstate/compile/_jit.py +314 -0
- brainstate/compile/_jit_test.py +143 -0
- brainstate/compile/_loop_collect_return.py +516 -0
- brainstate/compile/_loop_collect_return_test.py +59 -0
- brainstate/compile/_loop_no_collection.py +185 -0
- brainstate/compile/_loop_no_collection_test.py +51 -0
- brainstate/compile/_make_jaxpr.py +756 -0
- brainstate/compile/_make_jaxpr_test.py +134 -0
- brainstate/compile/_progress_bar.py +111 -0
- brainstate/compile/_unvmap.py +159 -0
- brainstate/compile/_util.py +147 -0
- brainstate/environ.py +408 -381
- brainstate/environ_test.py +34 -32
- brainstate/event/__init__.py +27 -0
- brainstate/event/_csr.py +316 -0
- brainstate/event/_csr_benchmark.py +14 -0
- brainstate/event/_csr_test.py +118 -0
- brainstate/event/_fixed_probability.py +708 -0
- brainstate/event/_fixed_probability_benchmark.py +128 -0
- brainstate/event/_fixed_probability_test.py +131 -0
- brainstate/event/_linear.py +359 -0
- brainstate/event/_linear_benckmark.py +82 -0
- brainstate/event/_linear_test.py +117 -0
- brainstate/{nn/event → event}/_misc.py +7 -7
- brainstate/event/_xla_custom_op.py +312 -0
- brainstate/event/_xla_custom_op_test.py +55 -0
- brainstate/functional/_activations.py +521 -511
- brainstate/functional/_activations_test.py +300 -300
- brainstate/functional/_normalization.py +43 -43
- brainstate/functional/_others.py +15 -15
- brainstate/functional/_spikes.py +49 -49
- brainstate/graph/__init__.py +33 -0
- brainstate/graph/_graph_context.py +443 -0
- brainstate/graph/_graph_context_test.py +65 -0
- brainstate/graph/_graph_convert.py +246 -0
- brainstate/graph/_graph_node.py +300 -0
- brainstate/graph/_graph_node_test.py +75 -0
- brainstate/graph/_graph_operation.py +1746 -0
- brainstate/graph/_graph_operation_test.py +724 -0
- brainstate/init/_base.py +28 -10
- brainstate/init/_generic.py +175 -172
- brainstate/init/_random_inits.py +470 -415
- brainstate/init/_random_inits_test.py +150 -0
- brainstate/init/_regular_inits.py +66 -69
- brainstate/init/_regular_inits_test.py +51 -0
- brainstate/mixin.py +236 -244
- brainstate/mixin_test.py +44 -46
- brainstate/nn/__init__.py +26 -51
- brainstate/nn/_collective_ops.py +199 -0
- brainstate/nn/_dyn_impl/__init__.py +46 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +290 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +162 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +315 -0
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +132 -0
- brainstate/nn/_dyn_impl/_inputs.py +154 -0
- brainstate/nn/{event/__init__.py → _dyn_impl/_projection_alignpost.py} +8 -8
- brainstate/nn/_dyn_impl/_rate_rnns.py +400 -0
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +64 -0
- brainstate/nn/_dyn_impl/_readout.py +128 -0
- brainstate/nn/_dyn_impl/_readout_test.py +54 -0
- brainstate/nn/_dynamics/__init__.py +37 -0
- brainstate/nn/_dynamics/_dynamics_base.py +631 -0
- brainstate/nn/_dynamics/_dynamics_base_test.py +79 -0
- brainstate/nn/_dynamics/_projection_base.py +346 -0
- brainstate/nn/_dynamics/_state_delay.py +453 -0
- brainstate/nn/_dynamics/_synouts.py +161 -0
- brainstate/nn/_dynamics/_synouts_test.py +58 -0
- brainstate/nn/_elementwise/__init__.py +22 -0
- brainstate/nn/_elementwise/_dropout.py +418 -0
- brainstate/nn/_elementwise/_dropout_test.py +100 -0
- brainstate/nn/_elementwise/_elementwise.py +1122 -0
- brainstate/nn/_elementwise/_elementwise_test.py +171 -0
- brainstate/nn/_exp_euler.py +97 -0
- brainstate/nn/_exp_euler_test.py +36 -0
- brainstate/nn/_interaction/__init__.py +41 -0
- brainstate/nn/_interaction/_conv.py +499 -0
- brainstate/nn/_interaction/_conv_test.py +239 -0
- brainstate/nn/_interaction/_embedding.py +59 -0
- brainstate/nn/_interaction/_linear.py +582 -0
- brainstate/nn/_interaction/_linear_test.py +42 -0
- brainstate/nn/_interaction/_normalizations.py +388 -0
- brainstate/nn/_interaction/_normalizations_test.py +75 -0
- brainstate/nn/_interaction/_poolings.py +1179 -0
- brainstate/nn/_interaction/_poolings_test.py +219 -0
- brainstate/nn/_module.py +328 -0
- brainstate/nn/_module_test.py +211 -0
- brainstate/nn/metrics.py +309 -309
- brainstate/optim/__init__.py +14 -2
- brainstate/optim/_base.py +66 -0
- brainstate/optim/_lr_scheduler.py +363 -400
- brainstate/optim/_lr_scheduler_test.py +25 -24
- brainstate/optim/_optax_optimizer.py +121 -176
- brainstate/optim/_optax_optimizer_test.py +41 -1
- brainstate/optim/_sgd_optimizer.py +950 -1025
- brainstate/random/_rand_funs.py +3269 -3268
- brainstate/random/_rand_funs_test.py +568 -0
- brainstate/random/_rand_seed.py +149 -117
- brainstate/random/_rand_seed_test.py +50 -0
- brainstate/random/_rand_state.py +1356 -1321
- brainstate/random/_random_for_unit.py +13 -13
- brainstate/surrogate.py +1262 -1243
- brainstate/{nn/_projection/_utils.py → transform.py} +1 -2
- brainstate/typing.py +157 -130
- brainstate/util/__init__.py +52 -0
- brainstate/util/_caller.py +100 -0
- brainstate/util/_dict.py +734 -0
- brainstate/util/_dict_test.py +160 -0
- brainstate/{nn/_projection/__init__.py → util/_error.py} +9 -13
- brainstate/util/_filter.py +178 -0
- brainstate/util/_others.py +497 -0
- brainstate/util/_pretty_repr.py +208 -0
- brainstate/util/_scaling.py +260 -0
- brainstate/util/_struct.py +524 -0
- brainstate/util/_tracers.py +75 -0
- brainstate/{_visualization.py → util/_visualization.py} +16 -16
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/METADATA +11 -11
- brainstate-0.1.0.post20241122.dist-info/RECORD +144 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/top_level.txt +1 -0
- brainstate/_module.py +0 -1637
- brainstate/_module_test.py +0 -207
- brainstate/nn/_base.py +0 -251
- brainstate/nn/_connections.py +0 -686
- brainstate/nn/_dynamics.py +0 -426
- brainstate/nn/_elementwise.py +0 -1438
- brainstate/nn/_embedding.py +0 -66
- brainstate/nn/_misc.py +0 -133
- brainstate/nn/_normalizations.py +0 -389
- brainstate/nn/_others.py +0 -101
- brainstate/nn/_poolings.py +0 -1229
- brainstate/nn/_poolings_test.py +0 -231
- brainstate/nn/_projection/_align_post.py +0 -546
- brainstate/nn/_projection/_align_pre.py +0 -599
- brainstate/nn/_projection/_delta.py +0 -241
- brainstate/nn/_projection/_vanilla.py +0 -101
- brainstate/nn/_rate_rnns.py +0 -410
- brainstate/nn/_readout.py +0 -136
- brainstate/nn/_synouts.py +0 -166
- brainstate/nn/event/csr.py +0 -312
- brainstate/nn/event/csr_test.py +0 -118
- brainstate/nn/event/fixed_probability.py +0 -276
- brainstate/nn/event/fixed_probability_test.py +0 -127
- brainstate/nn/event/linear.py +0 -220
- brainstate/nn/event/linear_test.py +0 -111
- brainstate/random/random_test.py +0 -593
- brainstate/transform/_autograd.py +0 -585
- brainstate/transform/_autograd_test.py +0 -1181
- brainstate/transform/_conditions.py +0 -334
- brainstate/transform/_conditions_test.py +0 -220
- brainstate/transform/_error_if.py +0 -94
- brainstate/transform/_error_if_test.py +0 -55
- brainstate/transform/_jit.py +0 -265
- brainstate/transform/_jit_test.py +0 -118
- brainstate/transform/_loop_collect_return.py +0 -502
- brainstate/transform/_loop_no_collection.py +0 -170
- brainstate/transform/_make_jaxpr.py +0 -739
- brainstate/transform/_make_jaxpr_test.py +0 -131
- brainstate/transform/_mapping.py +0 -109
- brainstate/transform/_progress_bar.py +0 -111
- brainstate/transform/_unvmap.py +0 -143
- brainstate/util.py +0 -746
- brainstate-0.0.2.post20241010.dist-info/RECORD +0 -87
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/LICENSE +0 -0
- {brainstate-0.0.2.post20241010.dist-info → brainstate-0.1.0.post20241122.dist-info}/WHEEL +0 -0
@@ -24,7 +24,7 @@ from brainstate._utils import set_module_as
|
|
24
24
|
from brainstate.typing import ArrayLike
|
25
25
|
|
26
26
|
__all__ = [
|
27
|
-
|
27
|
+
'weight_standardization',
|
28
28
|
]
|
29
29
|
|
30
30
|
|
@@ -35,49 +35,49 @@ def weight_standardization(
|
|
35
35
|
gain: Optional[jax.Array] = None,
|
36
36
|
out_axis: int = -1,
|
37
37
|
) -> Union[jax.Array, u.Quantity]:
|
38
|
-
|
39
|
-
|
40
|
-
|
38
|
+
"""
|
39
|
+
Scaled Weight Standardization,
|
40
|
+
see `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization <https://paperswithcode.com/paper/weight-standardization>`_.
|
41
41
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
42
|
+
Parameters
|
43
|
+
----------
|
44
|
+
w : ArrayLike
|
45
|
+
The weight tensor.
|
46
|
+
eps : float
|
47
|
+
A small value to avoid division by zero.
|
48
|
+
gain : Array
|
49
|
+
The gain function, by default None.
|
50
|
+
out_axis : int
|
51
|
+
The output axis, by default -1.
|
52
52
|
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
53
|
+
Returns
|
54
|
+
-------
|
55
|
+
ArrayLike
|
56
|
+
The scaled weight tensor.
|
57
|
+
"""
|
58
|
+
if out_axis < 0:
|
59
|
+
out_axis = w.ndim + out_axis
|
60
|
+
fan_in = 1 # get the fan-in of the weight tensor
|
61
|
+
axes = [] # get the axes of the weight tensor
|
62
|
+
for i in range(w.ndim):
|
63
|
+
if i != out_axis:
|
64
|
+
fan_in *= w.shape[i]
|
65
|
+
axes.append(i)
|
66
|
+
# normalize the weight
|
67
|
+
mean = u.math.mean(w, axis=axes, keepdims=True)
|
68
|
+
var = u.math.var(w, axis=axes, keepdims=True)
|
69
69
|
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
70
|
+
temp = u.math.maximum(var * fan_in, eps)
|
71
|
+
if isinstance(temp, u.Quantity):
|
72
|
+
unit = temp.unit
|
73
|
+
temp = temp.mantissa
|
74
|
+
if unit.is_unitless:
|
75
|
+
scale = jax.lax.rsqrt(temp)
|
76
|
+
else:
|
77
|
+
scale = u.Quantity(jax.lax.rsqrt(temp), unit=1 / unit ** 0.5)
|
76
78
|
else:
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
shift = mean * scale
|
83
|
-
return w * scale - shift
|
79
|
+
scale = jax.lax.rsqrt(temp)
|
80
|
+
if gain is not None:
|
81
|
+
scale = gain * scale
|
82
|
+
shift = mean * scale
|
83
|
+
return w * scale - shift
|
brainstate/functional/_others.py
CHANGED
@@ -23,7 +23,7 @@ import jax.numpy as jnp
|
|
23
23
|
from brainstate.typing import PyTree
|
24
24
|
|
25
25
|
__all__ = [
|
26
|
-
|
26
|
+
'clip_grad_norm',
|
27
27
|
]
|
28
28
|
|
29
29
|
|
@@ -32,17 +32,17 @@ def clip_grad_norm(
|
|
32
32
|
max_norm: float | jax.Array,
|
33
33
|
norm_type: int | str | None = None
|
34
34
|
):
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
35
|
+
"""
|
36
|
+
Clips gradient norm of an iterable of parameters.
|
37
|
+
|
38
|
+
The norm is computed over all gradients together, as if they were
|
39
|
+
concatenated into a single vector. Gradients are modified in-place.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
grad (PyTree): an iterable of Tensors or a single Tensor that will have gradients normalized
|
43
|
+
max_norm (float): max norm of the gradients.
|
44
|
+
norm_type (int, str, None): type of the used p-norm. Can be ``'inf'`` for infinity norm.
|
45
|
+
"""
|
46
|
+
norm_fn = partial(jnp.linalg.norm, ord=norm_type)
|
47
|
+
norm = norm_fn(jnp.asarray(jax.tree.leaves(jax.tree.map(norm_fn, grad))))
|
48
|
+
return jax.tree.map(lambda x: jnp.where(norm < max_norm, x, x * max_norm / (norm + 1e-6)), grad)
|
brainstate/functional/_spikes.py
CHANGED
@@ -16,74 +16,74 @@
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
18
|
__all__ = [
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
19
|
+
'spike_bitwise_or',
|
20
|
+
'spike_bitwise_and',
|
21
|
+
'spike_bitwise_iand',
|
22
|
+
'spike_bitwise_not',
|
23
|
+
'spike_bitwise_xor',
|
24
|
+
'spike_bitwise_ixor',
|
25
|
+
'spike_bitwise',
|
26
26
|
]
|
27
27
|
|
28
28
|
|
29
29
|
def spike_bitwise_or(x, y):
|
30
|
-
|
31
|
-
|
30
|
+
"""Bitwise OR operation for spike tensors."""
|
31
|
+
return x + y - x * y
|
32
32
|
|
33
33
|
|
34
34
|
def spike_bitwise_and(x, y):
|
35
|
-
|
36
|
-
|
35
|
+
"""Bitwise AND operation for spike tensors."""
|
36
|
+
return x * y
|
37
37
|
|
38
38
|
|
39
39
|
def spike_bitwise_iand(x, y):
|
40
|
-
|
41
|
-
|
40
|
+
"""Bitwise IAND operation for spike tensors."""
|
41
|
+
return (1 - x) * y
|
42
42
|
|
43
43
|
|
44
44
|
def spike_bitwise_not(x):
|
45
|
-
|
46
|
-
|
45
|
+
"""Bitwise NOT operation for spike tensors."""
|
46
|
+
return 1 - x
|
47
47
|
|
48
48
|
|
49
49
|
def spike_bitwise_xor(x, y):
|
50
|
-
|
51
|
-
|
50
|
+
"""Bitwise XOR operation for spike tensors."""
|
51
|
+
return x + y - 2 * x * y
|
52
52
|
|
53
53
|
|
54
54
|
def spike_bitwise_ixor(x, y):
|
55
|
-
|
56
|
-
|
55
|
+
"""Bitwise IXOR operation for spike tensors."""
|
56
|
+
return x * (1 - y) + (1 - x) * y
|
57
57
|
|
58
58
|
|
59
59
|
def spike_bitwise(x, y, op: str):
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
60
|
+
r"""Bitwise operation for spike tensors.
|
61
|
+
|
62
|
+
.. math::
|
63
|
+
|
64
|
+
\begin{array}{ccc}
|
65
|
+
\hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
|
66
|
+
\hline \text { ADD } & x+y & x+y \\
|
67
|
+
\text { AND } & x \cap y & x \cdot y \\
|
68
|
+
\text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
|
69
|
+
\text { OR } & x \cup y & (x+y)-(x \cdot y) \\
|
70
|
+
\hline
|
71
|
+
\end{array}
|
72
|
+
|
73
|
+
Args:
|
74
|
+
x: A spike tensor.
|
75
|
+
y: A spike tensor.
|
76
|
+
op: A string indicating the bitwise operation to perform.
|
77
|
+
"""
|
78
|
+
if op == 'or':
|
79
|
+
return spike_bitwise_or(x, y)
|
80
|
+
elif op == 'and':
|
81
|
+
return spike_bitwise_and(x, y)
|
82
|
+
elif op == 'iand':
|
83
|
+
return spike_bitwise_iand(x, y)
|
84
|
+
elif op == 'xor':
|
85
|
+
return spike_bitwise_xor(x, y)
|
86
|
+
elif op == 'ixor':
|
87
|
+
return spike_bitwise_ixor(x, y)
|
88
|
+
else:
|
89
|
+
raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
|
@@ -0,0 +1,33 @@
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
|
17
|
+
from ._graph_context import *
|
18
|
+
from ._graph_context import __all__ as _graph_context__all__
|
19
|
+
from ._graph_convert import *
|
20
|
+
from ._graph_convert import __all__ as _graph_convert__all__
|
21
|
+
from ._graph_node import *
|
22
|
+
from ._graph_node import __all__ as _graph_node__all__
|
23
|
+
from ._graph_operation import *
|
24
|
+
from ._graph_operation import __all__ as _graph_operation__all__
|
25
|
+
|
26
|
+
__all__ = (_graph_context__all__ +
|
27
|
+
_graph_convert__all__ +
|
28
|
+
_graph_node__all__ +
|
29
|
+
_graph_operation__all__)
|
30
|
+
del (_graph_context__all__,
|
31
|
+
_graph_convert__all__,
|
32
|
+
_graph_node__all__,
|
33
|
+
_graph_operation__all__)
|