brainmass 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. brainmass/__init__.py +29 -0
  2. brainmass/bold.py +143 -0
  3. brainmass/coupling.py +171 -0
  4. brainmass/integration.py +253 -0
  5. brainmass/noise.py +77 -0
  6. brainmass/noise_test.py +35 -0
  7. brainmass/wilson_cowan.py +141 -0
  8. brainmass-0.0.1.dist-info/LICENSE +202 -0
  9. brainmass-0.0.1.dist-info/METADATA +255 -0
  10. brainmass-0.0.1.dist-info/RECORD +69 -0
  11. brainmass-0.0.1.dist-info/WHEEL +6 -0
  12. brainmass-0.0.1.dist-info/top_level.txt +2 -0
  13. examples/datasets/README.md +4 -0
  14. examples/datasets/__init__.py +21 -0
  15. examples/datasets/gw/NAP_001/functional/BOLD_rsfMRI.mat +0 -0
  16. examples/datasets/gw/NAP_001/structural/DTI_CM.mat +0 -0
  17. examples/datasets/gw/NAP_001/structural/DTI_LEN.mat +0 -0
  18. examples/datasets/gw/NAP_002/functional/BOLD_rsfMRI.mat +0 -0
  19. examples/datasets/gw/NAP_002/structural/DTI_CM.mat +0 -0
  20. examples/datasets/gw/NAP_002/structural/DTI_LEN.mat +0 -0
  21. examples/datasets/gw/NAP_007/functional/BOLD_rsfMRI.mat +0 -0
  22. examples/datasets/gw/NAP_007/structural/DTI_CM.mat +0 -0
  23. examples/datasets/gw/NAP_007/structural/DTI_LEN.mat +0 -0
  24. examples/datasets/gw/NAP_009/functional/BOLD_rsfMRI.mat +0 -0
  25. examples/datasets/gw/NAP_009/structural/DTI_CM.mat +0 -0
  26. examples/datasets/gw/NAP_009/structural/DTI_LEN.mat +0 -0
  27. examples/datasets/gw/NAP_013/functional/BOLD_rsfMRI.mat +0 -0
  28. examples/datasets/gw/NAP_013/structural/DTI_CM.mat +0 -0
  29. examples/datasets/gw/NAP_013/structural/DTI_LEN.mat +0 -0
  30. examples/datasets/hcp/101309/functional/TC_rsfMRI_REST1_LR.mat +0 -0
  31. examples/datasets/hcp/101309/structural/DTI_CM.mat +0 -0
  32. examples/datasets/hcp/101309/structural/DTI_LEN.mat +0 -0
  33. examples/datasets/hcp/101309/structural/nvoxel.txt +94 -0
  34. examples/datasets/hcp/101309/structural/waytotal.txt +94 -0
  35. examples/datasets/hcp/102311/functional/TC_rsfMRI_REST1_LR.mat +0 -0
  36. examples/datasets/hcp/102311/structural/DTI_CM.mat +0 -0
  37. examples/datasets/hcp/102311/structural/DTI_LEN.mat +0 -0
  38. examples/datasets/hcp/102311/structural/nvoxel.txt +94 -0
  39. examples/datasets/hcp/102311/structural/waytotal.txt +94 -0
  40. examples/datasets/hcp/102816/functional/TC_rsfMRI_REST1_LR.mat +0 -0
  41. examples/datasets/hcp/102816/structural/DTI_CM.mat +0 -0
  42. examples/datasets/hcp/102816/structural/DTI_LEN.mat +0 -0
  43. examples/datasets/hcp/102816/structural/nvoxel.txt +94 -0
  44. examples/datasets/hcp/102816/structural/waytotal.txt +94 -0
  45. examples/datasets/hcp/131217/functional/TC_rsfMRI_REST1_LR.mat +0 -0
  46. examples/datasets/hcp/131217/structural/DTI_CM.mat +0 -0
  47. examples/datasets/hcp/131217/structural/DTI_LEN.mat +0 -0
  48. examples/datasets/hcp/131217/structural/nvoxel.txt +94 -0
  49. examples/datasets/hcp/131217/structural/waytotal.txt +94 -0
  50. examples/datasets/hcp/211619/functional/TC_rsfMRI_REST1_LR.mat +0 -0
  51. examples/datasets/hcp/211619/structural/DTI_CM.mat +0 -0
  52. examples/datasets/hcp/211619/structural/DTI_LEN.mat +0 -0
  53. examples/datasets/hcp/211619/structural/nvoxel.txt +94 -0
  54. examples/datasets/hcp/211619/structural/waytotal.txt +94 -0
  55. examples/datasets/hcp/213522/functional/TC_rsfMRI_REST1_LR.mat +0 -0
  56. examples/datasets/hcp/213522/structural/DTI_CM.mat +0 -0
  57. examples/datasets/hcp/213522/structural/DTI_LEN.mat +0 -0
  58. examples/datasets/hcp/213522/structural/nvoxel.txt +94 -0
  59. examples/datasets/hcp/213522/structural/waytotal.txt +94 -0
  60. examples/datasets/hcp/377451/functional/TC_rsfMRI_REST1_LR.mat +0 -0
  61. examples/datasets/hcp/377451/structural/DTI_CM.mat +0 -0
  62. examples/datasets/hcp/377451/structural/DTI_LEN.mat +0 -0
  63. examples/datasets/hcp/377451/structural/nvoxel.txt +94 -0
  64. examples/datasets/hcp/377451/structural/waytotal.txt +94 -0
  65. examples/datasets/load_data.py +318 -0
  66. examples/parameter-exploration.ipynb +301 -0
  67. examples/rww_pytorch_model.py +1204 -0
  68. examples/the_model.py +89 -0
  69. examples/wilsonwowan-osillator.ipynb +374 -0
