vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121):
  1. vbi/__init__.py +37 -0
  2. vbi/_version.py +17 -0
  3. vbi/dataset/__init__.py +0 -0
  4. vbi/dataset/connectivity_84/centers.txt +84 -0
  5. vbi/dataset/connectivity_84/centres.txt +84 -0
  6. vbi/dataset/connectivity_84/cortical.txt +84 -0
  7. vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
  8. vbi/dataset/connectivity_84/weights.txt +84 -0
  9. vbi/dataset/connectivity_88/Aud_88.txt +88 -0
  10. vbi/dataset/connectivity_88/Bold.npz +0 -0
  11. vbi/dataset/connectivity_88/Labels.txt +17 -0
  12. vbi/dataset/connectivity_88/Region_labels.txt +88 -0
  13. vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
  14. vbi/dataset/connectivity_88/weights.txt +88 -0
  15. vbi/feature_extraction/__init__.py +1 -0
  16. vbi/feature_extraction/calc_features.py +293 -0
  17. vbi/feature_extraction/features.json +535 -0
  18. vbi/feature_extraction/features.py +2124 -0
  19. vbi/feature_extraction/features_settings.py +374 -0
  20. vbi/feature_extraction/features_utils.py +1357 -0
  21. vbi/feature_extraction/infodynamics.jar +0 -0
  22. vbi/feature_extraction/utility.py +507 -0
  23. vbi/inference.py +98 -0
  24. vbi/models/__init__.py +0 -0
  25. vbi/models/cpp/__init__.py +0 -0
  26. vbi/models/cpp/_src/__init__.py +0 -0
  27. vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
  28. vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
  29. vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
  30. vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  31. vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  32. vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  33. vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
  34. vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
  35. vbi/models/cpp/_src/bold.hpp +303 -0
  36. vbi/models/cpp/_src/do.hpp +167 -0
  37. vbi/models/cpp/_src/do.i +17 -0
  38. vbi/models/cpp/_src/do.py +467 -0
  39. vbi/models/cpp/_src/do_wrap.cxx +12811 -0
  40. vbi/models/cpp/_src/jr_sdde.hpp +352 -0
  41. vbi/models/cpp/_src/jr_sdde.i +19 -0
  42. vbi/models/cpp/_src/jr_sdde.py +688 -0
  43. vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
  44. vbi/models/cpp/_src/jr_sde.hpp +264 -0
  45. vbi/models/cpp/_src/jr_sde.i +17 -0
  46. vbi/models/cpp/_src/jr_sde.py +470 -0
  47. vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
  48. vbi/models/cpp/_src/km_sde.hpp +158 -0
  49. vbi/models/cpp/_src/km_sde.i +19 -0
  50. vbi/models/cpp/_src/km_sde.py +671 -0
  51. vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
  52. vbi/models/cpp/_src/makefile +52 -0
  53. vbi/models/cpp/_src/mpr_sde.hpp +327 -0
  54. vbi/models/cpp/_src/mpr_sde.i +19 -0
  55. vbi/models/cpp/_src/mpr_sde.py +711 -0
  56. vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
  57. vbi/models/cpp/_src/utility.hpp +307 -0
  58. vbi/models/cpp/_src/vep.hpp +171 -0
  59. vbi/models/cpp/_src/vep.i +16 -0
  60. vbi/models/cpp/_src/vep.py +464 -0
  61. vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
  62. vbi/models/cpp/_src/wc_ode.hpp +294 -0
  63. vbi/models/cpp/_src/wc_ode.i +19 -0
  64. vbi/models/cpp/_src/wc_ode.py +686 -0
  65. vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
  66. vbi/models/cpp/damp_oscillator.py +143 -0
  67. vbi/models/cpp/jansen_rit.py +543 -0
  68. vbi/models/cpp/km.py +187 -0
  69. vbi/models/cpp/mpr.py +289 -0
  70. vbi/models/cpp/vep.py +150 -0
  71. vbi/models/cpp/wc.py +216 -0
  72. vbi/models/cupy/__init__.py +0 -0
  73. vbi/models/cupy/bold.py +111 -0
  74. vbi/models/cupy/ghb.py +284 -0
  75. vbi/models/cupy/jansen_rit.py +473 -0
  76. vbi/models/cupy/km.py +224 -0
  77. vbi/models/cupy/mpr.py +475 -0
  78. vbi/models/cupy/mpr_modified_bold.py +12 -0
  79. vbi/models/cupy/utils.py +184 -0
  80. vbi/models/numba/__init__.py +0 -0
  81. vbi/models/numba/_ww_EI.py +444 -0
  82. vbi/models/numba/damp_oscillator.py +162 -0
  83. vbi/models/numba/ghb.py +208 -0
  84. vbi/models/numba/mpr.py +383 -0
  85. vbi/models/pytorch/__init__.py +0 -0
  86. vbi/models/pytorch/data/default_parameters.npz +0 -0
  87. vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
  88. vbi/models/pytorch/data/input/fc_test.csv +68 -0
  89. vbi/models/pytorch/data/input/fc_train.csv +68 -0
  90. vbi/models/pytorch/data/input/fc_vali.csv +68 -0
  91. vbi/models/pytorch/data/input/fcd_test.mat +0 -0
  92. vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
  93. vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
  94. vbi/models/pytorch/data/input/fcd_train.mat +0 -0
  95. vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
  96. vbi/models/pytorch/data/input/myelin.csv +68 -0
  97. vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
  98. vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
  99. vbi/models/pytorch/data/input/sc_test.csv +68 -0
  100. vbi/models/pytorch/data/input/sc_train.csv +68 -0
  101. vbi/models/pytorch/data/input/sc_vali.csv +68 -0
  102. vbi/models/pytorch/data/obs_kong0.npz +0 -0
  103. vbi/models/pytorch/ww_sde_kong.py +570 -0
  104. vbi/models/tvbk/__init__.py +9 -0
  105. vbi/models/tvbk/tvbk_wrapper.py +166 -0
  106. vbi/models/tvbk/utils.py +72 -0
  107. vbi/papers/__init__.py +0 -0
  108. vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
  109. vbi/tests/__init__.py +0 -0
  110. vbi/tests/_test_mpr_nb.py +36 -0
  111. vbi/tests/test_features.py +355 -0
  112. vbi/tests/test_ghb_cupy.py +90 -0
  113. vbi/tests/test_mpr_cupy.py +49 -0
  114. vbi/tests/test_mpr_numba.py +84 -0
  115. vbi/tests/test_suite.py +19 -0
  116. vbi/utils.py +402 -0
  117. vbi-0.1.3.dist-info/METADATA +166 -0
  118. vbi-0.1.3.dist-info/RECORD +121 -0
  119. vbi-0.1.3.dist-info/WHEEL +5 -0
  120. vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
  121. vbi-0.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,143 @@
