vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/__init__.py +37 -0
- vbi/_version.py +17 -0
- vbi/dataset/__init__.py +0 -0
- vbi/dataset/connectivity_84/centers.txt +84 -0
- vbi/dataset/connectivity_84/centres.txt +84 -0
- vbi/dataset/connectivity_84/cortical.txt +84 -0
- vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
- vbi/dataset/connectivity_84/weights.txt +84 -0
- vbi/dataset/connectivity_88/Aud_88.txt +88 -0
- vbi/dataset/connectivity_88/Bold.npz +0 -0
- vbi/dataset/connectivity_88/Labels.txt +17 -0
- vbi/dataset/connectivity_88/Region_labels.txt +88 -0
- vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
- vbi/dataset/connectivity_88/weights.txt +88 -0
- vbi/feature_extraction/__init__.py +1 -0
- vbi/feature_extraction/calc_features.py +293 -0
- vbi/feature_extraction/features.json +535 -0
- vbi/feature_extraction/features.py +2124 -0
- vbi/feature_extraction/features_settings.py +374 -0
- vbi/feature_extraction/features_utils.py +1357 -0
- vbi/feature_extraction/infodynamics.jar +0 -0
- vbi/feature_extraction/utility.py +507 -0
- vbi/inference.py +98 -0
- vbi/models/__init__.py +0 -0
- vbi/models/cpp/__init__.py +0 -0
- vbi/models/cpp/_src/__init__.py +0 -0
- vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/bold.hpp +303 -0
- vbi/models/cpp/_src/do.hpp +167 -0
- vbi/models/cpp/_src/do.i +17 -0
- vbi/models/cpp/_src/do.py +467 -0
- vbi/models/cpp/_src/do_wrap.cxx +12811 -0
- vbi/models/cpp/_src/jr_sdde.hpp +352 -0
- vbi/models/cpp/_src/jr_sdde.i +19 -0
- vbi/models/cpp/_src/jr_sdde.py +688 -0
- vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
- vbi/models/cpp/_src/jr_sde.hpp +264 -0
- vbi/models/cpp/_src/jr_sde.i +17 -0
- vbi/models/cpp/_src/jr_sde.py +470 -0
- vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
- vbi/models/cpp/_src/km_sde.hpp +158 -0
- vbi/models/cpp/_src/km_sde.i +19 -0
- vbi/models/cpp/_src/km_sde.py +671 -0
- vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
- vbi/models/cpp/_src/makefile +52 -0
- vbi/models/cpp/_src/mpr_sde.hpp +327 -0
- vbi/models/cpp/_src/mpr_sde.i +19 -0
- vbi/models/cpp/_src/mpr_sde.py +711 -0
- vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
- vbi/models/cpp/_src/utility.hpp +307 -0
- vbi/models/cpp/_src/vep.hpp +171 -0
- vbi/models/cpp/_src/vep.i +16 -0
- vbi/models/cpp/_src/vep.py +464 -0
- vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
- vbi/models/cpp/_src/wc_ode.hpp +294 -0
- vbi/models/cpp/_src/wc_ode.i +19 -0
- vbi/models/cpp/_src/wc_ode.py +686 -0
- vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
- vbi/models/cpp/damp_oscillator.py +143 -0
- vbi/models/cpp/jansen_rit.py +543 -0
- vbi/models/cpp/km.py +187 -0
- vbi/models/cpp/mpr.py +289 -0
- vbi/models/cpp/vep.py +150 -0
- vbi/models/cpp/wc.py +216 -0
- vbi/models/cupy/__init__.py +0 -0
- vbi/models/cupy/bold.py +111 -0
- vbi/models/cupy/ghb.py +284 -0
- vbi/models/cupy/jansen_rit.py +473 -0
- vbi/models/cupy/km.py +224 -0
- vbi/models/cupy/mpr.py +475 -0
- vbi/models/cupy/mpr_modified_bold.py +12 -0
- vbi/models/cupy/utils.py +184 -0
- vbi/models/numba/__init__.py +0 -0
- vbi/models/numba/_ww_EI.py +444 -0
- vbi/models/numba/damp_oscillator.py +162 -0
- vbi/models/numba/ghb.py +208 -0
- vbi/models/numba/mpr.py +383 -0
- vbi/models/pytorch/__init__.py +0 -0
- vbi/models/pytorch/data/default_parameters.npz +0 -0
- vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
- vbi/models/pytorch/data/input/fc_test.csv +68 -0
- vbi/models/pytorch/data/input/fc_train.csv +68 -0
- vbi/models/pytorch/data/input/fc_vali.csv +68 -0
- vbi/models/pytorch/data/input/fcd_test.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_train.mat +0 -0
- vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
- vbi/models/pytorch/data/input/myelin.csv +68 -0
- vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
- vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
- vbi/models/pytorch/data/input/sc_test.csv +68 -0
- vbi/models/pytorch/data/input/sc_train.csv +68 -0
- vbi/models/pytorch/data/input/sc_vali.csv +68 -0
- vbi/models/pytorch/data/obs_kong0.npz +0 -0
- vbi/models/pytorch/ww_sde_kong.py +570 -0
- vbi/models/tvbk/__init__.py +9 -0
- vbi/models/tvbk/tvbk_wrapper.py +166 -0
- vbi/models/tvbk/utils.py +72 -0
- vbi/papers/__init__.py +0 -0
- vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
- vbi/tests/__init__.py +0 -0
- vbi/tests/_test_mpr_nb.py +36 -0
- vbi/tests/test_features.py +355 -0
- vbi/tests/test_ghb_cupy.py +90 -0
- vbi/tests/test_mpr_cupy.py +49 -0
- vbi/tests/test_mpr_numba.py +84 -0
- vbi/tests/test_suite.py +19 -0
- vbi/utils.py +402 -0
- vbi-0.1.3.dist-info/METADATA +166 -0
- vbi-0.1.3.dist-info/RECORD +121 -0
- vbi-0.1.3.dist-info/WHEEL +5 -0
- vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
- vbi-0.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,143 @@
|
|
1
|
+
import os
|
2
|
+
from typing import Any
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
try:
|
6
|
+
from vbi.models.cpp._src.do import DO as _DO
|
7
|
+
except ImportError as e:
|
8
|
+
print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
|
9
|
+
|
10
|
+
class DO:
    '''
    Damp Oscillator model class.

    Thin Python front end around the compiled ``_DO`` engine: it validates and
    stores the parameters, casts them to the types handed to the engine and
    runs one of the engine's integrators.
    '''

    valid_params = ["a", "b", "dt", "t_start", "t_end", "t_transition",
                    "initial_state", "method", "output"]

    # Integration scheme name (lower case) -> name of the engine method.
    _integrators = {
        "euler": "eulerIntegrate",
        "heun": "heunIntegrate",
        "rk4": "rk4Integrate",
    }

    # ---------------------------------------------------------------
    def __init__(self, par=None):
        '''
        Parameters
        ----------
        par : dictionary, optional
            parameters overriding the defaults of
            :meth:`get_default_parameters`. Accepted keys:

            - **a**, **b** [double] model coefficients forwarded to the engine.
            - **dt** [double] time step.
            - **t_start** [double] initial time for simulation.
            - **t_end** [double] final time for simulation.
            - **t_transition** accepted and stored; `run` does not forward it
              to the engine.
            - **initial_state** [list] initial state of the system.
            - **method** [str] 'euler', 'heun' or 'rk4' (case-insensitive).
            - **output** [str] output directory, created by `prepare_input`.

        Raises
        ------
        ValueError
            if `par` contains a key that is not in `valid_params`.
        '''
        # None instead of a shared mutable `{}` default.
        par = {} if par is None else par
        self.check_parameters(par)
        self._par = self.get_default_parameters()
        self._par.update(par)

        # Every parameter is also exposed as an instance attribute.
        for name, value in self._par.items():
            setattr(self, name, value)

    def __str__(self) -> str:
        '''
        Return a readable, multi-line listing of the stored parameters.
        '''
        lines = ["Damp Oscillator model", "----------------"]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        '''
        Print the model name and return the parameter dictionary built in
        `__init__` (positional and keyword arguments are ignored).
        '''
        print("Damp Oscillator model")
        return self._par

    def check_parameters(self, par):
        '''
        check if the parameters are valid.

        Raises
        ------
        ValueError
            on the first key of `par` that is not listed in `valid_params`.
        '''
        for key in par:
            if key not in self.valid_params:
                # f-string so that a non-string key still yields ValueError.
                raise ValueError(f"Invalid parameter: {key}")

    def get_default_parameters(self):
        '''
        return default parameters for damp oscillator model.
        '''
        return {
            "a": 0.1,
            "b": 0.05,
            "dt": 0.01,
            "t_start": 0,
            "method": "rk4",
            "t_end": 100.0,
            "t_transition": 20,
            "output": "output",
            "initial_state": [0.5, 1.0],
        }

    def prepare_input(self):
        '''
        prepare input for cpp model.

        Casts the scalar parameters to ``float``, falls back to the defaults
        for a missing output directory / initial state, creates the output
        directory and converts the initial state to a float64 array.
        '''
        self.t_start = float(self.t_start)
        self.t_end = float(self.t_end)
        self.dt = float(self.dt)
        self.a = float(self.a)
        self.b = float(self.b)

        if self.output is None:
            self.output = "output"
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.output, exist_ok=True)

        if self.initial_state is None:
            self.initial_state = [0.5, 1.0]
        self.initial_state = np.asarray(self.initial_state, dtype=np.float64)

    # ---------------------------------------------------------------
    def run(self, par=None, x0=None, verbose=False):
        '''
        Integrate the damp oscillator system of equations

        Parameters
        ----------
        par : dictionary, optional
            parameters to control the model parameters; each entry is set as
            an attribute on the instance before integrating.
        x0 : sequence of length 2, optional
            initial state; replaces the stored one when given.
        verbose : bool
            currently unused.

        Returns
        -------
        dict
            - **t** : times reported by the engine
            - **x** : coordinates reported by the engine

        Raises
        ------
        ValueError
            for an invalid key in `par` or an unknown integration method.
        '''
        par = {} if par is None else par

        if x0 is not None:
            assert (len(x0) == 2)
            self.initial_state = x0

        self.check_parameters(par)
        for key, value in par.items():
            setattr(self, key, value)

        # Validate the scheme first so that a typo fails before any directory
        # or engine object is created (and never terminates the interpreter).
        method = str(self.method).lower()
        if method not in self._integrators:
            raise ValueError(
                f"unknown integration method: {self.method!r}; "
                f"expected one of {sorted(self._integrators)}")

        self.prepare_input()

        obj = _DO(self.dt,
                  self.a,
                  self.b,
                  self.t_start,
                  self.t_end,
                  self.initial_state)
        getattr(obj, self._integrators[method])()

        sol = np.asarray(obj.get_coordinates())
        times = np.asarray(obj.get_times())

        return {"t": times, "x": sol}