brainmass/__init__.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
# Version string of this release of the package.
__version__ = "0.0.1"

# Public API of the package: the names exported by ``from brainmass import *``.
# They are provided by the submodule star-imports that follow this list.
__all__ = [
    'DiffusiveCoupling',
    'AdditiveCoupling',
    'WilsonCowanModel',
    'OUProcess',
    'BOLDSignal',
]
25
+
26
+ from .bold import *
27
+ from .coupling import *
28
+ from .noise import *
29
+ from .wilson_cowan import *
brainmass/bold.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Union, Callable
17
+
18
+ import brainstate
19
+ import jax.numpy as jnp
20
+
21
+ from .integration import ode_rk2_step
22
+
23
# Names exported by ``from brainmass.bold import *``.
__all__ = [
    'BOLDSignal',
]
26
+
27
+
28
class BOLDSignal(brainstate.nn.Dynamics):
    r"""
    Balloon-Windkessel hemodynamic model of Friston et al. (2003) [1]_.

    The Balloon-Windkessel model describes the coupling of perfusion to BOLD signal, with a
    dynamical model of the transduction of neuronal activity into perfusion changes. The
    model assumes that the BOLD signal is a static nonlinear function of the normalized total
    deoxyhemoglobin voxel content, normalized venous volume, resting net oxygen extraction
    fraction by the capillary bed, and resting blood volume fraction. The BOLD signal of each
    brain area is driven by the level of synaptic activity in that particular cortical area,
    denoted $z_i$ for cortical area $i$.

    For the i-th region, synaptic activity $z_i$ causes an increase in a vasodilatory signal $x_i$
    that is subject to autoregulatory feedback. Inflow $f_i$ responds in proportion to this signal
    with concomitant changes in blood volume $v_i$ and deoxyhemoglobin content $q_i$. The equations
    relating these biophysical processes are as follows:

    $$
    \begin{gathered}
    \dot{x}_i=z_i-k_i x_i-\gamma_i\left(f_i-1\right) \\
    \dot{f}_i=x_i \\
    \tau_i \dot{v}_i=f_i-v_i^{1 / \alpha} \\
    \tau_i \dot{q}_i=\frac{f_i}{\rho}\left[1-(1-\rho)^{1 / f_i}\right]-q_i v_i^{1 / \alpha-1},
    \end{gathered}
    $$

    where $\rho$ is the resting oxygen extraction fraction. The BOLD signal is given by:

    $$
    \mathrm{BOLD}_i=V_0\left[k_1\left(1-q_i\right)+k_2\left(1-q_i / v_i\right)+k_3\left(1-v_i\right)\right],
    $$

    where $V_0 = 0.02$, $k_1 = 7\rho$, $k_2 = 2$ and $k_3 = 2\rho - 0.2$. All biophysical parameters
    are taken from Friston et al. (2003) [1]_. The BOLD model converts the local synaptic activity
    of a given cortical area into an observable BOLD signal and does not couple the signals of
    different cortical areas.

    The state is initialised at the resting point of the equations above for $z = 0$, namely
    $x = 0$ and $f = v = q = 1$, where every derivative and the BOLD signal are zero.

    Parameters
    ----------
    in_size : int
        Size of the input vector (number of brain regions).
    gamma : float or callable, optional
        Rate of flow-dependent elimination, i.e. the autoregulatory feedback of inflow on the
        vasodilatory signal (default is 0.41).
    k : float or callable, optional
        Rate of signal decay of the vasodilatory signal (default is 0.65).
    alpha : float or callable, optional
        Grubb's exponent (default is 0.32).
    tau : float or callable, optional
        Hemodynamic transit time (default is 0.98).
    rho : float or callable, optional
        Resting oxygen extraction fraction (default is 0.34).
    V0 : float, optional
        Resting blood volume fraction (default is 0.02).

    References
    ----------
    .. [1] Friston KJ, Harrison L, Penny W (2003) Dynamic causal modelling. Neuroimage 19:1273–1302,
           doi:10.1016/S1053-8119(03)00202-7
    """

    def __init__(
        self,
        in_size,
        gamma: Union[brainstate.typing.ArrayLike, Callable] = 0.41,
        k: Union[brainstate.typing.ArrayLike, Callable] = 0.65,
        alpha: Union[brainstate.typing.ArrayLike, Callable] = 0.32,
        tau: Union[brainstate.typing.ArrayLike, Callable] = 0.98,
        rho: Union[brainstate.typing.ArrayLike, Callable] = 0.34,
        V0: float = 0.02,
    ):
        super().__init__(in_size)

        self.gamma = brainstate.init.param(gamma, self.varshape)
        self.k = brainstate.init.param(k, self.varshape)
        self.alpha = brainstate.init.param(alpha, self.varshape)
        self.tau = brainstate.init.param(tau, self.varshape)
        self.rho = brainstate.init.param(rho, self.varshape)

        # Coefficients of the BOLD observation equation.
        self.V0 = V0
        self.k1 = 7 * self.rho
        self.k2 = 2.
        self.k3 = 2 * self.rho - 0.2

        # Initializers for the resting state: f, v and q rest at 1, while the
        # vasodilatory signal x rests at 0 (dx/dt = 0 requires x = 0 when z = 0, f = 1).
        self.init = brainstate.init.Constant(1.)
        self.init_x = brainstate.init.Constant(0.)

    def init_state(self, batch_size=None, **kwargs):
        """Create the hidden states ``x``, ``f``, ``v`` and ``q`` at their resting values."""
        self.x = brainstate.HiddenState(brainstate.init.param(self.init_x, self.varshape, batch_size))
        self.f = brainstate.HiddenState(brainstate.init.param(self.init, self.varshape, batch_size))
        self.v = brainstate.HiddenState(brainstate.init.param(self.init, self.varshape, batch_size))
        self.q = brainstate.HiddenState(brainstate.init.param(self.init, self.varshape, batch_size))

    def reset_state(self, batch_size=None, **kwargs):
        """Reset ``x``, ``f``, ``v`` and ``q`` back to their resting values."""
        self.x.value = brainstate.init.param(self.init_x, self.varshape, batch_size)
        self.f.value = brainstate.init.param(self.init, self.varshape, batch_size)
        self.v.value = brainstate.init.param(self.init, self.varshape, batch_size)
        self.q.value = brainstate.init.param(self.init, self.varshape, batch_size)

    def derivative(self, y, t, z):
        """
        Right-hand side of the hemodynamic ODE system.

        Parameters
        ----------
        y : tuple
            Current ``(x, f, v, q)``.
        t : float
            Current time; unused (the system is autonomous), but required by the
            ``f(y, t, *args)`` integrator signature.
        z : ArrayLike
            Synaptic activity driving the model.

        Returns
        -------
        tuple
            ``(dx, df, dv, dq)``.
        """
        x, f, v, q = y
        dx = z - self.k * x - self.gamma * (f - 1)
        df = x
        dv = (f - jnp.power(v, 1 / self.alpha)) / self.tau
        # Oxygen extraction fraction as a function of inflow.
        E = 1 - jnp.power(1 - self.rho, 1 / f)
        # v^(1/alpha) * q / v == q * v^(1/alpha - 1), as in the model equation.
        dq = (f * E / self.rho - jnp.power(v, 1 / self.alpha) * q / v) / self.tau
        return dx, df, dv, dq

    def update(self, z):
        """Advance the hemodynamic state by one RK2 (Heun) step driven by activity ``z``."""
        x, f, v, q = ode_rk2_step(self.derivative, (self.x.value, self.f.value, self.v.value, self.q.value), 0., z)
        self.x.value = x
        self.f.value = f
        self.v.value = v
        self.q.value = q

    def bold(self):
        """
        Return the BOLD signal computed from the current state.

        ``V0 * [k1 * (1 - q) + k2 * (1 - q / v) + k3 * (1 - v)]``, which is zero at rest.
        """
        q = self.q.value
        v = self.v.value
        return self.V0 * (self.k1 * (1 - q) +
                          self.k2 * (1 - q / v) +
                          self.k3 * (1 - v))
