brainmass 0.0.1__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- brainmass/__init__.py +29 -0
- brainmass/bold.py +143 -0
- brainmass/coupling.py +171 -0
- brainmass/integration.py +253 -0
- brainmass/noise.py +77 -0
- brainmass/noise_test.py +35 -0
- brainmass/wilson_cowan.py +141 -0
- brainmass-0.0.1.dist-info/LICENSE +202 -0
- brainmass-0.0.1.dist-info/METADATA +255 -0
- brainmass-0.0.1.dist-info/RECORD +69 -0
- brainmass-0.0.1.dist-info/WHEEL +6 -0
- brainmass-0.0.1.dist-info/top_level.txt +2 -0
- examples/datasets/README.md +4 -0
- examples/datasets/__init__.py +21 -0
- examples/datasets/gw/NAP_001/functional/BOLD_rsfMRI.mat +0 -0
- examples/datasets/gw/NAP_001/structural/DTI_CM.mat +0 -0
- examples/datasets/gw/NAP_001/structural/DTI_LEN.mat +0 -0
- examples/datasets/gw/NAP_002/functional/BOLD_rsfMRI.mat +0 -0
- examples/datasets/gw/NAP_002/structural/DTI_CM.mat +0 -0
- examples/datasets/gw/NAP_002/structural/DTI_LEN.mat +0 -0
- examples/datasets/gw/NAP_007/functional/BOLD_rsfMRI.mat +0 -0
- examples/datasets/gw/NAP_007/structural/DTI_CM.mat +0 -0
- examples/datasets/gw/NAP_007/structural/DTI_LEN.mat +0 -0
- examples/datasets/gw/NAP_009/functional/BOLD_rsfMRI.mat +0 -0
- examples/datasets/gw/NAP_009/structural/DTI_CM.mat +0 -0
- examples/datasets/gw/NAP_009/structural/DTI_LEN.mat +0 -0
- examples/datasets/gw/NAP_013/functional/BOLD_rsfMRI.mat +0 -0
- examples/datasets/gw/NAP_013/structural/DTI_CM.mat +0 -0
- examples/datasets/gw/NAP_013/structural/DTI_LEN.mat +0 -0
- examples/datasets/hcp/101309/functional/TC_rsfMRI_REST1_LR.mat +0 -0
- examples/datasets/hcp/101309/structural/DTI_CM.mat +0 -0
- examples/datasets/hcp/101309/structural/DTI_LEN.mat +0 -0
- examples/datasets/hcp/101309/structural/nvoxel.txt +94 -0
- examples/datasets/hcp/101309/structural/waytotal.txt +94 -0
- examples/datasets/hcp/102311/functional/TC_rsfMRI_REST1_LR.mat +0 -0
- examples/datasets/hcp/102311/structural/DTI_CM.mat +0 -0
- examples/datasets/hcp/102311/structural/DTI_LEN.mat +0 -0
- examples/datasets/hcp/102311/structural/nvoxel.txt +94 -0
- examples/datasets/hcp/102311/structural/waytotal.txt +94 -0
- examples/datasets/hcp/102816/functional/TC_rsfMRI_REST1_LR.mat +0 -0
- examples/datasets/hcp/102816/structural/DTI_CM.mat +0 -0
- examples/datasets/hcp/102816/structural/DTI_LEN.mat +0 -0
- examples/datasets/hcp/102816/structural/nvoxel.txt +94 -0
- examples/datasets/hcp/102816/structural/waytotal.txt +94 -0
- examples/datasets/hcp/131217/functional/TC_rsfMRI_REST1_LR.mat +0 -0
- examples/datasets/hcp/131217/structural/DTI_CM.mat +0 -0
- examples/datasets/hcp/131217/structural/DTI_LEN.mat +0 -0
- examples/datasets/hcp/131217/structural/nvoxel.txt +94 -0
- examples/datasets/hcp/131217/structural/waytotal.txt +94 -0
- examples/datasets/hcp/211619/functional/TC_rsfMRI_REST1_LR.mat +0 -0
- examples/datasets/hcp/211619/structural/DTI_CM.mat +0 -0
- examples/datasets/hcp/211619/structural/DTI_LEN.mat +0 -0
- examples/datasets/hcp/211619/structural/nvoxel.txt +94 -0
- examples/datasets/hcp/211619/structural/waytotal.txt +94 -0
- examples/datasets/hcp/213522/functional/TC_rsfMRI_REST1_LR.mat +0 -0
- examples/datasets/hcp/213522/structural/DTI_CM.mat +0 -0
- examples/datasets/hcp/213522/structural/DTI_LEN.mat +0 -0
- examples/datasets/hcp/213522/structural/nvoxel.txt +94 -0
- examples/datasets/hcp/213522/structural/waytotal.txt +94 -0
- examples/datasets/hcp/377451/functional/TC_rsfMRI_REST1_LR.mat +0 -0
- examples/datasets/hcp/377451/structural/DTI_CM.mat +0 -0
- examples/datasets/hcp/377451/structural/DTI_LEN.mat +0 -0
- examples/datasets/hcp/377451/structural/nvoxel.txt +94 -0
- examples/datasets/hcp/377451/structural/waytotal.txt +94 -0
- examples/datasets/load_data.py +318 -0
- examples/parameter-exploration.ipynb +301 -0
- examples/rww_pytorch_model.py +1204 -0
- examples/the_model.py +89 -0
- examples/wilsonwowan-osillator.ipynb +374 -0
brainmass/__init__.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
__version__ = "0.0.1"

