vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/__init__.py +37 -0
- vbi/_version.py +17 -0
- vbi/dataset/__init__.py +0 -0
- vbi/dataset/connectivity_84/centers.txt +84 -0
- vbi/dataset/connectivity_84/centres.txt +84 -0
- vbi/dataset/connectivity_84/cortical.txt +84 -0
- vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
- vbi/dataset/connectivity_84/weights.txt +84 -0
- vbi/dataset/connectivity_88/Aud_88.txt +88 -0
- vbi/dataset/connectivity_88/Bold.npz +0 -0
- vbi/dataset/connectivity_88/Labels.txt +17 -0
- vbi/dataset/connectivity_88/Region_labels.txt +88 -0
- vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
- vbi/dataset/connectivity_88/weights.txt +88 -0
- vbi/feature_extraction/__init__.py +1 -0
- vbi/feature_extraction/calc_features.py +293 -0
- vbi/feature_extraction/features.json +535 -0
- vbi/feature_extraction/features.py +2124 -0
- vbi/feature_extraction/features_settings.py +374 -0
- vbi/feature_extraction/features_utils.py +1357 -0
- vbi/feature_extraction/infodynamics.jar +0 -0
- vbi/feature_extraction/utility.py +507 -0
- vbi/inference.py +98 -0
- vbi/models/__init__.py +0 -0
- vbi/models/cpp/__init__.py +0 -0
- vbi/models/cpp/_src/__init__.py +0 -0
- vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/bold.hpp +303 -0
- vbi/models/cpp/_src/do.hpp +167 -0
- vbi/models/cpp/_src/do.i +17 -0
- vbi/models/cpp/_src/do.py +467 -0
- vbi/models/cpp/_src/do_wrap.cxx +12811 -0
- vbi/models/cpp/_src/jr_sdde.hpp +352 -0
- vbi/models/cpp/_src/jr_sdde.i +19 -0
- vbi/models/cpp/_src/jr_sdde.py +688 -0
- vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
- vbi/models/cpp/_src/jr_sde.hpp +264 -0
- vbi/models/cpp/_src/jr_sde.i +17 -0
- vbi/models/cpp/_src/jr_sde.py +470 -0
- vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
- vbi/models/cpp/_src/km_sde.hpp +158 -0
- vbi/models/cpp/_src/km_sde.i +19 -0
- vbi/models/cpp/_src/km_sde.py +671 -0
- vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
- vbi/models/cpp/_src/makefile +52 -0
- vbi/models/cpp/_src/mpr_sde.hpp +327 -0
- vbi/models/cpp/_src/mpr_sde.i +19 -0
- vbi/models/cpp/_src/mpr_sde.py +711 -0
- vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
- vbi/models/cpp/_src/utility.hpp +307 -0
- vbi/models/cpp/_src/vep.hpp +171 -0
- vbi/models/cpp/_src/vep.i +16 -0
- vbi/models/cpp/_src/vep.py +464 -0
- vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
- vbi/models/cpp/_src/wc_ode.hpp +294 -0
- vbi/models/cpp/_src/wc_ode.i +19 -0
- vbi/models/cpp/_src/wc_ode.py +686 -0
- vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
- vbi/models/cpp/damp_oscillator.py +143 -0
- vbi/models/cpp/jansen_rit.py +543 -0
- vbi/models/cpp/km.py +187 -0
- vbi/models/cpp/mpr.py +289 -0
- vbi/models/cpp/vep.py +150 -0
- vbi/models/cpp/wc.py +216 -0
- vbi/models/cupy/__init__.py +0 -0
- vbi/models/cupy/bold.py +111 -0
- vbi/models/cupy/ghb.py +284 -0
- vbi/models/cupy/jansen_rit.py +473 -0
- vbi/models/cupy/km.py +224 -0
- vbi/models/cupy/mpr.py +475 -0
- vbi/models/cupy/mpr_modified_bold.py +12 -0
- vbi/models/cupy/utils.py +184 -0
- vbi/models/numba/__init__.py +0 -0
- vbi/models/numba/_ww_EI.py +444 -0
- vbi/models/numba/damp_oscillator.py +162 -0
- vbi/models/numba/ghb.py +208 -0
- vbi/models/numba/mpr.py +383 -0
- vbi/models/pytorch/__init__.py +0 -0
- vbi/models/pytorch/data/default_parameters.npz +0 -0
- vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
- vbi/models/pytorch/data/input/fc_test.csv +68 -0
- vbi/models/pytorch/data/input/fc_train.csv +68 -0
- vbi/models/pytorch/data/input/fc_vali.csv +68 -0
- vbi/models/pytorch/data/input/fcd_test.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_train.mat +0 -0
- vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
- vbi/models/pytorch/data/input/myelin.csv +68 -0
- vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
- vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
- vbi/models/pytorch/data/input/sc_test.csv +68 -0
- vbi/models/pytorch/data/input/sc_train.csv +68 -0
- vbi/models/pytorch/data/input/sc_vali.csv +68 -0
- vbi/models/pytorch/data/obs_kong0.npz +0 -0
- vbi/models/pytorch/ww_sde_kong.py +570 -0
- vbi/models/tvbk/__init__.py +9 -0
- vbi/models/tvbk/tvbk_wrapper.py +166 -0
- vbi/models/tvbk/utils.py +72 -0
- vbi/papers/__init__.py +0 -0
- vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
- vbi/tests/__init__.py +0 -0
- vbi/tests/_test_mpr_nb.py +36 -0
- vbi/tests/test_features.py +355 -0
- vbi/tests/test_ghb_cupy.py +90 -0
- vbi/tests/test_mpr_cupy.py +49 -0
- vbi/tests/test_mpr_numba.py +84 -0
- vbi/tests/test_suite.py +19 -0
- vbi/utils.py +402 -0
- vbi-0.1.3.dist-info/METADATA +166 -0
- vbi-0.1.3.dist-info/RECORD +121 -0
- vbi-0.1.3.dist-info/WHEEL +5 -0
- vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
- vbi-0.1.3.dist-info/top_level.txt +1 -0
vbi/models/numba/ghb.py
ADDED
@@ -0,0 +1,208 @@
|
|
1
|
+
import warnings
|
2
|
+
import numpy as np
|
3
|
+
from numba import njit
|
4
|
+
from numba.experimental import jitclass
|
5
|
+
from numba.core.errors import NumbaPerformanceWarning
|
6
|
+
from numba import float64, int64
|
7
|
+
|
8
|
+
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
|
9
|
+
|
10
|
+
|
11
|
+
@njit(nogil=True)
def run(P, times):
    """Integrate the GHB network and its Balloon-Windkessel BOLD signal.

    Each node follows the Hopf normal form in (x, y) with diffusive coupling
    ``G * sum_j SC[i, j] * (x_j - x_i)`` and additive noise of amplitude
    ``sigma`` (Euler-Maruyama with step ``P.dt``). ``x`` drives a per-node
    hemodynamic model (s, f, v, q) integrated with forward Euler.

    Parameters
    ----------
    P : ParGHB
        Model parameters; ``P.init_state`` holds ``[x (nn values), y (nn values)]``.
        When ``P.seed > 0`` numba's random generator is seeded with it.
    times : 1-D float array
        Its length ``nt`` sets the number of steps (``nt - 1`` steps of size
        ``P.dt``); its entries label the recorded samples.

    Returns
    -------
    t_bold : 1-D array
        ``times[::P.decimate]`` restricted to values ``> P.tcut``.
    bold : 2-D array, shape (nn, len(t_bold))
        BOLD signal; column k is the state at ``t_bold[k]``.
    """
    G = P.G
    dt = P.dt
    eta = P.eta
    SC = P.weights
    sigma = P.sigma
    omega = P.omega
    tcut = P.tcut
    decimate = P.decimate
    init_state = P.init_state

    # np.random.seed called from the interpreter does not reach numba's
    # generator, so seed it here to make the noise reproducible.
    if P.seed > 0:
        np.random.seed(P.seed)

    epsilon = 0.5
    itaus = 1.25
    itauf = 2.5
    itauo = 1.02040816327
    ialpha = 5.0
    Eo = 0.4
    V0 = 4.0
    k1 = 2.77264
    k2 = 0.572
    k3 = -0.43

    nt = times.shape[0]
    nn = SC.shape[0]
    noise_scale = np.sqrt(dt) * sigma

    # One BOLD column per entry of times[::decimate]: column k holds the state
    # at times[k * decimate]. Evaluates to 0 when times is empty.
    n_buffer = (nt - 1) // decimate + 1

    x = np.zeros(nn)
    y = np.zeros(nn)
    x[:] = init_state[:nn]
    y[:] = init_state[nn:]

    # Hemodynamic state at rest: s = 0, f = v = q = 1.
    s = np.zeros(nn)
    f = np.ones(nn)
    v = np.ones(nn)
    q = np.ones(nn)

    bold = np.zeros((nn, n_buffer))
    if n_buffer > 0:
        for i in range(nn):
            bold[i, 0] = V0 * (
                k1 - k1 * q[i] + k2 - k2 * (q[i] / v[i]) + k3 - k3 * v[i]
            )

    dx = np.empty(nn)
    dy = np.empty(nn)
    for it in range(nt - 1):
        # Drift of every node from the state at the current time, computed
        # before any node is advanced; updating inside this loop would let
        # later nodes couple to already-advanced neighbours.
        for i in range(nn):
            gx = 0.0
            gy = 0.0
            for j in range(nn):
                gx += SC[i, j] * (x[j] - x[i])
                gy += SC[i, j] * (y[j] - y[i])
            dx[i] = (
                (x[i] * (eta[i] - (x[i] * x[i]) - (y[i] * y[i])))
                - (omega[i] * y[i])
                + (G * gx)
            )
            dy[i] = (
                (y[i] * (eta[i] - (x[i] * x[i]) - (y[i] * y[i])))
                + (omega[i] * x[i])
                + (G * gy)
            )

        # After this step the state corresponds to times[it + 1].
        record = (it + 1) % decimate == 0
        col = (it + 1) // decimate
        for i in range(nn):
            # Hemodynamics driven by x[i] at the current (pre-update) time.
            fv = v[i] ** ialpha
            ds = epsilon * x[i] - itaus * s[i] - itauf * (f[i] - 1)
            df = s[i]
            dv = itauo * (f[i] - fv)
            dq = itauo * (
                f[i] * (1 - (1 - Eo) ** (1 / f[i])) / Eo - fv * q[i] / v[i]
            )

            x[i] += dt * dx[i] + noise_scale * np.random.randn()
            y[i] += dt * dy[i] + noise_scale * np.random.randn()
            s[i] += dt * ds
            f[i] += dt * df
            v[i] += dt * dv
            q[i] += dt * dq

            if record:
                bold[i, col] = V0 * (
                    k1 - k1 * q[i] + k2 - k2 * (q[i] / v[i]) + k3 - k3 * v[i]
                )

    # Mask and labels both come from the decimated grid, so they line up
    # with the recorded columns.
    t_dec = times[::decimate]
    keep = t_dec > tcut
    return t_dec[keep], bold[:, keep]
