vbi-0.1.3-cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/__init__.py +37 -0
- vbi/_version.py +17 -0
- vbi/dataset/__init__.py +0 -0
- vbi/dataset/connectivity_84/centers.txt +84 -0
- vbi/dataset/connectivity_84/centres.txt +84 -0
- vbi/dataset/connectivity_84/cortical.txt +84 -0
- vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
- vbi/dataset/connectivity_84/weights.txt +84 -0
- vbi/dataset/connectivity_88/Aud_88.txt +88 -0
- vbi/dataset/connectivity_88/Bold.npz +0 -0
- vbi/dataset/connectivity_88/Labels.txt +17 -0
- vbi/dataset/connectivity_88/Region_labels.txt +88 -0
- vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
- vbi/dataset/connectivity_88/weights.txt +88 -0
- vbi/feature_extraction/__init__.py +1 -0
- vbi/feature_extraction/calc_features.py +293 -0
- vbi/feature_extraction/features.json +535 -0
- vbi/feature_extraction/features.py +2124 -0
- vbi/feature_extraction/features_settings.py +374 -0
- vbi/feature_extraction/features_utils.py +1357 -0
- vbi/feature_extraction/infodynamics.jar +0 -0
- vbi/feature_extraction/utility.py +507 -0
- vbi/inference.py +98 -0
- vbi/models/__init__.py +0 -0
- vbi/models/cpp/__init__.py +0 -0
- vbi/models/cpp/_src/__init__.py +0 -0
- vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/bold.hpp +303 -0
- vbi/models/cpp/_src/do.hpp +167 -0
- vbi/models/cpp/_src/do.i +17 -0
- vbi/models/cpp/_src/do.py +467 -0
- vbi/models/cpp/_src/do_wrap.cxx +12811 -0
- vbi/models/cpp/_src/jr_sdde.hpp +352 -0
- vbi/models/cpp/_src/jr_sdde.i +19 -0
- vbi/models/cpp/_src/jr_sdde.py +688 -0
- vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
- vbi/models/cpp/_src/jr_sde.hpp +264 -0
- vbi/models/cpp/_src/jr_sde.i +17 -0
- vbi/models/cpp/_src/jr_sde.py +470 -0
- vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
- vbi/models/cpp/_src/km_sde.hpp +158 -0
- vbi/models/cpp/_src/km_sde.i +19 -0
- vbi/models/cpp/_src/km_sde.py +671 -0
- vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
- vbi/models/cpp/_src/makefile +52 -0
- vbi/models/cpp/_src/mpr_sde.hpp +327 -0
- vbi/models/cpp/_src/mpr_sde.i +19 -0
- vbi/models/cpp/_src/mpr_sde.py +711 -0
- vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
- vbi/models/cpp/_src/utility.hpp +307 -0
- vbi/models/cpp/_src/vep.hpp +171 -0
- vbi/models/cpp/_src/vep.i +16 -0
- vbi/models/cpp/_src/vep.py +464 -0
- vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
- vbi/models/cpp/_src/wc_ode.hpp +294 -0
- vbi/models/cpp/_src/wc_ode.i +19 -0
- vbi/models/cpp/_src/wc_ode.py +686 -0
- vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
- vbi/models/cpp/damp_oscillator.py +143 -0
- vbi/models/cpp/jansen_rit.py +543 -0
- vbi/models/cpp/km.py +187 -0
- vbi/models/cpp/mpr.py +289 -0
- vbi/models/cpp/vep.py +150 -0
- vbi/models/cpp/wc.py +216 -0
- vbi/models/cupy/__init__.py +0 -0
- vbi/models/cupy/bold.py +111 -0
- vbi/models/cupy/ghb.py +284 -0
- vbi/models/cupy/jansen_rit.py +473 -0
- vbi/models/cupy/km.py +224 -0
- vbi/models/cupy/mpr.py +475 -0
- vbi/models/cupy/mpr_modified_bold.py +12 -0
- vbi/models/cupy/utils.py +184 -0
- vbi/models/numba/__init__.py +0 -0
- vbi/models/numba/_ww_EI.py +444 -0
- vbi/models/numba/damp_oscillator.py +162 -0
- vbi/models/numba/ghb.py +208 -0
- vbi/models/numba/mpr.py +383 -0
- vbi/models/pytorch/__init__.py +0 -0
- vbi/models/pytorch/data/default_parameters.npz +0 -0
- vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
- vbi/models/pytorch/data/input/fc_test.csv +68 -0
- vbi/models/pytorch/data/input/fc_train.csv +68 -0
- vbi/models/pytorch/data/input/fc_vali.csv +68 -0
- vbi/models/pytorch/data/input/fcd_test.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_train.mat +0 -0
- vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
- vbi/models/pytorch/data/input/myelin.csv +68 -0
- vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
- vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
- vbi/models/pytorch/data/input/sc_test.csv +68 -0
- vbi/models/pytorch/data/input/sc_train.csv +68 -0
- vbi/models/pytorch/data/input/sc_vali.csv +68 -0
- vbi/models/pytorch/data/obs_kong0.npz +0 -0
- vbi/models/pytorch/ww_sde_kong.py +570 -0
- vbi/models/tvbk/__init__.py +9 -0
- vbi/models/tvbk/tvbk_wrapper.py +166 -0
- vbi/models/tvbk/utils.py +72 -0
- vbi/papers/__init__.py +0 -0
- vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
- vbi/tests/__init__.py +0 -0
- vbi/tests/_test_mpr_nb.py +36 -0
- vbi/tests/test_features.py +355 -0
- vbi/tests/test_ghb_cupy.py +90 -0
- vbi/tests/test_mpr_cupy.py +49 -0
- vbi/tests/test_mpr_numba.py +84 -0
- vbi/tests/test_suite.py +19 -0
- vbi/utils.py +402 -0
- vbi-0.1.3.dist-info/METADATA +166 -0
- vbi-0.1.3.dist-info/RECORD +121 -0
- vbi-0.1.3.dist-info/WHEEL +5 -0
- vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
- vbi-0.1.3.dist-info/top_level.txt +1 -0
vbi/models/cpp/km.py
ADDED
@@ -0,0 +1,187 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
try:
|
4
|
+
from vbi.models.cpp._src.km_sde import KM_sde as _KM_sde
|
5
|
+
except ImportError as e:
|
6
|
+
print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
|
7
|
+
|
8
|
+
|
9
|
+
class KM_sde:
    '''
    Kuramoto model with noise (sde), C++ implementation.

    Parameters
    ----------
    par : dict
        Dictionary of parameters. Keys must appear in ``valid_parameters``;
        keys that are not given take the values from
        :meth:`get_default_parameters`. ``omega`` is required (its length
        fixes the number of nodes) and ``weights`` must be set before
        :meth:`run` is called.

    '''

    valid_parameters = [
        "G",              # global coupling strength
        "dt",             # time step
        "noise_amp",      # noise amplitude
        "omega",          # natural angular frequency
        "weights",        # weighted connection matrix
        "noise_seed",     # fix random seed for noise in Cpp code
        "seed",           # fix random seed for initial state
        "alpha",          # frustration matrix
        "t_initial",      # initial time
        "t_transition",   # transition time
        "t_end",          # end time
        "output",         # output directory
        "num_threads",    # number of threads using openmp
        "initial_state",  # initial state
        "type"            # output times series data type
    ]

    def __init__(self, par) -> None:

        self.check_parameters(par)
        self._par = self.get_default_parameters()
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

        assert (self.omega is not None)

        if self.seed is not None:
            np.random.seed(self.seed)

        self.num_nodes = len(self.omega)

        # Always define the flag. It used to be assigned only when no initial
        # state was given, so passing ``initial_state`` made run() and
        # prepare_input() fail with AttributeError.
        self.INITIAL_STATE_SET = self.initial_state is not None

    def set_initial_state(self):
        '''Draw random initial phases for ``num_nodes`` nodes and mark them as set.'''
        self.INITIAL_STATE_SET = True
        # Calls the module-level helper of the same name.
        self.initial_state = set_initial_state(self.num_nodes, self.seed)

    def __str__(self) -> str:
        '''Return a readable listing of the model parameters (no printing).'''
        lines = [
            "Kuramoto model with noise (sde), C++ implementation.",
            "----------------",
        ]
        lines += [f"{name} = {value}" for name, value in self._par.items()]
        return "\n".join(lines)

    def __call__(self):
        '''Return the parameter dictionary the model was built with.'''
        return self._par

    def get_default_parameters(self):
        '''Return a fresh dictionary with the default parameter values.'''
        return {
            "G": 1.0,                # global coupling strength
            "dt": 0.01,              # time step
            "noise_amp": 0.1,        # noise amplitude
            "weights": None,         # weighted connection matrix
            "alpha": None,           # frustration matrix
            "omega": None,           # natural angular frequency
            "noise_seed": 0,         # fix random seed for noise in Cpp code
            "seed": None,            # fix random seed for initial state
            "t_initial": 0.0,        # initial time
            "t_transition": 0.0,     # transition time
            "t_end": 100.0,          # end time
            "num_threads": 1,        # number of threads using openmp
            "output": "output",      # output directory
            "initial_state": None,   # initial state
            "type": np.float32
        }

    def check_parameters(self, par):
        '''
        Raise ``ValueError`` if ``par`` contains a key not in ``valid_parameters``.
        '''
        for key in par.keys():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter: {key}")

    def prepare_input(self):
        '''
        Validate the parameters and cast them to the types passed to the C++ engine.

        Draws a random initial state if none is set yet. A missing ``alpha``
        becomes a zero matrix shaped like ``weights``.

        Raises
        ------
        ValueError
            If ``weights`` or ``omega`` is missing.
        '''
        if self.weights is None:
            raise ValueError("Missing weights.")
        if self.omega is None:
            raise ValueError("Missing omega.")

        # omega may have been replaced through run(par=...); keep the node
        # count in sync with it instead of using the value from __init__.
        self.num_nodes = len(self.omega)
        nn = self.num_nodes

        if not self.INITIAL_STATE_SET:
            self.set_initial_state()

        self.weights = np.array(self.weights, dtype=np.float64)
        self.omega = np.array(self.omega, dtype=np.float64)
        self.initial_state = np.array(self.initial_state, dtype=np.float64)
        self.G = float(self.G)
        self.dt = float(self.dt)
        self.noise_amp = float(self.noise_amp)
        self.t_initial = float(self.t_initial)
        self.t_transition = float(self.t_transition)
        self.t_end = float(self.t_end)
        self.noise_seed = int(self.noise_seed)
        # Plain Python int for the C++ binding (numpy integers are not ints).
        self.num_threads = int(self.num_threads)
        if self.alpha is None:
            self.alpha = np.zeros_like(self.weights, dtype=np.float64)
        else:
            self.alpha = np.array(self.alpha, dtype=np.float64)
        assert (self.alpha.shape == (nn, nn))

    def _apply_run_parameters(self, par):
        '''
        Validate ``par`` and copy its values onto the instance.

        Values may be given directly (``{"G": 0.5}``, the same form
        ``__init__`` accepts) or in the legacy wrapped form
        ``{"G": {"value": 0.5}}``.

        Raises
        ------
        ValueError
            On the first key that is not in ``valid_parameters``.
        '''
        for key, value in par.items():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")
            if isinstance(value, dict) and "value" in value:
                value = value["value"]
            setattr(self, key, value)

    def run(self, par=None, x0=None, verbose=False):
        '''
        Simulate the model.

        Parameters
        ----------
        par : dict, optional
            Parameters to update before integrating. Values may be given
            directly (``{"G": 0.5}``) or wrapped as ``{"G": {"value": 0.5}}``.
        x0 : array, optional
            Initial state; must have ``num_nodes`` entries. When omitted, a
            random state is drawn unless one is already set.
        verbose : bool
            If True, report when the default initial state is drawn.

        Returns
        -------
        dict
            t : array
                Time points returned by the engine.
            x : array
                State time series from the engine, transposed and cast to
                ``type``.
        '''
        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            assert (len(x0) == self.num_nodes)
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        self._apply_run_parameters(par if par is not None else {})
        self.prepare_input()

        obj = _KM_sde(self.dt,
                      self.t_initial,
                      self.t_transition,
                      self.t_end,
                      self.G,
                      self.noise_amp,
                      self.initial_state,
                      self.omega,
                      self.alpha,
                      self.weights,
                      self.noise_seed,
                      self.num_threads
                      )
        obj.IntegrateHeun()
        t = np.asarray(obj.get_times())
        x = np.asarray(obj.get_theta()).T.astype(self.type)

        return {"t": t, "x": x}