1
+ import os
2
+ from typing import Any
3
+ import numpy as np
4
+
5
+ try:
6
+ from vbi.models.cpp._src.do import DO as _DO
7
+ except ImportError as e:
8
+ print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
9
+
10
class DO:
    '''
    Damp Oscillator model.

    Python front-end for the compiled ``_DO`` integrator: it keeps the model
    and simulation parameters as instance attributes, converts them to the
    types the C++ engine expects, and returns the integrated trajectory.
    '''

    valid_params = ["a", "b", "dt", "t_start", "t_end", "t_transition",
                    "initial_state", "method", "output"]

    # Integration methods accepted by ``run`` (case-insensitive), mapped to
    # the name of the method on the compiled ``_DO`` object that performs it.
    _integrators = {
        "euler": "eulerIntegrate",
        "heun": "heunIntegrate",
        "rk4": "rk4Integrate",
    }

    # ---------------------------------------------------------------
    def __init__(self, par=None):
        '''
        Parameters
        ----------
        par : dict, optional
            Overrides for the defaults returned by ``get_default_parameters``.
            Every key must appear in ``valid_params``:

            - **a**, **b** [double] model parameters.
            - **dt** [double] time step.
            - **t_start** [double] initial time for simulation.
            - **t_end** [double] final time for simulation.
            - **t_transition** [double] stored, but not forwarded to the C++
              engine by ``run``.
            - **initial_state** [list] initial state of the system (2 values).
            - **method** [str] "euler", "heun" or "rk4".
            - **output** [str] output directory, created by ``prepare_input``.

        Raises
        ------
        ValueError
            If ``par`` contains a key that is not in ``valid_params``.
        '''
        # ``None`` instead of a ``{}`` default avoids a shared mutable default.
        par = {} if par is None else par
        self.check_parameters(par)
        self._par = self.get_default_parameters()
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

    def __str__(self) -> str:
        '''Return a human-readable listing of the construction parameters.'''
        lines = ["Damp Oscillator model", "----------------"]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        '''Print the model name and return the construction parameters.'''
        print("Damp Oscillator model")
        return self._par

    def check_parameters(self, par):
        '''
        Check that every key of ``par`` is a valid parameter name.

        Raises
        ------
        ValueError
            On the first key that is not in ``valid_params``.
        '''
        for key in par:
            if key not in self.valid_params:
                raise ValueError("Invalid parameter: " + key)

    def get_default_parameters(self):
        '''
        Return a fresh dict with the default parameters of the damp
        oscillator model.
        '''
        return {
            "a": 0.1,
            "b": 0.05,
            "dt": 0.01,
            "t_start": 0,
            "method": "rk4",
            "t_end": 100.0,
            "t_transition": 20,
            "output": "output",
            "initial_state": [0.5, 1.0],
        }

    def prepare_input(self):
        '''
        Prepare input for the C++ model.

        Casts the scalar parameters to ``float``, replaces a ``None`` output
        directory / initial state with the defaults, creates the output
        directory if needed and converts the initial state to a float64 array.
        '''
        self.t_start = float(self.t_start)
        self.t_end = float(self.t_end)
        self.dt = float(self.dt)
        self.a = float(self.a)
        self.b = float(self.b)

        if self.output is None:
            self.output = "output"
        # exist_ok avoids the race between an existence check and makedirs.
        os.makedirs(self.output, exist_ok=True)

        if self.initial_state is None:
            self.initial_state = [0.5, 1.0]
        self.initial_state = np.asarray(self.initial_state, dtype=np.float64)

    # ---------------------------------------------------------------
    def run(self, par=None, x0=None, verbose=False):
        '''
        Integrate the damp oscillator system of equations.

        Parameters
        ----------
        par : dict, optional
            Parameter updates applied (as attributes) before integrating;
            keys must be in ``valid_params``. A ``par["initial_state"]``
            takes precedence over ``x0``.
        x0 : sequence of length 2, optional
            Initial state; when omitted the current ``initial_state`` is used.
        verbose : bool
            Currently unused.

        Returns
        -------
        dict
            - **t** : np.ndarray of the time points reported by the engine.
            - **x** : np.ndarray of the coordinates reported by the engine.

        Raises
        ------
        ValueError
            For an invalid parameter name, an ``x0`` whose length is not 2,
            or an unknown integration method.
        '''
        par = {} if par is None else par
        self.check_parameters(par)

        if x0 is not None:
            if len(x0) != 2:
                raise ValueError(f"x0 must have length 2, got {len(x0)}")
            self.initial_state = x0

        for key, value in par.items():
            setattr(self, key, value)

        # Validate the method up front: the previous behaviour called
        # exit(0), silently terminating the whole process with success.
        method = str(self.method).lower()
        if method not in self._integrators:
            raise ValueError(
                f"unknown integration method {self.method!r}; "
                f"expected one of {sorted(self._integrators)}")

        self.prepare_input()

        obj = _DO(self.dt,
                  self.a,
                  self.b,
                  self.t_start,
                  self.t_end,
                  self.initial_state)
        getattr(obj, self._integrators[method])()

        sol = np.asarray(obj.get_coordinates())
        times = np.asarray(obj.get_times())
        del obj

        return {"t": times, "x": sol}