|
107
|
+
|
108
|
+
|
109
|
+
class GHB_sde(object):
|
110
|
+
def __init__(self, par: dict = {}) -> None:
|
111
|
+
self.valid_par = [par_spec[i][0] for i in range(len(par_spec))]
|
112
|
+
self.check_parameters(par)
|
113
|
+
self.P = self.get_par_obj(par)
|
114
|
+
|
115
|
+
def get_par_obj(self, par: dict):
|
116
|
+
if "init_state" in par.keys():
|
117
|
+
par["init_state"] = np.array(par["init_state"])
|
118
|
+
if "weights" in par.keys():
|
119
|
+
par["weights"] = np.array(par["weights"])
|
120
|
+
return ParGHB(**par)
|
121
|
+
|
122
|
+
def __str__(self) -> str:
|
123
|
+
print("GHB model")
|
124
|
+
for key in self.valid_par:
|
125
|
+
print(f"{key}: {getattr(self.P, key)}")
|
126
|
+
return ""
|
127
|
+
|
128
|
+
def check_parameters(self, par: dict) -> None:
|
129
|
+
for key in par.keys():
|
130
|
+
if key not in self.valid_par:
|
131
|
+
raise ValueError(f"Invalid parameter: {key}")
|
132
|
+
|
133
|
+
def set_initial_state(self, seed=None):
|
134
|
+
if seed is not None:
|
135
|
+
np.random.seed(seed)
|
136
|
+
assert self.P.weights is not None
|
137
|
+
return np.random.uniform(0, 1, 2 * self.P.weights.shape[0])
|
138
|
+
|
139
|
+
def check_input(self):
|
140
|
+
assert self.P.weights is not None
|
141
|
+
assert self.P.weights.shape[0] == self.P.weights.shape[1]
|
142
|
+
assert self.P.eta is not None
|
143
|
+
assert self.P.omega is not None
|
144
|
+
assert self.P.weights.shape[0] == self.P.eta.shape[0]
|
145
|
+
|
146
|
+
def run(self, par={}, tspan=None, x0=None, verbose=True):
|
147
|
+
if x0 is None:
|
148
|
+
self.seed = self.P.seed if self.P.seed > 0 else None
|
149
|
+
self.P.init_state = self.set_initial_state(seed=self.seed)
|
150
|
+
else:
|
151
|
+
self.P.init_state = x0
|
152
|
+
|
153
|
+
if tspan is None:
|
154
|
+
times = np.arange(0, self.P.tend, self.P.dt)
|
155
|
+
else:
|
156
|
+
times = np.arange(tspan[0], tspan[1], self.P.dt)
|
157
|
+
|
158
|
+
if par:
|
159
|
+
self.check_parameters(par)
|
160
|
+
for key in par.keys():
|
161
|
+
setattr(self.P, key, par[key])
|
162
|
+
|
163
|
+
self.check_input()
|
164
|
+
t, b = run(self.P, times)
|
165
|
+
return {'t': t, 'bold': b}
|
166
|
+
|
167
|
+
|
168
|
+
# Field names and numba types of the ParGHB jitclass. The order also fixes the
# order in which GHB_sde lists (and prints) its valid parameters.
par_spec = list(
    {
        "G": float64,
        "dt": float64,
        "seed": int64,
        "tend": float64,
        "tcut": float64,
        "sigma": float64,
        "eta": float64[:],
        "decimate": int64,
        "omega": float64[:],
        "weights": float64[:, :],
        "init_state": float64[:],
    }.items()
)
|
181
|
+
|
182
|
+
|
183
|
+
@jitclass(par_spec)
class ParGHB:
    """Parameter container for the numba GHB model (fields listed in ``par_spec``).

    ``seed`` (default ``-1``) is read by ``GHB_sde.run``, which uses it only
    when positive. It is now a constructor argument so that ``seed``, which
    ``GHB_sde`` accepts as a valid key, can be passed at construction time.
    """

    def __init__(
        self,
        G=1.0,
        dt=0.001,
        sigma=0.1,
        tend=10.0,
        tcut=0.0,
        eta=np.array([]),
        init_state=np.array([]),
        omega=np.array([]),
        weights=np.array([[], []]),
        decimate=1,
        seed=-1,
    ):
        self.G = G
        self.dt = dt
        self.seed = seed
        self.eta = eta
        self.tend = tend
        self.tcut = tcut
        self.sigma = sigma
        self.omega = omega
        self.weights = weights
        self.decimate = decimate
        self.init_state = init_state