|
182
|
+
|
183
|
+
|
184
|
+
def set_initial_state(num_nodes, seed=None):
    """
    Draw random initial phases, uniformly distributed in [0, 2*pi).

    Parameters
    ----------
    num_nodes : int
        Number of phases to draw.
    seed : int, optional
        If given, the global NumPy random state is re-seeded first.

    Returns
    -------
    numpy.ndarray
        Array of ``num_nodes`` phases.
    """
    if seed is not None:
        np.random.seed(seed)
    low, high = 0, 2 * np.pi
    return np.random.uniform(low, high, num_nodes)
|
vbi/models/cpp/mpr.py
ADDED
@@ -0,0 +1,289 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from typing import Union
|
3
|
+
from copy import deepcopy
|
4
|
+
|
5
|
+
try:
|
6
|
+
from vbi.models.cpp._src.mpr_sde import MPR_sde as _MPR_sde
|
7
|
+
from vbi.models.cpp._src.mpr_sde import BoldParams as _BoldParams
|
8
|
+
except ImportError as e:
|
9
|
+
print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
|
10
|
+
|
11
|
+
|
12
|
+
class MPR_sde:
    """
    MPR model with noise (sde), C++ implementation.

    The compiled ``MPR_sde`` engine integrates the model and can record the
    activity and a BOLD signal; the BOLD parameters are held in
    :class:`BoldParams`.

    Parameters
    ----------
    par : dict, optional
        Model parameters overriding :meth:`get_default_parameters`.
    parbold : dict, optional
        BOLD parameters overriding ``BoldParams.get_default_parameters``.
    """

    def __init__(self, par: dict = None, parbold: dict = None) -> None:

        # Work on a private copy so the caller's dict (and arrays in it) is
        # never aliased by the model's attributes.
        par = deepcopy(par) if par is not None else {}
        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        # Always define the flag; it used to be assigned only when no initial
        # state was given, so run() raised AttributeError otherwise.
        self.INITIAL_STATE_SET = self.initial_state is not None
        # Known up front so run(x0=...) can validate x0 before prepare_input().
        self.num_nodes = (
            None if self.weights is None else np.asarray(self.weights).shape[0]
        )

        self.BP = BoldParams(parbold if parbold is not None else {})

    def set_initial_state(self):
        """Draw a random initial state sized from ``weights`` and mark it as set."""
        # np.asarray so weights given as nested lists also work.
        self.num_nodes = np.asarray(self.weights).shape[0]
        self.initial_state = set_initial_state(self.num_nodes, self.seed)
        self.INITIAL_STATE_SET = True

    # -------------------------------------------------------------------------

    def __str__(self) -> str:
        """Return a readable listing of the model parameters (no printing)."""
        lines = ["MPR sde model.", "----------------"]
        lines += [f"{name} = {value}" for name, value in self._par.items()]
        return "\n".join(lines)

    # -------------------------------------------------------------------------

    def __call__(self):
        """Return the parameter dictionary the model was built with."""
        return self._par

    # -------------------------------------------------------------------------

    def check_parameters(self, par: dict):
        """Raise ``ValueError`` if ``par`` has a key not in ``valid_parameters``."""
        for key in par.keys():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")

    def get_default_parameters(self):
        """Return a fresh dictionary with the default model parameters."""

        params = {
            "G": 0.733,  # global coupling strength
            "dt": 0.01,  # for mpr model [ms]
            "dt_bold": 0.001,  # for Balloon model [s]
            "J": 14.5,  # model parameter
            "eta": -4.6,  # model parameter
            "tau": 1.0,  # model parameter
            "delta": 0.7,  # model parameter
            "tr": 500.0,  # sampling from mpr time series
            "rv_decimate": 10,  # sampling from activity time series
            "noise_amp": 0.037,  # amplitude of noise
            "noise_seed": 0,  # fix seed for noise
            "iapp": 0.0,  # constant applied current
            "seed": None,
            "initial_state": None,  # initial condition of the system
            "t_cut": 0.0,  # transition time [ms]
            "t_end": 5 * 60 * 1000.0,  # end time [ms]
            "weights": None,  # weighted connection matrix
            "output": "output",  # output directory
            "RECORD_RV": 0,  # true to store large time series in file
            "RECORD_BOLD": 1,
        }

        return params

    def prepare_input(self):
        """
        Prepare input parameters for passing to C++ engine.

        Casts every parameter to the scalar/array type the engine receives.
        Node-wise parameters (``eta``, ``J``, ``tau``, ``delta``, ``iapp``) are
        broadcast to one float64 value per node.

        ``t_cut`` and ``t_end`` stay in the user-facing units (ms). The values
        the engine receives (divided by 10, matching the ``* 10.0`` applied to
        recorded times in :meth:`run`) are stored in ``_t_cut_engine`` and
        ``_t_end_engine``. Previously the attributes themselves were divided,
        so every additional ``run()`` shrank them by another factor of 10.
        """

        self.dt = float(self.dt)
        self.dt_bold = float(self.dt_bold)
        self.tr = float(self.tr)
        self.initial_state = np.asarray(self.initial_state).astype(np.float64)
        self.weights = np.asarray(self.weights).astype(np.float64)
        self.num_nodes = self.weights.shape[0]
        nn = self.num_nodes
        self.G = float(self.G)

        self.eta = np.asarray(check_sequence(self.eta, nn)).astype(np.float64)
        self.J = np.asarray(check_sequence(self.J, nn)).astype(np.float64)
        self.tau = np.asarray(check_sequence(self.tau, nn)).astype(np.float64)
        self.delta = np.asarray(check_sequence(self.delta, nn)).astype(np.float64)
        self.iapp = np.asarray(check_sequence(self.iapp, nn)).astype(np.float64)

        self.noise_amp = float(self.noise_amp)
        self.rv_decimate = int(self.rv_decimate)
        self.t_cut = float(self.t_cut)
        self.t_end = float(self.t_end)
        self._t_cut_engine = self.t_cut / 10.0
        self._t_end_engine = self.t_end / 10.0
        self.RECORD_RV = int(self.RECORD_RV)
        self.RECORD_BOLD = int(self.RECORD_BOLD)
        self.noise_seed = int(self.noise_seed)

    def run(self, par: dict = None, x0: np.ndarray = None, verbose: bool = False):
        """
        Integrate the MPR model with the given parameters.

        Parameters
        ----------
        par : dict, optional
            Parameters to update before integrating.
        x0 : array_like, optional
            Initial condition of the system; must have ``2 * num_nodes``
            entries. When omitted, a random state is drawn unless one is
            already set.
        verbose : bool
            If True, report when the default initial state is drawn.

        Returns
        -------
        dict
            ``bold_t``, ``bold_d`` : BOLD sample times (engine times multiplied
            by 10.0) and float32 BOLD data, restricted (when the data is 2-D)
            to samples after ``t_cut``; empty arrays when ``RECORD_BOLD`` is 0.
            ``rv_t``, ``rv_d`` : the same for the activity series returned by
            ``get_r_t`` / ``get_r_d``; empty arrays when ``RECORD_RV`` is 0.
        """
        if par is None:
            par = {}

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            # Refresh the node count so x0 can be checked even before the
            # first prepare_input() call.
            if self.weights is not None:
                self.num_nodes = np.asarray(self.weights).shape[0]
            assert len(x0) == self.num_nodes * 2
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        for key, value in par.items():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")
            setattr(self, key, value)

        self.prepare_input()

        obj = _MPR_sde(
            self.dt,
            self.dt_bold,
            self.rv_decimate,
            self.weights,
            self.initial_state,
            self.delta,
            self.tau,
            self.eta,
            self.J,
            self.iapp,
            self.noise_amp,
            self.G,
            self._t_end_engine,
            self._t_cut_engine,
            self.tr,
            self.RECORD_RV,
            self.RECORD_BOLD,
            self.noise_seed,
            self.BP.get_params()
        )

        obj.integrate()

        bold_d = np.array([])
        bold_t = np.array([])
        r_d = np.array([])
        r_t = np.array([])

        if self.RECORD_BOLD:
            bold_d = np.asarray(obj.get_bold_d()).astype(np.float32)
            bold_t = np.asarray(obj.get_bold_t())
            if bold_d.ndim == 2:
                keep = bold_t > self._t_cut_engine
                bold_d = bold_d[keep, :]
                bold_t = bold_t[keep] * 10.0

        if self.RECORD_RV:
            r_d = np.asarray(obj.get_r_d()).astype(np.float32)
            r_t = np.asarray(obj.get_r_t())
            if r_d.ndim == 2:
                keep = r_t > self._t_cut_engine
                r_d = r_d[keep, :]
                r_t = r_t[keep] * 10.0

        return {
            "rv_t": r_t,
            "rv_d": r_d,
            "bold_t": bold_t,
            "bold_d": bold_d,
        }
