vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/__init__.py +37 -0
- vbi/_version.py +17 -0
- vbi/dataset/__init__.py +0 -0
- vbi/dataset/connectivity_84/centers.txt +84 -0
- vbi/dataset/connectivity_84/centres.txt +84 -0
- vbi/dataset/connectivity_84/cortical.txt +84 -0
- vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
- vbi/dataset/connectivity_84/weights.txt +84 -0
- vbi/dataset/connectivity_88/Aud_88.txt +88 -0
- vbi/dataset/connectivity_88/Bold.npz +0 -0
- vbi/dataset/connectivity_88/Labels.txt +17 -0
- vbi/dataset/connectivity_88/Region_labels.txt +88 -0
- vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
- vbi/dataset/connectivity_88/weights.txt +88 -0
- vbi/feature_extraction/__init__.py +1 -0
- vbi/feature_extraction/calc_features.py +293 -0
- vbi/feature_extraction/features.json +535 -0
- vbi/feature_extraction/features.py +2124 -0
- vbi/feature_extraction/features_settings.py +374 -0
- vbi/feature_extraction/features_utils.py +1357 -0
- vbi/feature_extraction/infodynamics.jar +0 -0
- vbi/feature_extraction/utility.py +507 -0
- vbi/inference.py +98 -0
- vbi/models/__init__.py +0 -0
- vbi/models/cpp/__init__.py +0 -0
- vbi/models/cpp/_src/__init__.py +0 -0
- vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/bold.hpp +303 -0
- vbi/models/cpp/_src/do.hpp +167 -0
- vbi/models/cpp/_src/do.i +17 -0
- vbi/models/cpp/_src/do.py +467 -0
- vbi/models/cpp/_src/do_wrap.cxx +12811 -0
- vbi/models/cpp/_src/jr_sdde.hpp +352 -0
- vbi/models/cpp/_src/jr_sdde.i +19 -0
- vbi/models/cpp/_src/jr_sdde.py +688 -0
- vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
- vbi/models/cpp/_src/jr_sde.hpp +264 -0
- vbi/models/cpp/_src/jr_sde.i +17 -0
- vbi/models/cpp/_src/jr_sde.py +470 -0
- vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
- vbi/models/cpp/_src/km_sde.hpp +158 -0
- vbi/models/cpp/_src/km_sde.i +19 -0
- vbi/models/cpp/_src/km_sde.py +671 -0
- vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
- vbi/models/cpp/_src/makefile +52 -0
- vbi/models/cpp/_src/mpr_sde.hpp +327 -0
- vbi/models/cpp/_src/mpr_sde.i +19 -0
- vbi/models/cpp/_src/mpr_sde.py +711 -0
- vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
- vbi/models/cpp/_src/utility.hpp +307 -0
- vbi/models/cpp/_src/vep.hpp +171 -0
- vbi/models/cpp/_src/vep.i +16 -0
- vbi/models/cpp/_src/vep.py +464 -0
- vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
- vbi/models/cpp/_src/wc_ode.hpp +294 -0
- vbi/models/cpp/_src/wc_ode.i +19 -0
- vbi/models/cpp/_src/wc_ode.py +686 -0
- vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
- vbi/models/cpp/damp_oscillator.py +143 -0
- vbi/models/cpp/jansen_rit.py +543 -0
- vbi/models/cpp/km.py +187 -0
- vbi/models/cpp/mpr.py +289 -0
- vbi/models/cpp/vep.py +150 -0
- vbi/models/cpp/wc.py +216 -0
- vbi/models/cupy/__init__.py +0 -0
- vbi/models/cupy/bold.py +111 -0
- vbi/models/cupy/ghb.py +284 -0
- vbi/models/cupy/jansen_rit.py +473 -0
- vbi/models/cupy/km.py +224 -0
- vbi/models/cupy/mpr.py +475 -0
- vbi/models/cupy/mpr_modified_bold.py +12 -0
- vbi/models/cupy/utils.py +184 -0
- vbi/models/numba/__init__.py +0 -0
- vbi/models/numba/_ww_EI.py +444 -0
- vbi/models/numba/damp_oscillator.py +162 -0
- vbi/models/numba/ghb.py +208 -0
- vbi/models/numba/mpr.py +383 -0
- vbi/models/pytorch/__init__.py +0 -0
- vbi/models/pytorch/data/default_parameters.npz +0 -0
- vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
- vbi/models/pytorch/data/input/fc_test.csv +68 -0
- vbi/models/pytorch/data/input/fc_train.csv +68 -0
- vbi/models/pytorch/data/input/fc_vali.csv +68 -0
- vbi/models/pytorch/data/input/fcd_test.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_train.mat +0 -0
- vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
- vbi/models/pytorch/data/input/myelin.csv +68 -0
- vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
- vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
- vbi/models/pytorch/data/input/sc_test.csv +68 -0
- vbi/models/pytorch/data/input/sc_train.csv +68 -0
- vbi/models/pytorch/data/input/sc_vali.csv +68 -0
- vbi/models/pytorch/data/obs_kong0.npz +0 -0
- vbi/models/pytorch/ww_sde_kong.py +570 -0
- vbi/models/tvbk/__init__.py +9 -0
- vbi/models/tvbk/tvbk_wrapper.py +166 -0
- vbi/models/tvbk/utils.py +72 -0
- vbi/papers/__init__.py +0 -0
- vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
- vbi/tests/__init__.py +0 -0
- vbi/tests/_test_mpr_nb.py +36 -0
- vbi/tests/test_features.py +355 -0
- vbi/tests/test_ghb_cupy.py +90 -0
- vbi/tests/test_mpr_cupy.py +49 -0
- vbi/tests/test_mpr_numba.py +84 -0
- vbi/tests/test_suite.py +19 -0
- vbi/utils.py +402 -0
- vbi-0.1.3.dist-info/METADATA +166 -0
- vbi-0.1.3.dist-info/RECORD +121 -0
- vbi-0.1.3.dist-info/WHEEL +5 -0
- vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
- vbi-0.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,444 @@
|
|
1
|
+
import warnings
|
2
|
+
import numpy as np
|
3
|
+
from numba import njit
|
4
|
+
from numba.experimental import jitclass
|
5
|
+
from numba.core.errors import NumbaPerformanceWarning
|
6
|
+
from numba import float64, boolean, int64, types
|
7
|
+
|
8
|
+
# Silence numba performance warnings. warnings.simplefilter installs the filter
# process-wide, so this affects every importer, not just this module.
# NOTE(review): presumably aimed at warnings raised while compiling np.dot in
# f_ww — confirm before narrowing the filter.
warnings.simplefilter(action="ignore", category=NumbaPerformanceWarning)
|
9
|
+
|
10
|
+
|
11
|
+
# Field layout of the ParWW jitclass as (name, numba type) pairs.
# Every attribute assigned in ParWW.__init__ must appear here: numba rejects
# stores to jitclass members that are not declared in the spec. WW_sde also
# derives its list of accepted parameter names from this list, in this order.
w_spec = [
    ("G", float64),

    ("a_I", float64),
    ("b_I", float64),
    ("d_I", float64),
    ("tau_I", float64),

    ("a_E", float64),
    ("b_E", float64),
    ("d_E", float64),
    ("tau_E", float64),

    ("w_II", float64),
    ("w_EE", float64),
    ("w_IE", float64),
    ("w_EI", float64),

    ("W_E", float64),
    ("W_I", float64),

    ("gamma", float64),
    ("dt", float64),
    ("J_NMDA", float64),
    ("J_I", float64),
    # ParWW.__init__ assigns I0, sigma_I and sigma_E, so they must be declared.
    ("I0", float64),

    ("I_I", float64[:]),
    ("I_E", float64[:]),

    ("sigma_I", float64),
    ("sigma_E", float64),

    ("initial_state", float64[:]),
    ("weights", float64[:, :]),
    ("seed", int64),
    ("method", types.string),
    ("t_end", float64),
    ("t_cut", float64),
    ("nn", int64),
    ("ts_decimate", int64),
    ("fmri_decimate", int64),
    ("RECORD_TS", boolean),
    ("RECORD_FMRI", boolean),
]
|
55
|
+
|
56
|
+
# Field layout of the ParBaloon jitclass as (name, numba type) pairs, in the
# order eps ... inv_tauf (all float64), then nn (int64), then dt_bold.
_B_FLOAT_FIELDS = (
    "eps", "E0", "V0", "alpha", "inv_alpha",
    "K1", "K2", "K3",
    "taus", "tauo", "tauf",
    "inv_tauo", "inv_taus", "inv_tauf",
)
b_spec = [(field_name, float64) for field_name in _B_FLOAT_FIELDS] + [
    ("nn", int64),
    ("dt_bold", float64),
]
|
74
|
+
|
75
|
+
|
76
|
+
@jitclass(w_spec)
class ParWW:
    def __init__(
        self,
        G=0.0,
        a_I=615.0,
        b_I=177.0,
        d_I=0.087,
        tau_I=0.01,
        a_E=310.0,
        b_E=125.0,
        d_E=0.16,
        tau_E=0.1,
        gamma=0.641,
        w_II=1.0,
        w_IE=1.4,
        w_EI=1.0,
        w_EE=1.0,
        dt=0.01,
        W_E=1.0,
        W_I=0.7,
        I0=0.382,
        J_NMDA=0.15,
        I_I=np.array([0.296]),
        I_E=np.array([0.377]),
        sigma_I=0.001,
        sigma_E=0.001,
        initial_state=np.array([]),
        weights=np.array([[], []]),
        seed=-1,
        method="heun",
        t_end=300.0,
        t_cut=0.0,
        ts_decimate=10,
        fmri_decimate=10,
        RECORD_TS=True,
        RECORD_FMRI=True,
    ):
        """
        Store the Wong-Wang simulation parameters on a numba jitclass.

        ``nn`` (number of nodes) is taken from ``len(initial_state)``, or set
        to -1 when no initial state is supplied; WW_sde overwrites it before
        integrating.

        NOTE(review): ``I0``, ``sigma_I`` and ``sigma_E`` are stored here, so
        they must be declared in ``w_spec``. ``J_I`` is declared in ``w_spec``
        but is never initialised by this constructor.
        """
        # global coupling
        self.G = G

        # parameters with suffix _I
        self.a_I, self.b_I, self.d_I, self.tau_I = a_I, b_I, d_I, tau_I
        # parameters with suffix _E
        self.a_E, self.b_E, self.d_E, self.tau_E = a_E, b_E, d_E, tau_E

        # local weights and kinetic constant
        self.w_II, self.w_IE, self.w_EI, self.w_EE = w_II, w_IE, w_EI, w_EE
        self.gamma = gamma

        # input-related constants
        self.W_E, self.W_I = W_E, W_I
        self.I0 = I0
        self.I_E, self.I_I = I_E, I_I
        self.J_NMDA = J_NMDA

        # noise amplitudes
        self.sigma_I, self.sigma_E = sigma_I, sigma_E

        # network and simulation control
        self.initial_state = initial_state
        self.weights = weights
        self.seed = seed
        self.method = method
        self.dt = dt
        self.t_end = t_end
        self.t_cut = t_cut
        self.ts_decimate = ts_decimate
        self.fmri_decimate = fmri_decimate
        self.RECORD_TS = RECORD_TS
        self.RECORD_FMRI = RECORD_FMRI

        # -1 marks "unknown yet"; the driver fills nn in from the weights.
        self.nn = len(initial_state) if len(initial_state) > 0 else -1
