vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. vbi/__init__.py +37 -0
  2. vbi/_version.py +17 -0
  3. vbi/dataset/__init__.py +0 -0
  4. vbi/dataset/connectivity_84/centers.txt +84 -0
  5. vbi/dataset/connectivity_84/centres.txt +84 -0
  6. vbi/dataset/connectivity_84/cortical.txt +84 -0
  7. vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
  8. vbi/dataset/connectivity_84/weights.txt +84 -0
  9. vbi/dataset/connectivity_88/Aud_88.txt +88 -0
  10. vbi/dataset/connectivity_88/Bold.npz +0 -0
  11. vbi/dataset/connectivity_88/Labels.txt +17 -0
  12. vbi/dataset/connectivity_88/Region_labels.txt +88 -0
  13. vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
  14. vbi/dataset/connectivity_88/weights.txt +88 -0
  15. vbi/feature_extraction/__init__.py +1 -0
  16. vbi/feature_extraction/calc_features.py +293 -0
  17. vbi/feature_extraction/features.json +535 -0
  18. vbi/feature_extraction/features.py +2124 -0
  19. vbi/feature_extraction/features_settings.py +374 -0
  20. vbi/feature_extraction/features_utils.py +1357 -0
  21. vbi/feature_extraction/infodynamics.jar +0 -0
  22. vbi/feature_extraction/utility.py +507 -0
  23. vbi/inference.py +98 -0
  24. vbi/models/__init__.py +0 -0
  25. vbi/models/cpp/__init__.py +0 -0
  26. vbi/models/cpp/_src/__init__.py +0 -0
  27. vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
  28. vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
  29. vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
  30. vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  31. vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  32. vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  33. vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
  34. vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
  35. vbi/models/cpp/_src/bold.hpp +303 -0
  36. vbi/models/cpp/_src/do.hpp +167 -0
  37. vbi/models/cpp/_src/do.i +17 -0
  38. vbi/models/cpp/_src/do.py +467 -0
  39. vbi/models/cpp/_src/do_wrap.cxx +12811 -0
  40. vbi/models/cpp/_src/jr_sdde.hpp +352 -0
  41. vbi/models/cpp/_src/jr_sdde.i +19 -0
  42. vbi/models/cpp/_src/jr_sdde.py +688 -0
  43. vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
  44. vbi/models/cpp/_src/jr_sde.hpp +264 -0
  45. vbi/models/cpp/_src/jr_sde.i +17 -0
  46. vbi/models/cpp/_src/jr_sde.py +470 -0
  47. vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
  48. vbi/models/cpp/_src/km_sde.hpp +158 -0
  49. vbi/models/cpp/_src/km_sde.i +19 -0
  50. vbi/models/cpp/_src/km_sde.py +671 -0
  51. vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
  52. vbi/models/cpp/_src/makefile +52 -0
  53. vbi/models/cpp/_src/mpr_sde.hpp +327 -0
  54. vbi/models/cpp/_src/mpr_sde.i +19 -0
  55. vbi/models/cpp/_src/mpr_sde.py +711 -0
  56. vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
  57. vbi/models/cpp/_src/utility.hpp +307 -0
  58. vbi/models/cpp/_src/vep.hpp +171 -0
  59. vbi/models/cpp/_src/vep.i +16 -0
  60. vbi/models/cpp/_src/vep.py +464 -0
  61. vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
  62. vbi/models/cpp/_src/wc_ode.hpp +294 -0
  63. vbi/models/cpp/_src/wc_ode.i +19 -0
  64. vbi/models/cpp/_src/wc_ode.py +686 -0
  65. vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
  66. vbi/models/cpp/damp_oscillator.py +143 -0
  67. vbi/models/cpp/jansen_rit.py +543 -0
  68. vbi/models/cpp/km.py +187 -0
  69. vbi/models/cpp/mpr.py +289 -0
  70. vbi/models/cpp/vep.py +150 -0
  71. vbi/models/cpp/wc.py +216 -0
  72. vbi/models/cupy/__init__.py +0 -0
  73. vbi/models/cupy/bold.py +111 -0
  74. vbi/models/cupy/ghb.py +284 -0
  75. vbi/models/cupy/jansen_rit.py +473 -0
  76. vbi/models/cupy/km.py +224 -0
  77. vbi/models/cupy/mpr.py +475 -0
  78. vbi/models/cupy/mpr_modified_bold.py +12 -0
  79. vbi/models/cupy/utils.py +184 -0
  80. vbi/models/numba/__init__.py +0 -0
  81. vbi/models/numba/_ww_EI.py +444 -0
  82. vbi/models/numba/damp_oscillator.py +162 -0
  83. vbi/models/numba/ghb.py +208 -0
  84. vbi/models/numba/mpr.py +383 -0
  85. vbi/models/pytorch/__init__.py +0 -0
  86. vbi/models/pytorch/data/default_parameters.npz +0 -0
  87. vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
  88. vbi/models/pytorch/data/input/fc_test.csv +68 -0
  89. vbi/models/pytorch/data/input/fc_train.csv +68 -0
  90. vbi/models/pytorch/data/input/fc_vali.csv +68 -0
  91. vbi/models/pytorch/data/input/fcd_test.mat +0 -0
  92. vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
  93. vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
  94. vbi/models/pytorch/data/input/fcd_train.mat +0 -0
  95. vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
  96. vbi/models/pytorch/data/input/myelin.csv +68 -0
  97. vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
  98. vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
  99. vbi/models/pytorch/data/input/sc_test.csv +68 -0
  100. vbi/models/pytorch/data/input/sc_train.csv +68 -0
  101. vbi/models/pytorch/data/input/sc_vali.csv +68 -0
  102. vbi/models/pytorch/data/obs_kong0.npz +0 -0
  103. vbi/models/pytorch/ww_sde_kong.py +570 -0
  104. vbi/models/tvbk/__init__.py +9 -0
  105. vbi/models/tvbk/tvbk_wrapper.py +166 -0
  106. vbi/models/tvbk/utils.py +72 -0
  107. vbi/papers/__init__.py +0 -0
  108. vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
  109. vbi/tests/__init__.py +0 -0
  110. vbi/tests/_test_mpr_nb.py +36 -0
  111. vbi/tests/test_features.py +355 -0
  112. vbi/tests/test_ghb_cupy.py +90 -0
  113. vbi/tests/test_mpr_cupy.py +49 -0
  114. vbi/tests/test_mpr_numba.py +84 -0
  115. vbi/tests/test_suite.py +19 -0
  116. vbi/utils.py +402 -0
  117. vbi-0.1.3.dist-info/METADATA +166 -0
  118. vbi-0.1.3.dist-info/RECORD +121 -0
  119. vbi-0.1.3.dist-info/WHEEL +5 -0
  120. vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
  121. vbi-0.1.3.dist-info/top_level.txt +1 -0