|
207
|
+
|
208
|
+
|
209
|
+
class BoldParams:
    """
    Parameters of the BOLD signal computation used by :class:`MPR_sde`.

    Parameters
    ----------
    par : dict, optional
        Values overriding :meth:`get_default_parameters`. Unknown keys raise
        ``ValueError``.
    """

    def __init__(self, par=None):

        # None instead of a shared mutable ``{}`` default.
        if par is None:
            par = {}
        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

    def check_parameters(self, par):
        """Raise ``ValueError`` if ``par`` has a key not in ``valid_parameters``."""
        for key in par.keys():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")

    def get_default_parameters(self):
        """Return a fresh dictionary with the default BOLD parameters."""
        return {
            "kappa": 0.7,
            "gamma": 0.5,
            "tau": 1.0,
            "alpha": 0.35,
            "epsilon": 0.36,
            "Eo": 0.42,
            "TE": 0.05,
            "vo": 0.09,
            "r0": 26.0,
            "theta0": 41.0,
            "rtol": 1e-6,
            "atol": 1e-9,
        }

    def get_params(self):
        """
        Return a C++ ``BoldParams`` object filled with this object's values.

        Every field listed in ``valid_parameters`` is copied, in the order of
        :meth:`get_default_parameters`.
        """
        bp = _BoldParams()
        for name in self.valid_parameters:
            setattr(bp, name, getattr(self, name))
        return bp