|
165
|
+
|
166
|
+
|
167
|
+
@jitclass(b_spec)
class ParBaloon:
    def __init__(
        self, eps=0.5, E0=0.4, V0=4.0,
        alpha=0.32, taus=1.54, tauo=0.98, tauf=1.44
    ):
        """
        Store the Balloon-model constants and the quantities derived from them.

        The reciprocals (inv_*) are consumed by f_fmri and K1..K3 by
        integrate_fmri. They are computed once here, so changing a raw constant
        via setattr afterwards does NOT refresh them. ``nn`` is not set here;
        WW_sde assigns it before integration. ``dt_bold`` is fixed at 0.01.
        """
        # constructor arguments, stored as given
        self.eps = eps
        self.E0 = E0
        self.V0 = V0
        self.alpha = alpha
        self.taus = taus
        self.tauo = tauo
        self.tauf = tauf

        # precomputed reciprocals
        self.inv_alpha = 1.0 / alpha
        self.inv_taus = 1.0 / taus
        self.inv_tauo = 1.0 / tauo
        self.inv_tauf = 1.0 / tauf

        # coefficients of the BOLD signal equation
        self.K1 = 7.0 * E0
        self.K2 = 2 * E0
        self.K3 = 1 - eps

        # step size of the Balloon-model Heun integrator
        self.dt_bold = 0.01
