vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl → 0.2__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/feature_extraction/features.json +4 -1
- vbi/feature_extraction/features.py +10 -4
- vbi/inference.py +50 -22
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/jr_sde.hpp +5 -6
- vbi/models/cpp/_src/jr_sde_wrap.cxx +28 -28
- vbi/models/cpp/jansen_rit.py +2 -9
- vbi/models/cupy/bold.py +117 -0
- vbi/models/cupy/jansen_rit.py +1 -1
- vbi/models/cupy/km.py +62 -34
- vbi/models/cupy/mpr.py +24 -4
- vbi/models/cupy/utils.py +163 -2
- vbi/models/cupy/wilson_cowan.py +317 -0
- vbi/models/cupy/ww.py +342 -0
- vbi/models/numba/__init__.py +4 -0
- vbi/models/numba/jansen_rit.py +532 -0
- vbi/models/numba/mpr.py +8 -0
- vbi/models/numba/wilson_cowan.py +443 -0
- vbi/models/numba/ww.py +564 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/METADATA +30 -11
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/RECORD +30 -26
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/WHEEL +1 -1
- vbi/models/numba/_ww_EI.py +0 -444
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/licenses/LICENSE +0 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,317 @@
|
|
1
|
+
import os
|
2
|
+
import tqdm
|
3
|
+
import logging
|
4
|
+
import numpy as np
|
5
|
+
from copy import copy
|
6
|
+
from vbi.models.cupy.utils import *
|
7
|
+
|
8
|
+
try:
|
9
|
+
import cupy as cp
|
10
|
+
except ImportError:
|
11
|
+
logging.warning("Cupy is not installed. Using Numpy instead.")
|
12
|
+
|
13
|
+
class WC_sde:
    r"""
    Wilson-Cowan model of neural population dynamics.

    Every node has an excitatory (E) and an inhibitory (I) population. The
    state array used by :meth:`derivative` and the integrators has shape
    ``(2 * nn, num_sim)``: rows ``[:nn]`` are E, rows ``[nn:]`` are I, and each
    column is an independent simulation.

    **References**:

    .. [WC_1972] Wilson, H.R. and Cowan, J.D. *Excitatory and inhibitory
        interactions in localized populations of model neurons*, Biophysical
        Journal, 12: 1-24, 1972.
    .. [WC_1973] Wilson, H.R. and Cowan, J.D. *A Mathematical Theory of the
        Functional Dynamics of Cortical and Thalamic Nervous Tissue*

    .. [D_2011] Daffertshofer, A. and van Wijk, B. *On the influence of
        amplitude on the connectivity between phases*
        Frontiers in Neuroinformatics, July, 2011

    """

    def __init__(self, par: "dict | None" = None) -> None:
        """
        Build the model from default parameters plus user overrides.

        Parameters
        ----------
        par : dict, optional
            Overrides for entries of :meth:`get_default_parameters`. Every
            parameter also becomes an instance attribute (``self.dt``, ...).

        Raises
        ------
        ValueError
            If ``par`` contains a key that is not a known parameter.
        """
        # None default avoids sharing one mutable dict across instances.
        par = {} if par is None else par

        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

        # array module chosen from `engine` (presumably numpy or cupy)
        self.xp = get_module(self.engine)
        if self.seed is not None:
            self.xp.random.seed(self.seed)

        os.makedirs(self.output, exist_ok=True)
        self.update_dependent_parameters()
        self.PREPARE_INPUT = False

    def __call__(self):
        """Print the model name and return the parameter dictionary."""
        print("Wilson-Cowan model of neural population dynamics")
        return self._par

    def __str__(self) -> str:
        """Return a header followed by one ``name = value`` line per parameter."""
        lines = [
            "Wilson-Cowan model of neural population dynamics",
            "----------------",
        ]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    def get_default_parameters(self):
        """
        Return a fresh dict with the default model and simulation parameters.

        ``method`` selects the integrator used by :meth:`run`: ``"heun"`` /
        ``"heunStochastic"`` or ``"euler"`` / ``"euler_maruyama"``.
        ``RECORD_EI`` may contain ``"E"`` and/or ``"I"`` (case-insensitive).
        """
        par = {
            'c_ee': 16.0,
            'c_ei': 12.0,
            'c_ie': 15.0,
            'c_ii': 3.0,
            'tau_e': 8.0,
            'tau_i': 8.0,
            'a_e': 1.3,
            'a_i': 2.0,
            'b_e': 4.0,
            'b_i': 3.7,
            'c_e': 1.0,
            'c_i': 1.0,
            'theta_e': 0.0,
            'theta_i': 0.0,
            'r_e': 1.0,
            'r_i': 1.0,
            'k_e': 0.994,
            'k_i': 0.999,
            'alpha_e': 1.0,
            'alpha_i': 1.0,
            'P': 0.0,  # external input to excitatory population
            'Q': 0.0,  # external input to inhibitory population
            'g_e': 0.0,  # coupling excitatory
            'g_i': 0.0,  # coupling inhibitory
            "method": "heun",  # integration method
            "weights": None,  # connectivity matrix
            'seed': None,  # random seed
            "t_end": 300.0,  # end time
            "t_cut": 0.0,  # cut time
            "dt": 0.01,  # time step
            "noise_amp": 0.0,  # noise
            "output": "output",  # output directory
            "num_sim": 1,
            "engine": "cpu",
            "same_initial_state": False,
            "dtype": "float",
            "RECORD_EI": "E",
            "initial_state": None,
            "decimate": 1,
            "shift_sigmoid": False,
        }

        return par

    def update_dependent_parameters(self):
        """
        Recompute quantities derived from other parameters (``1 / tau``).

        Accepts scalar time constants as well as sequences/arrays.
        """
        tau_e = self.tau_e if np.isscalar(self.tau_e) else self.xp.asarray(self.tau_e)
        tau_i = self.tau_i if np.isscalar(self.tau_i) else self.xp.asarray(self.tau_i)
        self.inv_tau_e = 1.0 / tau_e
        self.inv_tau_i = 1.0 / tau_i

    def check_parameters(self, par: dict) -> None:
        """Raise ``ValueError`` for any key of ``par`` that is not a known parameter."""
        for key in par.keys():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter: {key} provided.")

    def set_initial_state(self):
        """Draw a random initial state into ``self.x0`` (see module ``set_initial_state``)."""
        self.x0 = set_initial_state(
            self.nn,
            self.num_sim,
            self.engine,
            self.seed,
            self.same_initial_state,
            self.dtype,
        )

    def prepare_input(self):
        '''
        Prepare input parameters, check dimensions and convert to cupy array if needed. Some parameters
        can be scalars, vectors or 2D arrays. 2D arrays parameters are heterogeneous across nodes and
        simulations, but 1D arrays parameters are homogeneous across nodes and heterogeneous across
        simulations.

        vector parameters: ns
        scalar parameters: 1
        matrix parameters: nn x ns
        '''
        assert self.weights is not None, "weights must be provided"
        self.g_e = self.xp.array(self.g_e)
        self.g_i = self.xp.array(self.g_i)

        for i in ["P", "Q", "c_ee", "c_ei", "c_ie", "c_ii", "tau_e", "tau_i"]:
            setattr(self, i, prepare_vec(getattr(self, i), self.num_sim, self.engine, self.dtype))

        self.weights = self.xp.array(self.weights)
        self.weights = move_data(self.weights, self.engine)
        self.num_nodes = self.nn = self.weights.shape[0]

        # tau_e / tau_i were just converted; refresh 1/tau so derivative() uses them.
        self.update_dependent_parameters()
        self.PREPARE_INPUT = True

    def derivative(self, x, t):
        """
        Right-hand side of the Wilson-Cowan equations.

        For ``E = x[:nn]`` and ``I = x[nn:]``::

            dE/dt = (-E + (k_e - r_e E) S_e(alpha_e (c_ee E - c_ei I + P - theta_e + g_e W E))) / tau_e
            dI/dt = (-I + (k_i - r_i I) S_i(alpha_i (c_ie E - c_ii I + Q - theta_i + g_i W I))) / tau_i

        ``t`` is accepted for the integrator interface but not used.
        """
        nn = self.nn
        E = x[:nn, :]
        I = x[nn:, :]
        dxdt = self.xp.zeros((2*nn, self.num_sim), dtype=self.dtype)
        lc_e = lc_i = 0.0

        # Skip the matrix product only when coupling is exactly zero;
        # negative couplings must still contribute.
        if (self.g_e != 0.0).any():
            lc_e = self.g_e * (self.weights @ E)
        if (self.g_i != 0.0).any():
            lc_i = self.g_i * (self.weights @ I)

        x_e = self.alpha_e * (self.c_ee * E - self.c_ei * I + self.P - self.theta_e + lc_e)
        x_i = self.alpha_i * (self.c_ie * E - self.c_ii * I + self.Q - self.theta_i + lc_i)
        s_e = self.sigmoid(x_e, self.a_e, self.b_e, self.c_e)
        s_i = self.sigmoid(x_i, self.a_i, self.b_i, self.c_i)
        dxdt[:nn, :] = self.inv_tau_e * (-E + (self.k_e - self.r_e * E) * s_e)
        dxdt[nn:, :] = self.inv_tau_i * (-I + (self.k_i - self.r_i * I) * s_i)

        return dxdt

    def sigmoid(self, x, a, b, c):
        '''
        Logistic sigmoid ``c / (1 + exp(-a (x - b)))``.

        With ``shift_sigmoid`` enabled the value at ``x = 0`` is subtracted,
        so the shifted sigmoid passes through zero at the origin.
        '''
        if self.shift_sigmoid:
            return c * (1.0 / (1.0 + self.xp.exp(-a * (x - b))) - 1.0 / (1.0 + self.xp.exp(-a * -b)))
        else:
            return c / (1.0 + self.xp.exp(-a * (x - b)))

    def euler_maruyama(self, x, t):
        '''
        One Euler-Maruyama step with additive noise ``noise_amp * sqrt(dt) * N(0, 1)``.
        '''
        dw = self.xp.random.normal(size=x.shape)
        coeff = self.noise_amp * self.xp.sqrt(self.dt)

        return x + self.dt * self.derivative(x, t) + coeff * dw

    def heunStochastic(self, x, t):
        '''
        One stochastic Heun (predictor-corrector) step with additive noise.

        The same noise increment is used in the predictor and the corrector.
        '''
        coeff = self.noise_amp * self.xp.sqrt(self.dt)
        dw = self.xp.random.normal(size=x.shape)

        k1 = self.derivative(x, t)
        x_predictor = x + self.dt * k1 + coeff * dw
        k2 = self.derivative(x_predictor, t + self.dt)

        return x + self.dt * (k1 + k2) / 2.0 + coeff * dw

    def _get_stepper(self, method):
        """Return the bound single-step integrator for ``method`` (case-insensitive)."""
        name = str(method).lower()
        if name in ("heun", "heunstochastic"):
            return self.heunStochastic
        if name in ("euler", "euler_maruyama"):
            return self.euler_maruyama
        raise ValueError(f"Invalid method: {method}")

    def run(self, x0=None, tspan=None, verbose=True):
        '''
        Run the Wilson-Cowan model.

        Parameters
        ----------
        x0 : array, optional
            Initial state of shape ``(2*nn, num_sim)``. If omitted, the
            ``initial_state`` parameter is used; if that is also None a random
            state is drawn with :meth:`set_initial_state`.
        tspan : array, optional
            Time points; defaults to ``np.arange(0, t_end, dt)``. The state is
            advanced by ``dt`` once per entry.
        verbose : bool
            Show a progress bar.

        Returns
        -------
        dict
            ``t``: float32 sample times; ``E`` / ``I``: float32 arrays of shape
            ``(n_samples, nn, num_sim)``, or None when not selected in ``RECORD_EI``.

        Raises
        ------
        ValueError
            If ``method`` is not a supported integrator name.

        #TODO: optimize memory usage
        '''
        self.prepare_input()
        step = self._get_stepper(self.method)

        if x0 is not None:
            self.x0 = x0
        elif self.initial_state is not None:
            self.x0 = self.xp.asarray(self.initial_state, dtype=self.dtype)
        else:
            self.set_initial_state()

        t = np.arange(0.0, self.t_end, self.dt) if tspan is None else np.asarray(tspan)

        nn = self.nn
        t_cut = self.t_cut
        decimate = self.decimate
        RECORD_EI = self.RECORD_EI.lower()
        record_e = "e" in RECORD_EI
        record_i = "i" in RECORD_EI

        valid_points = np.sum(t > t_cut)
        buffer_size = valid_points // decimate
        t_buffer = np.zeros((buffer_size), dtype=np.float32)
        E = np.zeros((buffer_size, nn, self.num_sim), dtype=np.float32) if record_e else None
        I = np.zeros((buffer_size, nn, self.num_sim), dtype=np.float32) if record_i else None

        buffer_idx = 0
        for i in tqdm.trange(len(t), disable=not verbose, desc="Integrating"):
            # Use the actual time grid so recording agrees with the buffer sizing above.
            t_curr = t[i]

            self.x0 = step(self.x0, t_curr)

            if (t_curr > t_cut) and (i % decimate == 0) and (buffer_idx < buffer_size):
                t_buffer[buffer_idx] = t_curr
                if record_e:
                    E[buffer_idx] = get_(self.x0[:nn, :], self.engine, "f")
                if record_i:
                    I[buffer_idx] = get_(self.x0[nn:, :], self.engine, "f")
                buffer_idx += 1

        return {"t": t_buffer, "E": E, "I": I}

    def do_step_EI(self, x, t, method="heunStochastic"):
        '''
        Do a single step of the Wilson-Cowan model.

        ``method`` may be ``"heunStochastic"`` / ``"heun"`` or
        ``"euler_maruyama"`` / ``"euler"``; anything else raises ``ValueError``.
        Inputs are prepared on first use.
        '''
        if not self.PREPARE_INPUT:
            self.prepare_input()

        return self._get_stepper(method)(x, t)