brainmass/coupling.py ADDED
@@ -0,0 +1,171 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ from typing import Union
18
+
19
+ import brainstate
20
+ import brainunit as u
21
+ from brainstate.nn._dynamics import maybe_init_prefetch
22
+
23
# Any of the brainstate prefetch handles accepted as the input of a coupling
# module (checked with ``isinstance`` in the coupling constructors).
# NOTE(review): ``isinstance`` against a ``typing.Union`` requires Python >= 3.10.
Prefetch = Union[brainstate.nn.PrefetchDelayAt, brainstate.nn.PrefetchDelay, brainstate.nn.Prefetch]

# Names exported by ``from brainmass.coupling import *``.
__all__ = [
    'DiffusiveCoupling',
    'AdditiveCoupling',
]
29
+
30
+
31
class DiffusiveCoupling(brainstate.nn.Module):
    r"""
    Diffusive coupling.

    Computes, for each target unit :math:`i`,

    $$
    \mathrm{current}_i = k \sum_j \left(g_{ij}\, x_{D_{ij}} - y_i\right)
    $$

    where:
    - $\mathrm{current}_i$: the output current for unit $i$
    - $g_{ij}$: the connection strength from unit $j$ to unit $i$
    - $x_{D_{ij}}$: the delayed state of unit $j$, as seen by unit $i$
    - $y_i$: the state of unit $i$

    Note that $y_i$ is subtracted once per column $j$ of the weighted source
    matrix, including columns where $g_{ij} = 0$.

    Parameters
    ----------
    x : Prefetch
        Delayed state of the source units.
    y : Prefetch
        State of the target units.
    conn : brainstate.typing.Array
        Coupling strengths; a 1D array (multiplied element-wise with the flat
        source vector) or a 2D array of shape ``(n_targets, n_sources)``.
    k : float
        Global coupling strength. Default is 1.0.

    Attributes
    ----------
    x : Prefetch
        Delayed state of the source units.
    y : Prefetch
        State of the target units.
    conn : Array
        The connection matrix.
    k : float
        Global coupling strength.
    """

    def __init__(
        self,
        x: Prefetch,
        y: Prefetch,
        conn: brainstate.typing.Array,
        k: float = 1.0
    ):
        super().__init__()
        assert isinstance(x, Prefetch), f'The first element must be a Prefetch. But got {type(x)}.'
        assert isinstance(y, Prefetch), f'The second element must be a Prefetch. But got {type(y)}.'
        self.x, self.y, self.k = x, y, k

        weights = u.math.asarray(conn)
        self.conn = weights
        assert weights.ndim in (1, 2), f'Only support 1d, 2d connection matrix. But we got {weights.ndim}d.'

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        """Initialise the prefetch handles of the source and the target."""
        for handle in (self.x, self.y):
            maybe_init_prefetch(handle)

    def update(self):
        """Return ``k * sum_j (conn_ij * x_j - y_i)`` for every target ``i``."""
        source = self.x()
        # Target values as a column, so they broadcast across the source axis.
        # NOTE(review): axis=1 assumes y is unbatched (shape (N,)) — confirm for batched use.
        target = u.math.expand_dims(self.y(), axis=1)
        n_rows = target.shape[0]
        ndim = self.conn.ndim

        if ndim == 2:
            source = source.reshape(n_rows, -1)
            assert self.conn.shape == source.shape, (f'Connection matrix shape {self.conn.shape} does not '
                                                     f'match the variable shape {source.shape}.')
            weighted = self.conn * source
        elif ndim == 1:
            assert self.conn.size == source.shape[-1], (
                f'Connection matrix size {self.conn.size} does not '
                f'match the variable size {source.shape[-1]}.'
            )
            weighted = (self.conn * source).reshape(n_rows, -1)
        else:
            raise NotImplementedError(f'Only support 1d, 2d connection matrix. But we got {self.conn.ndim}d.')

        return self.k * (weighted - target).sum(axis=1)