@@ -0,0 +1,543 @@
1
+ import os
2
+ import numpy as np
3
+ from os.path import join
4
+
5
+ try:
6
+ from vbi.models.cpp._src.jr_sde import JR_sde as _JR_sde
7
+ from vbi.models.cpp._src.jr_sdde import JR_sdde as _JR_sdde
8
+ except ImportError as e:
9
+ print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
10
+
11
+
12
class JR_sde:
    '''
    Jansen-Rit network model (stochastic, no transmission delays), C++
    implementation.

    Parameters
    ----------
    par: dict, optional
        Overrides for the defaults returned by ``get_default_parameters``;
        every key must appear in ``valid_params``:

        - **weights** : [N x N array] connectivity; required, N is taken
          from its first dimension
        - **G** : global coupling strength
        - **A** : [mV] maximum amplitude of the excitatory PSP (EPSP)
        - **B** : [mV] maximum amplitude of the inhibitory PSP (IPSP)
        - **a** : [1/ms] 1/tau_e, sum of the reciprocal of the time constant
          of passive membrane and all other spatially distributed delays in
          the dendritic network
        - **b** : [1/ms] 1/tau_i
        - **r** : [mV] the steepness of the sigmoidal transformation
        - **v0**, **vmax** : parameters of the nonlinear sigmoid function
        - **C0**, **C1**, **C2**, **C3** : scalar or sequence of length N;
          average number of synaptic contacts in the inhibitory and
          excitatory feedback loops
        - **noise_mu**, **noise_std** : noise parameters forwarded to the
          C++ engine
        - **noise_seed** : forwarded to the C++ engine as 1 if truthy, else 0
        - **seed** : if not None, seeds NumPy's global RNG on construction
          (this affects the randomly drawn initial state)
        - **initial_state** : initial state of length 6N, or None to draw a
          random one on the first ``run``
        - **dt** : [ms] integration time step
        - **t_transition** : [ms] time to reach steady state
        - **t_end** : [ms] final time
        - **method** : integration method, "euler" or "heun"
        - **output** : output directory, created on construction
        - **control**, **RECORD_AVG** : accepted but not forwarded to the
          C++ engine by ``run``
    '''
    valid_params = [
        "noise_seed", "seed", "G", "weights", "A", "B", "a", "b",
        "noise_mu", "noise_std", "vmax", "v0", "r",
        "C0", "C1", "C2", "C3", "dt", "method", "t_transition",
        "t_end", "control", "output", "RECORD_AVG",
        "initial_state"
    ]

    # Integration methods accepted by ``run``, mapped to the method name on
    # the compiled ``_JR_sde`` object.
    _integrators = {
        "euler": "eulerIntegrate",
        "heun": "heunIntegrate",
    }

    def __init__(self, par=None):
        '''
        Build the model from the defaults updated with ``par``.

        Raises
        ------
        ValueError
            For an invalid parameter name or missing ``weights``.
        '''
        par = {} if par is None else par
        self.check_parameters(par)
        self._par = self.get_default_parameters()
        self._par.update(par)

        for name, value in self._par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        self._update_num_nodes()

        # A user-supplied initial state counts as "set"; otherwise run()
        # draws a random one. (Previously this flag was only assigned when
        # initial_state was None, so run() crashed with AttributeError.)
        self.INITIAL_STATE_SET = self.initial_state is not None

        self.noise_seed = 1 if self.noise_seed else 0
        os.makedirs(self.output, exist_ok=True)

    def __str__(self) -> str:
        '''Return a human-readable listing of the construction parameters.'''
        lines = ["Jansen-Rit sde model", "----------------"]
        lines.extend(f"{name} = {value}" for name, value in self._par.items())
        return "\n".join(lines)

    def __call__(self):
        '''Print the model name and return the construction parameters.'''
        print("Jansen-Rit sde model")
        return self._par

    def check_parameters(self, par):
        '''
        Check if the parameters are valid.

        Raises
        ------
        ValueError
            On the first key of ``par`` not found in ``valid_params``.
        '''
        for key in par:
            if key not in self.valid_params:
                raise ValueError("Invalid parameter: " + key)

    def get_default_parameters(self):
        '''
        Return default parameters for the Jansen-Rit sde model.
        '''
        return {
            'G': 0.5,  # global coupling strength
            "A": 3.25,  # mV
            "B": 22.0,  # mV
            "a": 0.1,  # 1/ms
            "b": 0.05,  # 1/ms
            "noise_mu": 0.24,
            "noise_std": 0.3,
            "vmax": 0.005,
            "v0": 6,  # mV
            "r": 0.56,  # mV
            "initial_state": None,

            'weights': None,
            "C0": 135.0 * 1.0,
            "C1": 135.0 * 0.8,
            "C2": 135.0 * 0.25,
            "C3": 135.0 * 0.25,

            "noise_seed": 0,
            "seed": None,

            "dt": 0.05,  # ms
            "dim": 6,
            "method": "heun",
            "t_transition": 500.0,  # ms
            "t_end": 2501.0,  # ms
            "output": "output",  # output directory
            "RECORD_AVG": False  # true to store large time series in file
        }

    def _update_num_nodes(self):
        '''Derive N (number of nodes) from the first dimension of ``weights``.'''
        if self.weights is None:
            raise ValueError("weights must be provided")
        self.N = self.num_nodes = int(np.asarray(self.weights).shape[0])

    # ---------------------------------------------------------------
    def set_initial_state(self):
        '''
        Set a random initial state (length 6N) for the system of JR equations
        with N nodes, using the module-level ``set_initial_state`` helper.
        '''
        self.initial_state = set_initial_state(self.num_nodes, self.seed)
        self.INITIAL_STATE_SET = True

    def prepare_input(self):
        '''
        Prepare input parameters for passing to the C++ engine: cast scalars
        to float/int, broadcast C0..C3 to length-N vectors and convert arrays.
        '''
        self._update_num_nodes()
        self.weights = np.asarray(self.weights)
        self.dt = float(self.dt)
        self.t_transition = float(self.t_transition)
        self.t_end = float(self.t_end)
        self.G = float(self.G)
        self.A = float(self.A)
        self.B = float(self.B)
        self.a = float(self.a)
        self.b = float(self.b)
        self.r = float(self.r)
        self.v0 = float(self.v0)
        self.vmax = float(self.vmax)
        self.C0 = check_sequence(self.C0, self.N)
        self.C1 = check_sequence(self.C1, self.N)
        self.C2 = check_sequence(self.C2, self.N)
        self.C3 = check_sequence(self.C3, self.N)
        self.noise_mu = float(self.noise_mu)
        self.noise_std = float(self.noise_std)
        self.noise_seed = int(self.noise_seed)
        self.initial_state = np.asarray(self.initial_state)

    # -------------------------------------------------------------------------
    def run(self, par=None, x0=None, verbose=False):
        '''
        Integrate the system of equations for Jansen-Rit sde model.

        Parameters
        ----------
        par: dict, optional
            parameter updates (keys from ``valid_params``) applied before
            integrating.
        x0: np.array, optional
            initial state of length 6N; takes precedence over
            ``par["initial_state"]``. When omitted and no initial state has
            been set, a random one is drawn.
        verbose: bool
            print a message when the default initial state is drawn.

        Returns
        -------
        dict
            - **t** : time points reported by the engine
            - **x** : transposed coordinate array reported by the engine

        Raises
        ------
        ValueError
            For an invalid parameter name, missing weights, an ``x0`` of the
            wrong length or an unknown integration method.
        '''
        par = {} if par is None else par
        self.check_parameters(par)
        for key, value in par.items():
            setattr(self, key, value)
        if "initial_state" in par:
            # Respect an initial state supplied through par instead of
            # overwriting it with a random one below.
            self.INITIAL_STATE_SET = par["initial_state"] is not None

        # Validate before doing any work; this used to call exit(0).
        if self.method not in self._integrators:
            raise ValueError(
                f"unknown integration method {self.method!r}; "
                f"expected one of {sorted(self._integrators)}")

        # weights may have been replaced through ``par``.
        self._update_num_nodes()

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            expected = self.num_nodes * self.dim
            if len(x0) != expected:
                raise ValueError(
                    f"x0 must have length {expected}, got {len(x0)}")
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        self.prepare_input()

        obj = _JR_sde(self.N,
                      self.dt,
                      self.t_transition,
                      self.t_end,
                      self.G,
                      self.weights,
                      self.initial_state,
                      self.A,
                      self.B,
                      self.a,
                      self.b,
                      self.r,
                      self.v0,
                      self.vmax,
                      self.C0,
                      self.C1,
                      self.C2,
                      self.C3,
                      self.noise_mu,
                      self.noise_std,
                      self.noise_seed)

        getattr(obj, self._integrators[self.method])()

        sol = np.asarray(obj.get_coordinates()).T
        times = np.asarray(obj.get_times())

        del obj

        return {"t": times, "x": sol}
