brainstate 0.1.3__py2.py3-none-any.whl → 0.1.5__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainstate/__init__.py +1 -1
- brainstate/_compatible_import.py +1 -16
- brainstate/_state.py +1 -0
- brainstate/augment/_mapping.py +9 -9
- brainstate/augment/_mapping_test.py +162 -0
- brainstate/compile/_jit.py +14 -5
- brainstate/compile/_make_jaxpr.py +78 -22
- brainstate/compile/_make_jaxpr_test.py +13 -2
- brainstate/graph/_graph_node.py +1 -1
- brainstate/graph/_graph_operation.py +4 -4
- brainstate/mixin.py +31 -2
- brainstate/nn/__init__.py +8 -5
- brainstate/nn/_common.py +7 -19
- brainstate/nn/_delay.py +13 -1
- brainstate/nn/_dropout.py +5 -4
- brainstate/nn/_dynamics.py +39 -44
- brainstate/nn/_exp_euler.py +13 -16
- brainstate/nn/{_fixedprob_mv.py → _fixedprob.py} +95 -24
- brainstate/nn/_inputs.py +1 -1
- brainstate/nn/_linear_mv.py +1 -1
- brainstate/nn/_module.py +5 -5
- brainstate/nn/_projection.py +190 -98
- brainstate/nn/_synapse.py +5 -9
- brainstate/nn/_synaptic_projection.py +376 -86
- brainstate/random/_rand_state.py +13 -7
- brainstate/surrogate.py +1 -1
- brainstate/typing.py +1 -1
- brainstate/util/__init__.py +14 -14
- brainstate/util/{_pretty_pytree.py → pretty_pytree.py} +2 -2
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/METADATA +1 -1
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/RECORD +42 -42
- /brainstate/nn/{_fixedprob_mv_test.py → _fixedprob_test.py} +0 -0
- /brainstate/util/{_caller.py → caller.py} +0 -0
- /brainstate/util/{_error.py → error.py} +0 -0
- /brainstate/util/{_others.py → others.py} +0 -0
- /brainstate/util/{_pretty_repr.py → pretty_repr.py} +0 -0
- /brainstate/util/{_pretty_table.py → pretty_table.py} +0 -0
- /brainstate/util/{_scaling.py → scaling.py} +0 -0
- /brainstate/util/{_struct.py → struct.py} +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/LICENSE +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/WHEEL +0 -0
- {brainstate-0.1.3.dist-info → brainstate-0.1.5.dist-info}/top_level.txt +0 -0
@@ -88,7 +88,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
88
88
|
self.assertTrue(jnp.allclose(jaxpr_as_fun(jaxpr)(jnp.zeros(1), st1.value)[0],
|
89
89
|
f3(jnp.zeros(1))))
|
90
90
|
|
91
|
-
def
|
91
|
+
def test_compare_jax_make_jaxpr2(self):
|
92
92
|
st1 = brainstate.State(jnp.ones(10))
|
93
93
|
|
94
94
|
def fa(x):
|
@@ -108,7 +108,7 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
108
108
|
print(jaxpr)
|
109
109
|
print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
110
110
|
|
111
|
-
def
|
111
|
+
def test_compare_jax_make_jaxpr3(self):
|
112
112
|
def fa(x):
|
113
113
|
return 1.
|
114
114
|
|
@@ -121,6 +121,17 @@ class TestMakeJaxpr(unittest.TestCase):
|
|
121
121
|
print(jaxpr)
|
122
122
|
# print(jaxpr_as_fun(jaxpr)(jnp.zeros(1)))
|
123
123
|
|
124
|
+
def test_static_argnames(self):
|
125
|
+
def func4(a, b): # Arg is a pair
|
126
|
+
temp = a + jnp.sin(b) * 3.
|
127
|
+
c = brainstate.random.rand_like(a)
|
128
|
+
return jnp.sum(temp + c)
|
129
|
+
|
130
|
+
jaxpr, states = brainstate.compile.make_jaxpr(func4, static_argnames='b')(jnp.zeros(8), 1.)
|
131
|
+
print()
|
132
|
+
print(jaxpr)
|
133
|
+
print(states)
|
134
|
+
|
124
135
|
|
125
136
|
def test_return_states():
|
126
137
|
import jax.numpy
|
brainstate/graph/_graph_node.py
CHANGED
@@ -25,7 +25,7 @@ import numpy as np
|
|
25
25
|
|
26
26
|
from brainstate._state import State, TreefyState
|
27
27
|
from brainstate.typing import Key
|
28
|
-
from brainstate.util.
|
28
|
+
from brainstate.util.pretty_pytree import PrettyObject
|
29
29
|
from ._graph_operation import register_graph_node_type
|
30
30
|
|
31
31
|
__all__ = [
|
@@ -30,10 +30,10 @@ from typing_extensions import TypeGuard, Unpack
|
|
30
30
|
from brainstate._state import State, TreefyState
|
31
31
|
from brainstate._utils import set_module_as
|
32
32
|
from brainstate.typing import PathParts, Filter, Predicate, Key
|
33
|
-
from brainstate.util.
|
34
|
-
from brainstate.util.
|
35
|
-
from brainstate.util.
|
36
|
-
from brainstate.util.
|
33
|
+
from brainstate.util.caller import ApplyCaller, CallableProxy, DelayedAccessor
|
34
|
+
from brainstate.util.pretty_pytree import NestedDict, FlattedDict, PrettyDict
|
35
|
+
from brainstate.util.pretty_repr import PrettyRepr, PrettyType, PrettyAttr, PrettyMapping, MappingReprMixin
|
36
|
+
from brainstate.util.struct import FrozenDict
|
37
37
|
from brainstate.util.filter import to_predicate
|
38
38
|
|
39
39
|
_max_int = np.iinfo(np.int32).max
|
brainstate/mixin.py
CHANGED
@@ -41,6 +41,14 @@ __all__ = [
|
|
41
41
|
]
|
42
42
|
|
43
43
|
|
44
|
+
def hashable(x):
|
45
|
+
try:
|
46
|
+
hash(x)
|
47
|
+
return True
|
48
|
+
except TypeError:
|
49
|
+
return False
|
50
|
+
|
51
|
+
|
44
52
|
class Mixin(object):
|
45
53
|
"""Base Mixin object.
|
46
54
|
|
@@ -67,6 +75,14 @@ class ParamDesc(Mixin):
|
|
67
75
|
|
68
76
|
|
69
77
|
class HashableDict(dict):
|
78
|
+
def __init__(self, the_dict: dict):
|
79
|
+
out = dict()
|
80
|
+
for k, v in the_dict.items():
|
81
|
+
if not hashable(v):
|
82
|
+
v = str(v) # convert to string if not hashable
|
83
|
+
out[k] = v
|
84
|
+
super().__init__(out)
|
85
|
+
|
70
86
|
def __hash__(self):
|
71
87
|
return hash(tuple(sorted(self.items())))
|
72
88
|
|
@@ -132,7 +148,6 @@ class AlignPost(Mixin):
|
|
132
148
|
raise NotImplementedError
|
133
149
|
|
134
150
|
|
135
|
-
|
136
151
|
class BindCondData(Mixin):
|
137
152
|
"""Bind temporary conductance data.
|
138
153
|
|
@@ -147,12 +162,26 @@ class BindCondData(Mixin):
|
|
147
162
|
self._conductance = None
|
148
163
|
|
149
164
|
|
165
|
+
def not_implemented(func):
|
166
|
+
|
167
|
+
def wrapper(*args, **kwargs):
|
168
|
+
raise NotImplementedError(f'{func.__name__} is not implemented.')
|
169
|
+
|
170
|
+
wrapper.not_implemented = True
|
171
|
+
return wrapper
|
172
|
+
|
173
|
+
|
174
|
+
|
150
175
|
class UpdateReturn(Mixin):
|
176
|
+
@not_implemented
|
151
177
|
def update_return(self) -> PyTree:
|
152
178
|
"""
|
153
179
|
The update function return of the model.
|
154
180
|
|
155
|
-
|
181
|
+
This function requires no parameters and must return a PyTree.
|
182
|
+
|
183
|
+
It is usually used for delay initialization, for example, ``Dynamics.output_delay`` relies on this function to
|
184
|
+
initialize the output delay.
|
156
185
|
|
157
186
|
"""
|
158
187
|
raise NotImplementedError(f'Must implement the "{self.update_return.__name__}()" function.')
|
brainstate/nn/__init__.py
CHANGED
@@ -33,12 +33,14 @@ from ._embedding import *
|
|
33
33
|
from ._embedding import __all__ as embed_all
|
34
34
|
from ._exp_euler import *
|
35
35
|
from ._exp_euler import __all__ as exp_euler_all
|
36
|
-
from .
|
36
|
+
from ._fixedprob import *
|
37
|
+
from._fixedprob import __all__ as fixedprob_all
|
37
38
|
from ._inputs import *
|
38
39
|
from ._inputs import __all__ as inputs_all
|
39
40
|
from ._linear import *
|
40
41
|
from ._linear import __all__ as linear_all
|
41
|
-
from ._linear_mv import
|
42
|
+
from ._linear_mv import *
|
43
|
+
from ._linear_mv import __all__ as linear_mv_all
|
42
44
|
from ._ltp import *
|
43
45
|
from ._ltp import __all__ as ltp_all
|
44
46
|
from ._module import *
|
@@ -69,9 +71,6 @@ from ._utils import __all__ as utils_all
|
|
69
71
|
__all__ = (
|
70
72
|
[
|
71
73
|
'metrics',
|
72
|
-
'EventLinear',
|
73
|
-
'EventFixedProb',
|
74
|
-
'EventFixedNumConn',
|
75
74
|
]
|
76
75
|
+ collective_ops_all
|
77
76
|
+ common_all
|
@@ -87,6 +86,8 @@ __all__ = (
|
|
87
86
|
+ linear_all
|
88
87
|
+ normalizations_all
|
89
88
|
+ poolings_all
|
89
|
+
+ fixedprob_all
|
90
|
+
+ linear_mv_all
|
90
91
|
+ embed_all
|
91
92
|
+ dropout_all
|
92
93
|
+ elementwise_all
|
@@ -115,6 +116,8 @@ del (
|
|
115
116
|
normalizations_all,
|
116
117
|
poolings_all,
|
117
118
|
embed_all,
|
119
|
+
fixedprob_all,
|
120
|
+
linear_mv_all,
|
118
121
|
dropout_all,
|
119
122
|
elementwise_all,
|
120
123
|
dyn_neuron_all,
|
brainstate/nn/_common.py
CHANGED
@@ -118,14 +118,14 @@ class Vmap(Module):
|
|
118
118
|
This class wraps a module and applies vectorized mapping to its execution,
|
119
119
|
allowing for efficient parallel processing across specified axes.
|
120
120
|
|
121
|
-
|
121
|
+
Args:
|
122
122
|
module (Module): The module to be vmapped.
|
123
|
-
in_axes (int | None | Sequence[Any]): Specifies how to map over inputs.
|
124
|
-
out_axes (Any): Specifies how to map over outputs.
|
125
|
-
vmap_states (Filter | Dict[Filter, int]): Specifies which states to vmap and on which axes.
|
126
|
-
vmap_out_states (Filter | Dict[Filter, int]): Specifies which output states to vmap and on which axes.
|
127
|
-
axis_name (AxisName | None): Name of the axis being mapped over.
|
128
|
-
axis_size (int | None): Size of the axis being mapped over.
|
123
|
+
in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
|
124
|
+
out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
|
125
|
+
vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
|
126
|
+
vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
|
127
|
+
axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
|
128
|
+
axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
|
129
129
|
"""
|
130
130
|
|
131
131
|
def __init__(
|
@@ -138,18 +138,6 @@ class Vmap(Module):
|
|
138
138
|
axis_name: AxisName | None = None,
|
139
139
|
axis_size: int | None = None,
|
140
140
|
):
|
141
|
-
"""
|
142
|
-
Initialize the Vmap instance.
|
143
|
-
|
144
|
-
Args:
|
145
|
-
module (Module): The module to be vmapped.
|
146
|
-
in_axes (int | None | Sequence[Any], optional): Specifies how to map over inputs. Defaults to 0.
|
147
|
-
out_axes (Any, optional): Specifies how to map over outputs. Defaults to 0.
|
148
|
-
vmap_states (Filter | Dict[Filter, int], optional): Specifies which states to vmap and on which axes. Defaults to None.
|
149
|
-
vmap_out_states (Filter | Dict[Filter, int], optional): Specifies which output states to vmap and on which axes. Defaults to None.
|
150
|
-
axis_name (AxisName | None, optional): Name of the axis being mapped over. Defaults to None.
|
151
|
-
axis_size (int | None, optional): Size of the axis being mapped over. Defaults to None.
|
152
|
-
"""
|
153
141
|
super().__init__()
|
154
142
|
|
155
143
|
# parameters
|
brainstate/nn/_delay.py
CHANGED
@@ -330,7 +330,14 @@ class Delay(Module):
|
|
330
330
|
indices = (delay_idx,) + indices
|
331
331
|
|
332
332
|
# the delay data
|
333
|
-
|
333
|
+
if self._unit is None:
|
334
|
+
return jax.tree.map(lambda a: a[indices], self.history.value)
|
335
|
+
else:
|
336
|
+
return jax.tree.map(
|
337
|
+
lambda hist, unit: u.maybe_decimal(hist[indices] * unit),
|
338
|
+
self.history.value,
|
339
|
+
self._unit
|
340
|
+
)
|
334
341
|
|
335
342
|
def retrieve_at_time(self, delay_time, *indices) -> PyTree:
|
336
343
|
"""
|
@@ -393,6 +400,9 @@ class Delay(Module):
|
|
393
400
|
"""
|
394
401
|
assert self.history is not None, 'The delay history is not initialized.'
|
395
402
|
|
403
|
+
if self.take_aware_unit and self._unit is None:
|
404
|
+
self._unit = jax.tree.map(lambda x: u.get_unit(x), current, is_leaf=u.math.is_quantity)
|
405
|
+
|
396
406
|
# update the delay data at the rotation index
|
397
407
|
if self.delay_method == _DELAY_ROTATE:
|
398
408
|
i = environ.get(environ.I)
|
@@ -419,6 +429,8 @@ class Delay(Module):
|
|
419
429
|
raise ValueError(f'Unknown updating method "{self.delay_method}"')
|
420
430
|
|
421
431
|
|
432
|
+
|
433
|
+
|
422
434
|
class StateWithDelay(Delay):
|
423
435
|
"""
|
424
436
|
A ``State`` type that defines the state in a differential equation.
|
brainstate/nn/_dropout.py
CHANGED
@@ -409,7 +409,8 @@ class DropoutFixed(ElementWiseBlock):
|
|
409
409
|
self.out_size = in_size
|
410
410
|
|
411
411
|
def init_state(self, batch_size=None, **kwargs):
|
412
|
-
|
412
|
+
if self.prob < 1.:
|
413
|
+
self.mask = ShortTermState(init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size))
|
413
414
|
|
414
415
|
def update(self, x):
|
415
416
|
dtype = u.math.get_dtype(x)
|
@@ -418,8 +419,8 @@ class DropoutFixed(ElementWiseBlock):
|
|
418
419
|
if self.mask.value.shape != x.shape:
|
419
420
|
raise ValueError(f"Input shape {x.shape} does not match the mask shape {self.mask.value.shape}. "
|
420
421
|
f"Please call `init_state()` method first.")
|
421
|
-
return
|
422
|
-
|
423
|
-
|
422
|
+
return u.math.where(self.mask.value,
|
423
|
+
u.math.asarray(x / self.prob, dtype=dtype),
|
424
|
+
u.math.asarray(0., dtype=dtype) * u.get_unit(x))
|
424
425
|
else:
|
425
426
|
return x
|
brainstate/nn/_dynamics.py
CHANGED
@@ -36,13 +36,14 @@ For handling the delays:
|
|
36
36
|
from typing import Any, Dict, Callable, Hashable, Optional, Union, TypeVar, TYPE_CHECKING
|
37
37
|
|
38
38
|
import brainunit as u
|
39
|
+
import jax
|
39
40
|
import numpy as np
|
40
41
|
|
41
42
|
from brainstate import environ
|
42
43
|
from brainstate._state import State
|
43
44
|
from brainstate.graph import Node
|
44
45
|
from brainstate.mixin import ParamDescriber, UpdateReturn
|
45
|
-
from brainstate.typing import Size, ArrayLike
|
46
|
+
from brainstate.typing import Size, ArrayLike, PyTree
|
46
47
|
from ._delay import StateWithDelay, Delay
|
47
48
|
from ._module import Module
|
48
49
|
|
@@ -811,18 +812,25 @@ class Dynamics(Module, UpdateReturn):
|
|
811
812
|
>>> n1.align_pre(brainstate.nn.Expon.desc(n1.varshape)) # n2 will run after n1
|
812
813
|
"""
|
813
814
|
if isinstance(dyn, Dynamics):
|
814
|
-
self._add_after_update(dyn
|
815
|
+
self._add_after_update(id(dyn), dyn)
|
815
816
|
return dyn
|
816
817
|
elif isinstance(dyn, ParamDescriber):
|
817
818
|
if not issubclass(dyn.cls, Dynamics):
|
818
819
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics}.')
|
819
820
|
if not self._has_after_update(dyn.identifier):
|
820
|
-
self._add_after_update(
|
821
|
+
self._add_after_update(
|
822
|
+
dyn.identifier,
|
823
|
+
dyn() if ('in_size' in dyn.kwargs or len(dyn.args) > 0) else dyn(in_size=self.varshape)
|
824
|
+
)
|
821
825
|
return self._get_after_update(dyn.identifier)
|
822
826
|
else:
|
823
827
|
raise TypeError(f'The input {dyn} should be an instance of {Dynamics} or a delayed initializer.')
|
824
828
|
|
825
|
-
def prefetch_delay(
|
829
|
+
def prefetch_delay(
|
830
|
+
self,
|
831
|
+
state: str,
|
832
|
+
delay: Optional[ArrayLike] = None
|
833
|
+
) -> 'PrefetchDelayAt':
|
826
834
|
"""
|
827
835
|
Create a reference to a delayed state or variable in the module.
|
828
836
|
|
@@ -840,7 +848,11 @@ class Dynamics(Module, UpdateReturn):
|
|
840
848
|
"""
|
841
849
|
return self.prefetch(state).delay.at(delay)
|
842
850
|
|
843
|
-
def output_delay(
|
851
|
+
def output_delay(
|
852
|
+
self,
|
853
|
+
delay: Optional[ArrayLike] = None,
|
854
|
+
variable_like: PyTree = None
|
855
|
+
) -> 'OutputDelayAt':
|
844
856
|
"""
|
845
857
|
Create a reference to the delayed output of the module.
|
846
858
|
|
@@ -851,6 +863,7 @@ class Dynamics(Module, UpdateReturn):
|
|
851
863
|
Args:
|
852
864
|
delay (Optional[ArrayLike]): The amount of time to delay the output access,
|
853
865
|
typically in time units (e.g., milliseconds). Defaults to None.
|
866
|
+
variable_like:
|
854
867
|
|
855
868
|
Returns:
|
856
869
|
OutputDelayAt: An object that provides access to the module's output at the specified delay time.
|
@@ -1102,8 +1115,6 @@ class OutputDelayAt(Node):
|
|
1102
1115
|
----------
|
1103
1116
|
module : Dynamics
|
1104
1117
|
The dynamics module that contains the referenced state or variable.
|
1105
|
-
item : str
|
1106
|
-
The name of the state or variable to access with delay.
|
1107
1118
|
time : ArrayLike
|
1108
1119
|
The amount of time to delay access by, typically in time units (e.g., milliseconds).
|
1109
1120
|
|
@@ -1113,56 +1124,40 @@ class OutputDelayAt(Node):
|
|
1113
1124
|
>>> import brainunit as u
|
1114
1125
|
>>> neuron = brainstate.nn.LIF(10)
|
1115
1126
|
>>> # Create a reference to voltage delayed by 5ms
|
1116
|
-
>>>
|
1127
|
+
>>> delayed_spike = OutputDelayAt(neuron, 5.0 * u.ms)
|
1117
1128
|
>>> # Get the delayed value
|
1118
|
-
>>> v_value =
|
1129
|
+
>>> v_value = delayed_spike()
|
1119
1130
|
"""
|
1120
1131
|
|
1121
1132
|
def __init__(
|
1122
1133
|
self,
|
1123
1134
|
module: Dynamics,
|
1124
1135
|
time: Optional[ArrayLike] = None,
|
1136
|
+
variable_like: Optional[PyTree] = None,
|
1125
1137
|
):
|
1126
|
-
"""
|
1127
|
-
Initialize a PrefetchDelayAt object.
|
1128
|
-
|
1129
|
-
Parameters
|
1130
|
-
----------
|
1131
|
-
module : AlignPre, Module
|
1132
|
-
The dynamics module that contains the referenced state or variable.
|
1133
|
-
time : ArrayLike
|
1134
|
-
The amount of time to delay access by, typically in time units.
|
1135
|
-
"""
|
1136
1138
|
super().__init__()
|
1137
|
-
assert isinstance(module,
|
1138
|
-
assert isinstance(module, Module), 'The module should be an instance of Module.'
|
1139
|
+
assert isinstance(module, Dynamics), 'The module should be an instance of Dynamics.'
|
1139
1140
|
self.module = module
|
1141
|
+
dt = environ.get_dt()
|
1142
|
+
if time is None:
|
1143
|
+
time = u.math.zeros_like(dt)
|
1140
1144
|
self.time = time
|
1141
|
-
|
1142
|
-
|
1143
|
-
|
1144
|
-
|
1145
|
-
|
1146
|
-
|
1147
|
-
|
1148
|
-
|
1149
|
-
|
1150
|
-
|
1151
|
-
|
1145
|
+
self.step = u.math.asarray(time / dt, dtype=environ.ditype())
|
1146
|
+
|
1147
|
+
# register the delay
|
1148
|
+
key = _get_output_delay_key()
|
1149
|
+
if not module._has_after_update(key):
|
1150
|
+
delay = Delay(
|
1151
|
+
jax.ShapeDtypeStruct(module.out_size, dtype=environ.dftype()),
|
1152
|
+
time,
|
1153
|
+
take_aware_unit=True
|
1154
|
+
)
|
1155
|
+
module._add_after_update(key, receive_update_output(delay))
|
1156
|
+
self.out_delay: Delay = module._get_after_update(key)
|
1157
|
+
self.out_delay.register_delay(time)
|
1152
1158
|
|
1153
1159
|
def __call__(self, *args, **kwargs):
|
1154
|
-
|
1155
|
-
Retrieve the value of the state at the specified delay time.
|
1156
|
-
|
1157
|
-
Returns
|
1158
|
-
-------
|
1159
|
-
Any
|
1160
|
-
The value of the state or variable at the specified delay time.
|
1161
|
-
"""
|
1162
|
-
if self.time is None:
|
1163
|
-
return self.module.update_return()
|
1164
|
-
else:
|
1165
|
-
return self.out_delay.retrieve_at_step(self.step)
|
1160
|
+
return self.out_delay.retrieve_at_step(self.step)
|
1166
1161
|
|
1167
1162
|
|
1168
1163
|
def _get_prefetch_delay_key(item) -> str:
|
brainstate/nn/_exp_euler.py
CHANGED
@@ -69,27 +69,24 @@ def exp_euler_step(
|
|
69
69
|
f'The input data type should be float64, float32, float16, or bfloat16 '
|
70
70
|
f'when using Exponential Euler method. But we got {args[0].dtype}.'
|
71
71
|
)
|
72
|
+
|
73
|
+
# drift
|
72
74
|
dt = environ.get('dt')
|
73
75
|
linear, derivative = vector_grad(fn, argnums=0, return_value=True)(*args, **kwargs)
|
74
76
|
linear = u.Quantity(u.get_mantissa(linear), u.get_unit(derivative) / u.get_unit(linear))
|
75
77
|
phi = u.math.exprel(dt * linear)
|
76
78
|
x_next = args[0] + dt * phi * derivative
|
77
79
|
|
80
|
+
# diffusion
|
78
81
|
if diffusion is not None:
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
# "Drift unit is {drift}, diffusion unit is {diffusion}, ",
|
90
|
-
# drift=drift_unit, diffusion=diffusion_unit * time_unit ** 0.5
|
91
|
-
# )
|
92
|
-
|
93
|
-
# diffusion
|
94
|
-
x_next += diffusion * u.math.sqrt(dt) * random.randn_like(args[0])
|
82
|
+
diffusion_part = diffusion(*args, **kwargs) * u.math.sqrt(dt) * random.randn_like(args[0])
|
83
|
+
if u.get_dim(x_next) != u.get_dim(diffusion_part):
|
84
|
+
drift_unit = u.get_unit(x_next)
|
85
|
+
time_unit = u.get_unit(dt)
|
86
|
+
raise ValueError(
|
87
|
+
f"Drift unit is {drift_unit}, "
|
88
|
+
f"expected diffusion unit is {drift_unit / time_unit ** 0.5}, "
|
89
|
+
f"but we got {u.get_unit(diffusion_part)}."
|
90
|
+
)
|
91
|
+
x_next += diffusion_part
|
95
92
|
return x_next
|
@@ -16,19 +16,20 @@
|
|
16
16
|
|
17
17
|
from typing import Union, Callable, Optional
|
18
18
|
|
19
|
+
import brainevent
|
19
20
|
import brainunit as u
|
20
21
|
import jax
|
21
22
|
import jax.numpy as jnp
|
22
23
|
import numpy as np
|
23
24
|
|
24
25
|
from brainstate import random, augment, environ, init
|
25
|
-
from brainstate.
|
26
|
-
from brainstate._state import ParamState
|
26
|
+
from brainstate._state import ParamState, FakeState
|
27
27
|
from brainstate.compile import for_loop
|
28
28
|
from brainstate.typing import Size, ArrayLike
|
29
29
|
from ._module import Module
|
30
30
|
|
31
31
|
__all__ = [
|
32
|
+
'FixedNumConn',
|
32
33
|
'EventFixedNumConn',
|
33
34
|
'EventFixedProb',
|
34
35
|
]
|
@@ -44,12 +45,11 @@ def init_indices_without_replace(
|
|
44
45
|
rng = random.default_rng(seed)
|
45
46
|
|
46
47
|
if method == 'vmap':
|
47
|
-
@augment.vmap
|
48
|
-
def rand_indices(
|
49
|
-
rng.set_key(key)
|
48
|
+
@augment.vmap(axis_size=n_pre)
|
49
|
+
def rand_indices():
|
50
50
|
return rng.choice(n_post, size=(conn_num,), replace=False)
|
51
51
|
|
52
|
-
return rand_indices(
|
52
|
+
return rand_indices()
|
53
53
|
|
54
54
|
elif method == 'for_loop':
|
55
55
|
return for_loop(
|
@@ -61,9 +61,9 @@ def init_indices_without_replace(
|
|
61
61
|
raise ValueError(f"Unknown method: {method}")
|
62
62
|
|
63
63
|
|
64
|
-
class
|
64
|
+
class FixedNumConn(Module):
|
65
65
|
"""
|
66
|
-
The
|
66
|
+
The ``FixedNumConn`` module implements a fixed probability connection with CSR sparse data structure.
|
67
67
|
|
68
68
|
Parameters
|
69
69
|
----------
|
@@ -77,7 +77,7 @@ class EventFixedNumConn(Module):
|
|
77
77
|
If it is an integer, representing the number of connections.
|
78
78
|
conn_weight : float or callable or jax.Array or brainunit.Quantity
|
79
79
|
Maximum synaptic conductance, i.e., synaptic weight.
|
80
|
-
|
80
|
+
efferent_target : str, optional
|
81
81
|
The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
|
82
82
|
a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
|
83
83
|
|
@@ -104,7 +104,8 @@ class EventFixedNumConn(Module):
|
|
104
104
|
out_size: Size,
|
105
105
|
conn_num: Union[int, float],
|
106
106
|
conn_weight: Union[Callable, ArrayLike],
|
107
|
-
|
107
|
+
efferent_target: str = 'post', # 'pre' or 'post'
|
108
|
+
afferent_ratio: Union[int, float] = 1.,
|
108
109
|
allow_multi_conn: bool = True,
|
109
110
|
seed: Optional[int] = None,
|
110
111
|
name: Optional[str] = None,
|
@@ -116,11 +117,14 @@ class EventFixedNumConn(Module):
|
|
116
117
|
# network parameters
|
117
118
|
self.in_size = in_size
|
118
119
|
self.out_size = out_size
|
119
|
-
self.
|
120
|
-
assert
|
120
|
+
self.efferent_target = efferent_target
|
121
|
+
assert efferent_target in ('pre', 'post'), 'The target of the connection must be either "pre" or "post".'
|
122
|
+
assert 0. <= afferent_ratio <= 1., 'Afferent ratio must be in [0, 1].'
|
121
123
|
if isinstance(conn_num, float):
|
122
124
|
assert 0. <= conn_num <= 1., 'Connection probability must be in [0, 1].'
|
123
|
-
conn_num = int(self.out_size[-1] * conn_num)
|
125
|
+
conn_num = (int(self.out_size[-1] * conn_num)
|
126
|
+
if efferent_target == 'post' else
|
127
|
+
int(self.in_size[-1] * conn_num))
|
124
128
|
assert isinstance(conn_num, int), 'Connection number must be an integer.'
|
125
129
|
self.conn_num = conn_num
|
126
130
|
self.seed = seed
|
@@ -128,14 +132,13 @@ class EventFixedNumConn(Module):
|
|
128
132
|
|
129
133
|
# connections
|
130
134
|
if self.conn_num >= 1:
|
131
|
-
if self.
|
135
|
+
if self.efferent_target == 'post':
|
132
136
|
n_post = self.out_size[-1]
|
133
137
|
n_pre = self.in_size[-1]
|
134
138
|
else:
|
135
139
|
n_post = self.in_size[-1]
|
136
140
|
n_pre = self.out_size[-1]
|
137
141
|
|
138
|
-
# indices of post connected neurons
|
139
142
|
with jax.ensure_compile_time_eval():
|
140
143
|
if allow_multi_conn:
|
141
144
|
rng = np.random if seed is None else np.random.RandomState(seed)
|
@@ -143,15 +146,83 @@ class EventFixedNumConn(Module):
|
|
143
146
|
else:
|
144
147
|
indices = init_indices_without_replace(self.conn_num, n_pre, n_post, seed, conn_init)
|
145
148
|
indices = u.math.asarray(indices, dtype=environ.ditype())
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
149
|
+
|
150
|
+
if afferent_ratio == 1.:
|
151
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (n_pre, self.conn_num), allow_none=False))
|
152
|
+
self.weight = param_type(conn_weight)
|
153
|
+
csr = (
|
154
|
+
brainevent.FixedPostNumConn((conn_weight, indices), shape=(n_pre, n_post))
|
155
|
+
if self.efferent_target == 'post' else
|
156
|
+
brainevent.FixedPreNumConn((conn_weight, indices), shape=(n_pre, n_post))
|
157
|
+
)
|
158
|
+
self.conn = csr
|
159
|
+
|
160
|
+
else:
|
161
|
+
self.pre_selected = np.random.random(n_pre) < afferent_ratio
|
162
|
+
indices = indices[self.pre_selected].flatten()
|
163
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (indices.size,), allow_none=False))
|
164
|
+
self.weight = param_type(conn_weight)
|
165
|
+
indptr = (jnp.arange(1, n_pre + 1) * self.conn_num -
|
166
|
+
jnp.cumsum(~self.pre_selected) * self.conn_num)
|
167
|
+
indptr = jnp.insert(indptr, 0, 0) # insert 0 at the beginning
|
168
|
+
csr = (
|
169
|
+
brainevent.CSR((conn_weight, indices, indptr), shape=(n_pre, n_post))
|
170
|
+
if self.efferent_target == 'post' else
|
171
|
+
brainevent.CSC((conn_weight, indices, indptr), shape=(n_pre, n_post))
|
172
|
+
)
|
173
|
+
self.conn = csr
|
174
|
+
|
175
|
+
else:
|
176
|
+
conn_weight = u.math.asarray(init.param(conn_weight, (), allow_none=False))
|
177
|
+
self.weight = FakeState(conn_weight)
|
178
|
+
|
179
|
+
def update(self, x: jax.Array) -> Union[jax.Array, u.Quantity]:
|
180
|
+
if self.conn_num >= 1:
|
181
|
+
csr = self.conn.with_data(self.weight.value)
|
182
|
+
return x @ csr
|
183
|
+
else:
|
184
|
+
weight = self.weight.value
|
185
|
+
r = u.math.zeros(x.shape[:-1] + (self.out_size[-1],), dtype=weight.dtype)
|
186
|
+
r = u.maybe_decimal(u.Quantity(r, unit=u.get_unit(weight)))
|
187
|
+
return u.math.asarray(r, dtype=environ.dftype())
|
188
|
+
|
189
|
+
|
190
|
+
class EventFixedNumConn(FixedNumConn):
|
191
|
+
"""
|
192
|
+
The FixedProb module implements a fixed probability connection with CSR sparse data structure.
|
193
|
+
|
194
|
+
Parameters
|
195
|
+
----------
|
196
|
+
in_size : Size
|
197
|
+
Number of pre-synaptic neurons, i.e., input size.
|
198
|
+
out_size : Size
|
199
|
+
Number of post-synaptic neurons, i.e., output size.
|
200
|
+
conn_num : float, int
|
201
|
+
If it is a float, representing the probability of connection, i.e., connection probability.
|
202
|
+
|
203
|
+
If it is an integer, representing the number of connections.
|
204
|
+
conn_weight : float or callable or jax.Array or brainunit.Quantity
|
205
|
+
Maximum synaptic conductance, i.e., synaptic weight.
|
206
|
+
conn_target : str, optional
|
207
|
+
The target of the connection. Default is 'post', meaning that each pre-synaptic neuron connects to
|
208
|
+
a fixed number of post-synaptic neurons. The connection number is determined by the value of ``n_conn``.
|
209
|
+
|
210
|
+
If 'pre', each post-synaptic neuron connects to a fixed number of pre-synaptic neurons.
|
211
|
+
conn_init : str, optional
|
212
|
+
The initialization method of the connection weight. Default is 'vmap', meaning that the connection weight
|
213
|
+
is initialized by parallelized across multiple threads.
|
214
|
+
|
215
|
+
If 'for_loop', the connection weight is initialized by a for loop.
|
216
|
+
allow_multi_conn : bool, optional
|
217
|
+
Whether multiple connections are allowed from a single pre-synaptic neuron.
|
218
|
+
Default is True, meaning that a value of ``a`` can be selected multiple times.
|
219
|
+
seed: int, optional
|
220
|
+
Random seed. Default is None. If None, the default random seed will be used.
|
221
|
+
name : str, optional
|
222
|
+
Name of the module.
|
223
|
+
"""
|
224
|
+
|
225
|
+
__module__ = 'brainstate.nn'
|
155
226
|
|
156
227
|
def update(self, spk: jax.Array) -> Union[jax.Array, u.Quantity]:
|
157
228
|
if self.conn_num >= 1:
|