vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl → 0.2__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/feature_extraction/features.json +4 -1
- vbi/feature_extraction/features.py +10 -4
- vbi/inference.py +50 -22
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/jr_sde.hpp +5 -6
- vbi/models/cpp/_src/jr_sde_wrap.cxx +28 -28
- vbi/models/cpp/jansen_rit.py +2 -9
- vbi/models/cupy/bold.py +117 -0
- vbi/models/cupy/jansen_rit.py +1 -1
- vbi/models/cupy/km.py +62 -34
- vbi/models/cupy/mpr.py +24 -4
- vbi/models/cupy/utils.py +163 -2
- vbi/models/cupy/wilson_cowan.py +317 -0
- vbi/models/cupy/ww.py +342 -0
- vbi/models/numba/__init__.py +4 -0
- vbi/models/numba/jansen_rit.py +532 -0
- vbi/models/numba/mpr.py +8 -0
- vbi/models/numba/wilson_cowan.py +443 -0
- vbi/models/numba/ww.py +564 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/METADATA +30 -11
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/RECORD +30 -26
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/WHEEL +1 -1
- vbi/models/numba/_ww_EI.py +0 -444
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/licenses/LICENSE +0 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,443 @@
|
|
1
|
+
|
2
|
+
import warnings
|
3
|
+
import numpy as np
|
4
|
+
from numba import njit, jit
|
5
|
+
from numba.experimental import jitclass
|
6
|
+
from numba.extending import register_jitable
|
7
|
+
from numba import float64, boolean, int64, types
|
8
|
+
from numba.core.errors import NumbaPerformanceWarning
|
9
|
+
|
10
|
+
# Ignore NumbaPerformanceWarning for the whole process (simplefilter installs a
# global filter, so this also affects code outside this module).
# NOTE(review): presumably meant to hide warnings about np.dot on array slices
# (e.g. x[:nn] in f_wc); confirm and consider scoping it more narrowly.
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
|
11
|
+
|
12
|
+
|
13
|
+
# ---------- utilities ----------
|
14
|
+
|
15
|
+
def _to_1d_array(x):
|
16
|
+
x = np.array(x, dtype=np.float64)
|
17
|
+
if x.ndim == 0:
|
18
|
+
x = x.reshape(1)
|
19
|
+
return x
|
20
|
+
|
21
|
+
def _to_2d_array(x):
|
22
|
+
x = np.array(x, dtype=np.float64)
|
23
|
+
if x.ndim == 1:
|
24
|
+
# try to guess a square matrix if possible
|
25
|
+
n = int(np.sqrt(x.size))
|
26
|
+
if n * n == x.size:
|
27
|
+
x = x.reshape(n, n)
|
28
|
+
else:
|
29
|
+
raise ValueError("weights must be square (nxn).")
|
30
|
+
return x
|
31
|
+
|
32
|
+
def check_vec_size(x, nn):
    """Broadcast ``x`` to a float64 vector of length ``nn``.

    A scalar or any single-element input is repeated ``nn`` times; an input
    that already has exactly ``nn`` elements is returned as a float64 copy.

    Raises
    ------
    ValueError
        If ``x`` has more than one element but not exactly ``nn``.
    """
    values = np.array(x, dtype=np.float64)
    if values.size == 1:
        # scalar or 1-element container: replicate the single value
        return np.ones(nn, dtype=np.float64) * float(values.reshape(-1)[0])
    if values.size == nn:
        return values.astype(np.float64)
    raise ValueError(f"Vector parameter has size {values.size} but nn={nn}.")
|
42
|
+
|
43
|
+
|
44
|
+
@register_jitable
def set_seed_compat(x):
    # Seed ``np.random`` with ``x``. Because it is register_jitable it can be
    # called from @njit code (set_initial_state does so).
    # NOTE: per the Numba docs, jitted code keeps its own RNG state separate
    # from NumPy's, so which generator gets seeded depends on whether this runs
    # inside jitted code or as plain Python.
    np.random.seed(x)