# Names re-exported at package level; they are expected to be provided by the
# wildcard imports of the submodules that follow.
__all__ = ['DiffusiveCoupling', 'AdditiveCoupling', 'WilsonCowanModel', 'OUProcess', 'BOLDSignal']
|
|
25
|
+
|
|
26
|
+
from .bold import *
|
|
27
|
+
from .coupling import *
|
|
28
|
+
from .noise import *
|
|
29
|
+
from .wilson_cowan import *
|
brainmass/bold.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Union, Callable
|
|
17
|
+
|
|
18
|
+
import brainstate
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
|
|
21
|
+
from .integration import ode_rk2_step
|
|
22
|
+
|
|
23
|
+
# Public API of this module.
__all__ = ['BOLDSignal']
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class BOLDSignal(brainstate.nn.Dynamics):
    r"""
    Balloon-Windkessel hemodynamic model of Friston et al. (2003) [1]_.

    The Balloon-Windkessel model describes the coupling of perfusion to BOLD signal, with a
    dynamical model of the transduction of neuronal activity into perfusion changes. The
    model assumes that the BOLD signal is a static nonlinear function of the normalized total
    deoxyhemoglobin voxel content, normalized venous volume, resting net oxygen extraction
    fraction by the capillary bed, and resting blood volume fraction. The BOLD-signal estimation
    for each brain area is computed from the level of synaptic activity in that particular
    cortical area, denoted $z_i$ for a given cortical area $i$.

    For the i-th region, synaptic activity $z_i$ causes an increase in a vasodilatory signal $x_i$
    that is subject to autoregulatory feedback. Inflow $f_i$ responds in proportion to this signal
    with concomitant changes in blood volume $v_i$ and deoxyhemoglobin content $q_i$. The equations
    relating these biophysical processes are as follows:

    $$
    \begin{gathered}
    \dot{x}_i=z_i-k_i x_i-\gamma_i\left(f_i-1\right) \\
    \dot{f}_i=x_i \\
    \tau_i \dot{v}_i=f_i-v_i^{1 / \alpha} \\
    \tau_i \dot{q}_i=\frac{f_i}{\rho}\left[1-(1-\rho)^{1 / f_i}\right]-q_i v_i^{1 / \alpha-1},
    \end{gathered}
    $$

    where $\rho$ is the resting oxygen extraction fraction. The BOLD signal is given by the following:

    $$
    \mathrm{BOLD}_i=V_0\left[k_1\left(1-q_i\right)+k_2\left(1-q_i / v_i\right)+k_3\left(1-v_i\right)\right],
    $$

    where $V_0 = 0.02$, $k_1 = 7\rho$, $k_2 = 2$, and $k_3 = 2\rho - 0.2$. All biophysical parameters
    were taken as in Friston et al. (2003) [1]_. The BOLD model converts the local synaptic activity
    of a given cortical area into an observable BOLD signal and does not actively couple the signals
    from other cortical areas.

    With zero input the system has a fixed point at $x_i = 0$, $f_i = v_i = q_i = 1$, where the
    BOLD signal is zero; the state variables are initialized at this resting point.

    Parameters
    ----------
    in_size : int
        Size of the input vector (number of brain regions).
    gamma : float or callable, optional
        Rate of flow-dependent elimination, i.e. the autoregulatory feedback gain on
        $f_i - 1$ (default is 0.41).
    k : float or callable, optional
        Rate of signal decay of the vasodilatory signal $x_i$ (default is 0.65).
    alpha : float or callable, optional
        Grubb's exponent (default is 0.32).
    tau : float or callable, optional
        Hemodynamic transit time (default is 0.98).
    rho : float or callable, optional
        Resting oxygen extraction fraction (default is 0.34).
    V0 : float, optional
        Resting blood volume fraction (default is 0.02).

    References
    ----------
    .. [1] Friston KJ, Harrison L, Penny W (2003) Dynamic causal modelling. Neuroimage 19:1273–1302,
           doi:10.1016/S1053-8119(03)00202-7
    """

    def __init__(
        self,
        in_size,
        gamma: Union[brainstate.typing.ArrayLike, Callable] = 0.41,
        k: Union[brainstate.typing.ArrayLike, Callable] = 0.65,
        alpha: Union[brainstate.typing.ArrayLike, Callable] = 0.32,
        tau: Union[brainstate.typing.ArrayLike, Callable] = 0.98,
        rho: Union[brainstate.typing.ArrayLike, Callable] = 0.34,
        V0: float = 0.02,
    ):
        super().__init__(in_size)

        self.gamma = brainstate.init.param(gamma, self.varshape)
        self.k = brainstate.init.param(k, self.varshape)
        self.alpha = brainstate.init.param(alpha, self.varshape)
        self.tau = brainstate.init.param(tau, self.varshape)
        self.rho = brainstate.init.param(rho, self.varshape)

        # Coefficients of the BOLD observation equation (Friston et al., 2003).
        self.V0 = V0
        self.k1 = 7 * self.rho
        self.k2 = 2.
        self.k3 = 2 * self.rho - 0.2

        # Resting-state initializers: inflow, volume and deoxyhemoglobin content
        # start at their normalized resting value 1, while the vasodilatory
        # signal starts at 0 so the model begins at its zero-input fixed point.
        self.init = brainstate.init.Constant(1.)
        self.x_init = brainstate.init.Constant(0.)

    def init_state(self, batch_size=None, **kwargs):
        """Create the hemodynamic state variables ``x, f, v, q`` at rest."""
        self.x = brainstate.HiddenState(brainstate.init.param(self.x_init, self.varshape, batch_size))
        self.f = brainstate.HiddenState(brainstate.init.param(self.init, self.varshape, batch_size))
        self.v = brainstate.HiddenState(brainstate.init.param(self.init, self.varshape, batch_size))
        self.q = brainstate.HiddenState(brainstate.init.param(self.init, self.varshape, batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        """Reset the hemodynamic state variables ``x, f, v, q`` to rest."""
        self.x.value = brainstate.init.param(self.x_init, self.varshape, batch_size)
        self.f.value = brainstate.init.param(self.init, self.varshape, batch_size)
        self.v.value = brainstate.init.param(self.init, self.varshape, batch_size)
        self.q.value = brainstate.init.param(self.init, self.varshape, batch_size)

    def derivative(self, y, t, z):
        """
        Right-hand side of the Balloon-Windkessel equations.

        Parameters
        ----------
        y : tuple
            Current state ``(x, f, v, q)``.
        t : float
            Current time (unused; the system is autonomous apart from ``z``).
        z : ArrayLike
            Synaptic activity driving the vasodilatory signal.

        Returns
        -------
        tuple
            Time derivatives ``(dx, df, dv, dq)``.
        """
        x, f, v, q = y
        dx = z - self.k * x - self.gamma * (f - 1)
        df = x
        dv = (f - jnp.power(v, 1 / self.alpha)) / self.tau
        # Oxygen extraction fraction E(f) = 1 - (1 - rho)^(1 / f).
        E = 1 - jnp.power(1 - self.rho, 1 / f)
        # q * v^(1/alpha - 1) is evaluated as v^(1/alpha) * q / v.
        dq = (f * E / self.rho - jnp.power(v, 1 / self.alpha) * q / v) / self.tau
        return dx, df, dv, dq

    def update(self, z):
        """Advance the state by one ``dt`` step (RK2 / Heun) driven by synaptic activity ``z``."""
        x, f, v, q = ode_rk2_step(self.derivative, (self.x.value, self.f.value, self.v.value, self.q.value), 0., z)
        self.x.value = x
        self.f.value = f
        self.v.value = v
        self.q.value = q

    def bold(self):
        """
        Observed BOLD signal for the current state.

        Returns
        -------
        Array
            ``V0 * (k1 * (1 - q) + k2 * (1 - q / v) + k3 * (1 - v))``, which is
            zero at the resting state ``q = v = 1``.
        """
        q = self.q.value
        v = self.v.value
        return self.V0 * (self.k1 * (1 - q) +
                          self.k2 * (1 - q / v) +
                          self.k3 * (1 - v))
|
brainmass/coupling.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
from typing import Union
|
|
18
|
+
|
|
19
|
+
import brainstate
|
|
20
|
+
import brainunit as u
|
|
21
|
+
from brainstate.nn._dynamics import maybe_init_prefetch
|
|
22
|
+
|
|
23
|
+
# Any of the brainstate prefetch handles accepted by the coupling modules.
Prefetch = Union[
    brainstate.nn.PrefetchDelayAt,
    brainstate.nn.PrefetchDelay,
    brainstate.nn.Prefetch,
]

# Public API of this module.
__all__ = ['DiffusiveCoupling', 'AdditiveCoupling']
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DiffusiveCoupling(brainstate.nn.Module):
    r"""
    Diffusive coupling.

    This class implements a diffusive coupling mechanism for neural network modules.
    It simulates the following model:

    $$
    \mathrm{current}_i = k \sum_j g_{ij} \left(x_{D_{ij}} - y_i\right)
    $$

    where:
    - $\mathrm{current}_i$: the output current for neuron $i$
    - $g_{ij}$: the connection strength between neuron $i$ and neuron $j$
    - $x_{D_{ij}}$: the delayed state variable for neuron $j$, as seen by neuron $i$
    - $y_i$: the state variable for neuron $i$
    - $k$: the global coupling strength

    Each connection weight scales the *difference* between the delayed source
    state and the local target state, so unconnected pairs ($g_{ij} = 0$)
    contribute nothing to the current.

    Parameters
    ----------
    x : Prefetch
        The delayed state variable for the source units.
    y : Prefetch
        The state variable for the target units.
    conn : brainstate.typing.Array
        The connection matrix specifying the coupling strengths between units.
        A 2D matrix is indexed as ``(target, source)``. A 1D array must have as
        many entries as the last axis of the delayed source variable and is
        reshaped to ``(n_target, -1)`` (target-major layout).
    k : float
        The global coupling strength. Default is 1.0.

    Attributes
    ----------
    x : Prefetch
        The delayed state variable for the source units.
    y : Prefetch
        The state variable for the target units.
    conn : Array
        The connection matrix.

    Raises
    ------
    TypeError
        If ``x`` or ``y`` is not a prefetch object.
    ValueError
        If ``conn`` is not 1D or 2D, or (in :meth:`update`) its size/shape does
        not match the delayed source variable.
    """

    def __init__(
        self,
        x: Prefetch,
        y: Prefetch,
        conn: brainstate.typing.Array,
        k: float = 1.0
    ):
        super().__init__()
        if not isinstance(x, Prefetch):
            raise TypeError(f'The first element must be a Prefetch. But got {type(x)}.')
        if not isinstance(y, Prefetch):
            raise TypeError(f'The second element must be a Prefetch. But got {type(y)}.')
        self.x = x
        self.y = y
        self.k = k

        # Connection matrix
        self.conn = u.math.asarray(conn)
        if self.conn.ndim not in (1, 2):
            raise ValueError(f'Only support 1d, 2d connection matrix. But we got {self.conn.ndim}d.')

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        """Initialize the prefetch handles if they have not been initialized yet."""
        maybe_init_prefetch(self.x)
        maybe_init_prefetch(self.y)

    def update(self):
        """
        Compute the diffusive coupling current for every target unit.

        Returns
        -------
        Array
            ``k * sum_j conn[i, j] * (x[i, j] - y[i])`` for each target ``i``.
        """
        delayed_x = self.x()
        # Target state as a column, shape (n_target, 1), so that it broadcasts
        # across the source axis of the (n_target, n_source) delayed values.
        y = u.math.expand_dims(self.y(), axis=1)
        n_target = y.shape[0]
        if self.conn.ndim == 1:
            if self.conn.size != delayed_x.shape[-1]:
                raise ValueError(
                    f'Connection matrix size {self.conn.size} does not '
                    f'match the variable size {delayed_x.shape[-1]}.'
                )
            conn = self.conn.reshape(n_target, -1)
            delayed_x = delayed_x.reshape(n_target, -1)
        elif self.conn.ndim == 2:
            delayed_x = delayed_x.reshape(n_target, -1)
            if self.conn.shape != delayed_x.shape:
                raise ValueError(f'Connection matrix shape {self.conn.shape} does not '
                                 f'match the variable shape {delayed_x.shape}.')
            conn = self.conn
        else:
            raise NotImplementedError(f'Only support 1d, 2d connection matrix. But we got {self.conn.ndim}d.')
        # Weight the state difference by the connection strength, so that the
        # target's own state only enters through its actual connections.
        diffusive = conn * (delayed_x - y)
        return self.k * diffusive.sum(axis=1)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class AdditiveCoupling(brainstate.nn.Module):
    r"""
    Additive coupling.

    This class implements an additive coupling mechanism for neural network modules.
    It simulates the following model:

    $$
    \mathrm{current}_i = k \sum_j g_{ij} x_{D_{ij}}
    $$

    where:
    - $\mathrm{current}_i$: the output current for neuron $i$
    - $g_{ij}$: the connection strength between neuron $i$ and neuron $j$
    - $x_{D_{ij}}$: the delayed state variable for neuron $j$, as seen by neuron $i$
    - $k$: the global coupling strength

    Parameters
    ----------
    x : Prefetch
        The delayed state variable for the source units.
    conn : brainstate.typing.Array
        The 2D connection matrix specifying the coupling strengths between
        units, indexed as ``(target, source)``.
    k : float
        The global coupling strength. Default is 1.0.

    Attributes
    ----------
    x : Prefetch
        The delayed state variable for the source units.
    conn : Array
        The connection matrix.

    Raises
    ------
    TypeError
        If ``x`` is not a prefetch object.
    ValueError
        If ``conn`` is not 2D, or (in :meth:`update`) its size does not match
        the delayed source variable.
    """

    def __init__(
        self,
        x: Prefetch,
        conn: brainstate.typing.Array,
        k: float = 1.0
    ):
        super().__init__()
        if not isinstance(x, Prefetch):
            raise TypeError(f'The first element must be a Prefetch. But got {type(x)}.')
        self.x = x
        self.k = k

        # Connection matrix
        self.conn = u.math.asarray(conn)
        if self.conn.ndim != 2:
            raise ValueError(f'Only support 2d connection matrix. But we got {self.conn.ndim}d.')

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        """Initialize the prefetch handle if it has not been initialized yet."""
        maybe_init_prefetch(self.x)

    def update(self):
        """
        Compute the additive coupling current for every target unit.

        Returns
        -------
        Array
            ``k * sum_j conn[i, j] * x[i, j]`` for each target ``i``.
        """
        delayed_x = self.x()
        if self.conn.size != delayed_x.size:
            raise ValueError(
                f'Connection matrix size {self.conn.size} does not '
                f'match the variable size {delayed_x.size}.'
            )
        # Lay the delayed values out like the (target, source) connection matrix.
        delayed_x = delayed_x.reshape(self.conn.shape)
        weighted = self.conn * delayed_x
        return self.k * weighted.sum(axis=1)
|
brainmass/integration.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from typing import Callable
|
|
17
|
+
|
|
18
|
+
import brainstate
|
|
19
|
+
import jax
|
|
20
|
+
import jax.numpy as jnp
|
|
21
|
+
from brainstate.typing import PyTree
|
|
22
|
+
|
|
23
|
+
__all__ = [
    'ode_euler_step',
    'ode_rk2_step',
    'ode_rk3_step',
    'ode_rk4_step',
    'sde_euler_step',
    'sde_milstein_step',
]

# Right-hand side of an ODE, called as ``f(y, t, *args)`` and returning dy/dt as
# a PyTree shaped like ``y``. ``typing.Callable`` cannot express "two fixed
# positional arguments followed by varargs" (an ``...`` placed inside the
# argument list is not valid typing syntax), so the argument list is left open.
ODE = Callable[..., PyTree]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def ode_euler_step(f: ODE, y: PyTree, t, *args):
    """
    Advance ``y`` by one explicit (forward) Euler step.

    Integrates ``dy/dt = f(y, t, *args)`` over a single step whose size ``dt``
    is read from ``brainstate.environ.get_dt()``:

        y_{n+1} = y_n + dt * f(y_n, t_n)

    Args:
        f (ODE): Right-hand side ``f(y, t, *args)``; must return a PyTree with
            the same structure as ``y``.
        y (PyTree): State at time ``t``.
        t (float): Current time.
        *args: Extra positional arguments forwarded to ``f``.

    Returns:
        PyTree: The state after one step, with the same structure as ``y``.

    Note:
        First-order accurate. It needs a single evaluation of ``f`` per step,
        making it the cheapest but least accurate ODE stepper in this module.
    """
    step = brainstate.environ.get_dt()
    slope = f(y, t, *args)

    def _advance(state, rate):
        return state + step * rate

    return jax.tree.map(_advance, y, slope)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def ode_rk2_step(f: ODE, y: PyTree, t, *args):
    """
    Second-order Runge-Kutta step using Heun's method (explicit trapezoidal rule).

    Heun's method averages the slope at the start of the step with the slope at
    an Euler-predicted end point:

        k1 = f(y_n, t_n)
        k2 = f(y_n + dt*k1, t_n + dt)
        y_{n+1} = y_n + dt/2 * (k1 + k2)

    Note that this is *not* the explicit midpoint method, which instead uses
    ``k2 = f(y_n + dt/2*k1, t_n + dt/2)`` and ``y_{n+1} = y_n + dt*k2``; both
    are second order, but they produce different results.

    The step size ``dt`` is read from ``brainstate.environ.get_dt()``.

    Args:
        f (ODE): The differential equation function dy/dt = f(y, t, *args)
        y (PyTree): Current state vector
        t (float): Current time
        *args: Additional arguments passed to function f

    Returns:
        PyTree: Updated state vector y_{n+1}

    Note:
        This is a second-order method, more accurate than Euler at the cost of
        only one additional function evaluation per step.
    """
    dt = brainstate.environ.get_dt()
    # Slope at the start of the step.
    k1 = f(y, t, *args)
    # Slope at the Euler-predicted end point y_n + dt * k1.
    y_pred = jax.tree.map(lambda x, k: x + dt * k, y, k1)
    k2 = f(y_pred, t + dt, *args)
    # Trapezoidal average of the two slopes.
    return jax.tree.map(lambda x, _k1, _k2: x + dt / 2 * (_k1 + _k2), y, k1, k2)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def ode_rk3_step(f: ODE, y: PyTree, t, *args):
    """
    Advance ``y`` by one step of Kutta's third-order Runge-Kutta scheme.

    With step size ``dt`` taken from ``brainstate.environ.get_dt()``, three
    slope evaluations are combined as:

        k1 = f(y_n, t_n)
        k2 = f(y_n + dt/2*k1, t_n + dt/2)
        k3 = f(y_n - dt*k1 + 2*dt*k2, t_n + dt)
        y_{n+1} = y_n + dt/6 * (k1 + 4*k2 + k3)

    Args:
        f (ODE): Right-hand side ``f(y, t, *args)``; must return a PyTree with
            the same structure as ``y``.
        y (PyTree): State at time ``t``.
        t (float): Current time.
        *args: Extra positional arguments forwarded to ``f``.

    Returns:
        PyTree: The state after one step, with the same structure as ``y``.

    Note:
        Third-order accurate; one evaluation of ``f`` more than RK2.
    """
    h = brainstate.environ.get_dt()
    half_h = h / 2

    s1 = f(y, t, *args)
    s2 = f(jax.tree.map(lambda base, a: base + half_h * a, y, s1), t + half_h, *args)

    def _third_probe(base, a, b):
        # Stage point for the third slope: y_n - dt*k1 + 2*dt*k2.
        return base - h * a + 2 * h * b

    s3 = f(jax.tree.map(_third_probe, y, s1, s2), t + h, *args)

    def _blend(base, a, b, c):
        return base + h / 6 * (a + 4 * b + c)

    return jax.tree.map(_blend, y, s1, s2, s3)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def ode_rk4_step(f: ODE, y: PyTree, t, *args):
    """
    Advance ``y`` by one step of the classic fourth-order Runge-Kutta scheme.

    With step size ``dt`` taken from ``brainstate.environ.get_dt()``:

        k1 = f(y_n, t_n)
        k2 = f(y_n + dt/2*k1, t_n + dt/2)
        k3 = f(y_n + dt/2*k2, t_n + dt/2)
        k4 = f(y_n + dt*k3, t_n + dt)
        y_{n+1} = y_n + dt/6 * (k1 + 2*k2 + 2*k3 + k4)

    Args:
        f (ODE): Right-hand side ``f(y, t, *args)``; must return a PyTree with
            the same structure as ``y``.
        y (PyTree): State at time ``t``.
        t (float): Current time.
        *args: Extra positional arguments forwarded to ``f``.

    Returns:
        PyTree: The state after one step, with the same structure as ``y``.

    Note:
        Fourth-order accurate with four evaluations of ``f`` per step; the most
        accurate ODE stepper in this module.
    """
    h = brainstate.environ.get_dt()
    half_h = h / 2

    def _probe(slope, scale):
        # State displaced from y_n along ``slope`` by ``scale``.
        return jax.tree.map(lambda base, s: base + scale * s, y, slope)

    s1 = f(y, t, *args)
    s2 = f(_probe(s1, half_h), t + half_h, *args)
    s3 = f(_probe(s2, half_h), t + half_h, *args)
    s4 = f(_probe(s3, h), t + h, *args)

    def _blend(base, a, b, c, d):
        return base + h / 6 * (a + 2 * b + 2 * c + d)

    return jax.tree.map(_blend, y, s1, s2, s3, s4)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def sde_euler_step(df, dg, y, t, sde_type='ito', **kwargs):
    """
    Euler-Maruyama method for solving stochastic differential equations (SDEs).

    Solves SDEs of the form:
        dy = f(y, t)dt + g(y, t)dW

    where f is the drift term, g is the diffusion term, and dW is Wiener noise.

    The Euler-Maruyama scheme approximates:
        y_{n+1} = y_n + f(y_n, t_n)*dt + g(y_n, t_n)*ΔW_n

    where ΔW_n ~ N(0, dt) is the Wiener increment, drawn independently for
    every element of every leaf of ``y``. The step size ``dt`` is read from
    ``brainstate.environ.get_dt()``.

    Args:
        df (Callable): Drift function f(y, t, **kwargs) -> PyTree
        dg (Callable): Diffusion function g(y, t, **kwargs) -> PyTree
        y (PyTree): Current state vector
        t (float): Current time
        sde_type (str): Type of SDE interpretation ('ito' only supported)
        **kwargs: Additional arguments passed to df and dg

    Returns:
        PyTree: Updated state vector y_{n+1}

    Raises:
        ValueError: If ``sde_type`` is not ``'ito'``.

    Note:
        This method has strong convergence order 0.5 and weak convergence order 1.0.
        Only Itô interpretation is currently supported.
    """
    # An explicit check (rather than ``assert``) so that, even under ``python -O``,
    # a Stratonovich request is never silently integrated as Itô.
    if sde_type != 'ito':
        raise ValueError(f"sde_euler_step only supports sde_type='ito', but got {sde_type!r}.")

    dt = brainstate.environ.get_dt()
    dt_sqrt = jnp.sqrt(dt)
    drifts = df(y, t, **kwargs)
    diffusions = dg(y, t, **kwargs)
    y_bars = jax.tree.map(
        lambda y0, drift, diffusion: y0 + drift * dt + diffusion * brainstate.random.randn_like(y0) * dt_sqrt,
        y, drifts, diffusions
    )
    return y_bars
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def sde_milstein_step(df, dg, y, t, sde_type='ito', **kwargs):
    """
    Derivative-free Milstein method for solving stochastic differential equations (SDEs).

    Solves SDEs of the form:
        dy = f(y, t)dt + g(y, t)dW

    The Milstein scheme includes an additional correction term for higher accuracy:
        y_{n+1} = y_n + f(y_n, t_n)*dt + g(y_n, t_n)*ΔW_n +
                  (1/2)*g(y_n, t_n)*∂g/∂y(y_n, t_n)*((ΔW_n)² - dt)

    The product g*∂g/∂y is approximated without derivatives via a support point
        ȳ = y_n + f(y_n, t_n)*dt + g(y_n, t_n)*√dt
        g*∂g/∂y ≈ (g(ȳ, t_n) - g(y_n, t_n)) / √dt

    For the Stratonovich interpretation the correction uses (ΔW_n)² instead of
    (ΔW_n)² - dt. Wiener increments ΔW_n ~ N(0, dt) are drawn independently for
    every element of every leaf of ``y``, and ``dt`` is read from
    ``brainstate.environ.get_dt()``.

    Args:
        df (Callable): Drift function f(y, t, **kwargs) -> PyTree
        dg (Callable): Diffusion function g(y, t, **kwargs) -> PyTree
        y (PyTree): Current state vector
        t (float): Current time
        sde_type (str): SDE interpretation ('ito' or 'stra' for Stratonovich)
        **kwargs: Additional arguments passed to df and dg

    Returns:
        PyTree: Updated state vector y_{n+1}

    Raises:
        ValueError: If ``sde_type`` is neither ``'ito'`` nor ``'stra'``.

    Note:
        With the element-wise (diagonal) noise used here, this method has
        strong convergence order 1.0, better than Euler-Maruyama.
    """
    # An explicit check (rather than ``assert``) so that, even under ``python -O``,
    # an unknown interpretation is never silently treated as Stratonovich.
    if sde_type not in ('ito', 'stra'):
        raise ValueError(f"sde_type must be 'ito' or 'stra', but got {sde_type!r}.")

    dt = brainstate.environ.get_dt()
    dt_sqrt = jnp.sqrt(dt)

    # drift values
    drifts = df(y, t, **kwargs)

    # diffusion values
    diffusions = dg(y, t, **kwargs)

    # support point ȳ and the diffusion evaluated there
    y_bars = jax.tree.map(lambda y0, drift, diffusion: y0 + drift * dt + diffusion * dt_sqrt, y, drifts, diffusions)
    diffusion_bars = dg(y_bars, t, **kwargs)

    # integral results
    def f_integral(y0, drift, diffusion, diffusion_bar):
        noise = brainstate.random.randn_like(y0) * dt_sqrt
        # Itô correction subtracts dt from (ΔW)²; Stratonovich keeps (ΔW)².
        noise_p2 = (noise ** 2 - dt) if sde_type == 'ito' else noise ** 2
        # (1/2) * g * ∂g/∂y, with g * ∂g/∂y ≈ (g(ȳ) - g(y)) / √dt.
        minus = (diffusion_bar - diffusion) / 2 / dt_sqrt
        return y0 + drift * dt + diffusion * noise + minus * noise_p2

    integrals = jax.tree.map(f_integral, y, drifts, diffusions, diffusion_bars)
    return integrals
|
brainmass/noise.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
# Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
import brainstate
|
|
18
|
+
import brainunit as u
|
|
19
|
+
|
|
20
|
+
# Public API of this module.
__all__ = ['OUProcess']
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OUProcess(brainstate.nn.Dynamics):
    r"""
    The Ornstein–Uhlenbeck process.

    The Ornstein–Uhlenbeck process :math:`x_{t}` simulated here is defined by the
    following stochastic differential equation:

    .. math::

        dx_{t} = \frac{\mu - x_{t}}{\tau}\,dt + \frac{\sigma}{\sqrt{\tau}}\,dW_{t}

    where :math:`\mu` is the mean value, :math:`\tau > 0` the decay time constant,
    :math:`\sigma > 0` the noise amplitude and :math:`W_{t}` denotes the Wiener
    process. Each call to :meth:`update` advances the process by one step with
    ``brainstate.nn.exp_euler_step``.

    Parameters
    ----------
    in_size : int, sequence of int
        The model size.
    mean : ArrayLike
        The noise mean value :math:`\mu`. Its unit is also used as the unit of
        the state. Default is 0.
    sigma : ArrayLike
        The noise amplitude :math:`\sigma`. Default is 1.
    tau : ArrayLike
        The decay time constant :math:`\tau`. Default is 10.
    """

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        mean: brainstate.typing.ArrayLike = 0.,  # noise mean value
        sigma: brainstate.typing.ArrayLike = 1.,  # noise amplitude
        tau: brainstate.typing.ArrayLike = 10.,  # time constant
    ):
        super().__init__(in_size=in_size)

        # parameters
        self.mean = mean
        self.sigma = sigma
        self.tau = tau

    def _zero_state(self, batch_size=None):
        # Zero-filled state carrying the unit of ``mean``; a leading batch axis
        # is prepended when ``batch_size`` is given.
        size = self.in_size if batch_size is None else (batch_size, *self.in_size)
        return u.math.zeros(size, unit=u.get_unit(self.mean))

    def init_state(self, batch_size=None, **kwargs):
        """Create the process state ``x`` filled with zeros."""
        self.x = brainstate.HiddenState(self._zero_state(batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        """Reset the process state ``x`` to zeros."""
        self.x.value = self._zero_state(batch_size)

    def update(self):
        """
        Advance the process by one time step.

        Returns
        -------
        ArrayLike
            The updated value of the process.
        """
        def drift(x):
            return (self.mean - x) / self.tau

        def diffusion(x):
            return self.sigma / u.math.sqrt(self.tau)

        self.x.value = brainstate.nn.exp_euler_step(drift, diffusion, self.x.value)
        return self.x.value
|