|
vbi/models/numba/mpr.py
ADDED
@@ -0,0 +1,383 @@
|
|
1
|
+
import warnings
|
2
|
+
import numpy as np
|
3
|
+
from copy import copy
|
4
|
+
from numba import njit, jit
|
5
|
+
from numba.experimental import jitclass
|
6
|
+
from numba.extending import register_jitable
|
7
|
+
from numba import float64, boolean, int64, types
|
8
|
+
from numba.core.errors import NumbaPerformanceWarning
|
9
|
+
|
10
|
+
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
|
11
|
+
np.random.seed(42)
|
12
|
+
|
13
|
+
|
14
|
+
@njit
def f_mpr(x, t, P):
    """Right-hand side of the MPR network model.

    ``x`` is ``[r (nn values), v (nn values)]``. Coupling enters the ``v`` equation as
    ``G * (weights @ r)``. ``t`` is accepted for a uniform integrator
    signature but is not used. Returns a new array shaped like ``x``.
    """
    nn = P.nn
    r = x[:nn]
    v = x[nn:]
    tau = P.tau

    inv_tau = 1.0 / tau
    spread = P.delta / (tau * np.pi)
    self_exc = P.J * tau
    pi_sq = np.pi * np.pi
    tau_sq = tau * tau
    net_input = np.dot(P.weights, r)

    deriv = np.zeros_like(x)
    deriv[:nn] = inv_tau * (spread + 2 * r * v)
    deriv[nn:] = inv_tau * (
        v * v
        + P.eta
        + P.iapp
        + self_exc * r
        - (pi_sq * tau_sq * r * r)
        + P.G * net_input
    )
    return deriv