|
47
|
+
|
48
|
+
|
49
|
+
# ---------- core model (Numba) ----------
|
50
|
+
|
51
|
+
# Numba type spec for the ParWC jitclass: one (field name, numba type) pair per
# attribute. ``float64[:]`` fields are 1-D arrays holding one value per node
# (the Python builder broadcasts them to length nn); ``float64`` fields are
# scalars shared by every node.
wc_spec = [
    # local coupling: c_ee*E and c_ei*I enter the E sigmoid input,
    # c_ie*E and c_ii*I enter the I sigmoid input (see f_wc)
    ("c_ee", float64[:]),
    ("c_ei", float64[:]),
    ("c_ie", float64[:]),
    ("c_ii", float64[:]),
    # time constants; f_wc multiplies each derivative by 1/tau
    ("tau_e", float64[:]),
    ("tau_i", float64[:]),
    # sigmoid slope (a), midpoint (b) and amplitude (c) of each population
    ("a_e", float64),
    ("a_i", float64),
    ("b_e", float64),
    ("b_i", float64),
    ("c_e", float64),
    ("c_i", float64),
    # offsets subtracted from the sigmoid inputs
    ("theta_e", float64),
    ("theta_i", float64),
    # gating terms: derivatives use (k - r * state) * sigmoid
    ("r_e", float64),
    ("r_i", float64),
    ("k_e", float64),
    ("k_i", float64),
    # gains multiplying the whole sigmoid input
    ("alpha_e", float64),
    ("alpha_i", float64),
    # additive per-node inputs to E (P) and I (Q)
    ("P", float64[:]),
    ("Q", float64[:]),
    # global coupling gains applied to weights @ E and weights @ I
    ("g_e", float64),
    ("g_i", float64),
    # integration step, end time; samples with t < t_cut are discarded
    ("dt", float64),
    ("t_end", float64),
    ("t_cut", float64),
    # number of nodes (len(weights)) and the coupling matrix
    # (squareness is asserted in WC_sde_numba.check_input)
    ("nn", int64),
    ("weights", float64[:, :]),
    # RNG seed; a negative value means "do not reseed"
    ("seed", int64),
    # additive noise amplitude (scaled by sqrt(dt) in heun_sde)
    ("noise_amp", float64),
    # record every ``decimate``-th integration step
    ("decimate", int64),
    # populations to record: integrate checks for "e" / "i" (case-insensitive)
    ("RECORD_EI", types.string),
    # length 2*nn (E block then I block); empty means "draw randomly"
    ("initial_state", float64[:]),
    # if True the sigmoid is shifted so that its value at input 0 is 0
    ("shift_sigmoid", boolean),
]
|
88
|
+
|
89
|
+
|
90
|
+
# Numba jitclass holding all Wilson-Cowan model and integration parameters
# (field names/types are declared in ``wc_spec``). WC_sde_numba._get_par_wc
# builds instances of it after broadcasting the per-node vectors to length nn.
# NOTE(review): the array defaults (np.array([...]), np.empty(...)) are
# evaluated once at class-definition time; presumably instances built from
# those defaults share the same buffers - confirm before mutating in place.
# NOTE(review): with the default empty ``weights`` nn is 0 while the default
# per-node vectors have length 1; nothing in this class checks they agree.
@jitclass(wc_spec)
class ParWC:
    def __init__(
        self,
        c_ee=np.array([16.0]),
        c_ei=np.array([12.0]),
        c_ie=np.array([15.0]),
        c_ii=np.array([3.0]),
        tau_e=np.array([8.0]),
        tau_i=np.array([8.0]),
        a_e=1.3,
        a_i=2.0,
        b_e=4.0,
        b_i=3.7,
        c_e=1.0,
        c_i=1.0,
        theta_e=0.0,
        theta_i=0.0,
        r_e=1.0,
        r_i=1.0,
        k_e=0.994,
        k_i=0.999,
        alpha_e=1.0,
        alpha_i=1.0,
        P=np.array([0.0]),
        Q=np.array([0.0]),
        g_e=0.0,
        g_i=0.0,
        dt=0.01,
        t_end=300.0,
        t_cut=0.0,
        weights=np.empty((0, 0), dtype=np.float64),
        seed=-1,
        noise_amp=0.0,
        decimate=1,
        RECORD_EI="E",
        initial_state=np.empty(0, dtype=np.float64),
        shift_sigmoid=False,
    ):
        # arrays are stored as given (no copy, no length check here)
        self.c_ee = c_ee
        self.c_ei = c_ei
        self.c_ie = c_ie
        self.c_ii = c_ii
        self.tau_e = tau_e
        self.tau_i = tau_i
        self.a_e = a_e
        self.a_i = a_i
        self.b_e = b_e
        self.b_i = b_i
        self.c_e = c_e
        self.c_i = c_i
        self.theta_e = theta_e
        self.theta_i = theta_i
        self.r_e = r_e
        self.r_i = r_i
        self.k_e = k_e
        self.k_i = k_i
        self.alpha_e = alpha_e
        self.alpha_i = alpha_i
        self.P = P
        self.Q = Q
        self.g_e = g_e
        self.g_i = g_i
        self.dt = dt
        self.t_end = t_end
        self.t_cut = t_cut
        # node count is derived from the weight matrix (its number of rows)
        self.nn = len(weights)
        self.weights = weights
        self.seed = seed
        self.noise_amp = noise_amp
        self.decimate = decimate
        self.RECORD_EI = RECORD_EI
        # empty array means "no initial state yet" (filled by check_input)
        self.initial_state = initial_state
        self.shift_sigmoid = shift_sigmoid