vbi/models/cpp/km.py ADDED
@@ -0,0 +1,187 @@
1
+ import numpy as np
2
+
3
+ try:
4
+ from vbi.models.cpp._src.km_sde import KM_sde as _KM_sde
5
+ except ImportError as e:
6
+ print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
7
+
8
+
9
class KM_sde:
    """
    Kuramoto model with noise (sde), C++ implementation.

    Collects and validates parameters, converts them to the types the
    compiled ``_KM_sde`` engine expects, runs a Heun integration and
    returns the recorded phase time series.

    Parameters
    ----------
    par : dict
        Dictionary of parameters. Every key must be listed in
        ``valid_parameters``; missing keys take the values returned by
        ``get_default_parameters()``. ``omega`` is required.
    """

    valid_parameters = [
        "G",  # global coupling strength
        "dt",  # time step
        "noise_amp",  # noise amplitude
        "omega",  # natural angular frequency
        "weights",  # weighted connection matrix
        "noise_seed",  # fix random seed for noise in Cpp code
        "seed",  # fix random seed for the default initial state
        "alpha",  # frustration matrix
        "t_initial",  # initial time
        "t_transition",  # transition time
        "t_end",  # end time
        "output",  # output directory
        "num_threads",  # number of threads using openmp
        "initial_state",  # initial phases (None -> drawn at random)
        "type",  # output time series data type
    ]

    def __init__(self, par) -> None:
        self.check_parameters(par)
        self._par = self.get_default_parameters()
        self._par.update(par)

        # Expose every parameter as an attribute (self.G, self.dt, ...).
        for name, value in self._par.items():
            setattr(self, name, value)

        assert self.omega is not None

        if self.seed is not None:
            np.random.seed(self.seed)

        self.num_nodes = len(self.omega)

        # Always define the flag. It used to be set only when no initial
        # state was given, so run() raised AttributeError otherwise.
        self.INITIAL_STATE_SET = self.initial_state is not None

    def set_initial_state(self):
        """Draw random initial phases (module-level helper) and mark them set."""
        self.INITIAL_STATE_SET = True
        self.initial_state = set_initial_state(self.num_nodes, self.seed)

    def __str__(self) -> str:
        """Return a human-readable listing of the construction parameters."""
        lines = [
            "Kuramoto model with noise (sde), C++ implementation.",
            "----------------",
        ]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    def __call__(self):
        """Return the parameter dictionary used at construction."""
        return self._par

    def get_default_parameters(self):
        """Return a fresh dictionary of default parameter values."""
        return {
            "G": 1.0,  # global coupling strength
            "dt": 0.01,  # time step
            "noise_amp": 0.1,  # noise amplitude
            "weights": None,  # weighted connection matrix
            "alpha": None,  # frustration matrix
            "omega": None,  # natural angular frequency
            "noise_seed": 0,  # fix random seed for noise in Cpp code
            "seed": None,  # fix random seed for initial state
            "t_initial": 0.0,  # initial time
            "t_transition": 0.0,  # transition time
            "t_end": 100.0,  # end time
            "num_threads": 1,  # number of threads using openmp
            "output": "output",  # output directory
            "initial_state": None,  # initial state
            "type": np.float32,
        }

    def check_parameters(self, par):
        """Raise ValueError if ``par`` contains a key not in ``valid_parameters``."""
        for key in par.keys():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter: {key}")

    def _update_parameters(self, par):
        """
        Apply run-time parameter overrides as attributes.

        A value may be given directly (``{"G": 2.0}``) or wrapped as
        ``{"G": {"value": 2.0}}``. The wrapped form is the only one the
        previous implementation accepted, so it is kept for compatibility.

        Raises
        ------
        ValueError
            If a key is not a valid parameter name.
        """
        for key, value in par.items():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")
            if isinstance(value, dict) and "value" in value:
                value = value["value"]
            setattr(self, key, value)
            if key == "initial_state":
                self.INITIAL_STATE_SET = value is not None
        # A new omega changes the number of oscillators.
        if "omega" in par and self.omega is not None:
            self.num_nodes = len(self.omega)

    def prepare_input(self):
        """
        Validate parameters and convert them to the types expected by the
        C++ engine (float64 arrays, Python floats/ints).

        Draws a default initial state if none has been set.

        Raises
        ------
        ValueError
            If ``weights`` or ``omega`` is missing.
        """
        if self.weights is None:
            raise ValueError("Missing weights.")
        if self.omega is None:
            raise ValueError("Missing omega.")

        self.weights = np.array(self.weights, dtype=np.float64)
        self.omega = np.array(self.omega, dtype=np.float64)
        # Keep the node count in sync with omega, which may have been
        # replaced after construction.
        self.num_nodes = len(self.omega)
        nn = self.num_nodes

        if not self.INITIAL_STATE_SET:
            self.set_initial_state()

        self.initial_state = np.array(self.initial_state, dtype=np.float64)
        self.G = float(self.G)
        self.dt = float(self.dt)
        self.noise_amp = float(self.noise_amp)
        self.t_initial = float(self.t_initial)
        self.t_transition = float(self.t_transition)
        self.t_end = float(self.t_end)
        self.noise_seed = int(self.noise_seed)
        self.num_threads = int(self.num_threads)

        # No frustration matrix given -> all-zero matrix shaped like weights.
        if self.alpha is None:
            self.alpha = np.zeros_like(self.weights, dtype=np.float64)
        else:
            self.alpha = np.array(self.alpha, dtype=np.float64)
        assert self.alpha.shape == (nn, nn)

    def run(self, par=None, x0=None, verbose=False):
        """
        Simulate the model.

        Parameters
        ----------
        par : dict, optional
            Parameter overrides applied before the run. Values may be raw or
            wrapped as ``{"value": v}``.
        x0 : array, optional
            Initial state (one phase per node). It overrides any stored state.
        verbose : bool
            Print a note when the default initial state is drawn.

        Returns
        -------
        dict
            t : array
                Time points.
            x : array
                State time series (engine output transposed, cast to ``type``).
        """
        par = {} if par is None else par
        # Apply overrides first so x0 is checked against the node count
        # that will actually be simulated.
        self._update_parameters(par)

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            assert len(x0) == self.num_nodes
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        self.prepare_input()

        obj = _KM_sde(
            self.dt,
            self.t_initial,
            self.t_transition,
            self.t_end,
            self.G,
            self.noise_amp,
            self.initial_state,
            self.omega,
            self.alpha,
            self.weights,
            self.noise_seed,
            self.num_threads,
        )
        obj.IntegrateHeun()
        t = np.asarray(obj.get_times())
        x = np.asarray(obj.get_theta()).T.astype(self.type)

        return {"t": t, "x": x}