|
259
|
+
|
260
|
+
|
261
|
+
def check_sequence(x: Union[int, float, np.ndarray], n: int):
    """
    check if x is a scalar or a sequence of length n

    parameters
    ----------
    x: scalar or sequence of length n. Python/NumPy scalars and 0-d arrays
       are broadcast; lists, tuples and arrays with at least one dimension
       are returned unchanged.
    n: number of nodes

    returns
    -------
    x: sequence of length n

    raises
    ------
    AssertionError: if a sequence does not have length n
    """
    is_sequence = isinstance(x, (list, tuple)) or (
        # A 0-d array has no len(); treat it like a scalar and broadcast it.
        isinstance(x, np.ndarray) and x.ndim > 0
    )
    if is_sequence:
        assert len(x) == n, f" variable must be a sequence of length {n}"
        return x
    return x * np.ones(n)
|
279
|
+
|
280
|
+
|
281
|
+
def set_initial_state(nn, seed=None):
    """
    Draw a random initial state for ``nn`` nodes.

    The first ``nn`` entries are uniform in [0, 1.5) and the last ``nn``
    entries are uniform in [-2, 2).

    Parameters
    ----------
    nn : int
        Number of nodes.
    seed : int, optional
        If given, the global NumPy random state is re-seeded first.

    Returns
    -------
    numpy.ndarray
        Array of length ``2 * nn``.
    """
    if seed is not None:
        np.random.seed(seed)

    draws = np.random.rand(2 * nn)
    first_half = 1.5 * draws[:nn]
    second_half = 4.0 * draws[nn:] - 2.0
    return np.concatenate((first_half, second_half))