|
164
|
+
|
165
|
+
|
166
|
+
@njit
def sigmoid_vec(x, a, b, c, shift_sigmoid):
    """Element-wise logistic ``c / (1 + exp(-a * (x - b)))``.

    With ``shift_sigmoid`` the curve is lowered by its value at ``x = 0``,
    i.e. ``c * (sigmoid(a(x-b)) - sigmoid(-ab))``, so the output is 0 there.
    Returns a new array shaped like ``x``.
    """
    out = np.empty_like(x)
    n = x.size
    if not shift_sigmoid:
        for k in range(n):
            out[k] = c / (1.0 + np.exp(-a * (x[k] - b)))
        return out

    # unshifted logistic evaluated at x = 0
    offset = 1.0 / (1.0 + np.exp(-a * (-b)))
    for k in range(n):
        out[k] = c * (1.0 / (1.0 + np.exp(-a * (x[k] - b))) - offset)
    return out
|
178
|
+
|
179
|
+
|
180
|
+
@njit
def f_wc(x, t, P):
    """
    Wilson-Cowan ODE right-hand side (per-node, single simulation).

    x: shape (2*nn,) -- x[:nn] holds the E activities, x[nn:] the I activities.
    t: current time; not used by this right-hand side.
    P: ParWC parameter container.

    Returns dxdt with the same layout as x, where
        dE/dt = (1/tau_e) * (-E + (k_e - r_e * E) * S_e(x_e))
        dI/dt = (1/tau_i) * (-I + (k_i - r_i * I) * S_i(x_i))
    S_e / S_i are ``sigmoid_vec`` and x_e / x_i are the sigmoid inputs below.
    """
    nn = P.nn
    dxdt = np.zeros_like(x)

    # views into the state vector (no copy)
    E = x[:nn]
    I = x[nn:]

    # Linear coupling (weights @ state) scaled by the global gains g_e / g_i;
    # the matrix-vector product is skipped when the gain is exactly 0.
    lc_e = P.g_e * np.dot(P.weights, E) if P.g_e != 0.0 else np.zeros(nn)
    lc_i = P.g_i * np.dot(P.weights, I) if P.g_i != 0.0 else np.zeros(nn)

    # Inputs to sigmoids: local E/I terms + additive input (P / Q)
    # - offset theta + network coupling, all multiplied by the gain alpha
    x_e = P.alpha_e * (P.c_ee * E - P.c_ei * I + P.P - P.theta_e + lc_e)
    x_i = P.alpha_i * (P.c_ie * E - P.c_ii * I + P.Q - P.theta_i + lc_i)

    s_e = sigmoid_vec(x_e, P.a_e, P.b_e, P.c_e, P.shift_sigmoid)
    s_i = sigmoid_vec(x_i, P.a_i, P.b_i, P.c_i, P.shift_sigmoid)

    # Time constants (vectorized)
    inv_tau_e = 1.0 / P.tau_e
    inv_tau_i = 1.0 / P.tau_i

    # dE/dt
    for i in range(nn):
        dxdt[i] = inv_tau_e[i] * (-E[i] + (P.k_e - P.r_e * E[i]) * s_e[i])
    # dI/dt
    for i in range(nn):
        dxdt[nn + i] = inv_tau_i[i] * (-I[i] + (P.k_i - P.r_i * I[i]) * s_i[i])

    return dxdt