182
+
183
+
184
def set_initial_state(num_nodes, seed=None):
    """
    Draw initial oscillator phases uniformly from [0, 2*pi).

    Parameters
    ----------
    num_nodes : int
        Number of phases to draw.
    seed : int, optional
        If given, NumPy's global random generator is reseeded first.

    Returns
    -------
    numpy.ndarray
        Array of ``num_nodes`` phases.
    """
    if seed is not None:
        np.random.seed(seed)
    lower, upper = 0.0, 2.0 * np.pi
    phases = np.random.uniform(low=lower, high=upper, size=num_nodes)
    return phases
vbi/models/cpp/mpr.py ADDED
@@ -0,0 +1,289 @@
1
+ import numpy as np
2
+ from typing import Union
3
+ from copy import deepcopy
4
+
5
+ try:
6
+ from vbi.models.cpp._src.mpr_sde import MPR_sde as _MPR_sde
7
+ from vbi.models.cpp._src.mpr_sde import BoldParams as _BoldParams
8
+ except ImportError as e:
9
+ print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
10
+
11
+
12
class MPR_sde:
    """
    MPR model (stochastic), C++ implementation.

    Parameters
    ----------
    par : dict, optional
        Model parameters. Every key must appear in
        ``get_default_parameters()``. The dict is deep-copied.
    parbold : dict, optional
        Parameters forwarded to :class:`BoldParams`.
    """

    # User-facing times (t_cut, t_end) are in ms. The C++ engine receives
    # them divided by this factor, and the times it returns are multiplied
    # back by it.
    _TIME_SCALE = 10.0

    def __init__(
        self, par: Union[dict, None] = None, parbold: Union[dict, None] = None
    ) -> None:
        par = {} if par is None else deepcopy(par)
        parbold = {} if parbold is None else parbold

        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        # Expose every parameter as an attribute (self.G, self.eta, ...).
        for name, value in self._par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        # Always define the flag. It used to be set only when no initial
        # state was given, so run() raised AttributeError otherwise.
        self.INITIAL_STATE_SET = self.initial_state is not None
        # Known up front when weights are given, so run(x0=...) can validate
        # x0 before prepare_input() has been called.
        self.num_nodes = None if self.weights is None else np.shape(self.weights)[0]

        self.BP = BoldParams(parbold)

    def set_initial_state(self):
        """Draw a random initial state (module-level helper) and mark it set."""
        if self.weights is None:
            raise ValueError("Missing weights.")
        self.num_nodes = np.shape(self.weights)[0]
        self.initial_state = set_initial_state(self.num_nodes, self.seed)
        self.INITIAL_STATE_SET = True

    # -------------------------------------------------------------------------

    def __str__(self) -> str:
        """Return a human-readable listing of the construction parameters."""
        lines = ["MPR sde model.", "----------------"]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    # -------------------------------------------------------------------------

    def __call__(self):
        """Return the parameter dictionary used at construction."""
        return self._par

    # -------------------------------------------------------------------------

    def check_parameters(self, par: dict):
        """Raise ValueError if ``par`` contains an unknown key."""
        for key in par.keys():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")

    def get_default_parameters(self):
        """Return a fresh dictionary of default parameter values."""
        params = {
            "G": 0.733,  # global coupling strength
            "dt": 0.01,  # for mpr model [ms]
            "dt_bold": 0.001,  # for Balloon model [s]
            "J": 14.5,  # model parameter
            "eta": -4.6,  # model parameter
            "tau": 1.0,  # model parameter
            "delta": 0.7,  # model parameter
            "tr": 500.0,  # sampling from mpr time series
            "rv_decimate": 10,  # sampling from activity time series
            "noise_amp": 0.037,  # amplitude of noise
            "noise_seed": 0,  # fix seed for noise
            "iapp": 0.0,  # constant applied current
            "seed": None,
            "initial_state": None,  # initial condition of the system
            "t_cut": 0.0,  # transition time [ms]
            "t_end": 5 * 60 * 1000.0,  # end time [ms]
            "weights": None,  # weighted connection matrix
            "output": "output",  # output directory
            "RECORD_RV": 0,  # true to store large time series in file
            "RECORD_BOLD": 1,
        }

        return params

    def _engine_times(self):
        """
        Return ``(t_cut, t_end)`` converted to the C++ engine's time axis.

        The values are derived from the user-facing ms values on every call,
        so repeated calls (and repeated runs) never rescale them twice.
        """
        return (
            float(self.t_cut) / self._TIME_SCALE,
            float(self.t_end) / self._TIME_SCALE,
        )

    @classmethod
    def _trim_and_rescale(cls, t, d, t_cut):
        """
        Drop samples with ``t <= t_cut`` and convert ``t`` back to ms.

        ``t`` and ``t_cut`` are on the engine time axis. The trimming is
        applied only when ``d`` is 2-D (rows indexed by time); otherwise
        ``t`` and ``d`` are returned unchanged.
        """
        if d.ndim == 2:
            keep = t > t_cut
            d = d[keep, :]
            t = t[keep] * cls._TIME_SCALE
        return t, d

    def _update_parameters(self, par: dict):
        """
        Apply run-time parameter overrides as attributes.

        Raises
        ------
        ValueError
            If a key is not a valid parameter name.
        """
        for key, value in par.items():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")
            setattr(self, key, value)
        if "initial_state" in par:
            self.INITIAL_STATE_SET = self.initial_state is not None
        if self.weights is not None:
            self.num_nodes = np.shape(self.weights)[0]

    def prepare_input(self):
        """
        Prepare input parameters for passing to C++ engine.

        Converts scalars to Python ``float``/``int`` and broadcasts the
        per-node parameters (``eta``, ``J``, ``tau``, ``delta``, ``iapp``) to
        float64 arrays of length ``num_nodes``. It also computes the
        engine-scaled ``t_cut``/``t_end``. The ``t_cut``/``t_end``
        attributes themselves stay in ms, so calling this repeatedly is safe.

        Raises
        ------
        ValueError
            If ``weights`` is missing.
        """
        if self.weights is None:
            raise ValueError("Missing weights.")

        self.weights = np.asarray(self.weights).astype(np.float64)
        self.num_nodes = self.weights.shape[0]
        nn = self.num_nodes
        self.initial_state = np.asarray(self.initial_state).astype(np.float64)

        self.dt = float(self.dt)
        self.dt_bold = float(self.dt_bold)
        self.tr = float(self.tr)
        self.G = float(self.G)
        self.noise_amp = float(self.noise_amp)
        self.rv_decimate = int(self.rv_decimate)
        self.RECORD_RV = int(self.RECORD_RV)
        self.RECORD_BOLD = int(self.RECORD_BOLD)
        self.noise_seed = int(self.noise_seed)

        # Pass every per-node parameter as a float64 array. Previously only
        # eta was converted, so lists could reach the engine unchanged.
        self.eta = np.asarray(check_sequence(self.eta, nn)).astype(np.float64)
        self.J = np.asarray(check_sequence(self.J, nn)).astype(np.float64)
        self.tau = np.asarray(check_sequence(self.tau, nn)).astype(np.float64)
        self.delta = np.asarray(check_sequence(self.delta, nn)).astype(np.float64)
        self.iapp = np.asarray(check_sequence(self.iapp, nn)).astype(np.float64)

        # Bug fix: t_cut/t_end used to be overwritten with the divided
        # values, so every additional run() shrank them by another 10x.
        self.t_cut = float(self.t_cut)
        self.t_end = float(self.t_end)
        self._t_cut_engine, self._t_end_engine = self._engine_times()

    def run(
        self,
        par: Union[dict, None] = None,
        x0: np.ndarray = None,
        verbose: bool = False,
    ):
        """
        Integrate the MPR model with the given parameters.

        Parameters
        ----------
        par : dict, optional
            Parameter overrides applied before the run.
        x0 : array_like, optional
            Initial condition of the system (length ``2 * num_nodes``).
        verbose : bool
            If True, report when the default initial state is drawn.

        Returns
        -------
        dict
            ``rv_t``, ``rv_d``: activity times/data (empty unless
            ``RECORD_RV``).
            ``bold_t``, ``bold_d``: BOLD times/data (empty unless
            ``RECORD_BOLD``).
            Data are float32. Times are multiplied back by 10 (ms), and only
            samples after ``t_cut`` are kept.

        Raises
        ------
        ValueError
            On an unknown parameter name or missing ``weights``.
        """
        par = {} if par is None else par
        # Apply overrides first so x0 is checked against the weights that
        # will actually be simulated.
        self._update_parameters(par)
        if self.weights is None:
            raise ValueError("Missing weights.")

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            assert len(x0) == self.num_nodes * 2
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        self.prepare_input()

        obj = _MPR_sde(
            self.dt,
            self.dt_bold,
            self.rv_decimate,
            self.weights,
            self.initial_state,
            self.delta,
            self.tau,
            self.eta,
            self.J,
            self.iapp,
            self.noise_amp,
            self.G,
            self._t_end_engine,
            self._t_cut_engine,
            self.tr,
            self.RECORD_RV,
            self.RECORD_BOLD,
            self.noise_seed,
            self.BP.get_params(),
        )

        obj.integrate()

        bold_d = np.array([])
        bold_t = np.array([])
        r_d = np.array([])
        r_t = np.array([])

        if self.RECORD_BOLD:
            bold_d = np.asarray(obj.get_bold_d()).astype(np.float32)
            bold_t = np.asarray(obj.get_bold_t())
            bold_t, bold_d = self._trim_and_rescale(bold_t, bold_d, self._t_cut_engine)

        if self.RECORD_RV:
            r_d = np.asarray(obj.get_r_d()).astype(np.float32)
            r_t = np.asarray(obj.get_r_t())
            r_t, r_d = self._trim_and_rescale(r_t, r_d, self._t_cut_engine)

        return {
            "rv_t": r_t,
            "rv_d": r_d,
            "bold_t": bold_t,
            "bold_d": bold_d,
        }