|
36
|
+
|
37
|
+
|
38
|
+
@njit
def heun_sde(x, t, P):
    """Advance ``x`` by one stochastic Heun step, in place, and return it.

    The same noise increments are added in the predictor and the corrector.
    ``P.sigma_r`` and ``P.sigma_v`` are used directly as per-step standard
    deviations (``ParMPR``'s constructor includes a ``sqrt(dt)`` factor).
    After the step, negative rates ``x[:nn]`` are set to zero.
    """
    nn = P.nn
    h = P.dt
    kick_r = P.sigma_r * np.random.randn(nn)
    kick_v = P.sigma_v * np.random.randn(nn)

    # predictor
    slope_start = f_mpr(x, t, P)
    x_pred = x + h * slope_start
    x_pred[:nn] += kick_r
    x_pred[nn:] += kick_v

    # corrector
    slope_end = f_mpr(x_pred, t + h, P)
    x += 0.5 * h * (slope_start + slope_end)
    x[nn:] += kick_v
    x[:nn] += kick_r
    x[:nn] = (x[:nn] > 0.0) * x[:nn]
    return x
|
55
|
+
|
56
|
+
|
57
|
+
@njit
def do_bold_step(r_in, s, f, ftilde, vtilde, qtilde, v, q, dtt, P):
    """Advance the Balloon-Windkessel state by one step of size ``dtt``, in place.

    Every state argument is a two-row buffer: row 0 holds the current values
    and row 1 receives the new ones. At the end, row 1 is copied back into
    row 0. ``f``, ``v`` and ``q`` are integrated in log space (``ftilde``,
    ``vtilde``, ``qtilde``) and recovered with ``exp``. ``P`` supplies
    ``kappa``, ``gamma``, ``alpha``, ``tau`` and ``Eo``. Returns nothing.
    """
    inv_alpha = 1 / P.alpha
    tau0 = P.tau
    E0 = P.Eo

    # vasodilatory signal (uses the flow before clamping)
    s[1] = s[0] + dtt * (r_in - P.kappa * s[0] - P.gamma * (f[0] - 1))

    # NOTE(review): the flow is clamped to >= 1 before the remaining updates.
    f[0] = np.clip(f[0], 1, None)
    ftilde[1] = ftilde[0] + dtt * (s[0] / f[0])

    outflow = v[0] ** inv_alpha
    vtilde[1] = vtilde[0] + dtt * ((f[0] - outflow) / (tau0 * v[0]))

    q[0] = np.clip(q[0], 0.01, None)
    extraction = (1 - (1 - E0) ** (1 / f[0])) / E0
    qtilde[1] = qtilde[0] + dtt * (
        (f[0] * extraction - outflow * q[0] / v[0]) / (tau0 * q[0])
    )

    f[1] = np.exp(ftilde[1])
    v[1] = np.exp(vtilde[1])
    q[1] = np.exp(qtilde[1])

    # the new values become the current ones
    s[0] = s[1]
    f[0] = f[1]
    v[0] = v[1]
    q[0] = q[1]
    ftilde[0] = ftilde[1]
    vtilde[0] = vtilde[1]
    qtilde[0] = qtilde[1]
