vbi-0.1.3-cp310-cp310-manylinux2014_x86_64.whl → vbi-0.2-cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/feature_extraction/features.json +4 -1
- vbi/feature_extraction/features.py +10 -4
- vbi/inference.py +50 -22
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/jr_sde.hpp +5 -6
- vbi/models/cpp/_src/jr_sde_wrap.cxx +28 -28
- vbi/models/cpp/jansen_rit.py +2 -9
- vbi/models/cupy/bold.py +117 -0
- vbi/models/cupy/jansen_rit.py +1 -1
- vbi/models/cupy/km.py +62 -34
- vbi/models/cupy/mpr.py +24 -4
- vbi/models/cupy/utils.py +163 -2
- vbi/models/cupy/wilson_cowan.py +317 -0
- vbi/models/cupy/ww.py +342 -0
- vbi/models/numba/__init__.py +4 -0
- vbi/models/numba/jansen_rit.py +532 -0
- vbi/models/numba/mpr.py +8 -0
- vbi/models/numba/wilson_cowan.py +443 -0
- vbi/models/numba/ww.py +564 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/METADATA +30 -11
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/RECORD +30 -26
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/WHEEL +1 -1
- vbi/models/numba/_ww_EI.py +0 -444
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/licenses/LICENSE +0 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/top_level.txt +0 -0
vbi/models/numba/ww.py
ADDED
@@ -0,0 +1,564 @@
```python
# gpt5

import warnings
import numpy as np
from copy import copy
from numba import njit, jit
from numba.experimental import jitclass
from numba.extending import register_jitable
from numba import float64, boolean, int64, types
from numba.core.errors import NumbaPerformanceWarning

warnings.simplefilter("ignore", category=NumbaPerformanceWarning)


# -----------------------------
# Helper utilities
# -----------------------------

# @register_jitable
# def set_seed_compat(x):
#     np.random.seed(x)


@jit(nopython=True)
def initialize_random_state(seed):
    """Call this once to set the seed in Numba context"""
    np.random.seed(seed)


def check_vec_size_1d(x, nn):
    """Return a 1D vector of size nn, broadcasting scalar if needed (no numba)."""
    x = np.array(x, dtype=np.float64) if np.ndim(x) > 0 else np.array([x], dtype=np.float64)
    return np.ones(nn, dtype=np.float64) * x if x.size != nn else x.astype(np.float64)


# -----------------------------
# BOLD model parameters (same structure as in mpr.py)
# -----------------------------

bold_spec = [
    ("kappa", float64),
    ("gamma", float64),
    ("tau", float64),
    ("alpha", float64),
    ("epsilon", float64),
    ("Eo", float64),
    ("TE", float64),
    ("vo", float64),
    ("r0", float64),
    ("theta0", float64),
    ("t_min", float64),
    ("rtol", float64),
    ("atol", float64),
]


@jitclass(bold_spec)
class ParBold:
    def __init__(
        self,
        kappa=0.65,
        gamma=0.41,
        tau=0.98,
        alpha=0.32,
        epsilon=0.34,
        Eo=0.4,
        TE=0.04,
        vo=0.08,
        r0=25.0,
        theta0=40.3,
        t_min=0.0,
        rtol=1e-5,
        atol=1e-8,
    ):
        self.kappa = kappa
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.epsilon = epsilon
        self.Eo = Eo
        self.TE = TE
        self.vo = vo
        self.r0 = r0
        self.theta0 = theta0
        self.t_min = t_min
        self.rtol = rtol
        self.atol = atol


@jit(nopython=True)
def do_bold_step(r_in, s, f, ftilde, vtilde, qtilde, v, q, dtt, P):
    """
    One BOLD step for all nodes (vectorized over nn). Same as mpr.py.
    r_in should be non-negative neural drive per node (here we use S_exc).
    """
    kappa = P.kappa
    gamma = P.gamma
    ialpha = 1.0 / P.alpha
    tau = P.tau
    Eo = P.Eo

    s[1] = s[0] + dtt * (r_in - kappa * s[0] - gamma * (f[0] - 1.0))
    # keep f[0] >= 1 to avoid log issues
    f[0] = np.maximum(f[0], 1.0)
    ftilde[1] = ftilde[0] + dtt * (s[0] / f[0])
    fv = v[0] ** ialpha  # outflow
    vtilde[1] = vtilde[0] + dtt * ((f[0] - fv) / (tau * v[0]))
    q[0] = np.maximum(q[0], 0.01)
    ff = (1.0 - (1.0 - Eo) ** (1.0 / f[0])) / Eo  # oxygen extraction
    qtilde[1] = qtilde[0] + dtt * ((f[0] * ff - fv * q[0] / v[0]) / (tau * q[0]))

    # exponentiate back
    f[1] = np.exp(ftilde[1])
    v[1] = np.exp(vtilde[1])
    q[1] = np.exp(qtilde[1])

    # roll state
    f[0] = f[1]
    s[0] = s[1]
    ftilde[0] = ftilde[1]
    vtilde[0] = vtilde[1]
    qtilde[0] = qtilde[1]
    v[0] = v[1]
    q[0] = q[1]


# -----------------------------
# Wong–Wang model params (Numba jitclass)
# -----------------------------

ww_spec = [
    # local population parameters
    ("a_exc", float64),
    ("a_inh", float64),
    ("b_exc", float64),
    ("b_inh", float64),
    ("d_exc", float64),
    ("d_inh", float64),
    ("tau_exc", float64),
    ("tau_inh", float64),
    ("gamma_exc", float64),
    ("gamma_inh", float64),
    ("W_exc", float64),
    ("W_inh", float64),
    ("ext_current", float64[:]),
    ("J_NMDA", float64),
    ("J_I", float64),
    ("w_plus", float64),
    ("lambda_inh_exc", float64),
    # global / simulation parameters
    ("t_end", float64),
    ("t_cut", float64),
    ("dt", float64),
    ("G_exc", float64),
    ("G_inh", float64),
    ("weights", float64[:, :]),
    ("tr", float64),
    ("s_decimate", int64),
    ("sigma", float64),
    ("nn", int64),
    ("seed", int64),
    ("output", types.string),
    ("dtype", types.string),
    ("initial_state", float64[:]),
    ("RECORD_S", boolean),
    ("RECORD_BOLD", boolean),
]


@jitclass(ww_spec)
class ParWW:
    def __init__(
        self,
        # exc/inh params (Wong & Wang 2006 / Deco et al.)
        a_exc=310.0,
        a_inh=0.615,
        b_exc=125.0,
        b_inh=177.0,
        d_exc=0.16,
        d_inh=0.087,
        tau_exc=100.0,  # ms
        tau_inh=10.0,  # ms
        gamma_exc=0.641 / 1000.0,
        gamma_inh=1.0 / 1000.0,
        W_exc=1.0,
        W_inh=0.7,
        ext_current=np.array([0.382]),  # nA
        J_NMDA=0.15,
        J_I=1.0,
        w_plus=1.4,
        lambda_inh_exc=0.0,
        # simulation
        t_end=1000.0,
        t_cut=0.0,
        dt=0.1,
        G_exc=0.0,
        G_inh=0.0,
        weights=np.array([[], []]),
        tr=300.0,  # ms
        s_decimate=1,
        sigma=0.0,
        nn=1,
        seed=-1,
        output="output",
        dtype="f",
        initial_state=np.array([0.0]),
        RECORD_S=False,
        RECORD_BOLD=True,
    ):
        # assign
        self.a_exc = a_exc
        self.a_inh = a_inh
        self.b_exc = b_exc
        self.b_inh = b_inh
        self.d_exc = d_exc
        self.d_inh = d_inh
        self.tau_exc = tau_exc
        self.tau_inh = tau_inh
        self.gamma_exc = gamma_exc
        self.gamma_inh = gamma_inh
        self.W_exc = W_exc
        self.W_inh = W_inh
        self.ext_current = ext_current
        self.J_NMDA = J_NMDA
        self.J_I = J_I
        self.w_plus = w_plus
        self.lambda_inh_exc = lambda_inh_exc

        self.t_end = t_end
        self.t_cut = t_cut
        self.dt = dt
        self.G_exc = G_exc
        self.G_inh = G_inh
        self.weights = weights
        self.tr = tr
        self.s_decimate = s_decimate
        self.sigma = sigma
        self.nn = nn
        self.seed = seed
        self.output = output
        self.dtype = dtype
        self.initial_state = initial_state
        self.RECORD_S = RECORD_S
        self.RECORD_BOLD = RECORD_BOLD


# -----------------------------
# Wong–Wang dynamics (Numba)
# -----------------------------

@njit
def firing_rate(current, a, b, d):
    """
    r(I) = (a I - b) / (1 - exp(-d (a I - b)))
    Safe for vector inputs.
    """
    u = a * current - b
    den = 1.0 - np.exp(-d * u)
    # avoid division by ~0: as u -> 0, 1 - exp(-d u) ~ d u, so u/den -> 1/d
    out = np.zeros_like(current)
    for i in range(current.shape[0]):
        if np.abs(den[i]) < 1e-12:
            out[i] = 1.0 / d  # analytic limit at the removable singularity u = 0
        else:
            out[i] = u[i] / den[i]
    return out

|
270
|
+
def f_ww(S, t, P):
|
271
|
+
"""
|
272
|
+
Right-hand side for Wong–Wang model.
|
273
|
+
S: length 2*nn vector [S_exc, S_inh]
|
274
|
+
returns dS/dt shape (2*nn,)
|
275
|
+
"""
|
276
|
+
nn = P.nn
|
277
|
+
S_exc = S[:nn]
|
278
|
+
S_inh = S[nn:]
|
279
|
+
|
280
|
+
# network couplings
|
281
|
+
network_exc_exc = P.weights.dot(S_exc)
|
282
|
+
if P.lambda_inh_exc > 0.0:
|
283
|
+
network_inh_exc = P.weights.dot(S_inh)
|
284
|
+
else:
|
285
|
+
network_inh_exc = np.zeros_like(S_exc)
|
286
|
+
|
287
|
+
# currents
|
288
|
+
current_exc = (
|
289
|
+
P.W_exc * P.ext_current
|
290
|
+
+ P.w_plus * P.J_NMDA * S_exc
|
291
|
+
+ P.G_exc * P.J_NMDA * network_exc_exc
|
292
|
+
- P.J_I * S_inh
|
293
|
+
)
|
294
|
+
|
295
|
+
current_inh = (
|
296
|
+
P.W_inh * P.ext_current
|
297
|
+
+ P.J_NMDA * S_inh
|
298
|
+
- S_inh
|
299
|
+
+ P.G_inh * P.J_NMDA * network_inh_exc
|
300
|
+
)
|
301
|
+
|
302
|
+
# firing rates
|
303
|
+
r_exc = firing_rate(current_exc, P.a_exc, P.b_exc, P.d_exc)
|
304
|
+
r_inh = firing_rate(current_inh, P.a_inh, P.b_inh, P.d_inh)
|
305
|
+
|
306
|
+
dSdt = np.zeros(2 * nn)
|
307
|
+
|
308
|
+
# exc
|
309
|
+
dSdt[:nn] = (-S_exc / P.tau_exc) + (1.0 - S_exc) * P.gamma_exc * r_exc
|
310
|
+
# inh
|
311
|
+
dSdt[nn:] = (-S_inh / P.tau_inh) + P.gamma_inh * r_inh
|
312
|
+
|
313
|
+
return dSdt
|
314
|
+
|
315
|
+
|
316
|
+
@jit(nopython=True)
|
317
|
+
def heun_sde(S, t, P):
|
318
|
+
"""
|
319
|
+
One Heun stochastic step for S (2*nn vector).
|
320
|
+
"""
|
321
|
+
dt = P.dt
|
322
|
+
nn = P.nn
|
323
|
+
|
324
|
+
dW = P.sigma * np.sqrt(dt) * np.random.randn(2 * nn)
|
325
|
+
|
326
|
+
k1 = f_ww(S, t, P)
|
327
|
+
y_ = S + dt * k1 + dW
|
328
|
+
k2 = f_ww(y_, t + dt, P)
|
329
|
+
S = S + 0.5 * dt * (k1 + k2) + dW
|
330
|
+
|
331
|
+
return S
|
332
|
+
|
333
|
+
|
334
|
+
# -----------------------------
|
335
|
+
# Public-facing class (mirror mpr.py style)
|
336
|
+
# -----------------------------
|
337
|
+
|
338
|
+
class WW_sde:
|
339
|
+
def __init__(self, par: dict = None, Bpar: dict = None) -> None:
|
340
|
+
if par is None:
|
341
|
+
par = {}
|
342
|
+
if Bpar is None:
|
343
|
+
Bpar = {}
|
344
|
+
|
345
|
+
# sanity & defaults
|
346
|
+
nn = par.get("nn", None)
|
347
|
+
weights = par.get("weights", None)
|
348
|
+
if weights is None:
|
349
|
+
# default single node
|
350
|
+
weights = np.zeros((1, 1), dtype=np.float64)
|
351
|
+
weights = np.array(weights, dtype=np.float64)
|
352
|
+
if nn is None:
|
353
|
+
nn = weights.shape[0]
|
354
|
+
par.setdefault("nn", nn)
|
355
|
+
|
356
|
+
# broadcast scalars to vectors where necessary
|
357
|
+
par.setdefault("ext_current", 0.382)
|
358
|
+
par["ext_current"] = check_vec_size_1d(par["ext_current"], nn)
|
359
|
+
|
360
|
+
# dt-based noise scalars are computed inside heun_sde (uses sigma directly)
|
361
|
+
|
362
|
+
# initial state
|
363
|
+
if "initial_state" in par:
|
364
|
+
par["initial_state"] = np.array(par["initial_state"], dtype=np.float64)
|
365
|
+
else:
|
366
|
+
par["initial_state"] = set_initial_state(nn, par.get("seed", -1))
|
367
|
+
|
368
|
+
par.setdefault("dtype", "f") # kept for compatibility
|
369
|
+
par.setdefault("output", "output")
|
370
|
+
|
371
|
+
# create numba jitclass param holders
|
372
|
+
self.P = ParWW(
|
373
|
+
a_exc=par.get("a_exc", 310.0),
|
374
|
+
a_inh=par.get("a_inh", 0.615),
|
375
|
+
b_exc=par.get("b_exc", 125.0),
|
376
|
+
b_inh=par.get("b_inh", 177.0),
|
377
|
+
d_exc=par.get("d_exc", 0.16),
|
378
|
+
d_inh=par.get("d_inh", 0.087),
|
379
|
+
tau_exc=par.get("tau_exc", 100.0),
|
380
|
+
tau_inh=par.get("tau_inh", 10.0),
|
381
|
+
gamma_exc=par.get("gamma_exc", 0.641 / 1000.0),
|
382
|
+
gamma_inh=par.get("gamma_inh", 1.0 / 1000.0),
|
383
|
+
W_exc=par.get("W_exc", 1.0),
|
384
|
+
W_inh=par.get("W_inh", 0.7),
|
385
|
+
ext_current=par["ext_current"].astype(np.float64),
|
386
|
+
J_NMDA=par.get("J_NMDA", 0.15),
|
387
|
+
J_I=par.get("J_I", 1.0),
|
388
|
+
w_plus=par.get("w_plus", 1.4),
|
389
|
+
lambda_inh_exc=par.get("lambda_inh_exc", 0.0),
|
390
|
+
t_end=par.get("t_end", 1000.0),
|
391
|
+
t_cut=par.get("t_cut", 0.0),
|
392
|
+
dt=par.get("dt", 0.1),
|
393
|
+
G_exc=par.get("G_exc", 0.0),
|
394
|
+
G_inh=par.get("G_inh", 0.0),
|
395
|
+
weights=weights,
|
396
|
+
tr=par.get("tr", 300.0),
|
397
|
+
s_decimate=int(par.get("s_decimate", 1)),
|
398
|
+
sigma=par.get("sigma", 0.0),
|
399
|
+
nn=nn,
|
400
|
+
seed=int(par.get("seed", -1)),
|
401
|
+
output=par.get("output", "output"),
|
402
|
+
dtype=par.get("dtype", "f"),
|
403
|
+
initial_state=par["initial_state"],
|
404
|
+
RECORD_S=bool(par.get("RECORD_S", False)),
|
405
|
+
RECORD_BOLD=bool(par.get("RECORD_BOLD", True)),
|
406
|
+
)
|
407
|
+
|
408
|
+
# Bold parameters
|
409
|
+
self.B = ParBold(
|
410
|
+
kappa=Bpar.get("kappa", 0.65),
|
411
|
+
gamma=Bpar.get("gamma", 0.41),
|
412
|
+
tau=Bpar.get("tau", 0.98),
|
413
|
+
alpha=Bpar.get("alpha", 0.32),
|
414
|
+
epsilon=Bpar.get("epsilon", 0.34),
|
415
|
+
Eo=Bpar.get("Eo", 0.4),
|
416
|
+
TE=Bpar.get("TE", 0.04),
|
417
|
+
vo=Bpar.get("vo", 0.08),
|
418
|
+
r0=Bpar.get("r0", 25.0),
|
419
|
+
theta0=Bpar.get("theta0", 40.3),
|
420
|
+
t_min=Bpar.get("t_min", 0.0),
|
421
|
+
rtol=Bpar.get("rtol", 1e-5),
|
422
|
+
atol=Bpar.get("atol", 1e-8),
|
423
|
+
)
|
424
|
+
|
425
|
+
# seeding
|
426
|
+
if self.P.seed >= 0:
|
427
|
+
# set_seed_compat(self.P.seed)
|
428
|
+
initialize_random_state(self.P.seed)
|
429
|
+
# print(f"WW_sde: setting random seed to {self.P.seed}")
|
430
|
+
|
431
|
+
def __str__(self) -> str:
|
432
|
+
lines = [
|
433
|
+
"Wong-Wang (Numba) model",
|
434
|
+
"Parameters: --------------------------------",
|
435
|
+
]
|
436
|
+
for name in [
|
437
|
+
"nn", "dt", "t_end", "t_cut", "G_exc", "G_inh", "sigma", "tr",
|
438
|
+
"a_exc", "b_exc", "d_exc", "tau_exc", "gamma_exc",
|
439
|
+
"a_inh", "b_inh", "d_inh", "tau_inh", "gamma_inh",
|
440
|
+
"W_exc", "W_inh", "w_plus", "J_NMDA", "J_I",
|
441
|
+
]:
|
442
|
+
lines.append(f"{name} = {getattr(self.P, name)}")
|
443
|
+
lines.append("--------------------------------------------")
|
444
|
+
return "\n".join(lines)
|
445
|
+
|
446
|
+
# -----------------------------
|
447
|
+
# Simulation
|
448
|
+
# -----------------------------
|
449
|
+
def run(self, par: dict = None, x0=None, verbose=True):
|
450
|
+
"""
|
451
|
+
Run simulation and return dict with:
|
452
|
+
- 'S': recorded S_exc if RECORD_S (shape [T, nn])
|
453
|
+
- 't': times for S (ms)
|
454
|
+
- 'bold_t': times for BOLD (ms)
|
455
|
+
- 'bold_d': BOLD signal [T_bold, nn]
|
456
|
+
"""
|
457
|
+
# update runtime parameters if provided
|
458
|
+
if par:
|
459
|
+
for key, val in par.items():
|
460
|
+
if key == "ext_current":
|
461
|
+
val = check_vec_size_1d(val, self.P.nn).astype(np.float64)
|
462
|
+
setattr(self.P, key, val)
|
463
|
+
|
464
|
+
# initial state
|
465
|
+
if x0 is None:
|
466
|
+
S = copy(self.P.initial_state)
|
467
|
+
else:
|
468
|
+
S = np.array(x0, dtype=np.float64)
|
469
|
+
|
470
|
+
# sanity
|
471
|
+
assert self.P.weights is not None
|
472
|
+
assert self.P.weights.shape[0] == self.P.weights.shape[1]
|
473
|
+
assert len(S) == 2 * self.P.nn, "x0 must be length 2*nn"
|
474
|
+
assert self.P.t_cut < self.P.t_end
|
475
|
+
|
476
|
+
# time grid
|
477
|
+
nt = int(np.floor(self.P.t_end / self.P.dt))
|
478
|
+
t = np.arange(nt) * self.P.dt
|
479
|
+
valid_mask = t > self.P.t_cut
|
480
|
+
s_buffer_len = int(np.sum(valid_mask) // max(1, self.P.s_decimate))
|
481
|
+
|
482
|
+
# buffers
|
483
|
+
t_buf = np.zeros((s_buffer_len,), dtype=np.float32)
|
484
|
+
S_rec = np.zeros((s_buffer_len, self.P.nn), dtype=np.float32) if self.P.RECORD_S else np.array([])
|
485
|
+
|
486
|
+
# BOLD buffers
|
487
|
+
tr = self.P.tr
|
488
|
+
bold_decimate = int(np.round(tr / self.P.dt))
|
489
|
+
dtt = self.P.dt / 1000.0 # seconds
|
490
|
+
s = np.zeros((2, self.P.nn))
|
491
|
+
f = np.zeros((2, self.P.nn))
|
492
|
+
ftilde = np.zeros((2, self.P.nn))
|
493
|
+
vtilde = np.zeros((2, self.P.nn))
|
494
|
+
qtilde = np.zeros((2, self.P.nn))
|
495
|
+
v = np.zeros((2, self.P.nn))
|
496
|
+
q = np.zeros((2, self.P.nn))
|
497
|
+
vv = np.zeros((nt // max(1, bold_decimate), self.P.nn), dtype=np.float64)
|
498
|
+
qq = np.zeros_like(vv)
|
499
|
+
|
500
|
+
# init BOLD states
|
501
|
+
s[0] = 1.0
|
502
|
+
f[0] = 1.0
|
503
|
+
v[0] = 1.0
|
504
|
+
q[0] = 1.0
|
505
|
+
ftilde[0] = 0.0
|
506
|
+
vtilde[0] = 0.0
|
507
|
+
qtilde[0] = 0.0
|
508
|
+
|
509
|
+
# main loop
|
510
|
+
s_idx = 0
|
511
|
+
for i in range(nt):
|
512
|
+
t_curr = i * self.P.dt
|
513
|
+
S = heun_sde(S, t_curr, self.P)
|
514
|
+
|
515
|
+
if (t_curr > self.P.t_cut) and (i % max(1, self.P.s_decimate) == 0):
|
516
|
+
if s_idx < s_buffer_len:
|
517
|
+
t_buf[s_idx] = t_curr
|
518
|
+
if self.P.RECORD_S:
|
519
|
+
S_rec[s_idx] = S[: self.P.nn].astype(np.float32)
|
520
|
+
s_idx += 1
|
521
|
+
|
522
|
+
if self.P.RECORD_BOLD:
|
523
|
+
do_bold_step(S[: self.P.nn], s, f, ftilde, vtilde, qtilde, v, q, dtt, self.B)
|
524
|
+
if (i % max(1, bold_decimate) == 0) and ((i // max(1, bold_decimate)) < vv.shape[0]):
|
525
|
+
vv[i // max(1, bold_decimate)] = v[1]
|
526
|
+
qq[i // max(1, bold_decimate)] = q[1]
|
527
|
+
|
528
|
+
# finalize BOLD
|
529
|
+
bold_t = np.linspace(0.0, self.P.t_end - self.P.dt * max(1, bold_decimate), vv.shape[0])
|
530
|
+
if self.P.RECORD_BOLD:
|
531
|
+
# cut off t <= t_cut
|
532
|
+
valid = bold_t > self.P.t_cut
|
533
|
+
bold_t = bold_t[valid]
|
534
|
+
if bold_t.size > 0:
|
535
|
+
vv = vv[valid]
|
536
|
+
qq = qq[valid]
|
537
|
+
k1 = 4.3 * self.B.theta0 * self.B.Eo * self.B.TE
|
538
|
+
k2 = self.B.epsilon * self.B.r0 * self.B.Eo * self.B.TE
|
539
|
+
k3 = 1.0 - self.B.epsilon
|
540
|
+
bold_d = self.B.vo * (k1 * (1.0 - qq) + k2 * (1.0 - qq / vv) + k3 * (1.0 - vv))
|
541
|
+
else:
|
542
|
+
bold_d = np.array([])
|
543
|
+
else:
|
544
|
+
bold_d = np.array([])
|
545
|
+
bold_t = np.array([])
|
546
|
+
|
547
|
+
return {
|
548
|
+
"S": S_rec,
|
549
|
+
"t": t_buf,
|
550
|
+
"bold_t": bold_t.astype(np.float32),
|
551
|
+
"bold_d": bold_d.astype(np.float32),
|
552
|
+
}
|
553
|
+
|
554
|
+
|
555
|
+
# -----------------------------
|
556
|
+
# API helpers
|
557
|
+
# -----------------------------
|
558
|
+
|
559
|
+
def set_initial_state(nn, seed=-1):
|
560
|
+
if seed is not None and seed >= 0:
|
561
|
+
np.random.seed(seed)
|
562
|
+
# initialize_random_state(seed)
|
563
|
+
y0 = np.random.rand(2 * nn) * 0.1 # small positive
|
564
|
+
return y0.astype(np.float64)
|
{vbi-0.1.3.dist-info → vbi-0.2.dist-info}/METADATA

````diff
@@ -1,14 +1,13 @@
 Metadata-Version: 2.4
 Name: vbi
-Version: 0.1.3
+Version: 0.2
 Summary: Virtual brain inference.
 Author-email: Abolfazl Ziaeemehr <a.ziaeemehr@gmail.com>, Meysam Hashemi <meysam.hashemi@gmail.com>, Marmaduke Woodman <marmaduke.woodman@gmail.com>
-License:
+License-Expression: Apache-2.0
 Project-URL: homepage, https://ziaeemehr.github.io/vbi_paper/
 Project-URL: repository, https://github.com/Ziaeemehr/vbi_paper
 Classifier: Programming Language :: Python :: 3
 Classifier: Topic :: Scientific/Engineering :: Information Analysis
-Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Operating System :: OS Independent
 Requires-Python: >=3.8
 Description-Content-Type: text/markdown
@@ -75,11 +74,17 @@ Dynamic: license-file
 ```bash
 conda env create --name vbi python=3.10
 conda activate vbi