207
+
208
+
209
class BoldParams:
    """
    Parameters of the BOLD (Balloon) model passed to the C++ MPR engine.

    Parameters
    ----------
    par : dict, optional
        Overrides for the defaults in ``get_default_parameters()``. Unknown
        keys raise ValueError.
    """

    def __init__(self, par=None):
        # None instead of a shared mutable {} default.
        par = {} if par is None else par

        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        # Expose every parameter as an attribute (self.kappa, ...).
        for name, value in self._par.items():
            setattr(self, name, value)

    def check_parameters(self, par):
        """Raise ValueError if ``par`` contains an unknown key."""
        for key in par.keys():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")

    def get_default_parameters(self):
        """Return a fresh dictionary of default BOLD parameters."""
        return {
            "kappa": 0.7,
            "gamma": 0.5,
            "tau": 1.0,
            "alpha": 0.35,
            "epsilon": 0.36,
            "Eo": 0.42,
            "TE": 0.05,
            "vo": 0.09,
            "r0": 26.0,
            "theta0": 41.0,
            "rtol": 1e-6,
            "atol": 1e-9,
        }

    def get_params(self):
        """
        Build the C++ ``BoldParams`` object from the current attribute values.

        The fields are copied in the order of ``get_default_parameters()``.
        """
        bp = _BoldParams()
        for name in self.valid_parameters:
            setattr(bp, name, getattr(self, name))
        return bp