|
@@ -0,0 +1,543 @@
|
|
1
|
+
import os
|
2
|
+
import numpy as np
|
3
|
+
from os.path import join
|
4
|
+
|
5
|
+
try:
|
6
|
+
from vbi.models.cpp._src.jr_sde import JR_sde as _JR_sde
|
7
|
+
from vbi.models.cpp._src.jr_sdde import JR_sdde as _JR_sdde
|
8
|
+
except ImportError as e:
|
9
|
+
print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
|
10
|
+
|
11
|
+
|
12
|
+
class JR_sde:
    r'''
    Jansen-Rit model C++ implementation.

    Thin Python front end around the compiled ``_JR_sde`` engine.

    Parameters
    ----------

    par: dict
        Including the following (see `get_default_parameters` for defaults):

        - **G** : global coupling strength
        - **weights** : coupling matrix (required); the size of its first
          axis is used as the number of nodes N
        - **A** : [mV] determine the maximum amplitude of the excitatory PSP (EPSP)
        - **B** : [mV] determine the maximum amplitude of the inhibitory PSP (IPSP)
        - **a** : 1/tau_e, :math:`\sum` of the reciprocal of the time constant
          of passive membrane and all other spatially distributed delays in
          the dendritic network (default annotated in 1/ms)
        - **b** : 1/tau_i (default annotated in 1/ms)
        - **r** [mV] the steepness of the sigmoidal transformation.
        - **v0** parameter of nonlinear sigmoid function
        - **vmax** parameter of nonlinear sigmoid function
        - **C0**, **C1**, **C2**, **C3** [scalar or sequence of length N]
          average number of synaptic contacts in the inhibitory and
          excitatory feedback loops
        - **noise_mu**, **noise_std** noise parameters forwarded to the engine
        - **noise_seed** any truthy value is stored as 1, otherwise 0
        - **seed** if not None, used to seed NumPy's global random generator
        - **initial_state** initial state; drawn at random by `run` if omitted
        - **dt** integration time step (default annotated in ms)
        - **t_transition** time to reach steady state (default annotated in ms)
        - **t_end** final time (default annotated in ms)
        - **method** [str] method of integration, 'euler' or 'heun'
        - **output** [str] output directory, created on construction
        - **RECORD_AVG**, **control** accepted and stored; not used by `run`

    '''
    valid_params = [
        "noise_seed", "seed", "G", "weights", "A", "B", "a", "b",
        "noise_mu", "noise_std", "vmax", "v0", "r",
        "C0", "C1", "C2", "C3", "dt", "method", "t_transition",
        "t_end", "control", "output", "RECORD_AVG",
        "initial_state"
    ]

    # Integration scheme name (lower case) -> name of the engine method.
    _integrators = {"euler": "eulerIntegrate", "heun": "heunIntegrate"}

    def __init__(self, par=None):
        '''
        Merge `par` into the defaults and expose every entry as an attribute.

        Raises
        ------
        ValueError
            for a key outside `valid_params` or when `weights` is missing.
        '''
        # None instead of a shared mutable `{}` default.
        par = {} if par is None else par

        self.check_parameters(par)
        self._par = self.get_default_parameters()
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        # Without this guard a missing matrix surfaces as an opaque
        # "tuple index out of range" from the shape lookup below.
        if self.weights is None:
            raise ValueError("weights must be provided")
        self.N = self.num_nodes = np.asarray(self.weights).shape[0]

        # Always define the flag: run() reads it, and it used to be left
        # unset whenever an initial state was passed to the constructor.
        self.INITIAL_STATE_SET = self.initial_state is not None

        self.noise_seed = 1 if self.noise_seed else 0
        os.makedirs(self.output, exist_ok=True)

    def __str__(self) -> str:
        '''
        Return a readable, multi-line listing of the stored parameters.
        '''
        lines = ["Jansen-Rit sde model", "----------------"]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    def __call__(self):
        '''
        Print the model name and return the parameter dictionary built in
        `__init__`.
        '''
        print("Jansen-Rit sde model")
        return self._par

    def check_parameters(self, par):
        '''
        Check if the parameters are valid.

        Raises
        ------
        ValueError
            on the first key of `par` that is not listed in `valid_params`.
        '''
        for key in par:
            if key not in self.valid_params:
                # f-string so that a non-string key still yields ValueError.
                raise ValueError(f"Invalid parameter: {key}")

    def get_default_parameters(self):
        '''
        return default parameters for the Jansen-Rit sde model.
        '''

        par = {
            'G': 0.5,                   # global coupling strength
            "A": 3.25,                  # mV
            "B": 22.0,                  # mV
            "a": 0.1,                   # 1/ms
            "b": 0.05,                  # 1/ms
            "noise_mu": 0.24,
            "noise_std": 0.3,
            "vmax": 0.005,
            "v0": 6,                    # mV
            "r": 0.56,                  # mV
            "initial_state": None,

            'weights': None,
            "C0": 135.0 * 1.0,
            "C1": 135.0 * 0.8,
            "C2": 135.0 * 0.25,
            "C3": 135.0 * 0.25,

            "noise_seed": 0,
            "seed": None,

            "dt": 0.05,                 # ms
            "dim": 6,
            "method": "heun",
            "t_transition": 500.0,      # ms
            "t_end": 2501.0,            # ms
            "output": "output",         # output directory
            "RECORD_AVG": False         # true to store large time series in file
        }
        return par

    # ---------------------------------------------------------------
    def set_initial_state(self):
        '''
        Set initial state for the system of JR equations with N nodes.

        Delegates to the module-level ``set_initial_state`` helper and marks
        the initial state as set.
        '''

        self.initial_state = set_initial_state(self.num_nodes, self.seed)
        self.INITIAL_STATE_SET = True

    # -------------------------------------------------------------------------

    def prepare_input(self):
        '''
        prepare input parameters for passing to C++ engine.

        Scalars are cast to ``int``/``float``, array-like inputs go through
        ``np.asarray`` and C0..C3 are expanded with ``check_sequence``.
        '''

        self.N = int(self.N)
        self.weights = np.asarray(self.weights)
        self.dt = float(self.dt)
        self.t_transition = float(self.t_transition)
        self.t_end = float(self.t_end)
        self.G = float(self.G)
        self.A = float(self.A)
        self.B = float(self.B)
        self.a = float(self.a)
        self.b = float(self.b)
        self.r = float(self.r)
        self.v0 = float(self.v0)
        self.vmax = float(self.vmax)
        self.C0 = check_sequence(self.C0, self.N)
        self.C1 = check_sequence(self.C1, self.N)
        self.C2 = check_sequence(self.C2, self.N)
        self.C3 = check_sequence(self.C3, self.N)
        self.noise_mu = float(self.noise_mu)
        self.noise_std = float(self.noise_std)
        self.noise_seed = int(self.noise_seed)
        self.initial_state = np.asarray(self.initial_state)

    # -------------------------------------------------------------------------
    def run(self, par=None, x0=None, verbose=False):
        '''
        Integrate the system of equations for Jansen-Rit sde model.

        Parameters
        ----------

        par: dict
            parameters to control the Jansen-Rit sde model.
        x0: np.array
            initial state
        verbose: bool
            print the message if True

        Returns
        -------
        dict
            - **t** : time series
            - **x** : state variables (engine coordinates, transposed)

        Raises
        ------
        ValueError
            for an invalid key in `par` or an unknown integration method.

        '''
        par = {} if par is None else par

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            self.INITIAL_STATE_SET = True
            self.initial_state = x0

        # Validate everything first so an invalid key cannot leave the
        # instance with only part of `par` applied.
        self.check_parameters(par)
        for key, value in par.items():
            setattr(self, key, value)

        # Reject an unknown scheme before building the engine object, and
        # raise instead of terminating the interpreter.
        method = str(self.method).lower()
        if method not in self._integrators:
            raise ValueError(
                f"unknown integration method: {self.method!r}; "
                f"expected one of {sorted(self._integrators)}")

        self.prepare_input()

        obj = _JR_sde(self.N,
                      self.dt,
                      self.t_transition,
                      self.t_end,
                      self.G,
                      self.weights,
                      self.initial_state,
                      self.A,
                      self.B,
                      self.a,
                      self.b,
                      self.r,
                      self.v0,
                      self.vmax,
                      self.C0,
                      self.C1,
                      self.C2,
                      self.C3,
                      self.noise_mu,
                      self.noise_std,
                      self.noise_seed)
        getattr(obj, self._integrators[method])()

        sol = np.asarray(obj.get_coordinates()).T
        times = np.asarray(obj.get_times())

        return {"t": times, "x": sol}