+# from pip: Recommended
+pip install vbi
+# from source: More recent update
 git clone https://github.com/ins-amu/vbi.git
 cd vbi
 pip install .
 
 # pip install -e .[all,dev,docs]
+
+# To skip C++ compilation, use the following environment variable and install from source:
+SKIP_CPP=1 pip install -e .
 ```
 
 ## Using Docker
@@ -143,14 +148,28 @@ We welcome contributions to the VBI project! If you have suggestions, bug report
 ## Citation
 
 ```bibtex
-@article{VBI,
-
-
-
-
-
-
-
+@article{VBI,
+  title={Virtual Brain Inference (VBI): A flexible and integrative toolkit for efficient probabilistic inference on virtual brain models},
+  author={Ziaeemehr, Abolfazl and Woodman, Marmaduke and Domide, Lia and Petkoski, Spase and Jirsa, Viktor and Hashemi, Meysam},
+  DOI={10.7554/elife.106194.1},
+  url={http://dx.doi.org/10.7554/eLife.106194.1},
+  publisher={eLife Sciences Publications, Ltd},
+  year={2025},
+  abstract = {Network neuroscience has proven essential for understanding the principles and mechanisms
+  underlying complex brain (dys)function and cognition. In this context, whole-brain network modeling–
+  also known as virtual brain modeling–combines computational models of brain dynamics (placed at each network node)
+  with individual brain imaging data (to coordinate and connect the nodes), advancing our understanding of
+  the complex dynamics of the brain and its neurobiological underpinnings. However, there remains a critical
+  need for automated model inversion tools to estimate control (bifurcation) parameters at large scales
+  associated with neuroimaging modalities, given their varying spatio-temporal resolutions.
+  This study aims to address this gap by introducing a flexible and integrative toolkit for efficient Bayesian inference
+  on virtual brain models, called Virtual Brain Inference (VBI). This open-source toolkit provides fast simulations,
+  taxonomy of feature extraction, efficient data storage and loading, and probabilistic machine learning algorithms,
+  enabling biophysically interpretable inference from non-invasive and invasive recordings.
+  Through in-silico testing, we demonstrate the accuracy and reliability of inference for commonly used
+  whole-brain network models and their associated neuroimaging data. VBI shows potential to improve hypothesis
+  evaluation in network neuroscience through uncertainty quantification, and contribute to advances in precision
+  medicine by enhancing the predictive power of virtual brain models.}
 }
 ```
````