259
+
260
+
261
def check_sequence(x: Union[int, float, np.ndarray], n: int):
    """
    Broadcast a scalar to length ``n``, or validate a sequence of length ``n``.

    Parameters
    ----------
    x : scalar or sequence
        A Python/NumPy scalar, a 0-d array, or a sequence (list, tuple, array).
    n : int
        Required length (number of nodes).

    Returns
    -------
    sequence of length n
        ``x * np.ones(n)`` for scalars. Sequences are returned unchanged.

    Raises
    ------
    ValueError
        If ``x`` is a sequence whose length differs from ``n``.
    """
    # np.ndim treats 0-d arrays as scalars; len() used to raise TypeError on them.
    if np.ndim(x) == 0:
        return x * np.ones(n)
    # An explicit raise (not assert) so the check survives `python -O`;
    # otherwise a wrong-length vector would be passed on silently.
    if len(x) != n:
        raise ValueError(
            f"variable must be a scalar or a sequence of length {n}, got length {len(x)}"
        )
    return x
279
+
280
+
281
def set_initial_state(nn, seed=None):
    """
    Draw a random initial state of length ``2 * nn``.

    The first ``nn`` entries are uniform in [0, 1.5). The last ``nn``
    entries are uniform in [-2, 2).

    Parameters
    ----------
    nn : int
        Number of nodes.
    seed : int, optional
        If given, NumPy's global random generator is reseeded first.
    """
    if seed is not None:
        np.random.seed(seed)

    draws = np.random.rand(2 * nn)
    first_half = 1.5 * draws[:nn]
    second_half = 4.0 * draws[nn:] - 2.0
    return np.concatenate((first_half, second_half))