|
282
|
+
|
283
|
+
|
284
|
+
############################ Jansen-Rit sdde ##################################
|
285
|
+
|
286
|
+
class JR_sdde:
    '''
    Jansen-Rit model with delays, C++ implementation.

    Thin Python front end around the compiled ``_JR_sdde`` engine. Both
    `weights` and `delays` must be supplied; the remaining parameters fall
    back to `get_default_parameters`.
    '''

    # "initial_state" is accepted here because it has a default and is read
    # in __init__; the duplicated "t_end" entry was dropped.
    valid_params = ["weights", "delays", "dt", "t_end", "G", "A", "a", "B", "b", "mu",
                    "nstart", "t_transition", "sigma", "C", "record_step",
                    "C0", "C1", "C2", "C3", "vmax", "r", "v0", "output",
                    'sti_ti', 'sti_duration', 'sti_amplitude', 'sti_gain',
                    "noise_seed", "seed", "method", "initial_state"]
    # -------------------------------------------------------------------------

    def __init__(self, par=None) -> None:
        '''
        Merge `par` into the defaults and expose every entry as an attribute.

        Raises
        ------
        ValueError
            for a key outside `valid_params`.
        AssertionError
            when `weights` or `delays` is missing.
        '''
        # None instead of a shared mutable `{}` default.
        par = {} if par is None else par

        self.check_parameters(par)
        _par = self.get_default_parameters()
        _par.update(par)

        for name, value in _par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        self.noise_seed = 1 if self.noise_seed else 0
        assert (self.weights is not None), "weights must be provided"
        assert (self.delays is not None), "delays must be provided"
        self.N = self.num_nodes = len(self.weights)

        # Scalars are expanded to one value per node.
        self.C0 = check_sequence(self.C0, self.N)
        self.C1 = check_sequence(self.C1, self.N)
        self.C2 = check_sequence(self.C2, self.N)
        self.C3 = check_sequence(self.C3, self.N)
        self.sti_amplitude = check_sequence(self.sti_amplitude, self.N)

        # Always define the flag; run() reads it unconditionally.
        self.INITIAL_STATE_SET = self.initial_state is not None
        os.makedirs(self.output, exist_ok=True)

    def check_parameters(self, par):
        '''
        check if the parameters are valid

        Raises
        ------
        ValueError
            on the first key of `par` that is not listed in `valid_params`.
        '''
        for key in par:
            if key not in self.valid_params:
                # f-string so that a non-string key still yields ValueError.
                raise ValueError(f"Invalid parameter: {key}")
    # -------------------------------------------------------------------------

    def get_default_parameters(self):
        '''
        get default parameters for the system of JR equations.
        '''

        param = {
            "dt": 0.01,
            "G": 0.01,
            "mu": 0.22,
            "sigma": 0.005,
            "dim": 6,
            "A": 3.25,
            "a": 0.1,
            "B": 22.0,
            "b": 0.05,
            "v0": 6.0,
            "vmax": 0.005,
            "r": 0.56,
            "C0": 135.0 * 1.0,
            "C1": 135.0 * 0.8,
            "C2": 135.0 * 0.25,
            "C3": 135.0 * 0.25,
            'sti_ti': 0.0,
            'sti_duration': 0.0,
            'sti_amplitude': 0.0,  # scalar or sequence of length N
            'sti_gain': 0.0,
            "noise_seed": False,
            "seed": None,
            "initial_state": None,
            "method": "heun",
            "output": "output",
            "t_end": 2000.0,
            "t_transition": 1000.0
        }

        return param
    # -------------------------------------------------------------------------

    def prepare_stimulus(self, sti_gain, sti_ti):
        '''
        prepare stimulation parameters

        Only checks that, for a non-zero gain, the stimulation onset `sti_ti`
        is not earlier than `t_transition`. It is not called by `run`.
        '''
        if np.abs(sti_gain) > 0.0:
            assert (
                sti_ti >= self.t_transition), "stimulation must start after transition"
    # -------------------------------------------------------------------------

    def set_initial_state(self):
        '''
        set initial state for the system of JR equations with N nodes.

        Delegates to the module-level ``set_initial_state`` helper and marks
        the initial state as set.
        '''
        self.initial_state = set_initial_state(self.num_nodes, self.seed)
        self.INITIAL_STATE_SET = True
    # -------------------------------------------------------------------------

    def prepare_input(self):
        '''
        prepare input parameters for C++ code.

        Scalars are cast to ``int``/``float`` and array-like inputs go
        through ``np.asarray``.
        '''

        self.dt = float(self.dt)
        self.t_transition = float(self.t_transition)
        self.t_end = float(self.t_end)
        self.G = float(self.G)
        self.A = float(self.A)
        self.B = float(self.B)
        self.a = float(self.a)
        self.b = float(self.b)
        self.r = float(self.r)
        self.v0 = float(self.v0)
        self.vmax = float(self.vmax)
        self.C0 = np.asarray(self.C0)
        self.C1 = np.asarray(self.C1)
        self.C2 = np.asarray(self.C2)
        self.C3 = np.asarray(self.C3)
        self.sti_amplitude = np.asarray(self.sti_amplitude)
        self.sti_gain = float(self.sti_gain)
        self.sti_ti = float(self.sti_ti)
        self.sti_duration = float(self.sti_duration)
        self.mu = float(self.mu)
        self.sigma = float(self.sigma)
        self.noise_seed = int(self.noise_seed)
        self.initial_state = np.asarray(self.initial_state)
        self.weights = np.asarray(self.weights)
        self.delays = np.asarray(self.delays)
    # -------------------------------------------------------------------------

    def run(self, par=None, x0=None, verbose=False):
        '''
        Integrate the system of equations for Jansen-Rit model.

        Parameters
        ----------
        par: dict
            parameters to control the model; each entry is set as an
            attribute on the instance before integrating.
        x0: np.array
            initial state of length ``num_nodes * dim``
        verbose: bool
            print the message if True

        Returns
        -------
        dict
            - **t** : times reported by the engine
            - **x** : state variables reported by the engine
            - **sti** : stimulation vector reported by the engine

            The last ``int(max(delays) / dt) + 1`` samples are dropped from
            each of them.
        '''
        par = {} if par is None else par

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            assert (len(x0) == self.num_nodes * self.dim)
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        # Validate everything first so an invalid key cannot leave the
        # instance with only part of `par` applied.
        self.check_parameters(par)
        for key, value in par.items():
            setattr(self, key, value)

        self.prepare_input()
        obj = _JR_sdde(self.dt,
                       self.initial_state,
                       self.weights,
                       self.delays,
                       self.G,
                       self.dim,
                       self.A,
                       self.B,
                       self.a,
                       self.b,
                       self.r,
                       self.v0,
                       self.vmax,
                       self.C0,
                       self.C1,
                       self.C2,
                       self.C3,
                       self.sti_amplitude,
                       self.sti_gain,
                       self.sti_ti,
                       self.sti_duration,
                       self.mu,
                       self.sigma,
                       self.t_transition,
                       self.t_end,
                       self.noise_seed)
        obj.integrate(self.method)

        # Number of trailing samples removed from every output.
        # NOTE(review): presumably this matches the delay-history length used
        # by the engine -- confirm against jr_sdde.hpp.
        nstart = int((np.max(self.delays)) / self.dt) + 1
        t = np.asarray(obj.get_t())[:-nstart]
        y = np.asarray(obj.get_y())[:, :-nstart]
        sti_vector = np.asarray(obj.get_sti_vector())[:-nstart]

        return {"t": t, "x": y, "sti": sti_vector}