282
+
283
+
284
+ ############################ Jansen-Rit sdde ##################################
285
+
286
class JR_sdde:
    '''
    Jansen-Rit network model with transmission delays (stochastic delay
    differential equations), C++ implementation.

    Parameters
    ----------
    par: dict, optional
        Overrides for the defaults returned by ``get_default_parameters``;
        every key must appear in ``valid_params``. ``weights`` and ``delays``
        (N x N arrays) are required. ``C0``..``C3`` and ``sti_amplitude`` may
        be scalars or sequences of length N. ``nstart``, ``C`` and
        ``record_step`` are accepted but not forwarded to the C++ engine.
    '''

    valid_params = ["weights", "delays", "dt", "t_end", "G", "A", "a", "B", "b", "mu",
                    "nstart", "t_transition", "sigma", "C", "record_step",
                    "C0", "C1", "C2", "C3", "vmax", "r", "v0", "output",
                    'sti_ti', 'sti_duration', 'sti_amplitude', 'sti_gain',
                    "noise_seed", "seed", "method", "initial_state"]
    # -------------------------------------------------------------------------

    def __init__(self, par=None) -> None:
        '''
        Build the model from the defaults updated with ``par``.

        Raises
        ------
        ValueError
            For an invalid parameter name, or when ``weights`` / ``delays``
            are not provided.
        '''
        par = {} if par is None else par
        self.check_parameters(par)
        _par = self.get_default_parameters()
        _par.update(par)

        for name, value in _par.items():
            setattr(self, name, value)

        if self.seed is not None:
            np.random.seed(self.seed)

        self.noise_seed = 1 if self.noise_seed else 0
        if self.delays is None:
            raise ValueError("delays must be provided")
        self._update_num_nodes()

        self.C0 = check_sequence(self.C0, self.N)
        self.C1 = check_sequence(self.C1, self.N)
        self.C2 = check_sequence(self.C2, self.N)
        self.C3 = check_sequence(self.C3, self.N)
        self.sti_amplitude = check_sequence(self.sti_amplitude, self.N)

        # A user-supplied initial state counts as "set"; otherwise run()
        # draws a random one.
        self.INITIAL_STATE_SET = self.initial_state is not None
        os.makedirs(self.output, exist_ok=True)

    def check_parameters(self, par):
        '''
        Check if the parameters are valid.

        Raises
        ------
        ValueError
            On the first key of ``par`` not found in ``valid_params``.
        '''
        for key in par:
            if key not in self.valid_params:
                raise ValueError("Invalid parameter: " + key)
    # -------------------------------------------------------------------------

    def get_default_parameters(self):
        '''
        Get default parameters for the system of JR equations.

        ``weights`` and ``delays`` default to None and must be supplied.
        '''
        return {
            "weights": None,
            "delays": None,
            "dt": 0.01,
            "G": 0.01,
            "mu": 0.22,
            "sigma": 0.005,
            "dim": 6,
            "A": 3.25,
            "a": 0.1,
            "B": 22.0,
            "b": 0.05,
            "v0": 6.0,
            "vmax": 0.005,
            "r": 0.56,
            "C0": 135.0 * 1.0,
            "C1": 135.0 * 0.8,
            "C2": 135.0 * 0.25,
            "C3": 135.0 * 0.25,
            'sti_ti': 0.0,
            'sti_duration': 0.0,
            'sti_amplitude': 0.0,  # scalar or sequence of length N
            'sti_gain': 0.0,
            "noise_seed": False,
            "seed": None,
            "initial_state": None,
            "method": "heun",
            "output": "output",
            "t_end": 2000.0,
            "t_transition": 1000.0
        }
    # -------------------------------------------------------------------------

    def _update_num_nodes(self):
        '''Derive N (number of nodes) from the length of ``weights``.'''
        if self.weights is None:
            raise ValueError("weights must be provided")
        self.N = self.num_nodes = len(self.weights)

    def prepare_stimulus(self, sti_gain, sti_ti):
        '''
        Check that a stimulus with non-zero gain starts after the transient
        (``sti_ti >= t_transition``). Note: ``run`` does not call this.

        Raises
        ------
        ValueError
            If the stimulus would start before ``t_transition``.
        '''
        if np.abs(sti_gain) > 0.0 and sti_ti < self.t_transition:
            raise ValueError("stimulation must start after transition")
    # -------------------------------------------------------------------------

    def set_initial_state(self):
        '''
        Set a random initial state (length 6N) for the system of JR equations
        with N nodes, using the module-level ``set_initial_state`` helper.
        '''
        self.initial_state = set_initial_state(self.num_nodes, self.seed)
        self.INITIAL_STATE_SET = True
    # -------------------------------------------------------------------------

    def prepare_input(self):
        '''
        Prepare input parameters for the C++ code: cast scalars, broadcast
        per-node parameters to length-N vectors and convert arrays.
        '''
        self.weights = np.asarray(self.weights)
        self.delays = np.asarray(self.delays)
        self._update_num_nodes()
        self.dt = float(self.dt)
        self.t_transition = float(self.t_transition)
        self.t_end = float(self.t_end)
        self.G = float(self.G)
        self.A = float(self.A)
        self.B = float(self.B)
        self.a = float(self.a)
        self.b = float(self.b)
        self.r = float(self.r)
        self.v0 = float(self.v0)
        self.vmax = float(self.vmax)
        # Use check_sequence (as __init__ does) so scalars passed through
        # run(par) become length-N vectors instead of 0-d arrays.
        self.C0 = np.asarray(check_sequence(self.C0, self.N))
        self.C1 = np.asarray(check_sequence(self.C1, self.N))
        self.C2 = np.asarray(check_sequence(self.C2, self.N))
        self.C3 = np.asarray(check_sequence(self.C3, self.N))
        self.sti_amplitude = np.asarray(
            check_sequence(self.sti_amplitude, self.N))
        self.sti_gain = float(self.sti_gain)
        self.sti_ti = float(self.sti_ti)
        self.sti_duration = float(self.sti_duration)
        self.mu = float(self.mu)
        self.sigma = float(self.sigma)
        self.noise_seed = int(self.noise_seed)
        self.initial_state = np.asarray(self.initial_state)
    # -------------------------------------------------------------------------

    def run(self, par=None, x0=None, verbose=False):
        '''
        Integrate the system of equations for Jansen-Rit model.

        Parameters
        ----------
        par: dict, optional
            parameter updates (keys from ``valid_params``) applied before
            integrating.
        x0: sequence, optional
            initial state of length N * dim; takes precedence over
            ``par["initial_state"]``. When omitted and no initial state has
            been set, a random one is drawn.
        verbose: bool
            print a message when the default initial state is drawn.

        Returns
        -------
        dict
            - **t** : time points, **x** : states, **sti** : stimulus vector,
              each with the final ``nstart`` samples removed.

        Raises
        ------
        ValueError
            For an invalid parameter name, missing weights or an ``x0`` of
            the wrong length.
        '''
        par = {} if par is None else par
        self.check_parameters(par)
        for key, value in par.items():
            setattr(self, key, value)
        if "initial_state" in par:
            self.INITIAL_STATE_SET = par["initial_state"] is not None

        # weights may have been replaced through ``par``.
        self._update_num_nodes()

        if x0 is None:
            if not self.INITIAL_STATE_SET:
                self.set_initial_state()
                if verbose:
                    print("initial state set by default")
        else:
            expected = self.num_nodes * self.dim
            if len(x0) != expected:
                raise ValueError(
                    f"x0 must have length {expected}, got {len(x0)}")
            self.initial_state = x0
            self.INITIAL_STATE_SET = True

        self.prepare_input()
        obj = _JR_sdde(self.dt,
                       self.initial_state,
                       self.weights,
                       self.delays,
                       self.G,
                       self.dim,
                       self.A,
                       self.B,
                       self.a,
                       self.b,
                       self.r,
                       self.v0,
                       self.vmax,
                       self.C0,
                       self.C1,
                       self.C2,
                       self.C3,
                       self.sti_amplitude,
                       self.sti_gain,
                       self.sti_ti,
                       self.sti_duration,
                       self.mu,
                       self.sigma,
                       self.t_transition,
                       self.t_end,
                       self.noise_seed)
        obj.integrate(self.method)
        # NOTE(review): the last ``nstart`` samples are dropped; the reason is
        # not visible from Python (presumably related to the delay history
        # buffer) -- confirm against jr_sdde.hpp.
        nstart = int((np.max(self.delays)) / self.dt) + 1
        t = np.asarray(obj.get_t())[:-nstart]
        y = np.asarray(obj.get_y())[:, :-nstart]
        sti_vector = np.asarray(obj.get_sti_vector())[:-nstart]

        return {"t": t, "x": y, "sti": sti_vector}