|
188
|
+
|
189
|
+
|
190
|
+
@njit
def f_ww(S, P):
    """
    Deterministic right-hand side dS/dt of the Wong-Wang network.

    Computes x = w*J_N*S + I_o + G*J_N*(weights @ S), the transfer function
    H(x) = (a*x - b) / (1 - exp(-d*(a*x - b))), and returns
    -S/tau_s + (1 - S)*H*gamma.

    NOTE(review): this reads P.w, P.J_N, P.I_o, P.a, P.b, P.d and P.tau_s,
    none of which are declared in ``w_spec``/ParWW (that spec has E/I-suffixed
    parameters instead). As written, it cannot compile against a ParWW
    instance; the attribute names must be reconciled with the intended
    excitatory-inhibitory equations.
    """
    network_input = np.dot(P.weights, S)
    x = P.w * P.J_N * S + P.I_o + P.G * P.J_N * network_input
    shifted = P.a * x - P.b
    H = shifted / (1 - np.exp(-P.d * shifted))
    return -(S / P.tau_s) + (1 - S) * H * P.gamma
|
200
|
+
|
201
|
+
|
202
|
+
@njit
def f_fmri(xin, x, t, B):
    """
    Right-hand side of the Balloon model for all nodes.

    Parameters
    ----------
    xin : array [nn]
        Neural drive; it enters only the equation for the first state block.
    x : array [4*nn]
        Stacked state [s, f, v, q], each block of length B.nn.
    t : float
        Current time (unused; kept for the integrator's calling convention).
    B : ParBaloon
        Balloon-model constants (reads nn, E0 and the inv_* reciprocals).

    Returns
    -------
    dxdt : array [4*nn]
        Time derivative of x, in the same [s, f, v, q] layout.
    """
    nn = B.nn
    s = x[:nn]
    f = x[nn : 2 * nn]
    v = x[2 * nn : 3 * nn]
    q = x[3 * nn :]

    # v ** (1/alpha) appears in both the v and q equations.
    v_pow = v ** (B.inv_alpha)
    extraction_term = f * (1.0 - (1.0 - B.E0) ** (1.0 / f)) / B.E0

    dxdt = np.zeros(4 * nn)
    dxdt[:nn] = xin - B.inv_taus * s - B.inv_tauf * (f - 1.0)
    dxdt[nn : 2 * nn] = s
    dxdt[2 * nn : 3 * nn] = B.inv_tauo * (f - v_pow)
    dxdt[3 * nn :] = B.inv_tauo * (extraction_term - v_pow * (q / v))
    return dxdt