|
215
|
+
|
216
|
+
|
217
|
+
@njit
def heun_sde(x, t, P):
    """
    Advance the state by one step of length ``P.dt`` using a stochastic Heun
    (predictor-corrector) scheme with additive Gaussian noise.

    A single increment ``noise_amp * sqrt(dt) * N(0, 1)`` per state variable is
    drawn and reused in both the predictor and the corrector stage.
    Returns the new state (same shape as ``x``).
    """
    h = P.dt
    noise = P.noise_amp * np.sqrt(h) * np.random.randn(x.size)

    slope_start = f_wc(x, t, P)
    predictor = x + h * slope_start + noise
    slope_end = f_wc(predictor, t + h, P)
    return x + 0.5 * h * (slope_start + slope_end) + noise
|
228
|
+
|
229
|
+
|
230
|
+
@njit
def set_initial_state(nn, seed=-1):
    """Draw a random initial state of length ``2 * nn`` with ``np.random.rand``.

    A non-negative ``seed`` first reseeds the generator through
    ``set_seed_compat``; a negative seed leaves the generator state untouched.
    """
    n_state = 2 * nn
    if seed < 0:
        return np.random.rand(n_state)
    set_seed_compat(seed)
    return np.random.rand(n_state)
|
236
|
+
|
237
|
+
|
238
|
+
# ---------- high-level API (Python) ----------
|
239
|
+
|
240
|
+
@njit
def _seed_numba_rng(seed):
    """Seed the random generator used inside @njit code (e.g. heun_sde)."""
    set_seed_compat(seed)