495
+
496
+ ############################# helper functions ################################
497
+
498
+
499
def check_sequence(x, n):
    '''
    Broadcast a scalar to a length-n vector, or validate a length-n sequence.

    parameters
    ----------
    x: scalar (including NumPy scalars and 0-d arrays) or a 1-D sequence
       (list, tuple, np.ndarray) of length n
    n: number of nodes

    returns
    -------
    np.ndarray of shape (n,) and dtype float

    raises
    ------
    ValueError
        if x is a sequence that is not 1-D or whose length is not n.
    '''
    if np.ndim(x) == 0:
        # Scalars -- including 0-d arrays, which have no len() -- are
        # broadcast to every node.
        return x * np.ones(n)

    arr = np.asarray(x, dtype=float)
    if arr.ndim != 1 or arr.shape[0] != n:
        raise ValueError(
            f"variable must be a scalar or a sequence of length {n}, "
            f"got shape {arr.shape}")
    return arr
517
+
518
+
519
def set_initial_state(nn, seed=None):
    '''
    Draw a random initial state for the JR equations with nn nodes.

    Each of the six state variables is sampled per node from a symmetric
    uniform distribution U(-h, h) using NumPy's global RNG, with half-widths
    1, 500, 50, 6, 20 and 500 (in that order).

    parameters
    ----------
    nn: number of nodes
    seed: if not None, passed to ``np.random.seed`` before sampling

    returns
    -------
    y: 1-D array of length 6 * nn, laid out variable by variable (all nodes'
       first variable, then all nodes' second variable, and so on)
    '''
    if seed is not None:
        np.random.seed(seed)

    half_widths = (1, 500, 50, 6, 20, 500)
    # Sampling order matters: it fixes the values drawn for a given seed.
    blocks = [np.random.uniform(-h, h, nn) for h in half_widths]
    return np.hstack(blocks)