|
227
|
+
|
228
|
+
|
229
|
+
@njit
def euler_sde_step(S, P):
    """
    Advance S by one Euler-Maruyama step with additive Gaussian noise.

    The noise increment is sqrt(dt) * sigma_noise * N(0, 1), drawn per node.
    NOTE(review): P.sigma_noise is not declared in ``w_spec``/ParWW, which
    only carries sigma_I / sigma_E; reconcile before use.
    """
    noise = np.sqrt(P.dt) * P.sigma_noise * np.random.randn(P.nn)
    drift = f_ww(S, P)
    return S + P.dt * drift + noise
|
233
|
+
|
234
|
+
|
235
|
+
@njit
def heun_sde_step(S, P):
    """
    Advance S by one stochastic Heun step with additive Gaussian noise.

    The same noise increment is used in the predictor and the corrector.
    NOTE(review): P.sigma_noise is not declared in ``w_spec``/ParWW, which
    only carries sigma_I / sigma_E; reconcile before use.
    """
    h = P.dt
    noise = np.sqrt(h) * P.sigma_noise * np.random.randn(P.nn)
    slope_start = f_ww(S, P)
    predicted = S + h * slope_start + noise
    slope_end = f_ww(predicted, P)
    return S + 0.5 * h * (slope_start + slope_end) + noise
|
242
|
+
|
243
|
+
|
244
|
+
@njit
def heun_ode_step(yin, y, t, B):
    """
    One deterministic Heun step of the Balloon model with step B.dt_bold.

    ``y`` is updated in place and also returned.
    """
    h = B.dt_bold
    slope_start = f_fmri(yin, y, t, B)
    y_pred = y + h * slope_start
    slope_end = f_fmri(yin, y_pred, t + h, B)
    y += 0.5 * h * (slope_start + slope_end)
    return y
|
254
|
+
|
255
|
+
|
256
|
+
@njit
def integrate_fmri(yin, y, t, B):
    """
    Advance the Balloon model one step and compute the BOLD signal.

    Parameters
    ----------
    yin : array [nn]
        Neural drive for each node; its length defines nn here.
    y : array [4*nn]
        Balloon state [s, f, v, q]; updated in place by heun_ode_step.
    t : float
        Time passed through to the integrator.
    B : ParBaloon
        Balloon-model constants.

    Returns
    -------
    y : array [4*nn]
        Updated state.
    yb : array [nn]
        BOLD signal V0 * (K1*(1 - q) + K2*(1 - q/v) + K3*(1 - v)).
    """
    nn = yin.shape[0]
    y = heun_ode_step(yin, y, t, B)
    v = y[(2 * nn) : (3 * nn)]
    q = y[(3 * nn) :]
    yb = B.V0 * (
        B.K1 * (1.0 - q)
        + B.K2 * (1.0 - q / v)
        + B.K3 * (1.0 - v)
    )
    return y, yb