class WC_sde_numba:
    """
    Numba implementation of the Wilson-Cowan SDE, modeled after mpr.py and
    translated from the CuPy/Numpy reference.

    Parameters
    ----------
    par : dict, optional
        Model parameters forwarded to ``ParWC``. ``weights`` (an nn x nn
        matrix, or a flat array of length nn*nn) is required. The per-node
        keys c_ee, c_ei, c_ie, c_ii, tau_e, tau_i, P and Q accept a scalar or
        a length-nn vector. ``nn`` may be given but must match ``weights``.
        A ``seed >= 0`` seeds both NumPy's and Numba's random generators.
    """

    # per-node parameters (broadcast to length nn) and their defaults
    _VEC_DEFAULTS = {
        "c_ee": 16.0, "c_ei": 12.0, "c_ie": 15.0, "c_ii": 3.0,
        "tau_e": 8.0, "tau_i": 8.0, "P": 0.0, "Q": 0.0,
    }

    def __init__(self, par: dict = None):
        # Prepare raw dict and build jitclass (None replaces the former
        # mutable ``{}`` default; behaviour is otherwise unchanged)
        self.P = self._get_par_wc({} if par is None else par)
        self._seed_rngs()

    def __call__(self):
        """Return the underlying ``ParWC`` parameter object."""
        return self.P

    def __str__(self) -> str:
        params = [
            "nn", "dt", "t_end", "t_cut", "decimate", "noise_amp",
            "g_e", "g_i", "a_e", "a_i", "b_e", "b_i", "k_e", "k_i",
        ]
        s = ["Wilson-Cowan (Numba) parameters:"]
        for k in params:
            s.append(f"{k} = {getattr(self.P, k)}")
        return "\n".join(s)

    def _seed_rngs(self):
        """Seed NumPy's and Numba's global RNGs when ``P.seed >= 0``.

        The noise in ``heun_sde`` is drawn inside @njit code, which uses
        Numba's own generator; ``np.random.seed`` called from plain Python
        does not reach it, so both generators are seeded.
        """
        seed = self.P.seed
        if seed >= 0:
            np.random.seed(seed)
            _seed_numba_rng(seed)

    # ----- builders & checks -----
    def _get_par_wc(self, par: dict):
        """Normalize a raw parameter dict and build a ``ParWC``.

        Raises
        ------
        ValueError
            If ``weights`` is missing or rejected by ``_to_2d_array``, if an
            explicit ``nn`` disagrees with ``weights``, or if a per-node vector
            or ``initial_state`` has the wrong length.
        """
        par = dict(par)  # shallow copy; the caller's dict is left untouched

        # weights first (to infer nn)
        if "weights" not in par:
            raise ValueError("weights (nxn) must be provided.")
        W = _to_2d_array(par["weights"])
        nn = W.shape[0]
        par["weights"] = W

        # ``nn`` is derived from the weights and is not a ParWC argument;
        # accept it for convenience but require it to be consistent
        nn_given = par.pop("nn", None)
        if nn_given is not None and int(nn_given) != nn:
            raise ValueError(
                f"nn={nn_given} does not match weights of shape {W.shape}."
            )

        # convert scalar/vector params to length-nn arrays (defaults if absent)
        for k, v in self._VEC_DEFAULTS.items():
            if k in par:
                par[k] = check_vec_size(par[k], nn)
            else:
                par[k] = np.ones(nn) * v

        # initial_state (optional; empty or None means "draw randomly later")
        init = par.get("initial_state")
        if init is None:
            par["initial_state"] = np.empty(0, dtype=np.float64)
        else:
            # flatten so any array of 2*nn values fits the 1-D jitclass field
            arr = np.array(init, dtype=np.float64).reshape(-1)
            if arr.size != 0 and arr.size != 2 * nn:
                raise ValueError(f"initial_state must have length {2*nn}.")
            par["initial_state"] = arr

        # strings/flags
        par.setdefault("RECORD_EI", "E")
        par.setdefault("decimate", 1)
        par.setdefault("noise_amp", 0.0)

        # build jitclass
        return ParWC(**par)

    def set_initial_state(self):
        """Replace ``P.initial_state`` with a random state of length 2*nn."""
        self.P.initial_state = set_initial_state(self.P.nn, self.P.seed)

    def check_input(self):
        """Sanity-check ``self.P``; draws a random initial state if none is set.

        Raises AssertionError on inconsistent shapes or ``t_cut >= t_end``.
        """
        P = self.P
        assert P.weights.shape[0] == P.weights.shape[1], "weights must be square"
        assert P.nn == P.weights.shape[0], "nn must match weights shape"
        if P.initial_state.size == 0:
            self.set_initial_state()
        assert P.initial_state.size == 2 * P.nn, "initial_state length mismatch"
        assert P.t_cut < P.t_end, "t_cut must be less than t_end"

        # ensure vector parameters are length-nn (already enforced in builder)
        # but re-check shapes at runtime for safety
        for k in self._VEC_DEFAULTS:
            v = getattr(P, k)
            assert v.size == P.nn, f"{k} must be length nn"

    def run(self, par: dict = None, x0=None, verbose: bool = True):
        """Integrate the model; returns ``{"t", "E", "I"}`` from ``integrate``.

        Parameters
        ----------
        par : dict, optional
            Overrides merged onto the current parameters (the ``ParWC`` object
            is rebuilt). A ``seed`` entry reseeds the random generators.
        x0 : array_like, optional
            Initial state with 2*nn values (E block then I block).
        verbose : bool
            Forwarded to ``integrate``.
        """
        # update parameters if provided
        if par:
            # (rebuild jitclass when structure-changing params come in)
            merged = {**self._par_to_dict(), **par}
            self.P = self._get_par_wc(merged)
            # without this, a new seed would only take effect when a random
            # initial state happens to be drawn in check_input
            if "seed" in par:
                self._seed_rngs()

        # set external initial state if provided
        if x0 is not None:
            x0 = np.array(x0, dtype=np.float64).reshape(-1)
            if x0.size != 2 * self.P.nn:
                raise ValueError(f"x0 must be length {2*self.P.nn}")
            self.P.initial_state = x0

        # checks (may draw a random initial state)
        self.check_input()

        return integrate(self.P, verbose=verbose)

    def _par_to_dict(self):
        """Snapshot all ``ParWC`` fields into a plain dict (arrays copied)."""
        P = self.P
        d = {
            "c_ee": np.array(P.c_ee),
            "c_ei": np.array(P.c_ei),
            "c_ie": np.array(P.c_ie),
            "c_ii": np.array(P.c_ii),
            "tau_e": np.array(P.tau_e),
            "tau_i": np.array(P.tau_i),
            "a_e": P.a_e,
            "a_i": P.a_i,
            "b_e": P.b_e,
            "b_i": P.b_i,
            "c_e": P.c_e,
            "c_i": P.c_i,
            "theta_e": P.theta_e,
            "theta_i": P.theta_i,
            "r_e": P.r_e,
            "r_i": P.r_i,
            "k_e": P.k_e,
            "k_i": P.k_i,
            "alpha_e": P.alpha_e,
            "alpha_i": P.alpha_i,
            "P": np.array(P.P),
            "Q": np.array(P.Q),
            "g_e": P.g_e,
            "g_i": P.g_i,
            "dt": P.dt,
            "t_end": P.t_end,
            "t_cut": P.t_cut,
            "weights": np.array(P.weights),
            "seed": P.seed,
            "noise_amp": P.noise_amp,
            "decimate": P.decimate,
            "RECORD_EI": P.RECORD_EI,
            "initial_state": np.array(P.initial_state),
            "shift_sigmoid": P.shift_sigmoid,
        }
        return d