|
85
|
+
|
86
|
+
|
87
|
+
def integrate(P, B):
    """Simulate the MPR network and, optionally, its BOLD signal.

    This is a plain-Python driver loop. Each step calls the jitted
    ``heun_sde`` and, when ``P.RECORD_BOLD`` is set, ``do_bold_step``.

    Parameters
    ----------
    P : ParMPR
        Model parameters. The array obtained from ``P.initial_state`` is
        advanced in place. ``P.t_end`` and ``P.dt`` are in the internal time
        unit (``MPR_sde.check_input`` divides ``t_end`` by 10 beforehand), and
        the returned time axes are multiplied by 10.
        NOTE(review): ``P.t_cut`` is not applied here.
    B : ParBold
        Balloon-Windkessel parameters.

    Returns
    -------
    dict
        ``rv_t`` (length ``nt // rv_decimate``) and ``rv_d`` (shape
        ``(nt // rv_decimate, 2 * nn)``), both float32.
        ``bold_t`` and ``bold_d`` (shape ``(nt // bold_decimate, nn)``),
        both float32. Outputs whose recording flag is off are empty arrays.
    """
    nn = P.nn
    dt = P.dt
    rv_decimate = P.rv_decimate
    record_rv = P.RECORD_RV
    record_bold = P.RECORD_BOLD

    # Round instead of truncating: int(0.3 / 0.1) == 2 would drop a step.
    nt = int(np.round(P.t_end / dt))
    rv_current = P.initial_state

    # The BOLD integrator uses a step of 10 * dt ("extended" time), in seconds.
    r_period = dt * 10
    dtt = r_period / 1000.0
    # At least 1: for tr < r_period / 2 the rounded ratio is 0 and the modulo
    # below would divide by zero.
    bold_decimate = max(1, int(np.round(P.tr / r_period)))

    k1 = 4.3 * B.theta0 * B.Eo * B.TE
    k2 = B.epsilon * B.r0 * B.Eo * B.TE
    k3 = 1 - B.epsilon
    vo = B.vo

    rv_t = np.array([])
    rv_d = np.array([])
    bold_t = np.array([])
    bold_d = np.array([])

    n_rv = nt // rv_decimate if record_rv else 0
    if record_rv:
        rv_d = np.zeros((n_rv, 2 * nn), dtype=np.float32)
        rv_t = np.zeros(n_rv, dtype=np.float32)

    n_bold = nt // bold_decimate if record_bold else 0
    if record_bold:
        # Two-row buffers for do_bold_step: row 0 = current, row 1 = next.
        # The log-space states start at 0 (= log 1).
        s = np.zeros((2, nn))
        f = np.zeros((2, nn))
        ftilde = np.zeros((2, nn))
        vtilde = np.zeros((2, nn))
        qtilde = np.zeros((2, nn))
        v = np.zeros((2, nn))
        q = np.zeros((2, nn))
        # NOTE(review): s starts at 1; the usual resting state is s = 0.
        s[0] = 1
        f[0] = 1
        v[0] = 1
        q[0] = 1
        vv = np.zeros((n_bold, nn))
        qq = np.zeros((n_bold, nn))

    # Run all nt steps so that every preallocated row, including the last, is
    # written. With nt - 1 steps and a decimation of 1 the last row stayed
    # zero (and qq / vv became NaN).
    for i in range(nt):
        t_current = i * dt
        heun_sde(rv_current, t_current, P)

        if record_rv and i % rv_decimate == 0 and i // rv_decimate < n_rv:
            rv_d[i // rv_decimate, :] = rv_current
            rv_t[i // rv_decimate] = t_current

        if record_bold:
            do_bold_step(
                rv_current[:nn], s, f, ftilde, vtilde, qtilde, v, q, dtt, B
            )
            if i % bold_decimate == 0 and i // bold_decimate < n_bold:
                vv[i // bold_decimate] = v[1]
                qq[i // bold_decimate] = q[1]

    if record_bold:
        bold_d = vo * (k1 * (1 - qq) + k2 * (1 - qq / vv) + k3 * (1 - vv))
        # Sample k is taken at step k * bold_decimate (same labelling as rv_t).
        bold_t = np.arange(n_bold) * (bold_decimate * dt)

    return {
        "rv_t": rv_t * 10,
        "rv_d": rv_d,
        "bold_t": bold_t.astype("f") * 10,
        "bold_d": bold_d.astype("f"),
    }
|
170
|
+
|
171
|
+
|
172
|
+
class MPR_sde:
|
173
|
+
def __init__(self, par_mpr: dict = {}) -> None:
|
174
|
+
self.valid_par = [mpr_spec[i][0] for i in range(len(mpr_spec))]
|
175
|
+
self.check_parameters(par_mpr)
|
176
|
+
self.P = self.get_par_mpr(par_mpr)
|
177
|
+
self.B = ParBold()
|
178
|
+
|
179
|
+
self.seed = self.P.seed
|
180
|
+
if self.seed > 0:
|
181
|
+
np.random.seed(self.seed)
|
182
|
+
|
183
|
+
def __str__(self) -> str:
|
184
|
+
print("MPR model")
|
185
|
+
print("Parameters: --------------------------------")
|
186
|
+
for key in self.valid_par:
|
187
|
+
print(f"{key} = {getattr(self.P, key)}")
|
188
|
+
print("--------------------------------------------")
|
189
|
+
return ""
|
190
|
+
|
191
|
+
def check_parameters(self, par: dict) -> None:
|
192
|
+
for key in par.keys():
|
193
|
+
if key not in self.valid_par:
|
194
|
+
raise ValueError(f"Invalid parameter: {key}")
|
195
|
+
|
196
|
+
def get_par_mpr(self, par: dict):
|
197
|
+
"""
|
198
|
+
return default parameters of MPR model and update with user defined parameters.
|
199
|
+
"""
|
200
|
+
if "initial_state" in par.keys():
|
201
|
+
par["initial_state"] = np.array(par["initial_state"])
|
202
|
+
if "weights" in par.keys():
|
203
|
+
assert par["weights"] is not None
|
204
|
+
par["weights"] = np.array(par["weights"])
|
205
|
+
assert par["weights"].shape[0] == par["weights"].shape[1]
|
206
|
+
parP = ParMPR(**par)
|
207
|
+
return parP
|
208
|
+
|
209
|
+
def set_initial_state(self):
|
210
|
+
self.initial_state = set_initial_state(self.P.nn, self.seed)
|
211
|
+
self.INITIAL_STATE_SET = True
|
212
|
+
|
213
|
+
def check_input(self):
|
214
|
+
assert self.P.weights is not None
|
215
|
+
assert self.P.weights.shape[0] == self.P.weights.shape[1]
|
216
|
+
assert self.P.initial_state is not None
|
217
|
+
assert len(self.P.initial_state) == 2 * self.P.weights.shape[0]
|
218
|
+
self.P.eta = check_vec_size(self.P.eta, self.P.nn)
|
219
|
+
self.P.t_end /= 10
|
220
|
+
self.P.t_cut /= 10
|
221
|
+
|
222
|
+
def run(self, par={}, x0=None, verbose=True):
|
223
|
+
|
224
|
+
if x0 is None:
|
225
|
+
self.seed = self.P.seed if self.P.seed > 0 else None
|
226
|
+
self.set_initial_state()
|
227
|
+
self.P.initial_state = self.initial_state
|
228
|
+
else:
|
229
|
+
self.P.initial_state = x0
|
230
|
+
# self.P.nn = len(x0) // 2 # is it necessary?
|
231
|
+
if par:
|
232
|
+
self.check_parameters(par)
|
233
|
+
for key in par.keys():
|
234
|
+
setattr(self.P, key, par[key])
|
235
|
+
|
236
|
+
self.check_input()
|
237
|
+
|
238
|
+
return integrate(self.P, self.B)
|
239
|
+
|
240
|
+
|
241
|
+
@njit
def set_initial_state(nn, seed=None):
    """Return a random MPR state ``[r, v]`` of length ``2 * nn``.

    ``r`` is uniform on [0, 1.5) and ``v`` is uniform on [-2, 2). When
    ``seed`` is given, the generator of the jitted context is seeded first.
    """
    if seed is not None:
        set_seed_compat(seed)

    state = np.random.rand(2 * nn)
    rates = state[:nn]
    potentials = state[nn:]
    rates *= 1.5
    potentials *= 4
    potentials -= 2
    return state
|
251
|
+
|
252
|
+
|
253
|
+
# Field names and numba types of the ParMPR jitclass. The order also fixes the
# order in which MPR_sde lists (and prints) its valid parameters.
mpr_spec = list(
    {
        "G": float64,
        "dt": float64,
        "J": float64,
        "eta": float64[:],
        "tau": float64,
        "weights": float64[:, :],
        "delta": float64,
        "t_init": float64,
        "t_cut": float64,
        "t_end": float64,
        "nn": int64,
        "method": types.string,
        "seed": int64,
        "initial_state": float64[:],
        "noise_amp": float64,
        "sigma_r": float64,
        "sigma_v": float64,
        "iapp": float64,
        "output": types.string,
        "RECORD_RV": boolean,
        "RECORD_BOLD": boolean,
        "rv_decimate": int64,
        "tr": float64,
    }.items()
)
|
278
|
+
|
279
|
+
|
280
|
+
@jitclass(mpr_spec)
class ParMPR:
    """Parameter container for the numba MPR model (fields listed in ``mpr_spec``).

    Derived fields: ``nn`` comes from ``len(weights)``, and ``sigma_r`` and
    ``sigma_v`` come from ``dt`` and ``noise_amp``. ``tr`` is the BOLD
    repetition time in milliseconds. ``method`` is now initialized (it is a
    declared field that ``MPR_sde.__str__`` reads); ``integrate`` always
    uses the Heun scheme.
    """

    def __init__(
        self,
        G=0.5,
        dt=0.01,
        J=14.5,
        eta=np.array([-4.6]),
        tau=1.0,
        delta=0.7,
        rv_decimate=1,
        noise_amp=0.037,
        weights=np.array([[], []]),
        t_init=0.0,
        t_cut=0.0,
        t_end=1000.0,
        iapp=0.0,
        seed=-1,
        output="output",
        RECORD_RV=True,
        RECORD_BOLD=True,
        tr=500.0,  # TR in milliseconds
        method="heun",
    ):

        self.G = G
        self.dt = dt
        self.J = J
        self.eta = eta
        self.tau = tau
        self.delta = delta
        self.rv_decimate = rv_decimate  # int64 field; the default is an int
        self.noise_amp = noise_amp
        self.t_init = t_init
        self.t_cut = t_cut
        self.t_end = t_end
        self.iapp = iapp
        self.nn = len(weights)
        self.seed = seed
        self.output = output
        self.method = method
        self.weights = weights
        self.RECORD_RV = RECORD_RV
        self.RECORD_BOLD = RECORD_BOLD
        # per-step noise standard deviations (sqrt(dt) folded in)
        self.sigma_r = np.sqrt(dt) * np.sqrt(2 * noise_amp)
        self.sigma_v = np.sqrt(dt) * np.sqrt(4 * noise_amp)
        self.tr = tr
|
325
|
+
|
326
|
+
|
327
|
+
# Field names and numba types of the ParBold jitclass.
bold_spec = list(
    {
        "kappa": float64,
        "gamma": float64,
        "tau": float64,
        "alpha": float64,
        "epsilon": float64,
        "Eo": float64,
        "TE": float64,
        "vo": float64,
        "r0": float64,
        "theta0": float64,
        "t_min": float64,
        "rtol": float64,
        "atol": float64,
    }.items()
)
|
342
|
+
|
343
|
+
|
344
|
+
@jitclass(bold_spec)
class ParBold:
    """Balloon-Windkessel parameters read by ``do_bold_step`` and ``integrate``.

    ``t_min``, ``rtol`` and ``atol`` are stored but not read anywhere in this
    module.
    """

    def __init__(
        self,
        kappa=0.65,
        gamma=0.41,
        tau=0.98,
        alpha=0.32,
        epsilon=0.34,
        Eo=0.4,
        TE=0.04,
        vo=0.08,
        r0=25.0,
        theta0=40.3,
        t_min=0.0,
        rtol=1e-5,
        atol=1e-8,
    ):
        # hemodynamic state equations (do_bold_step)
        self.tau = tau
        self.alpha = alpha
        self.kappa = kappa
        self.gamma = gamma
        self.Eo = Eo
        # BOLD observation coefficients (integrate)
        self.TE = TE
        self.vo = vo
        self.r0 = r0
        self.theta0 = theta0
        self.epsilon = epsilon
        # solver settings
        self.t_min = t_min
        self.rtol = rtol
        self.atol = atol
|
375
|
+
|
376
|
+
|
377
|
+
def check_vec_size(x, nn):
    """Return ``x`` as a float vector of length ``nn``.

    A scalar or a single-element sequence is broadcast to ``nn`` entries. A
    sequence with ``nn`` elements is copied.

    Raises
    ------
    ValueError
        If ``x`` has neither 1 nor ``nn`` elements.
    """
    vec = np.atleast_1d(np.asarray(x, dtype=float)).ravel()
    if vec.size == nn:
        return vec.copy()
    if vec.size == 1:
        return np.full(nn, vec[0])
    raise ValueError(f"expected 1 or {nn} values, got {vec.size}")
|
379
|
+
|
380
|
+
|
381
|
+
@register_jitable
def set_seed_compat(x):
    """Seed the random generator of whichever context calls this.

    ``register_jitable`` makes this callable both from the interpreter and
    from ``@njit`` code. Numba keeps its own generator, so only a call made
    from jitted code (as in ``set_initial_state``) seeds the noise drawn in
    jitted functions.
    """
    seed_value = x
    np.random.seed(seed_value)
|
File without changes
|
Binary file
|
Binary file
|