brainstate 0.1.0.post20250212__py2.py3-none-any.whl → 0.1.0.post20250216__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/_state.py +853 -90
- brainstate/_state_test.py +1 -3
- brainstate/augment/__init__.py +2 -2
- brainstate/augment/_autograd.py +257 -115
- brainstate/augment/_autograd_test.py +2 -3
- brainstate/augment/_eval_shape.py +3 -4
- brainstate/augment/_mapping.py +582 -62
- brainstate/augment/_mapping_test.py +114 -30
- brainstate/augment/_random.py +61 -7
- brainstate/compile/_ad_checkpoint.py +2 -3
- brainstate/compile/_conditions.py +4 -5
- brainstate/compile/_conditions_test.py +1 -2
- brainstate/compile/_error_if.py +1 -2
- brainstate/compile/_error_if_test.py +1 -2
- brainstate/compile/_jit.py +23 -16
- brainstate/compile/_jit_test.py +1 -2
- brainstate/compile/_loop_collect_return.py +18 -10
- brainstate/compile/_loop_collect_return_test.py +1 -1
- brainstate/compile/_loop_no_collection.py +5 -5
- brainstate/compile/_make_jaxpr.py +23 -21
- brainstate/compile/_make_jaxpr_test.py +1 -2
- brainstate/compile/_progress_bar.py +1 -2
- brainstate/compile/_unvmap.py +1 -0
- brainstate/compile/_util.py +4 -2
- brainstate/environ.py +4 -4
- brainstate/environ_test.py +1 -2
- brainstate/functional/_activations.py +1 -2
- brainstate/functional/_activations_test.py +1 -1
- brainstate/functional/_normalization.py +1 -2
- brainstate/functional/_others.py +1 -2
- brainstate/functional/_spikes.py +136 -20
- brainstate/graph/_graph_node.py +2 -43
- brainstate/graph/_graph_operation.py +4 -20
- brainstate/graph/_graph_operation_test.py +3 -4
- brainstate/init/_base.py +1 -2
- brainstate/init/_generic.py +1 -2
- brainstate/nn/__init__.py +4 -0
- brainstate/nn/_collective_ops.py +351 -48
- brainstate/nn/_collective_ops_test.py +36 -0
- brainstate/nn/_common.py +194 -0
- brainstate/nn/_dyn_impl/_dynamics_neuron.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_neuron_test.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse.py +1 -2
- brainstate/nn/_dyn_impl/_dynamics_synapse_test.py +1 -2
- brainstate/nn/_dyn_impl/_inputs.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns.py +1 -2
- brainstate/nn/_dyn_impl/_rate_rnns_test.py +1 -2
- brainstate/nn/_dyn_impl/_readout.py +2 -3
- brainstate/nn/_dyn_impl/_readout_test.py +1 -2
- brainstate/nn/_dynamics/_dynamics_base.py +2 -3
- brainstate/nn/_dynamics/_dynamics_base_test.py +1 -2
- brainstate/nn/_dynamics/_state_delay.py +3 -3
- brainstate/nn/_dynamics/_synouts_test.py +1 -2
- brainstate/nn/_elementwise/_dropout.py +6 -7
- brainstate/nn/_elementwise/_dropout_test.py +1 -2
- brainstate/nn/_elementwise/_elementwise.py +1 -2
- brainstate/nn/_exp_euler.py +1 -2
- brainstate/nn/_exp_euler_test.py +1 -2
- brainstate/nn/_interaction/_conv.py +1 -2
- brainstate/nn/_interaction/_conv_test.py +1 -0
- brainstate/nn/_interaction/_linear.py +1 -2
- brainstate/nn/_interaction/_linear_test.py +1 -2
- brainstate/nn/_interaction/_normalizations.py +1 -2
- brainstate/nn/_interaction/_poolings.py +3 -4
- brainstate/nn/_module.py +63 -19
- brainstate/nn/_module_test.py +1 -2
- brainstate/nn/metrics.py +3 -4
- brainstate/optim/_lr_scheduler.py +1 -2
- brainstate/optim/_lr_scheduler_test.py +2 -3
- brainstate/optim/_optax_optimizer_test.py +1 -2
- brainstate/optim/_sgd_optimizer.py +2 -3
- brainstate/random/_rand_funs.py +1 -2
- brainstate/random/_rand_funs_test.py +2 -3
- brainstate/random/_rand_seed.py +2 -3
- brainstate/random/_rand_seed_test.py +1 -2
- brainstate/random/_rand_state.py +3 -4
- brainstate/surrogate.py +5 -2
- brainstate/transform.py +0 -3
- brainstate/typing.py +28 -25
- brainstate/util/__init__.py +9 -7
- brainstate/util/_caller.py +1 -2
- brainstate/util/_error.py +27 -0
- brainstate/util/_others.py +60 -15
- brainstate/util/{_dict.py → _pretty_pytree.py} +2 -2
- brainstate/util/{_dict_test.py → _pretty_pytree_test.py} +1 -2
- brainstate/util/_pretty_repr.py +1 -2
- brainstate/util/_pretty_table.py +2900 -0
- brainstate/util/_struct.py +11 -11
- brainstate/util/filter.py +472 -0
- {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/METADATA +2 -2
- brainstate-0.1.0.post20250216.dist-info/RECORD +127 -0
- brainstate/util/_filter.py +0 -178
- brainstate-0.1.0.post20250212.dist-info/RECORD +0 -124
- {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/LICENSE +0 -0
- {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/WHEEL +0 -0
- {brainstate-0.1.0.post20250212.dist-info → brainstate-0.1.0.post20250216.dist-info}/top_level.txt +0 -0
@@ -55,12 +55,10 @@ from __future__ import annotations
|
|
55
55
|
|
56
56
|
import functools
|
57
57
|
import inspect
|
58
|
+
import jax
|
58
59
|
import operator
|
59
60
|
from collections.abc import Hashable, Iterable, Sequence
|
60
61
|
from contextlib import ExitStack
|
61
|
-
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
62
|
-
|
63
|
-
import jax
|
64
62
|
from jax._src import source_info_util
|
65
63
|
from jax._src.linear_util import annotate
|
66
64
|
from jax._src.traceback_util import api_boundary
|
@@ -68,18 +66,18 @@ from jax.api_util import shaped_abstractify
|
|
68
66
|
from jax.extend.linear_util import transformation_with_aux, wrap_init
|
69
67
|
from jax.interpreters import partial_eval as pe
|
70
68
|
from jax.util import wraps
|
69
|
+
from typing import Any, Callable, Tuple, Union, Dict, Optional
|
71
70
|
|
72
71
|
from brainstate._state import State, StateTraceStack
|
73
72
|
from brainstate._utils import set_module_as
|
74
73
|
from brainstate.typing import PyTree
|
75
|
-
|
74
|
+
from brainstate.util import PrettyObject
|
76
75
|
|
77
76
|
if jax.__version_info__ < (0, 4, 38):
|
78
77
|
from jax.core import ClosedJaxpr
|
79
78
|
else:
|
80
79
|
from jax.extend.core import ClosedJaxpr
|
81
80
|
|
82
|
-
|
83
81
|
AxisName = Hashable
|
84
82
|
|
85
83
|
__all__ = [
|
@@ -125,8 +123,8 @@ def _new_jax_trace():
|
|
125
123
|
return frame, trace
|
126
124
|
|
127
125
|
|
128
|
-
def _init_state_trace_stack() -> StateTraceStack:
|
129
|
-
state_trace: StateTraceStack = StateTraceStack()
|
126
|
+
def _init_state_trace_stack(name) -> StateTraceStack:
|
127
|
+
state_trace: StateTraceStack = StateTraceStack(name=name)
|
130
128
|
|
131
129
|
if jax.__version_info__ < (0, 4, 36):
|
132
130
|
# Should be within the calling of ``jax.make_jaxpr()``
|
@@ -141,7 +139,7 @@ def _init_state_trace_stack() -> StateTraceStack:
|
|
141
139
|
return state_trace
|
142
140
|
|
143
141
|
|
144
|
-
class StatefulFunction(
|
142
|
+
class StatefulFunction(PrettyObject):
|
145
143
|
"""
|
146
144
|
A wrapper class for a function that collects the states that are read and written by the function. The states are
|
147
145
|
collected by the function and returned as a StateDictManager instance. The StateDictManager instance can be used to
|
@@ -189,6 +187,7 @@ class StatefulFunction(object):
|
|
189
187
|
abstracted_axes: Optional[Any] = None,
|
190
188
|
state_returns: Union[str, Tuple[str, ...]] = ('read', 'write'),
|
191
189
|
cache_type: Optional[str] = None,
|
190
|
+
name: Optional[str] = None,
|
192
191
|
):
|
193
192
|
# explicit parameters
|
194
193
|
self.fun = fun
|
@@ -197,6 +196,7 @@ class StatefulFunction(object):
|
|
197
196
|
self.abstracted_axes = abstracted_axes
|
198
197
|
self.state_returns = tuple(state_returns) if isinstance(state_returns, (tuple, list)) else (state_returns,)
|
199
198
|
assert cache_type in [None, 'jit']
|
199
|
+
self.name = name
|
200
200
|
|
201
201
|
# implicit parameters
|
202
202
|
self.cache_type = cache_type
|
@@ -205,12 +205,10 @@ class StatefulFunction(object):
|
|
205
205
|
self._cached_jaxpr_out_tree: Dict[Any, PyTree] = dict()
|
206
206
|
self._cached_state_trace: Dict[Any, StateTraceStack] = dict()
|
207
207
|
|
208
|
-
def
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
f"abstracted_axes={self.abstracted_axes}, "
|
213
|
-
f"state_returns={self.state_returns})")
|
208
|
+
def __pretty_repr_item__(self, k, v):
|
209
|
+
if k.startswith('_'):
|
210
|
+
return None
|
211
|
+
return k, v
|
214
212
|
|
215
213
|
def get_jaxpr(self, cache_key: Hashable = ()) -> jax.core.ClosedJaxpr:
|
216
214
|
"""
|
@@ -388,7 +386,7 @@ class StatefulFunction(object):
|
|
388
386
|
A tuple of the states that are read and written by the function and the output of the function.
|
389
387
|
"""
|
390
388
|
# state trace
|
391
|
-
state_trace = _init_state_trace_stack()
|
389
|
+
state_trace = _init_state_trace_stack(self.name)
|
392
390
|
self._cached_state_trace[cache_key] = state_trace
|
393
391
|
with state_trace:
|
394
392
|
out = self.fun(*args, **kwargs)
|
@@ -497,11 +495,7 @@ class StatefulFunction(object):
|
|
497
495
|
"""
|
498
496
|
state_trace = self.get_state_trace(self.get_arg_cache_key(*args, **kwargs))
|
499
497
|
state_vals, out = self.jaxpr_call([st.value for st in state_trace.states], *args, **kwargs)
|
500
|
-
|
501
|
-
if written:
|
502
|
-
st.value = val
|
503
|
-
else:
|
504
|
-
st.restore_value(val)
|
498
|
+
state_trace.assign_state_vals(state_vals)
|
505
499
|
return out
|
506
500
|
|
507
501
|
|
@@ -592,7 +586,15 @@ def make_jaxpr(
|
|
592
586
|
in (g,) }
|
593
587
|
"""
|
594
588
|
|
595
|
-
stateful_fun = StatefulFunction(
|
589
|
+
stateful_fun = StatefulFunction(
|
590
|
+
fun,
|
591
|
+
static_argnums=static_argnums,
|
592
|
+
axis_env=axis_env,
|
593
|
+
abstracted_axes=abstracted_axes,
|
594
|
+
state_returns=state_returns,
|
595
|
+
name='make_jaxpr'
|
596
|
+
|
597
|
+
)
|
596
598
|
|
597
599
|
@wraps(fun)
|
598
600
|
def make_jaxpr_f(*args, **kwargs):
|
@@ -17,9 +17,8 @@ from __future__ import annotations
|
|
17
17
|
|
18
18
|
import copy
|
19
19
|
import importlib.util
|
20
|
-
from typing import Optional, Callable, Any, Tuple, Dict
|
21
|
-
|
22
20
|
import jax
|
21
|
+
from typing import Optional, Callable, Any, Tuple, Dict
|
23
22
|
|
24
23
|
tqdm_installed = importlib.util.find_spec('tqdm') is not None
|
25
24
|
|
brainstate/compile/_unvmap.py
CHANGED
brainstate/compile/_util.py
CHANGED
@@ -132,8 +132,10 @@ def wrap_single_fun(
|
|
132
132
|
assert len(been_writen) == len(writen_state_vals) == len(read_state_vals)
|
133
133
|
|
134
134
|
# collect all written and read states
|
135
|
-
state_vals = [
|
136
|
-
|
135
|
+
state_vals = [
|
136
|
+
written_val if written else read_val
|
137
|
+
for written, written_val, read_val in zip(been_writen, writen_state_vals, read_state_vals)
|
138
|
+
]
|
137
139
|
|
138
140
|
# call the jaxpr
|
139
141
|
state_vals, (carry, out) = stateful_fun.jaxpr_call(state_vals, carry, inputs)
|
brainstate/environ.py
CHANGED
@@ -17,18 +17,18 @@
|
|
17
17
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
|
+
from collections import defaultdict
|
21
|
+
|
20
22
|
import contextlib
|
21
23
|
import dataclasses
|
22
24
|
import functools
|
25
|
+
import numpy as np
|
23
26
|
import os
|
24
27
|
import re
|
25
28
|
import threading
|
26
|
-
from collections import defaultdict
|
27
|
-
from typing import Any, Callable, Dict, Hashable
|
28
|
-
|
29
|
-
import numpy as np
|
30
29
|
from jax import config, devices, numpy as jnp
|
31
30
|
from jax.typing import DTypeLike
|
31
|
+
from typing import Any, Callable, Dict, Hashable
|
32
32
|
|
33
33
|
from .mixin import Mode
|
34
34
|
|
brainstate/environ_test.py
CHANGED
@@ -20,11 +20,10 @@ Shared neural network activations and other functions.
|
|
20
20
|
|
21
21
|
from __future__ import annotations
|
22
22
|
|
23
|
-
from typing import Any, Union, Sequence
|
24
|
-
|
25
23
|
import brainunit as u
|
26
24
|
import jax
|
27
25
|
from jax.scipy.special import logsumexp
|
26
|
+
from typing import Any, Union, Sequence
|
28
27
|
|
29
28
|
from brainstate import random
|
30
29
|
from brainstate.typing import ArrayLike
|
@@ -16,12 +16,12 @@
|
|
16
16
|
"""Tests for nn module."""
|
17
17
|
|
18
18
|
import itertools
|
19
|
-
from functools import partial
|
20
19
|
|
21
20
|
import jax
|
22
21
|
import jax.numpy as jnp
|
23
22
|
import scipy.stats
|
24
23
|
from absl.testing import parameterized
|
24
|
+
from functools import partial
|
25
25
|
from jax._src import test_util as jtu
|
26
26
|
from jax.test_util import check_grads
|
27
27
|
|
@@ -15,10 +15,9 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
from typing import Optional, Union
|
19
|
-
|
20
18
|
import brainunit as u
|
21
19
|
import jax
|
20
|
+
from typing import Optional, Union
|
22
21
|
|
23
22
|
from brainstate._utils import set_module_as
|
24
23
|
from brainstate.typing import ArrayLike
|
brainstate/functional/_others.py
CHANGED
brainstate/functional/_spikes.py
CHANGED
@@ -27,53 +27,169 @@ __all__ = [
|
|
27
27
|
|
28
28
|
|
29
29
|
def spike_bitwise_or(x, y):
|
30
|
-
"""
|
30
|
+
"""
|
31
|
+
Perform a bitwise OR operation on spike tensors.
|
32
|
+
|
33
|
+
This function computes the OR operation between two spike tensors.
|
34
|
+
The OR operation is implemented using the formula: x + y - x * y,
|
35
|
+
which is equivalent to the OR operation for binary values.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
x (Tensor): The first input spike tensor.
|
39
|
+
y (Tensor): The second input spike tensor.
|
40
|
+
|
41
|
+
Returns:
|
42
|
+
Tensor: The result of the bitwise OR operation applied to the input tensors.
|
43
|
+
The output tensor has the same shape as the input tensors.
|
44
|
+
|
45
|
+
Note:
|
46
|
+
This operation assumes that the input tensors contain binary (0 or 1) values.
|
47
|
+
For non-binary inputs, the behavior may not correspond to a true bitwise OR.
|
48
|
+
"""
|
31
49
|
return x + y - x * y
|
32
50
|
|
33
51
|
|
34
52
|
def spike_bitwise_and(x, y):
|
35
|
-
"""
|
53
|
+
"""
|
54
|
+
Perform a bitwise AND operation on spike tensors.
|
55
|
+
|
56
|
+
This function computes the AND operation between two spike tensors.
|
57
|
+
The AND operation is equivalent to element-wise multiplication for binary values.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
x (Tensor): The first input spike tensor.
|
61
|
+
y (Tensor): The second input spike tensor.
|
62
|
+
|
63
|
+
Returns:
|
64
|
+
Tensor: The result of the bitwise AND operation applied to the input tensors.
|
65
|
+
The output tensor has the same shape as the input tensors.
|
66
|
+
|
67
|
+
Note:
|
68
|
+
This operation is implemented using element-wise multiplication (x * y),
|
69
|
+
which is equivalent to the AND operation for binary values.
|
70
|
+
"""
|
36
71
|
return x * y
|
37
72
|
|
38
73
|
|
39
74
|
def spike_bitwise_iand(x, y):
|
40
|
-
"""
|
75
|
+
"""
|
76
|
+
Perform a bitwise IAND (Inverse AND) operation on spike tensors.
|
77
|
+
|
78
|
+
This function computes the Inverse AND (IAND) operation between two spike tensors.
|
79
|
+
IAND is defined as (NOT x) AND y.
|
80
|
+
|
81
|
+
Args:
|
82
|
+
x (Tensor): The first input spike tensor.
|
83
|
+
y (Tensor): The second input spike tensor.
|
84
|
+
|
85
|
+
Returns:
|
86
|
+
Tensor: The result of the bitwise IAND operation applied to the input tensors.
|
87
|
+
The output tensor has the same shape as the input tensors.
|
88
|
+
|
89
|
+
Note:
|
90
|
+
This operation is implemented using the formula: (1 - x) * y,
|
91
|
+
which is equivalent to the IAND operation for binary values.
|
92
|
+
"""
|
41
93
|
return (1 - x) * y
|
42
94
|
|
43
95
|
|
44
96
|
def spike_bitwise_not(x):
|
45
|
-
"""
|
97
|
+
"""
|
98
|
+
Perform a bitwise NOT operation on spike tensors.
|
99
|
+
|
100
|
+
This function computes the NOT operation on a spike tensor.
|
101
|
+
The NOT operation inverts the binary values in the tensor.
|
102
|
+
|
103
|
+
Args:
|
104
|
+
x (Tensor): The input spike tensor.
|
105
|
+
|
106
|
+
Returns:
|
107
|
+
Tensor: The result of the bitwise NOT operation applied to the input tensor.
|
108
|
+
The output tensor has the same shape as the input tensor.
|
109
|
+
|
110
|
+
Note:
|
111
|
+
This operation is implemented using the formula: 1 - x,
|
112
|
+
which is equivalent to the NOT operation for binary values.
|
113
|
+
"""
|
46
114
|
return 1 - x
|
47
115
|
|
48
116
|
|
49
117
|
def spike_bitwise_xor(x, y):
|
50
|
-
"""
|
118
|
+
"""
|
119
|
+
Perform a bitwise XOR operation on spike tensors.
|
120
|
+
|
121
|
+
This function computes the XOR operation between two spike tensors.
|
122
|
+
XOR is defined as (x OR y) AND NOT (x AND y).
|
123
|
+
|
124
|
+
Args:
|
125
|
+
x (Tensor): The first input spike tensor.
|
126
|
+
y (Tensor): The second input spike tensor.
|
127
|
+
|
128
|
+
Returns:
|
129
|
+
Tensor: The result of the bitwise XOR operation applied to the input tensors.
|
130
|
+
The output tensor has the same shape as the input tensors.
|
131
|
+
|
132
|
+
Note:
|
133
|
+
This operation is implemented using the formula: x + y - 2 * x * y,
|
134
|
+
which is equivalent to the XOR operation for binary values.
|
135
|
+
"""
|
51
136
|
return x + y - 2 * x * y
|
52
137
|
|
53
138
|
|
54
139
|
def spike_bitwise_ixor(x, y):
|
55
|
-
"""
|
140
|
+
"""
|
141
|
+
Perform a bitwise IXOR (Inverse XOR) operation on spike tensors.
|
142
|
+
|
143
|
+
This function computes the Inverse XOR (IXOR) operation between two spike tensors.
|
144
|
+
IXOR is defined as (x AND NOT y) OR (NOT x AND y).
|
145
|
+
|
146
|
+
Args:
|
147
|
+
x (Tensor): The first input spike tensor.
|
148
|
+
y (Tensor): The second input spike tensor.
|
149
|
+
|
150
|
+
Returns:
|
151
|
+
Tensor: The result of the bitwise IXOR operation applied to the input tensors.
|
152
|
+
The output tensor has the same shape as the input tensors.
|
153
|
+
|
154
|
+
Note:
|
155
|
+
This operation is implemented using the formula: x * (1 - y) + (1 - x) * y,
|
156
|
+
which is equivalent to the IXOR operation for binary values.
|
157
|
+
"""
|
56
158
|
return x * (1 - y) + (1 - x) * y
|
57
159
|
|
58
160
|
|
59
161
|
def spike_bitwise(x, y, op: str):
|
60
|
-
|
61
|
-
|
62
|
-
.. math::
|
162
|
+
"""
|
163
|
+
Perform bitwise operations on spike tensors.
|
63
164
|
|
64
|
-
|
65
|
-
|
66
|
-
\hline \text { ADD } & x+y & x+y \\
|
67
|
-
\text { AND } & x \cap y & x \cdot y \\
|
68
|
-
\text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
|
69
|
-
\text { OR } & x \cup y & (x+y)-(x \cdot y) \\
|
70
|
-
\hline
|
71
|
-
\end{array}
|
165
|
+
This function applies various bitwise operations on spike tensors based on the specified operation.
|
166
|
+
It supports 'or', 'and', 'iand', 'xor', and 'ixor' operations.
|
72
167
|
|
73
168
|
Args:
|
74
|
-
|
75
|
-
|
76
|
-
|
169
|
+
x (Tensor): The first input spike tensor.
|
170
|
+
y (Tensor): The second input spike tensor.
|
171
|
+
op (str): A string indicating the bitwise operation to perform.
|
172
|
+
Supported operations are 'or', 'and', 'iand', 'xor', and 'ixor'.
|
173
|
+
|
174
|
+
Returns:
|
175
|
+
Tensor: The result of the bitwise operation applied to the input tensors.
|
176
|
+
|
177
|
+
Raises:
|
178
|
+
NotImplementedError: If an unsupported bitwise operation is specified.
|
179
|
+
|
180
|
+
Note:
|
181
|
+
The function uses the following mathematical expressions for different operations:
|
182
|
+
|
183
|
+
.. math::
|
184
|
+
|
185
|
+
\begin{array}{ccc}
|
186
|
+
\hline \text { Mode } & \text { Expression for } \mathrm{g}(\mathrm{x}, \mathrm{y}) & \text { Code for } \mathrm{g}(\mathrm{x}, \mathrm{y}) \\
|
187
|
+
\hline \text { ADD } & x+y & x+y \\
|
188
|
+
\text { AND } & x \cap y & x \cdot y \\
|
189
|
+
\text { IAND } & (\neg x) \cap y & (1-x) \cdot y \\
|
190
|
+
\text { OR } & x \cup y & (x+y)-(x \cdot y) \\
|
191
|
+
\hline
|
192
|
+
\end{array}
|
77
193
|
"""
|
78
194
|
if op == 'or':
|
79
195
|
return spike_bitwise_or(x, y)
|
brainstate/graph/_graph_node.py
CHANGED
@@ -27,7 +27,7 @@ import numpy as np
|
|
27
27
|
|
28
28
|
from brainstate._state import State, TreefyState
|
29
29
|
from brainstate.typing import Key
|
30
|
-
from brainstate.util.
|
30
|
+
from brainstate.util._pretty_pytree import PrettyObject
|
31
31
|
from ._graph_operation import register_graph_node_type
|
32
32
|
|
33
33
|
__all__ = [
|
@@ -46,7 +46,7 @@ class GraphNodeMeta(ABCMeta):
|
|
46
46
|
return node
|
47
47
|
|
48
48
|
|
49
|
-
class Node(
|
49
|
+
class Node(PrettyObject, metaclass=GraphNodeMeta):
|
50
50
|
"""
|
51
51
|
Base class for all graph nodes.
|
52
52
|
|
@@ -84,47 +84,6 @@ class Node(PrettyRepr, metaclass=GraphNodeMeta):
|
|
84
84
|
state = deepcopy(state)
|
85
85
|
return treefy_merge(graphdef, state)
|
86
86
|
|
87
|
-
def __pretty_repr__(self):
|
88
|
-
"""
|
89
|
-
Pretty repr for the object.
|
90
|
-
"""
|
91
|
-
yield from yield_unique_pretty_repr_items(self, _default_repr_object, _default_repr_attr)
|
92
|
-
|
93
|
-
def __treescope_repr__(self, path, subtree_renderer):
|
94
|
-
"""
|
95
|
-
Treescope repr for the object.
|
96
|
-
"""
|
97
|
-
children = {}
|
98
|
-
for name, value in vars(self).items():
|
99
|
-
name, value = self.__leaf_fn__(name, value)
|
100
|
-
if name.startswith('_'):
|
101
|
-
continue
|
102
|
-
children[name] = value
|
103
|
-
import treescope # type: ignore[import-not-found,import-untyped]
|
104
|
-
return treescope.repr_lib.render_object_constructor(
|
105
|
-
object_type=type(self),
|
106
|
-
attributes=children,
|
107
|
-
path=path,
|
108
|
-
subtree_renderer=subtree_renderer,
|
109
|
-
color=treescope.formatting_util.color_from_string(type(self).__qualname__)
|
110
|
-
)
|
111
|
-
|
112
|
-
def __leaf_fn__(self, leaf, value):
|
113
|
-
return leaf, value
|
114
|
-
|
115
|
-
|
116
|
-
def _default_repr_object(node: Node):
|
117
|
-
yield PrettyType(type=type(node))
|
118
|
-
|
119
|
-
|
120
|
-
def _default_repr_attr(node: Node):
|
121
|
-
for name, value in vars(node).items():
|
122
|
-
name, value = node.__leaf_fn__(name, value)
|
123
|
-
if name.startswith('_'):
|
124
|
-
continue
|
125
|
-
# value = jax.tree.map(_to_shape_dtype, value, is_leaf=lambda x: isinstance(x, u.Quantity))
|
126
|
-
yield PrettyAttr(name, repr(value))
|
127
|
-
|
128
87
|
|
129
88
|
class String:
|
130
89
|
def __init__(self, msg):
|
@@ -18,21 +18,20 @@
|
|
18
18
|
from __future__ import annotations
|
19
19
|
|
20
20
|
import dataclasses
|
21
|
-
from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
|
22
|
-
Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
|
23
|
-
|
24
21
|
import jax
|
25
22
|
import numpy as np
|
23
|
+
from typing import (Any, Callable, Generic, Iterable, Iterator, Mapping, MutableMapping,
|
24
|
+
Sequence, Type, TypeVar, Union, Hashable, Tuple, Dict, Optional, overload)
|
26
25
|
from typing_extensions import TypeGuard, Unpack
|
27
26
|
|
28
27
|
from brainstate._state import State, TreefyState
|
29
28
|
from brainstate._utils import set_module_as
|
30
29
|
from brainstate.typing import PathParts, Filter, Predicate, Key
|
31
30
|
from brainstate.util._caller import ApplyCaller, CallableProxy, DelayedAccessor
|
32
|
-
from brainstate.util.
|
33
|
-
from brainstate.util._filter import to_predicate
|
31
|
+
from brainstate.util._pretty_pytree import NestedDict, FlattedDict, PrettyDict
|
34
32
|
from brainstate.util._pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
|
35
33
|
from brainstate.util._struct import FrozenDict
|
34
|
+
from brainstate.util.filter import to_predicate
|
36
35
|
|
37
36
|
_max_int = np.iinfo(np.int32).max
|
38
37
|
|
@@ -347,21 +346,6 @@ class NodeDef(GraphDef[Node], PrettyRepr):
|
|
347
346
|
yield PrettyAttr('metadata', self.metadata)
|
348
347
|
yield PrettyAttr('index_mapping', PrettyMapping(self.index_mapping) if self.index_mapping is not None else None)
|
349
348
|
|
350
|
-
def __treescope_repr__(self, path, subtree_renderer):
|
351
|
-
import treescope # type: ignore[import-not-found,import-untyped]
|
352
|
-
return treescope.repr_lib.render_object_constructor(
|
353
|
-
object_type=type(self),
|
354
|
-
attributes={'type': self.type,
|
355
|
-
'index': self.index,
|
356
|
-
'attributes': self.attributes,
|
357
|
-
'subgraphs': dict(self.subgraphs),
|
358
|
-
'static_fields': dict(self.static_fields),
|
359
|
-
'leaves': dict(self.leaves),
|
360
|
-
'metadata': self.metadata, },
|
361
|
-
path=path,
|
362
|
-
subtree_renderer=subtree_renderer,
|
363
|
-
)
|
364
|
-
|
365
349
|
def apply(
|
366
350
|
self,
|
367
351
|
state_map: GraphStateMapping,
|
@@ -15,13 +15,12 @@
|
|
15
15
|
|
16
16
|
from __future__ import annotations
|
17
17
|
|
18
|
-
import unittest
|
19
|
-
from collections.abc import Callable
|
20
|
-
from threading import Thread
|
21
|
-
|
22
18
|
import jax
|
23
19
|
import jax.numpy as jnp
|
20
|
+
import unittest
|
24
21
|
from absl.testing import absltest, parameterized
|
22
|
+
from collections.abc import Callable
|
23
|
+
from threading import Thread
|
25
24
|
|
26
25
|
import brainstate as bst
|
27
26
|
|
brainstate/init/_base.py
CHANGED
brainstate/init/_generic.py
CHANGED
@@ -16,11 +16,10 @@
|
|
16
16
|
# -*- coding: utf-8 -*-
|
17
17
|
from __future__ import annotations
|
18
18
|
|
19
|
-
from typing import Union, Callable, Optional, Sequence
|
20
|
-
|
21
19
|
import brainunit as bu
|
22
20
|
import jax
|
23
21
|
import numpy as np
|
22
|
+
from typing import Union, Callable, Optional, Sequence
|
24
23
|
|
25
24
|
from brainstate._state import State
|
26
25
|
from brainstate._utils import set_module_as
|
brainstate/nn/__init__.py
CHANGED
@@ -17,6 +17,8 @@
|
|
17
17
|
from . import metrics
|
18
18
|
from ._collective_ops import *
|
19
19
|
from ._collective_ops import __all__ as collective_ops_all
|
20
|
+
from ._common import *
|
21
|
+
from ._common import __all__ as common_all
|
20
22
|
from ._dyn_impl import *
|
21
23
|
from ._dyn_impl import __all__ as dyn_impl_all
|
22
24
|
from ._dynamics import *
|
@@ -33,6 +35,7 @@ from ._module import __all__ as module_all
|
|
33
35
|
__all__ = (
|
34
36
|
['metrics']
|
35
37
|
+ collective_ops_all
|
38
|
+
+ common_all
|
36
39
|
+ dyn_impl_all
|
37
40
|
+ dynamics_all
|
38
41
|
+ elementwise_all
|
@@ -43,6 +46,7 @@ __all__ = (
|
|
43
46
|
|
44
47
|
del (
|
45
48
|
collective_ops_all,
|
49
|
+
common_all,
|
46
50
|
dyn_impl_all,
|
47
51
|
dynamics_all,
|
48
52
|
elementwise_all,
|