vbi/models/cpp/vep.py ADDED
@@ -0,0 +1,150 @@
1
+ import os
2
+ import numpy as np
3
+ from copy import deepcopy
4
+ from os.path import join
5
+ from typing import Union
6
+
7
+ try:
8
+ from vbi.models.cpp._src.vep import VEP as _VEP
9
+ except ImportError as e:
10
+ print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
11
+
12
+
13
class VEP:
    """
    Virtual Epileptic Patient (VEP) model, C++ implementation.

    Parameters
    ----------
    par : dict, optional
        Model parameters. Every key must appear in
        ``get_default_parameters()``. The dict is deep-copied.
    """

    def __init__(self, par: Union[dict, None] = None):
        par = {} if par is None else deepcopy(par)
        self._par = self.get_default_parameters()
        self.valid_params = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        # Expose every parameter as an attribute (self.G, self.tau, ...).
        for name, value in self._par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        self.INITIAL_STATE_SET = self.initial_state is not None

    def set_initial_state(self):
        """Draw a random initial state (module-level helper) and mark it set."""
        if self.weights is None:
            raise ValueError("Missing weights.")
        self.nn = np.shape(self.weights)[0]
        self.initial_state = set_initial_state(self.nn, self.seed)
        self.INITIAL_STATE_SET = True

    def __str__(self) -> str:
        """Return a human-readable listing of the construction parameters."""
        lines = ["VEP model", "---------"]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    def __call__(self):
        """Return the parameter dictionary used at construction."""
        return self._par

    def check_parameters(self, par: dict):
        """Raise ValueError if ``par`` contains an unknown key."""
        for key in par.keys():
            if key not in self.valid_params:
                raise ValueError(f"Invalid parameter: {key}")

    def prepare_input(self):
        """
        Convert parameters to the types passed to the C++ engine.

        Raises
        ------
        ValueError
            If ``weights`` is missing.
        """
        if self.weights is None:
            raise ValueError("Missing weights.")

        self.weights = np.asarray(self.weights, dtype=np.float64)
        self.nn = self.weights.shape[0]
        self.initial_state = np.asarray(self.initial_state, dtype=np.float64)

        self.G = float(self.G)
        self.tau = float(self.tau)
        # Bug fix: the converted value used to go only to the unused
        # self.sigma, while the raw noise_sigma was handed to the engine.
        self.noise_sigma = float(self.noise_sigma)
        self.sigma = self.noise_sigma  # kept for backward compatibility
        self.dt = float(self.dt)
        self.tend = float(self.tend)
        self.tcut = float(self.tcut)
        self.noise_seed = int(self.noise_seed)
        # NOTE(review): record_step is validated but not passed to _VEP
        # in run(); confirm whether the engine should receive it.
        self.record_step = int(self.record_step)
        self.method = str(self.method)

        self.iext = np.asarray(check_sequence(self.iext, self.nn), dtype=np.float64)
        self.eta = np.asarray(check_sequence(self.eta, self.nn), dtype=np.float64)

    def get_default_parameters(self):
        """Return a fresh dictionary of default parameter values."""
        params = {
            "G": 1.0,
            "seed": None,
            "initial_state": None,
            "weights": None,
            "tau": 10.0,
            "eta": -1.5,
            "noise_sigma": 0.1,
            "iext": 0.0,
            "dt": 0.01,
            "tend": 100.0,
            "tcut": 0.0,
            "noise_seed": 0,
            "record_step": 1,
            "method": "euler",
            "output": "output",
        }
        return params

    def run(
        self,
        par: Union[dict, None] = None,
        x0: np.ndarray = None,
        verbose: bool = False,
    ):
        """
        Integrate the VEP model.

        Parameters
        ----------
        par : dict, optional
            Parameter overrides applied before the run.
        x0 : array_like, optional
            Initial state. It overrides any stored state.
        verbose : bool
            If True, report when the default initial state is drawn.

        Returns
        -------
        dict
            ``t``: time points. ``x``: states as float32 (engine output
            transposed).

        Raises
        ------
        ValueError
            On an unknown parameter name or missing ``weights``.
        """
        par = {} if par is None else par
        # Apply overrides first, so e.g. par={"weights": W} is available
        # when the default initial state is drawn.
        for key, value in par.items():
            if key not in self.valid_params:
                raise ValueError(f"Invalid parameter: {key}")
            setattr(self, key, value)
        if "initial_state" in par:
            self.INITIAL_STATE_SET = self.initial_state is not None

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        self.prepare_input()

        obj = _VEP(
            self.G,
            self.iext,
            self.eta,
            self.dt,
            self.tcut,
            self.tend,
            self.tau,
            self.noise_sigma,
            self.initial_state,
            self.weights,
            self.noise_seed,
            self.method,
        )
        obj.integrate()
        states = np.asarray(obj.get_states(), dtype=np.float32).T
        t = np.asarray(obj.get_times())
        return {"t": t, "x": states}