|
286
|
+
|
287
|
+
|
288
|
+
def set_initial_state(nn, ns, engine, seed=None, same_initial_state=False, dtype=float):
    """
    Draw a random initial state for the Wilson-Cowan model.

    Values are drawn uniformly from [0, 1) with ``np.random.rand`` (2 * nn
    values per simulation: E states followed by I states).

    Parameters
    ----------
    nn : int
        number of nodes
    ns : int
        number of simulations
    engine : str
        cpu or gpu
    seed : int, optional
        if given, ``np.random`` is reseeded before drawing
    same_initial_state : bool
        if True one random vector is drawn and repeated for every simulation
    dtype : str or type
        float: float64
        f : float32

    Returns
    -------
    array
        the state on ``engine`` cast to ``dtype``; NOTE(review): shape is
        ``(2*nn, ns)`` for the per-simulation path, and presumably also for the
        ``repmat_vec`` path — confirm in ``vbi.models.cupy.utils``.
    """
    if seed is not None:
        np.random.seed(seed)

    if not same_initial_state:
        state = move_data(np.random.rand(2 * nn, ns), engine)
    else:
        state = repmat_vec(np.random.rand(2 * nn), ns, engine)

    return state.astype(dtype)
|
vbi/models/cupy/ww.py
ADDED
@@ -0,0 +1,342 @@
|
|
1
|
+
import os
|
2
|
+
import tqdm
|
3
|
+
import logging
|
4
|
+
import numpy as np
|
5
|
+
from copy import copy
|
6
|
+
from vbi.models.cupy.utils import *
|
7
|
+
from vbi.models.cupy.bold import Bold
|
8
|
+
from typing import List, Dict
|
9
|
+
|
10
|
+
try:
|
11
|
+
import cupy as cp
|
12
|
+
except ImportError:
|
13
|
+
logging.warning("Cupy is not installed. Using Numpy instead.")
|
14
|
+
|
15
|
+
|
16
|
+
class WW_sde:
    """
    Wong-Wang neural mass including Excitatory and Inhibitory populations.

    The state array ``S`` has shape ``(2 * nn, num_sim)``: rows ``[:nn]`` are
    the excitatory gating variables and rows ``[nn:]`` the inhibitory ones;
    each column is an independent simulation.

    Main reference:
    [original] Wong, K. F., & Wang, X. J. (2006). A recurrent network mechanism
    of time integration in perceptual decisions. Journal of Neuroscience, 26(4),
    1314-1328.

    Additional references:
    [reduced] Deco, G., Ponce-Alvarez, A., Mantini, D., Romani, G. L., Hagmann,
    P., & Corbetta, M. (2013). Resting-state functional connectivity emerges
    from structurally and dynamically shaped slow linear fluctuations. Journal
    of Neuroscience, 33(27), 11239-11252.

    [original] Deco, G., Ponce-Alvarez, A., Hagmann, P., Romani, G. L., Mantini,
    D., & Corbetta, M. (2014). How local excitation-inhibition ratio impacts the
    whole brain dynamics. Journal of Neuroscience, 34(23), 7886-7898.

    Parameters
    ----------
    par: dict, optional
        Overrides for :meth:`get_default_parameters` (e.g. ``G_exc``, ``dt``,
        ``weights``). Unknown keys raise ``ValueError``.
    Bpar: dict, optional
        Parameters forwarded unchanged to the BOLD model ``Bold``.
    """

    def __init__(self, par: Dict = None, Bpar: Dict = None) -> None:
        # None defaults avoid sharing one mutable dict across instances.
        par = {} if par is None else par
        Bpar = {} if Bpar is None else Bpar

        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        for item in self._par.items():
            setattr(self, item[0], item[1])

        self.B = Bold(Bpar)

        # array module chosen from `engine` (presumably numpy or cupy)
        self.xp = get_module(self.engine)
        if self.seed is not None:
            self.xp.random.seed(self.seed)

        os.makedirs(self.output, exist_ok=True)

    def __call__(self):
        """Print the model name and return the parameter dictionary."""
        print("Wong-Wang model.")
        return self._par

    def __str__(self) -> str:
        """Return a centered header followed by one line per parameter."""
        header = "Wong-Wang Model Parameters"
        header = header.center(50, "=")
        params = "\n".join([f"{key:>20}: {value}" for key, value in self._par.items()])
        return f"{header}\n{params}"

    def set_initial_state(self):
        """Return a random initial state (see module ``set_initial_state``)."""
        return set_initial_state(
            self.nn,
            self.num_sim,
            self.engine,
            self.seed,
            self.same_initial_state,
            self.dtype,
        )

    def get_default_parameters(self) -> Dict:
        """Get default parameters for the Wong-Wang full model."""

        par = {
            # Excitatory parameters
            "a_exc": 310,  # nC^-1
            # 615 nC^-1 per Deco et al. (2014); 0.615 silenced the inhibitory pool
            "a_inh": 615.0,  # nC^-1
            "b_exc": 125,  # Hz
            "b_inh": 177,  # Hz
            "d_exc": 0.16,  # s
            "d_inh": 0.087,  # s
            "tau_exc": 100.0,  # ms
            "tau_inh": 10.0,  # ms
            "gamma_exc": 0.641 / 1000.0,  # scaled by 1/1000 since time is in ms
            "gamma_inh": 1.0 / 1000.0,  # scaled by 1/1000 since time is in ms
            "W_exc": 1.0,
            "W_inh": 0.7,
            "ext_current": 0.382,  # nA external current
            "J_NMDA": 0.15,  # nA
            "J_I": 1.0,  # nA
            "w_plus": 1.4,
            "lambda_inh_exc": 0.0,  # long-range feedforward inhibition: 1 if considered, otherwise 0
            # other parameters
            "t_end": 1000.0,  # end time of simulation in ms
            "t_cut": 0.0,  # time to cut off initial transient in ms
            "dt": 0.1,  # time step for integration in ms
            "G_exc": 0.0,  # global excitatory coupling strength
            "G_inh": 0.0,  # global inhibitory coupling strength
            "weights": None,  # connectivity matrix (nn x nn)
            "tr": 300.0,  # repetition time in ms for BOLD
            "s_decimate": 1,  # decimation factor for recording gating variables S
            "same_noise_per_sim": False,  # if True, same noise is used for all simulations
            "sigma": 0.0,  # noise strength
            "num_sim": 1,  # number of simulations
            "nn": 1,  # number of nodes (overwritten from weights in prepare_input)
            "engine": "cpu",  # computation engine (cpu or gpu)
            "seed": None,  # random seed
            "output": "output",  # output directory
            "dtype": "float32",  # data type (float or float32)
            "initial_state": None,  # initial state, shape (2*nn, num_sim)
            "same_initial_state": False,  # if True, same initial state for all simulations
            "RECORD_S": False,  # if True, record gating variables S
            "RECORD_BOLD": True,  # if True, record BOLD signal
        }
        return par

    def check_parameters(self, par):
        """Raise ``ValueError`` for any key of ``par`` that is not a known parameter."""
        for key in par.keys():
            if key not in self.valid_parameters:
                # plain {key}: a ':s' spec would itself fail for non-str keys
                raise ValueError(f"Invalid parameter {key} provided.")

    def prepare_input(self):
        """Convert parameters to arrays on the selected engine and fix ``nn``."""
        assert self.weights is not None, "Weights must be provided."
        self.weights = self.xp.array(self.weights, dtype=self.dtype)
        # nn must come from the connectivity before per-node inputs are shaped.
        self.nn = self.num_nodes = self.weights.shape[0]

        self.G_exc = self.xp.array(self.G_exc, dtype=self.dtype)
        self.ext_current = prepare_vec_2d(
            self.ext_current, self.nn, self.num_sim, self.engine, self.dtype
        )
        self.sigma = self.xp.array(self.sigma, dtype=self.dtype)

    def get_firing_rate(self, current: float, is_exc: bool = True):
        """
        Calculate firing rate based on input current.

        ``r = (a I - b) / (1 - exp(-d (a I - b)))`` with the excitatory or
        inhibitory ``a, b, d``. NOTE: the expression is 0/0 (nan) where
        ``a I == b`` exactly.
        """
        if is_exc:
            a, b, d = self.a_exc, self.b_exc, self.d_exc
        else:
            a, b, d = self.a_inh, self.b_inh, self.d_inh

        return (a * current - b) / (1.0 - np.exp(-d * (a * current - b)))

    def f_ww(self, S, t=None):
        """
        Wong-Wang neural mass model equations (Deco et al., 2014)::

            I_E = W_E I_0 + w_+ J_NMDA S_E + G_E J_NMDA (C S_E) - J_I S_I
            I_I = W_I I_0 + J_NMDA S_E - S_I + lambda G_I J_NMDA (C S_E)
            dS_E/dt = -S_E / tau_E + (1 - S_E) gamma_E r_E(I_E)
            dS_I/dt = -S_I / tau_I + gamma_I r_I(I_I)

        ``t`` is accepted for the integrator interface but not used.
        """
        xp = self.xp
        nn = self.nn
        ns = self.num_sim
        S_exc, S_inh = S[:nn, :], S[nn:, :]

        network_exc = self.weights @ S_exc
        if self.lambda_inh_exc > 0:
            # long-range feedforward inhibition is driven by excitatory activity
            network_inh_exc = self.lambda_inh_exc * network_exc
        else:
            network_inh_exc = 0.0

        current_exc = (
            self.W_exc * self.ext_current
            + self.w_plus * self.J_NMDA * S_exc
            + self.G_exc * self.J_NMDA * network_exc
            - self.J_I * S_inh
        )

        current_inh = (
            self.W_inh * self.ext_current
            + self.J_NMDA * S_exc
            - S_inh
            + self.G_inh * self.J_NMDA * network_inh_exc
        )

        r_exc = self.get_firing_rate(current_exc, is_exc=True)
        r_inh = self.get_firing_rate(current_inh, is_exc=False)
        dSdt = xp.zeros((2 * nn, ns), dtype=self.dtype)

        # exc
        dSdt[:nn, :] = (-S_exc / self.tau_exc) + (1.0 - S_exc) * self.gamma_exc * r_exc
        # inh
        dSdt[nn:, :] = (-S_inh / self.tau_inh) + self.gamma_inh * r_inh

        return dSdt

    def heunStochastic(self, y, t, dt):
        """
        One stochastic Heun step with additive noise ``sigma * sqrt(dt) * N(0, 1)``.

        With ``same_noise_per_sim`` one noise column is broadcast to all
        simulations. The same increment is used in predictor and corrector.
        """
        xp = self.xp
        nn = self.nn
        ns = self.num_sim

        n_cols = 1 if self.same_noise_per_sim else ns
        dW = self.sigma * xp.random.randn(2 * nn, n_cols) * np.sqrt(dt)
        # randn is float64; cast so the state keeps the configured dtype
        dW = dW.astype(self.dtype)

        k1 = self.f_ww(y, t)
        y_ = y + dt * k1 + dW
        k2 = self.f_ww(y_, t + dt)
        y = y + dt * 0.5 * (k1 + k2) + dW

        return y

    def do_step(self, S, t, dt):
        """run one step of the model"""
        S = self.heunStochastic(S, t, dt)
        return S

    def do_bold_step(self, r_in, s, f, ftilde, vtilde, qtilde, v, q, dt, P):
        """
        Step the BOLD model forward in time.

        Delegates to ``self.B.do_bold_step(r_in, dt)``, the same call
        :meth:`run` makes. The remaining arguments are kept only for backward
        compatibility and are ignored. NOTE(review): presumably ``Bold`` keeps
        its own balloon state; confirm against ``vbi.models.cupy.bold``.
        """
        return self.B.do_bold_step(r_in, dt)

    def run(self, x0=None, tspan: np.ndarray = None, verbose=True):
        """
        Run the Wong-Wang model simulation.

        Parameters
        ----------
        x0 : array, optional
            Initial state of shape ``(2*nn, num_sim)``. If omitted, the
            ``initial_state`` parameter is used; if that is also None a random
            state is drawn.
        tspan : array, optional
            Time points in ms; defaults to ``np.arange(0, t_end, dt)``.
        verbose : bool
            Show a progress bar.

        Returns
        -------
        dict
            ``S``: float32 excitatory gating variables ``(n_samples, nn, num_sim)``
            if ``RECORD_S`` else an empty array; ``t``: their times;
            ``bold_t`` / ``bold_d``: BOLD sample times and signal after
            ``t_cut`` (empty arrays when ``RECORD_BOLD`` is False).
        """
        self.prepare_input()
        xp = self.xp

        if x0 is not None:
            s_curr = xp.array(x0, dtype=self.dtype)  # copy; never mutate caller data
        elif self.initial_state is not None:
            s_curr = xp.array(self.initial_state, dtype=self.dtype)
        else:
            s_curr = self.set_initial_state()

        t = np.arange(0.0, self.t_end, self.dt) if tspan is None else np.asarray(tspan)

        dt = self.dt
        t_cut = self.t_cut
        dt_bold = dt / 1000.0  # BOLD time step in seconds
        nn = self.nn
        ns = self.num_sim
        engine = self.engine
        s_decimate = self.s_decimate
        # max(..., 1): a tr shorter than dt/2 would otherwise round to 0
        bold_decimate = max(int(np.round(self.tr / dt)), 1)
        n_steps = len(t)

        valid_points = np.sum(t > t_cut)
        s_buffer_size = valid_points // s_decimate
        t_buffer = np.zeros((s_buffer_size), dtype=np.float32)

        B = self.B
        B.allocate_memory(xp, nn, ns, n_steps, bold_decimate, self.dtype)

        S_exc = np.array([])
        if self.RECORD_S:
            S_exc = np.zeros((s_buffer_size, nn, ns), dtype=np.float32)

        n_bold = 0  # number of BOLD rows actually written
        buffer_idx = 0
        for i in tqdm.trange(n_steps, disable=not verbose, desc="Integrating"):
            t_curr = t[i]

            s_curr = self.do_step(s_curr, t_curr, dt)

            if (t_curr > t_cut) and (i % s_decimate == 0) and (buffer_idx < s_buffer_size):
                t_buffer[buffer_idx] = t_curr
                if self.RECORD_S:
                    S_exc[buffer_idx] = get_(s_curr[:nn, :], engine, "f")
                buffer_idx += 1

            if self.RECORD_BOLD:
                B.do_bold_step(s_curr[:nn, :], dt_bold)

                k = i // bold_decimate
                if (i % bold_decimate == 0) and (k < B.vv.shape[0]):
                    B.vv[k] = get_(B.v[1], engine, "f")
                    B.qq[k] = get_(B.q[1], engine, "f")
                    n_bold = k + 1

        bold_t = np.array([])
        bold_d = np.array([])
        if self.RECORD_BOLD and n_bold > 0:
            # time of each written BOLD row on the same clock as t_buffer
            all_bold_t = t[np.arange(n_bold) * bold_decimate]
            valid_indices = np.where(all_bold_t > t_cut)[0]
            if len(valid_indices) > 0:
                start_idx = valid_indices[0]
                vv = B.vv[start_idx:n_bold]
                qq = B.qq[start_idx:n_bold]
                bold_d = B.vo * (
                    B.k1 * (1 - qq)
                    + B.k2 * (1 - qq / vv)
                    + B.k3 * (1 - vv)
                )
                bold_t = all_bold_t[start_idx:]

        return {
            "S": S_exc,
            "t": t_buffer,
            "bold_t": bold_t,
            "bold_d": bold_d,
        }