|
290
|
+
|
291
|
+
|
292
|
+
@njit
def integrate(P, B, intg=heun_sde_step):
    """
    Integrate the Wong-Wang network and drive the Balloon model with it.

    Parameters
    ----------
    P : ParWW
        Network parameters; P.initial_state is the starting state and P.nn
        the number of nodes.
    B : ParBaloon
        Balloon-model constants; B.nn must equal P.nn.
    intg : callable
        One-step SDE integrator, ``intg(S, P) -> S_next``.

    Returns
    -------
    T : array
        Times of the recorded neural samples (every P.ts_decimate steps,
        t >= P.t_cut). Empty when P.RECORD_TS is False.
    S : array [len(T), nn]
        Recorded neural states (shape (0, 1) when not recording).
    t_fmri : array
        Times of the recorded BOLD samples, i * B.dt_bold (every
        P.fmri_decimate steps, t >= P.t_cut). Empty when P.RECORD_FMRI is False.
    d_fmri : array [len(t_fmri), nn]
        Recorded BOLD signal (shape (0, 1) when not recording).
    """
    if P.seed > 0:
        # np.random.seed called inside jitted code seeds numba's own generator,
        # which is the one np.random.randn in the step functions draws from.
        # Seeding NumPy from Python does not reach it.
        np.random.seed(P.seed)

    nt = len(np.arange(0, P.t_end, P.dt))
    nn = P.nn

    if P.RECORD_TS:
        n_ts = int(np.ceil(nt / P.ts_decimate))
        T = np.empty(n_ts)
        S = np.empty((n_ts, nn))
    else:
        T = np.empty(0)
        S = np.empty((0, 1))

    if P.RECORD_FMRI:
        n_fmri = int(np.ceil(nt / P.fmri_decimate))
        t_fmri = np.empty(n_fmri)
        d_fmri = np.empty((n_fmri, nn))
    else:
        t_fmri = np.empty(0)
        d_fmri = np.empty((0, 1))

    x0 = P.initial_state
    # Balloon state [s, f, v, q]: s starts at 0, f, v, q at 1.
    y0 = np.zeros((4 * nn))
    y0[nn:] = 1.0

    ii = 0
    jj = 0
    for i in range(1, nt):
        t_now = i * P.dt
        x0 = intg(x0, P)

        if P.RECORD_TS and i % P.ts_decimate == 0:
            S[ii, :] = x0
            T[ii] = t_now
            ii += 1

        if P.RECORD_FMRI:
            y0, fmri_i = integrate_fmri(x0, y0, t_now, B)
            if i % P.fmri_decimate == 0:
                d_fmri[jj, :] = fmri_i
                t_fmri[jj] = i * B.dt_bold
                jj += 1

    # The buffers hold ceil(nt / decimate) slots, but since the loop starts at
    # i = 1 only floor((nt - 1) / decimate) of them are written. Drop the
    # unwritten np.empty slot(s) so uninitialised memory never reaches the
    # t_cut mask or the output.
    T_rec = T[:ii]
    S_rec = S[:ii, :]
    t_fmri_rec = t_fmri[:jj]
    d_fmri_rec = d_fmri[:jj, :]

    keep_ts = T_rec >= P.t_cut
    keep_fmri = t_fmri_rec >= P.t_cut

    return (
        T_rec[keep_ts],
        S_rec[keep_ts, :],
        t_fmri_rec[keep_fmri],
        d_fmri_rec[keep_fmri, :],
    )