|
vbi/models/cpp/vep.py
ADDED
@@ -0,0 +1,150 @@
|
|
1
|
+
import os
|
2
|
+
import numpy as np
|
3
|
+
from copy import deepcopy
|
4
|
+
from os.path import join
|
5
|
+
from typing import Union
|
6
|
+
|
7
|
+
try:
|
8
|
+
from vbi.models.cpp._src.vep import VEP as _VEP
|
9
|
+
except ImportError as e:
|
10
|
+
print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
|
11
|
+
|
12
|
+
|
13
|
+
class VEP:
    """
    Virtual Epileptic Patient (VEP) model, C++ implementation.

    Parameters
    ----------
    par : dict, optional
        Parameters overriding :meth:`get_default_parameters`. ``weights``
        must be provided before :meth:`run` is called.
    """

    def __init__(self, par: dict = None):

        # Private copy; None instead of a shared mutable ``{}`` default.
        par = deepcopy(par) if par is not None else {}
        self._par = self.get_default_parameters()
        self.valid_params = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        self.INITIAL_STATE_SET = self.initial_state is not None

    def set_initial_state(self):
        """Draw a random initial state sized from ``weights`` and mark it as set."""
        # np.asarray so weights given as nested lists also work.
        self.nn = np.asarray(self.weights).shape[0]
        self.initial_state = set_initial_state(self.nn, self.seed)
        self.INITIAL_STATE_SET = True

    def __str__(self) -> str:
        """Return a readable listing of the model parameters (no printing)."""
        lines = ["VEP model", "---------"]
        lines += [f"{name} = {value}" for name, value in self._par.items()]
        return "\n".join(lines)

    def __call__(self):
        """Return the parameter dictionary the model was built with."""
        return self._par

    def check_parameters(self, par: dict):
        """Raise ``ValueError`` if ``par`` has a key not in ``valid_params``."""
        for key in par.keys():
            if key not in self.valid_params:
                raise ValueError(f"Invalid parameter: {key}")

    def prepare_input(self):
        """
        Cast the parameters to the types passed to the C++ engine.

        ``iext`` and ``eta`` are broadcast to one float64 value per node.
        ``weights`` and ``initial_state`` become float64 arrays.
        """
        self.weights = np.asarray(self.weights).astype(np.float64)
        self.nn = self.weights.shape[0]
        self.initial_state = np.asarray(self.initial_state).astype(np.float64)
        self.G = float(self.G)
        self.iext = np.asarray(check_sequence(self.iext, self.nn)).astype(np.float64)
        self.tau = float(self.tau)
        self.eta = np.asarray(check_sequence(self.eta, self.nn)).astype(np.float64)
        # noise_sigma is what run() passes to the engine, so cast it in place;
        # ``sigma`` is kept as an alias for backward compatibility.
        self.noise_sigma = float(self.noise_sigma)
        self.sigma = self.noise_sigma
        self.dt = float(self.dt)
        self.tend = float(self.tend)
        self.tcut = float(self.tcut)
        self.noise_seed = int(self.noise_seed)
        self.record_step = int(self.record_step)
        self.method = str(self.method)

    def get_default_parameters(self):
        """Return a fresh dictionary with the default model parameters."""
        params = {
            "G": 1.0,
            "seed": None,
            "initial_state": None,
            "weights": None,
            "tau": 10.0,
            "eta": -1.5,
            "noise_sigma": 0.1,
            "iext": 0.0,
            "dt": 0.01,
            "tend": 100.0,
            "tcut": 0.0,
            "noise_seed": 0,
            "record_step": 1,
            "method": "euler",
            "output": "output",
        }
        return params

    def run(self, par: dict = None, x0: np.ndarray = None, verbose: bool = False):
        """
        Integrate the model.

        Parameters
        ----------
        par : dict, optional
            Parameters to update before integrating.
        x0 : array_like, optional
            Initial state. When omitted, a random state is drawn unless one
            is already set.
        verbose : bool
            If True, report when the default initial state is drawn.

        Returns
        -------
        dict
            ``t``: time points from the engine; ``x``: states from the engine,
            transposed and cast to float32.
        """
        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        for key, value in (par if par is not None else {}).items():
            if key not in self.valid_params:
                raise ValueError(f"Invalid parameter: {key}")
            setattr(self, key, value)
        self.prepare_input()

        obj = _VEP(
            self.G,
            self.iext,
            self.eta,
            self.dt,
            self.tcut,
            self.tend,
            self.tau,
            self.noise_sigma,
            self.initial_state,
            self.weights,
            self.noise_seed,
            self.method,
        )
        obj.integrate()
        states = np.asarray(obj.get_states(), dtype=np.float32).T
        t = np.asarray(obj.get_times())
        return {"t": t, "x": states}