122
+
123
+
124
def set_initial_state(nn: int, seed: int = None):
    """
    Draw a random initial state of length ``2 * nn``.

    The first ``nn`` entries are uniform in [-3, -2). The last ``nn``
    entries are uniform in [0, 3.5).

    Parameters
    ----------
    nn : int
        Number of nodes.
    seed : int, optional
        If given, NumPy's global random generator is reseeded first.
    """
    if seed is not None:
        np.random.seed(seed)
    # Draw order matters for reproducibility: first block, then second block.
    first_block = np.random.uniform(-3.0, -2.0, size=nn)
    second_block = np.random.uniform(0.0, 3.5, size=nn)
    return np.concatenate([first_block, second_block])
131
+
132
+
133
def check_sequence(x: Union[int, float, np.ndarray], n: int):
    """
    Broadcast a scalar to length ``n``, or validate a sequence of length ``n``.

    Parameters
    ----------
    x : scalar or sequence
        A Python/NumPy scalar, a 0-d array, or a sequence (list, tuple, array).
    n : int
        Required number of elements.

    Returns
    -------
    sequence of length n
        ``x * np.ones(n)`` for scalars. Sequences are returned unchanged.

    Raises
    ------
    ValueError
        If ``x`` is a sequence whose length differs from ``n``.
    """
    # np.ndim treats 0-d arrays as scalars; len() used to raise TypeError on them.
    if np.ndim(x) == 0:
        return x * np.ones(n)
    # An explicit raise (not assert) so the check survives `python -O`;
    # otherwise a wrong-length vector would be passed on silently.
    if len(x) != n:
        raise ValueError(
            f"variable must be a scalar or a sequence of length {n}, got length {len(x)}"
        )
    return x