|
495
|
+
|
496
|
+
############################# helper functions ################################
|
497
|
+
|
498
|
+
|
499
|
+
def check_sequence(x, n):
    '''
    check if x is a scalar or a sequence of length n

    parameters
    ----------
    x: scalar or sequence of length n
        lists, tuples and arrays with at least one dimension count as
        sequences; anything else, including a 0-d array, is a scalar.
    n: number of nodes

    returns
    -------
    x: sequence of length n
        a sequence is returned unchanged, a scalar is broadcast to an array
        of n identical values.

    raises
    ------
    AssertionError
        if x is a sequence whose length is not n.
    '''
    # A 0-d array holds a single value and has no len(); route it to the
    # scalar branch instead of failing on the length check.
    is_sequence = isinstance(x, (list, tuple)) or (
        isinstance(x, np.ndarray) and x.ndim > 0)

    if is_sequence:
        assert (len(x) == n), f" variable must be a sequence of length {n}"
        return x

    return x * np.ones(n)
|
517
|
+
|
518
|
+
|
519
|
+
def set_initial_state(nn, seed=None):
    '''
    Draw a random initial state for the system of JR equations with N nodes.

    parameters
    ----------
    nn: number of nodes
    seed: random seed; when not None, NumPy's global generator is seeded
        with it before drawing

    returns
    -------
    y: initial state of length 6N, the six state variables stacked one
        after another, each drawn uniformly from its own interval

    '''
    if seed is not None:
        np.random.seed(seed)

    # (low, high) interval of each state variable, in stacking order.
    bounds = (
        (-1, 1),
        (-500, 500),
        (-50, 50),
        (-6, 6),
        (-20, 20),
        (-500, 500),
    )

    # Drawn strictly in order so the random stream is consumed exactly as
    # one uniform() call per variable.
    pieces = []
    for low, high in bounds:
        pieces.append(np.random.uniform(low, high, nn))

    return np.hstack(pieces)
|