109
+
110
+
111
class AdditiveCoupling(brainstate.nn.Module):
    r"""
    Additive coupling.

    Computes, for each target unit :math:`i`,

    $$
    \mathrm{current}_i = k \sum_j g_{ij}\, x_{D_{ij}}
    $$

    where:
    - $\mathrm{current}_i$: the output current for unit $i$
    - $g_{ij}$: the connection strength from unit $j$ to unit $i$
    - $x_{D_{ij}}$: the delayed state of unit $j$, as seen by unit $i$

    Parameters
    ----------
    x : Prefetch
        Delayed state of the source units. Its total number of elements must
        equal ``conn.size``; it is reshaped to ``conn.shape`` before weighting.
    conn : brainstate.typing.Array
        2D connection matrix of shape ``(n_targets, n_sources)`` specifying the
        coupling strengths. Only 2D matrices are supported.
    k : float
        Global coupling strength. Default is 1.0.

    Attributes
    ----------
    x : Prefetch
        Delayed state of the source units.
    conn : Array
        The connection matrix.
    k : float
        Global coupling strength.
    """

    def __init__(
        self,
        x: Prefetch,
        conn: brainstate.typing.Array,
        k: float = 1.0
    ):
        super().__init__()
        assert isinstance(x, Prefetch), f'The first element must be a Prefetch. But got {type(x)}.'
        self.x = x
        self.k = k

        # Connection matrix
        self.conn = u.math.asarray(conn)
        assert self.conn.ndim == 2, f'Only support 2d connection matrix. But we got {self.conn.ndim}d.'

    @brainstate.nn.call_order(2)
    def init_state(self, *args, **kwargs):
        """Initialise the prefetch handle of the source."""
        maybe_init_prefetch(self.x)

    def update(self):
        """Return ``k * sum_j conn_ij * x_ij`` for every target (row) ``i``."""
        delayed_x = self.x()
        assert self.conn.size == delayed_x.size, (
            f'Connection matrix size {self.conn.size} does not '
            f'match the variable size {delayed_x.size}.'
        )
        delayed_x = delayed_x.reshape(self.conn.shape)
        weighted = self.conn * delayed_x
        return self.k * weighted.sum(axis=1)