|
392
|
+
|
393
|
+
|
394
|
+
def integrate(P: ParWC, verbose=True):
    """
    Pure-Python driver (Numba-accelerated inner steps).

    Runs ``int(P.t_end / P.dt)`` steps of ``heun_sde`` starting from
    ``P.initial_state`` and records every ``P.decimate``-th step.

    Parameters
    ----------
    P : ParWC
        Parameters. ``P.RECORD_EI`` selects what is stored: the E block if it
        contains "e", the I block if it contains "i" (case-insensitive).
    verbose : bool
        Not used by this function; accepted for API compatibility.

    Returns
    -------
    dict
        ``{"t": t, "E": E, "I": I}`` as float32 arrays; ``E``/``I`` have shape
        (n_samples, nn), or are None when not recorded. Only samples with
        ``t >= P.t_cut`` are returned.
    """
    nn = P.nn
    dt = P.dt
    nt = max(0, int(P.t_end / dt))
    dec = max(1, int(P.decimate))

    # Steps i = 0, dec, 2*dec, ... (< nt) are recorded: ceil(nt / dec) samples.
    # (Sizing with nt // dec silently dropped the last one when nt % dec != 0.)
    nbuf = -(-nt // dec)
    which = P.RECORD_EI.lower()
    record_e = "e" in which
    record_i = "i" in which

    # Times are buffered in float64 so the t_cut test below is not skewed by
    # float32 rounding; they are converted to float32 only for the output.
    t_buf = np.empty(nbuf, dtype=np.float64)
    E_buf = np.empty((nbuf, nn), dtype=np.float32) if record_e else None
    I_buf = np.empty((nbuf, nn), dtype=np.float32) if record_i else None

    x = P.initial_state.copy()
    buf_idx = 0

    for i in range(nt):
        t_curr = i * dt
        x = heun_sde(x, t_curr, P)

        if i % dec == 0:
            # NOTE(review): x is the state after the step that starts at
            # t_curr, but it is stamped with t_curr (kept as before).
            t_buf[buf_idx] = t_curr
            if record_e:
                E_buf[buf_idx] = x[:nn]  # cast to float32 on assignment
            if record_i:
                I_buf[buf_idx] = x[nn:]
            buf_idx += 1

    # apply t_cut (comparison done in float64)
    keep = t_buf[:buf_idx] >= P.t_cut
    t_out = t_buf[:buf_idx][keep].astype(np.float32)
    E_out = E_buf[:buf_idx][keep] if record_e else None
    I_out = I_buf[:buf_idx][keep] if record_i else None

    return {"t": t_out, "E": E_out, "I": I_out}
|
441
|
+
|
442
|
+
|
443
|
+
# Shorter public alias for WC_sde_numba.
WC_sde = WC_sde_numba  # alias
|