|
345
|
+
|
346
|
+
|
347
|
+
class WW_sde(object):
    r"""
    Wong-Wang model.

    .. math::
        x_k &= w\,J_N \, S_k + I_o + G\,J_N \sum_j \, C_{kj} \,S_j \\
        H(x_k) &= \dfrac{ax_k - b}{1 - \exp(-d(ax_k -b))}\\
        \dot{S}_k &= -\dfrac{S_k}{\tau_s} + (1 - S_k) \, H(x_k) \, \gamma + \sigma \, \Xi_k

    - Kong-Fatt Wong and Xiao-Jing Wang, A Recurrent Network Mechanism of Time Integration in Perceptual Decisions. Journal of Neuroscience 26(4), 1314-1328, 2006.
    - Deco Gustavo, Ponce Alvarez Adrian, Dante Mantini, Gian Luca Romani, Patric Hagmann and Maurizio Corbetta. Resting-State Functional Connectivity Emerges from Structurally and Dynamically Shaped Slow Linear Fluctuations. The Journal of Neuroscience 33(27), 11239-11252, 2013. Equations taken from DPA 2013, page 11242.

    Parameters
    ----------
    par : dict, optional
        Keyword arguments for ParWW (see ``w_spec``). The caller's dict is not
        modified.
    parB : dict, optional
        Keyword arguments for ParBaloon (see ``b_spec``). The caller's dict is
        not modified.
    """

    # Constructor arguments of ParBaloon; the remaining b_spec fields are
    # derived from these (or assigned later) and cannot be passed to it.
    _BALLOON_CTOR_KEYS = ("eps", "E0", "V0", "alpha", "taus", "tauo", "tauf")

    def __init__(self, par: dict | None = None, parB: dict | None = None) -> None:
        par = {} if par is None else dict(par)
        parB = {} if parB is None else dict(parB)

        self.valid_parW = [name for name, _ in w_spec]
        self.valid_parB = [name for name, _ in b_spec]
        self.valid_par = self.valid_parW + self.valid_parB

        self.check_parameters(par)
        self.check_parameters(parB)
        self.P = self.get_par_WW_obj(par)
        self.B = self.get_par_Baloon_obj(parB)

    def __str__(self):
        """Return a listing of all Wong-Wang and Balloon parameters."""
        lines = [
            "Wong-Wang model of neural population dynamics",
            "Parameters:----------------------------",
        ]
        lines += [f"{key} :  {getattr(self.P, key)}" for key in self.valid_parW]
        lines.append("---------------------------------------")
        lines += [f"{key} :  {getattr(self.B, key)}" for key in self.valid_parB]
        return "\n".join(lines)

    def get_par_WW_obj(self, par: dict | None = None) -> ParWW:
        """
        Build a ParWW from ``par``; missing keys take ParWW's defaults.

        ``initial_state`` and ``weights`` are copied into float64 arrays, as
        the jitclass fields require. The caller's dict is left untouched.
        """
        par = {} if par is None else dict(par)
        if "initial_state" in par:
            par["initial_state"] = np.array(par["initial_state"], dtype=np.float64)
        if "weights" in par:
            assert par["weights"] is not None
            par["weights"] = np.array(par["weights"], dtype=np.float64)
            assert (
                par["weights"].ndim == 2
                and par["weights"].shape[0] == par["weights"].shape[1]
            )
        return ParWW(**par)

    def get_par_Baloon_obj(self, par: dict | None = None) -> ParBaloon:
        """Build a ParBaloon from ``par``; missing keys take ParBaloon's defaults."""
        par = {} if par is None else par
        return ParBaloon(**par)

    def check_parameters(self, par: dict) -> None:
        """Raise ValueError for any key that is neither a ParWW nor a ParBaloon field."""
        for key in par.keys():
            if key not in self.valid_par:
                raise ValueError(f"Invalid parameter {key}")

    def set_initial_state(self, seed=None):
        """
        Draw a uniform random initial state with one entry per node.

        Also sets P.nn and B.nn from the weights matrix. ``seed`` (if given)
        seeds NumPy's global generator.
        """
        if seed is not None:
            np.random.seed(seed)
        self.P.nn = self.P.weights.shape[0]
        self.initial_state = np.random.rand(self.P.nn)
        self.B.nn = self.P.nn

    def check_input(self):
        """Assert that weights are square and match the initial state; sync B.nn."""
        assert self.P.weights is not None
        assert self.P.weights.shape[0] == self.P.weights.shape[1]
        assert self.P.initial_state is not None
        assert len(self.P.initial_state) == self.P.weights.shape[0]
        self.B.nn = self.P.nn

    def _select_integrator(self):
        """Map P.method ("euler" / "heun", case-insensitive) to its step function."""
        method = str(self.P.method).lower()
        if method == "euler":
            return euler_sde_step
        if method == "heun":
            return heun_sde_step
        raise ValueError(
            f"Invalid method {self.P.method!r}; expected 'euler' or 'heun'"
        )

    def _update_balloon(self, parB: dict) -> None:
        """
        Apply Balloon-parameter updates and keep derived constants consistent.

        The derived inv_* and K* fields are only computed in ParBaloon's
        constructor, so a plain setattr of e.g. ``alpha`` would leave them
        stale. Instead, rebuild B from its current constructor arguments merged
        with the updates. Then restore dt_bold and apply any directly specified
        derived fields on top.
        """
        ctor_kwargs = {key: getattr(self.B, key) for key in self._BALLOON_CTOR_KEYS}
        ctor_kwargs.update(
            {key: value for key, value in parB.items() if key in self._BALLOON_CTOR_KEYS}
        )
        dt_bold = self.B.dt_bold
        self.B = ParBaloon(**ctor_kwargs)
        self.B.dt_bold = dt_bold
        for key, value in parB.items():
            if key not in self._BALLOON_CTOR_KEYS:
                setattr(self.B, key, value)

    def run(self, par=None, parB=None, x0=None, verbose=True):
        """
        Simulate the model and return the recorded time series.

        Parameters
        ----------
        par, parB : dict, optional
            Parameter updates for the Wong-Wang / Balloon models. They are
            applied before the initial state is built, so a new ``weights`` or
            ``seed`` takes effect in this run. An ``initial_state`` entry in
            ``par`` takes precedence over ``x0``.
        x0 : array-like, optional
            Initial state. When omitted, a random state is drawn (seeded by
            P.seed if it is > 0).
        verbose : bool
            Currently unused.

        Returns
        -------
        dict with keys "t", "s", "t_fmri", "d_fmri".
        """
        par = {} if par is None else dict(par)
        parB = {} if parB is None else dict(parB)
        self.check_parameters(par)
        self.check_parameters(parB)

        if "initial_state" in par:
            x0 = par.pop("initial_state")

        for key, value in par.items():
            if key == "weights":
                value = np.array(value, dtype=np.float64)
            setattr(self.P, key, value)
        if parB:
            self._update_balloon(parB)

        intg = self._select_integrator()

        if x0 is None:
            self.seed = self.P.seed if self.P.seed > 0 else None
            self.set_initial_state(self.seed)
            self.P.initial_state = self.initial_state
        else:
            self.P.initial_state = np.array(x0, dtype=np.float64)
            self.P.nn = len(x0)

        self.check_input()

        T, S, t_fmri, d_fmri = integrate(self.P, self.B, intg)

        return {"t": T, "s": S, "t_fmri": t_fmri, "d_fmri": d_fmri}