brainmass/integration.py ADDED
@@ -0,0 +1,253 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from typing import Callable
17
+
18
+ import brainstate
19
+ import jax
20
+ import jax.numpy as jnp
21
+ from brainstate.typing import PyTree
22
+
23
# Names exported by ``from brainmass.integration import *``.
__all__ = [
    'ode_euler_step',
    'ode_rk2_step',
    'ode_rk3_step',
    'ode_rk4_step',
    'sde_euler_step',
    'sde_milstein_step',
]

# Signature of an ODE right-hand side: ``f(y, t, *args) -> dy/dt`` (a PyTree).
# ``typing`` cannot express "two fixed positional parameters followed by
# arbitrary extras", so the argument list is left open with ``...``.
# The previous spelling ``Callable[[PyTree, float, ...], PyTree]`` is invalid:
# ``...`` is only allowed as the whole argument list. It is rejected by type
# checkers and raises ``TypeError`` at import time on older Pythons (e.g. 3.9).
ODE = Callable[..., PyTree]
33
+
34
+
35
def ode_euler_step(f: ODE, y: PyTree, t, *args):
    """
    Explicit (forward) Euler step for ordinary differential equations.

    Solves ODEs of the form dy/dt = f(y, t) with the update:

        y_{n+1} = y_n + dt * f(y_n, t_n)

    The step size ``dt`` is read from ``brainstate.environ.get_dt()``. The update
    is applied leaf-wise over the PyTree ``y``.

    Args:
        f (ODE): The differential equation function dy/dt = f(y, t, *args)
        y (PyTree): Current state
        t (float): Current time
        *args: Additional arguments passed to function f

    Returns:
        PyTree: Updated state y_{n+1}, with the same structure as ``y``.

    Note:
        First-order method: the local (per-step) truncation error is O(dt²) and
        the global error is O(dt). One evaluation of ``f`` per step.
    """
    dt = brainstate.environ.get_dt()
    k1 = f(y, t, *args)
    return jax.tree.map(lambda x, _k1: x + dt * _k1, y, k1)