|
122
|
+
|
123
|
+
|
124
|
+
def set_initial_state(nn: int, seed: int = None):
    """
    Draw a random initial state for ``nn`` nodes.

    The first ``nn`` entries are drawn uniformly from [-3.0, -2.0), then the
    last ``nn`` entries uniformly from [0.0, 3.5), in that order.

    Parameters
    ----------
    nn : int
        Number of nodes.
    seed : int, optional
        If given, the global NumPy random state is re-seeded first.

    Returns
    -------
    numpy.ndarray
        Array of length ``2 * nn``.
    """
    if seed is not None:
        np.random.seed(seed)
    lower_block = np.random.uniform(-3.0, -2.0, nn)
    upper_block = np.random.uniform(0.0, 3.5, nn)
    return np.concatenate((lower_block, upper_block))
|
131
|
+
|
132
|
+
|
133
|
+
def check_sequence(x: Union[int, float, np.ndarray], n: int):
    """
    check if x is a scalar or a sequence of length n

    parameters
    ----------
    x: scalar or sequence. Python/NumPy scalars and 0-d arrays are
       broadcast; lists, tuples and arrays with at least one dimension are
       returned unchanged.
    n: number of elements

    returns
    -------
    x: sequence of length n

    raises
    ------
    AssertionError: if a sequence does not have length n
    """
    is_sequence = isinstance(x, (list, tuple)) or (
        # A 0-d array has no len(); treat it like a scalar and broadcast it.
        isinstance(x, np.ndarray) and x.ndim > 0
    )
    if is_sequence:
        assert len(x) == n, f" variable must be a sequence of length {n}"
        return x
    return x * np.ones(n)
|