|
@@ -0,0 +1,162 @@
|
|
1
|
+
import os
|
2
|
+
import numpy as np
|
3
|
+
from numba.experimental import jitclass
|
4
|
+
from numba import float64, types
|
5
|
+
from numba import njit
|
6
|
+
|
7
|
+
# Field layout of the Param jitclass: six float64 scalars, then the output
# directory name, the initial state vector and the integration method name.
jit_spec = [
    (field_name, float64)
    for field_name in ("a", "b", "dt", "t_start", "t_end", "t_cut")
] + [
    ("output", types.string),
    ("initial_state", float64[:]),
    ("method", types.string),
]
|
17
|
+
|
18
|
+
@jitclass(jit_spec)
class Param:
    def __init__(self,
                 a=0.1,
                 b=0.05,
                 dt=0.01,
                 t_start=0,
                 t_end=100.0,
                 t_cut=20,
                 output="output",
                 method="euler",
                 initial_state=np.array([0.5, 1.0])
                 ):
        '''
        Store damp-oscillator parameters on a numba jitclass.

        Integer time arguments are stored in the float64 fields declared by
        ``jit_spec``.
        '''
        # model coefficients
        self.a, self.b = a, b
        # time grid: transient runs over [t_start, t_cut), output over [t_cut, t_end)
        self.dt = dt
        self.t_start = t_start
        self.t_cut = t_cut
        self.t_end = t_end
        # integration method name and starting state
        self.method = method
        self.initial_state = initial_state
        # output directory name
        self.output = output
|
40
|
+
|
41
|
+
|
42
|
+
@njit
def _f_sys(x, P):
    '''
    Right-hand side of the two-variable damp oscillator system.

    Returns [x0 - x0*x1 - a*x0*x0, x0*x1 - x1 - b*x1*x1] for x = [x0, x1].
    '''
    u = x[0]
    w = x[1]
    du = u - u * w - P.a * u * u
    dw = u * w - w - P.b * w * w
    return np.array([du, dw])
|
51
|
+
|
52
|
+
|
53
|
+
@njit
def euler(x, P):
    '''
    Advance the damp oscillator by one forward-Euler step of size P.dt.
    '''
    slope = _f_sys(x, P)
    return x + P.dt * slope
|
59
|
+
|
60
|
+
@njit
def heun(x, P):
    '''
    Advance the damp oscillator by one Heun (explicit trapezoidal) step.
    '''
    h = P.dt
    slope_start = _f_sys(x, P)
    slope_end = _f_sys(x + h * slope_start, P)
    return x + 0.5 * h * (slope_start + slope_end)