61
+
62
+
63
def ode_rk2_step(f: ODE, y: PyTree, t, *args):
    """
    Second-order Runge-Kutta step (Heun's method / explicit trapezoidal rule).

    Uses two evaluations of ``f``: an Euler predictor followed by an average of
    the slopes at both ends of the interval:

        k1 = f(y_n, t_n)
        k2 = f(y_n + dt*k1, t_n + dt)
        y_{n+1} = y_n + dt/2 * (k1 + k2)

    This is not the midpoint method, which instead evaluates
    ``f(y_n + dt/2*k1, t_n + dt/2)`` and uses only that slope.

    The step size ``dt`` is read from ``brainstate.environ.get_dt()``. The update
    is applied leaf-wise over the PyTree ``y``.

    Args:
        f (ODE): The differential equation function dy/dt = f(y, t, *args)
        y (PyTree): Current state
        t (float): Current time
        *args: Additional arguments passed to function f

    Returns:
        PyTree: Updated state y_{n+1}, with the same structure as ``y``.

    Note:
        Second-order method: local truncation error O(dt³), global error O(dt²).
    """
    dt = brainstate.environ.get_dt()
    k1 = f(y, t, *args)
    k2 = f(jax.tree.map(lambda x, k: x + dt * k, y, k1), t + dt, *args)
    return jax.tree.map(lambda x, _k1, _k2: x + dt / 2 * (_k1 + _k2), y, k1, k2)
91
+
92
+
93
def ode_rk3_step(f: ODE, y: PyTree, t, *args):
    """
    Third-order Runge-Kutta step (Kutta's third-order method).

    Uses three evaluations of ``f``:

        k1 = f(y_n, t_n)
        k2 = f(y_n + dt/2*k1, t_n + dt/2)
        k3 = f(y_n - dt*k1 + 2*dt*k2, t_n + dt)
        y_{n+1} = y_n + dt/6 * (k1 + 4*k2 + k3)

    The step size ``dt`` is read from ``brainstate.environ.get_dt()``. The update
    is applied leaf-wise over the PyTree ``y``.

    Args:
        f (ODE): The differential equation function dy/dt = f(y, t, *args)
        y (PyTree): Current state
        t (float): Current time
        *args: Additional arguments passed to function f

    Returns:
        PyTree: Updated state y_{n+1}, with the same structure as ``y``.

    Note:
        Third-order method: local truncation error O(dt⁴), global error O(dt³).
        One more evaluation of ``f`` per step than RK2.
    """
    dt = brainstate.environ.get_dt()
    k1 = f(y, t, *args)
    k2 = f(jax.tree.map(lambda x, k: x + dt / 2 * k, y, k1), t + dt / 2, *args)
    k3 = f(jax.tree.map(lambda x, k1_val, k2_val: x - dt * k1_val + 2 * dt * k2_val, y, k1, k2), t + dt, *args)
    return jax.tree.map(lambda x, _k1, _k2, _k3: x + dt / 6 * (_k1 + 4 * _k2 + _k3), y, k1, k2, k3)