|
309
|
+
|
310
|
+
|
311
|
+
def set_initial_state(nn, ns, engine, seed=None, same_initial_state=False, dtype=float):
    """
    Draw a random initial state for the Wong-Wang model.

    Values are ``np.random.rand`` draws scaled by 0.1, i.e. uniform in
    [0, 0.1): 2 * nn gating variables per simulation (excitatory first, then
    inhibitory).

    Parameters
    ----------
    nn : int
        number of nodes
    ns : int
        number of simulations
    engine : str
        cpu or gpu
    seed : int, optional
        if given, ``np.random`` is reseeded before drawing
    same_initial_state : bool
        if True one random vector is drawn and repeated for all simulations
    dtype : str or type
        float: float64
        f : float32

    Returns
    -------
    array
        the state on ``engine`` cast to ``dtype``; NOTE(review): shape is
        ``(2*nn, ns)`` for the per-simulation path, and presumably also for the
        ``repmat_vec`` path — confirm in ``vbi.models.cupy.utils``.
    """
    if seed is not None:
        np.random.seed(seed)

    scale = 0.1
    if not same_initial_state:
        state = move_data(np.random.rand(2 * nn, ns) * scale, engine)
    else:
        state = repmat_vec(np.random.rand(2 * nn) * scale, ns, engine)

    return state.astype(dtype)
|
vbi/models/numba/__init__.py
CHANGED