|
69
|
+
|
70
|
+
@njit
def rk4(x, P):
    '''
    Advance the damp oscillator by one classical 4th-order Runge-Kutta step.
    '''
    h = P.dt
    half_h = 0.5 * h
    k1 = _f_sys(x, P)
    k2 = _f_sys(x + half_h * k1, P)
    k3 = _f_sys(x + half_h * k2, P)
    k4 = _f_sys(x + h * k3, P)
    weighted = k1 + 2 * k2 + 2 * k3 + k4
    return x + h * weighted / 6.0
|
80
|
+
|
81
|
+
|
82
|
+
@njit
def _integrate(x, P, intg=euler):
    '''
    Integrate the damp oscillator and return the post-transient trajectory.

    Parameters
    ----------
    x : array [2]
        Initial state at time P.t_start.
    P : Param
        Model and time-grid parameters.
    intg : callable
        One-step integrator, ``intg(x, P) -> x_next`` (euler, heun or rk4).

    Returns
    -------
    t : array [n]
        Output times np.arange(P.t_cut, P.t_end, P.dt).
    x_out : array [n, len(x)]
        x_out[i] is the state at time t[i].
    '''
    # Discard the transient [t_start, t_cut); afterwards x is the state at t_cut.
    n_transient = len(np.arange(P.t_start, P.t_cut, P.dt))
    for _ in range(n_transient):
        x = intg(x, P)

    t = np.arange(P.t_cut, P.t_end, P.dt)
    n_out = len(t)
    x_out = np.zeros((n_out, len(x)))

    for i in range(n_out):
        # Record before stepping so the sample matches its time label; the
        # previous step-then-record order labelled the state at t[i] + dt as t[i].
        x_out[i, :] = x
        if i + 1 < n_out:
            x = intg(x, P)
    return t, x_out
|
97
|
+
|
98
|
+
|
99
|
+
class DO_nb:
    '''
    Damped Oscillator model class (numba backend).

    Parameters
    ----------
    par : dict, optional
        Keyword arguments for Param (see ``jit_spec``). The caller's dict is
        not modified. The directory named by ``output`` is created on
        construction.
    '''

    # Supported values of Param.method and their one-step integrators.
    _INTEGRATORS = {"euler": euler, "heun": heun, "rk4": rk4}

    def __init__(self, par=None):
        par = {} if par is None else dict(par)

        self.valid_params = [name for name, _ in jit_spec]
        self.check_parameters(par)
        self.P = self.get_parobj(par)

        # P.output is a numba string field and can never be None, so it is
        # used as-is.
        os.makedirs(self.P.output, exist_ok=True)

    def __str__(self) -> str:
        '''Return a listing of all model parameters.'''
        lines = ["Damp Oscillator model", "----------------"]
        lines += [f"{key} :  {getattr(self.P, key)}" for key in self.valid_params]
        return "\n".join(lines)

    def check_parameters(self, par):
        '''
        Raise ValueError for any key that is not a Param field.
        '''
        for key in par.keys():
            if key not in self.valid_params:
                raise ValueError("Invalid parameter: " + key)

    def get_parobj(self, par=None):
        '''
        Build a Param from ``par``; missing keys take Param's defaults.

        ``initial_state`` is copied into a float64 array, as the jitclass field
        requires. The caller's dict is left untouched.
        '''
        par = {} if par is None else dict(par)
        if "initial_state" in par:
            par["initial_state"] = np.array(par["initial_state"], dtype=np.float64)
        return Param(**par)

    def update_par(self, par=None):
        '''
        Validate ``par`` and assign its entries onto the existing Param.
        '''
        if par:
            self.check_parameters(par)
            for key, value in par.items():
                if key == "initial_state":
                    value = np.array(value, dtype=np.float64)
                setattr(self.P, key, value)

    def run(self, par=None, x0=None):
        '''
        Integrate the model and return ``(t, x)``; see ``_integrate``.

        Parameters
        ----------
        par : dict, optional
            Parameter updates applied before integrating.
        x0 : array-like of length 2, optional
            Initial state; overrides P.initial_state.

        Raises
        ------
        ValueError
            If P.method is not one of "euler", "heun", "rk4".
        '''
        self.update_par(par)
        if x0 is not None:
            assert len(x0) == 2, "Invalid initial state"
            self.P.initial_state = np.array(x0, dtype=np.float64)

        method = self.P.method
        intg = self._INTEGRATORS.get(method)
        if intg is None:
            raise ValueError(
                f"Invalid method {method!r}; expected one of "
                f"{sorted(self._INTEGRATORS)}"
            )

        return _integrate(self.P.initial_state, self.P, intg=intg)
|
162
|
+
|