122
+
123
+
124
def ode_rk4_step(f: ODE, y: PyTree, t, *args):
    """
    Classic fourth-order Runge-Kutta step (RK4).

    Uses four evaluations of ``f``:

        k1 = f(y_n, t_n)
        k2 = f(y_n + dt/2*k1, t_n + dt/2)
        k3 = f(y_n + dt/2*k2, t_n + dt/2)
        k4 = f(y_n + dt*k3, t_n + dt)
        y_{n+1} = y_n + dt/6 * (k1 + 2*k2 + 2*k3 + k4)

    The step size ``dt`` is read from ``brainstate.environ.get_dt()``. The update
    is applied leaf-wise over the PyTree ``y``.

    Args:
        f (ODE): The differential equation function dy/dt = f(y, t, *args)
        y (PyTree): Current state
        t (float): Current time
        *args: Additional arguments passed to function f

    Returns:
        PyTree: Updated state y_{n+1}, with the same structure as ``y``.

    Note:
        Fourth-order method: local truncation error O(dt⁵), global error O(dt⁴).
    """
    dt = brainstate.environ.get_dt()
    k1 = f(y, t, *args)
    k2 = f(jax.tree.map(lambda x, k: x + dt / 2 * k, y, k1), t + dt / 2, *args)
    k3 = f(jax.tree.map(lambda x, k: x + dt / 2 * k, y, k2), t + dt / 2, *args)
    k4 = f(jax.tree.map(lambda x, k: x + dt * k, y, k3), t + dt, *args)
    return jax.tree.map(
        lambda x, _k1, _k2, _k3, _k4: x + dt / 6 * (_k1 + 2 * _k2 + 2 * _k3 + _k4),
        y, k1, k2, k3, k4
    )
158
+
159
+
160
def sde_euler_step(df, dg, y, t, sde_type='ito', **kwargs):
    """
    Euler-Maruyama step for Itô stochastic differential equations.

    For SDEs of the form ``dy = f(y, t) dt + g(y, t) dW`` the update is

        y_{n+1} = y_n + f(y_n, t_n) * dt + g(y_n, t_n) * ΔW_n,   ΔW_n ~ N(0, dt)

    applied leaf-wise over the PyTree ``y``. A fresh standard-normal sample is
    drawn per leaf with ``brainstate.random.randn_like`` and scaled by ``sqrt(dt)``.
    ``dt`` is read from ``brainstate.environ.get_dt()``.

    Args:
        df (Callable): Drift function f(y, t, **kwargs) -> PyTree
        dg (Callable): Diffusion function g(y, t, **kwargs) -> PyTree
        y (PyTree): Current state
        t (float): Current time
        sde_type (str): SDE interpretation; only 'ito' is accepted.
        **kwargs: Extra keyword arguments forwarded to df and dg

    Returns:
        PyTree: Updated state y_{n+1}.

    Note:
        Strong convergence order 0.5, weak convergence order 1.0.
    """
    assert sde_type in ['ito', ]

    dt = brainstate.environ.get_dt()
    sqrt_dt = jnp.sqrt(dt)
    # Evaluate drift before diffusion, both at the current state.
    drift_tree = df(y, t, **kwargs)
    diffusion_tree = dg(y, t, **kwargs)

    def _advance(state, drift, diffusion):
        # Wiener increment: standard normal noise scaled by sqrt(dt).
        return state + drift * dt + diffusion * brainstate.random.randn_like(state) * sqrt_dt

    return jax.tree.map(_advance, y, drift_tree, diffusion_tree)
198
+
199
+
200
def sde_milstein_step(df, dg, y, t, sde_type='ito', **kwargs):
    """
    Derivative-free (Runge-Kutta style) Milstein step for SDEs.

    Solves SDEs of the form:
        dy = f(y, t)dt + g(y, t)dW

    The Milstein scheme adds a correction term to Euler-Maruyama:
        y_{n+1} = y_n + f*dt + g*ΔW_n + (1/2) * g * ∂g/∂y * ((ΔW_n)² - dt)

    where f, g are evaluated at (y_n, t_n). Instead of differentiating g, the
    product g·∂g/∂y is approximated with a supporting value:
        ȳ = y_n + f*dt + g*√dt
        g·∂g/∂y ≈ (g(ȳ, t_n) - g(y_n, t_n)) / √dt

    so the correction term becomes (g(ȳ) - g(y_n)) / (2√dt) * ((ΔW_n)² - dt).
    For the Stratonovich interpretation ('stra') the "- dt" is dropped.

    All operations are applied leaf-wise over the PyTree ``y`` (i.e. diagonal
    noise). One standard-normal sample per leaf is drawn with
    ``brainstate.random.randn_like``. ``dt`` is read from ``brainstate.environ.get_dt()``.

    Args:
        df (Callable): Drift function f(y, t, **kwargs) -> PyTree
        dg (Callable): Diffusion function g(y, t, **kwargs) -> PyTree
        y (PyTree): Current state
        t (float): Current time
        sde_type (str): SDE interpretation ('ito' or 'stra' for Stratonovich)
        **kwargs: Additional keyword arguments passed to df and dg

    Returns:
        PyTree: Updated state y_{n+1}.

    Note:
        Strong convergence order 1.0 for diagonal noise, versus 0.5 for
        Euler-Maruyama.
    """
    assert sde_type in ['ito', 'stra']

    dt = brainstate.environ.get_dt()
    dt_sqrt = jnp.sqrt(dt)

    # drift values
    drifts = df(y, t, **kwargs)

    # diffusion values
    diffusions = dg(y, t, **kwargs)

    # supporting values ȳ and the diffusion evaluated there (time is not advanced)
    y_bars = jax.tree.map(lambda y0, drift, diffusion: y0 + drift * dt + diffusion * dt_sqrt, y, drifts, diffusions)
    diffusion_bars = dg(y_bars, t, **kwargs)

    # integral results
    def f_integral(y0, drift, diffusion, diffusion_bar):
        noise = brainstate.random.randn_like(y0) * dt_sqrt
        noise_p2 = (noise ** 2 - dt) if sde_type == 'ito' else noise ** 2
        # Finite-difference estimate of (1/2) * g * dg/dy.
        minus = (diffusion_bar - diffusion) / 2 / dt_sqrt
        return y0 + drift * dt + diffusion * noise + minus * noise_p2

    integrals = jax.tree.map(f_integral, y, drifts, diffusions, diffusion_bars)
    return integrals
brainmass/noise.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright 2025 BDP Ecosystem Limited. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+
17
+ import brainstate
18
+ import brainunit as u
19
+
20
# Names exported by ``from brainmass.noise import *``.
__all__ = [
    'OUProcess',
]
23
+
24
+
25
class OUProcess(brainstate.nn.Dynamics):
    r"""
    The Ornstein–Uhlenbeck process.

    The state :math:`x_{t}` follows the stochastic differential equation

    .. math::

        dx_{t} = \frac{\mu - x_{t}}{\tau}\,dt + \frac{\sigma}{\sqrt{\tau}}\,dW_{t}

    where :math:`\mu` is ``mean``, :math:`\sigma` is ``sigma``, :math:`\tau` is ``tau``
    and :math:`W_{t}` denotes the Wiener process. The drift
    ``(mean - x) / tau`` and the diffusion ``sigma / sqrt(tau)`` are passed to
    ``brainstate.nn.exp_euler_step``, which advances the state by one step per
    call to :meth:`update`.

    Parameters
    ----------
    in_size : int, sequence of int
        The model size.
    mean : ArrayLike
        The noise mean value :math:`\mu`. Its unit also sets the unit of the state,
        which starts at zero.
    sigma : ArrayLike
        The noise amplitude :math:`\sigma`.
    tau : ArrayLike
        The decay time constant :math:`\tau`.
    """

    def __init__(
        self,
        in_size: brainstate.typing.Size,
        mean: brainstate.typing.ArrayLike = 0.,  # noise mean value
        sigma: brainstate.typing.ArrayLike = 1.,  # noise amplitude
        tau: brainstate.typing.ArrayLike = 10.,  # time constant
    ):
        super().__init__(in_size=in_size)

        # parameters
        self.mean = mean
        self.sigma = sigma
        self.tau = tau

    def init_state(self, batch_size=None, **kwargs):
        """Create the state ``x`` as zeros, carrying the unit of ``mean``."""
        size = self.in_size if batch_size is None else (batch_size, *self.in_size)
        self.x = brainstate.HiddenState(u.math.zeros(size, unit=u.get_unit(self.mean)))

    def reset_state(self, batch_size=None, **kwargs):
        """Reset the state ``x`` to zeros, carrying the unit of ``mean``."""
        size = self.in_size if batch_size is None else (batch_size, *self.in_size)
        self.x.value = u.math.zeros(size, unit=u.get_unit(self.mean))

    def update(self):
        """Advance the process by one step and return the new state."""

        def drift(x):
            # Relaxation towards the mean with time constant tau.
            return (self.mean - x) / self.tau

        def diffusion(x):
            # State-independent noise amplitude.
            return self.sigma / u.math.sqrt(self.tau)

        self.x.value = brainstate.nn.exp_euler_step(drift, diffusion, self.x.value)
        return self.x.value
+ return self.x.value