vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121) hide show
  1. vbi/__init__.py +37 -0
  2. vbi/_version.py +17 -0
  3. vbi/dataset/__init__.py +0 -0
  4. vbi/dataset/connectivity_84/centers.txt +84 -0
  5. vbi/dataset/connectivity_84/centres.txt +84 -0
  6. vbi/dataset/connectivity_84/cortical.txt +84 -0
  7. vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
  8. vbi/dataset/connectivity_84/weights.txt +84 -0
  9. vbi/dataset/connectivity_88/Aud_88.txt +88 -0
  10. vbi/dataset/connectivity_88/Bold.npz +0 -0
  11. vbi/dataset/connectivity_88/Labels.txt +17 -0
  12. vbi/dataset/connectivity_88/Region_labels.txt +88 -0
  13. vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
  14. vbi/dataset/connectivity_88/weights.txt +88 -0
  15. vbi/feature_extraction/__init__.py +1 -0
  16. vbi/feature_extraction/calc_features.py +293 -0
  17. vbi/feature_extraction/features.json +535 -0
  18. vbi/feature_extraction/features.py +2124 -0
  19. vbi/feature_extraction/features_settings.py +374 -0
  20. vbi/feature_extraction/features_utils.py +1357 -0
  21. vbi/feature_extraction/infodynamics.jar +0 -0
  22. vbi/feature_extraction/utility.py +507 -0
  23. vbi/inference.py +98 -0
  24. vbi/models/__init__.py +0 -0
  25. vbi/models/cpp/__init__.py +0 -0
  26. vbi/models/cpp/_src/__init__.py +0 -0
  27. vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
  28. vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
  29. vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
  30. vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  31. vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  32. vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  33. vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
  34. vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
  35. vbi/models/cpp/_src/bold.hpp +303 -0
  36. vbi/models/cpp/_src/do.hpp +167 -0
  37. vbi/models/cpp/_src/do.i +17 -0
  38. vbi/models/cpp/_src/do.py +467 -0
  39. vbi/models/cpp/_src/do_wrap.cxx +12811 -0
  40. vbi/models/cpp/_src/jr_sdde.hpp +352 -0
  41. vbi/models/cpp/_src/jr_sdde.i +19 -0
  42. vbi/models/cpp/_src/jr_sdde.py +688 -0
  43. vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
  44. vbi/models/cpp/_src/jr_sde.hpp +264 -0
  45. vbi/models/cpp/_src/jr_sde.i +17 -0
  46. vbi/models/cpp/_src/jr_sde.py +470 -0
  47. vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
  48. vbi/models/cpp/_src/km_sde.hpp +158 -0
  49. vbi/models/cpp/_src/km_sde.i +19 -0
  50. vbi/models/cpp/_src/km_sde.py +671 -0
  51. vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
  52. vbi/models/cpp/_src/makefile +52 -0
  53. vbi/models/cpp/_src/mpr_sde.hpp +327 -0
  54. vbi/models/cpp/_src/mpr_sde.i +19 -0
  55. vbi/models/cpp/_src/mpr_sde.py +711 -0
  56. vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
  57. vbi/models/cpp/_src/utility.hpp +307 -0
  58. vbi/models/cpp/_src/vep.hpp +171 -0
  59. vbi/models/cpp/_src/vep.i +16 -0
  60. vbi/models/cpp/_src/vep.py +464 -0
  61. vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
  62. vbi/models/cpp/_src/wc_ode.hpp +294 -0
  63. vbi/models/cpp/_src/wc_ode.i +19 -0
  64. vbi/models/cpp/_src/wc_ode.py +686 -0
  65. vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
  66. vbi/models/cpp/damp_oscillator.py +143 -0
  67. vbi/models/cpp/jansen_rit.py +543 -0
  68. vbi/models/cpp/km.py +187 -0
  69. vbi/models/cpp/mpr.py +289 -0
  70. vbi/models/cpp/vep.py +150 -0
  71. vbi/models/cpp/wc.py +216 -0
  72. vbi/models/cupy/__init__.py +0 -0
  73. vbi/models/cupy/bold.py +111 -0
  74. vbi/models/cupy/ghb.py +284 -0
  75. vbi/models/cupy/jansen_rit.py +473 -0
  76. vbi/models/cupy/km.py +224 -0
  77. vbi/models/cupy/mpr.py +475 -0
  78. vbi/models/cupy/mpr_modified_bold.py +12 -0
  79. vbi/models/cupy/utils.py +184 -0
  80. vbi/models/numba/__init__.py +0 -0
  81. vbi/models/numba/_ww_EI.py +444 -0
  82. vbi/models/numba/damp_oscillator.py +162 -0
  83. vbi/models/numba/ghb.py +208 -0
  84. vbi/models/numba/mpr.py +383 -0
  85. vbi/models/pytorch/__init__.py +0 -0
  86. vbi/models/pytorch/data/default_parameters.npz +0 -0
  87. vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
  88. vbi/models/pytorch/data/input/fc_test.csv +68 -0
  89. vbi/models/pytorch/data/input/fc_train.csv +68 -0
  90. vbi/models/pytorch/data/input/fc_vali.csv +68 -0
  91. vbi/models/pytorch/data/input/fcd_test.mat +0 -0
  92. vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
  93. vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
  94. vbi/models/pytorch/data/input/fcd_train.mat +0 -0
  95. vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
  96. vbi/models/pytorch/data/input/myelin.csv +68 -0
  97. vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
  98. vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
  99. vbi/models/pytorch/data/input/sc_test.csv +68 -0
  100. vbi/models/pytorch/data/input/sc_train.csv +68 -0
  101. vbi/models/pytorch/data/input/sc_vali.csv +68 -0
  102. vbi/models/pytorch/data/obs_kong0.npz +0 -0
  103. vbi/models/pytorch/ww_sde_kong.py +570 -0
  104. vbi/models/tvbk/__init__.py +9 -0
  105. vbi/models/tvbk/tvbk_wrapper.py +166 -0
  106. vbi/models/tvbk/utils.py +72 -0
  107. vbi/papers/__init__.py +0 -0
  108. vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
  109. vbi/tests/__init__.py +0 -0
  110. vbi/tests/_test_mpr_nb.py +36 -0
  111. vbi/tests/test_features.py +355 -0
  112. vbi/tests/test_ghb_cupy.py +90 -0
  113. vbi/tests/test_mpr_cupy.py +49 -0
  114. vbi/tests/test_mpr_numba.py +84 -0
  115. vbi/tests/test_suite.py +19 -0
  116. vbi/utils.py +402 -0
  117. vbi-0.1.3.dist-info/METADATA +166 -0
  118. vbi-0.1.3.dist-info/RECORD +121 -0
  119. vbi-0.1.3.dist-info/WHEEL +5 -0
  120. vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
  121. vbi-0.1.3.dist-info/top_level.txt +1 -0
vbi/models/cpp/wc.py ADDED
@@ -0,0 +1,216 @@
1
+ import numpy as np
2
+
3
+ try:
4
+ from vbi.models.cpp._src.wc_ode import WC_ode as _WC_ode
5
+ except ImportError as e:
6
+ print(f"Could not import modules: {e}, probably C++ code is not compiled or properly linked.")
7
+
8
+
9
+ ################################## Wilson-Cowan ode ###########################
10
+ ###############################################################################
11
+
12
class WC_ode(object):
    r"""
    Python front-end for the compiled Wilson-Cowan ODE integrator.

    The class collects the model parameters, casts them to the types handed
    to the C++ extension class ``_WC_ode`` and returns the integrated
    trajectory as NumPy arrays.

    **References**:

    .. [WC_1972] Wilson, H.R. and Cowan, J.D. *Excitatory and inhibitory
        interactions in localized populations of model neurons*, Biophysical
        journal, 12: 1-24, 1972.
    .. [WC_1973] Wilson, H.R. and Cowan, J.D *A Mathematical Theory of the
        Functional Dynamics of Cortical and Thalamic Nervous Tissue*
    .. [D_2011] Daffertshofer, A. and van Wijk, B. *On the influence of
        amplitude on the connectivity between phases*
        Frontiers in Neuroinformatics, July, 2011

    Used Eqns 11 and 12 from [WC_1972]_ in ``rhs``. P and Q represent external
    inputs, which when exploring the phase portrait of the local model are set
    to constant values. However in the case of a full network, P and Q are the
    entry point to our long range and local couplings, that is, the activity
    from all other nodes is the external input to the local population
    [WC_1973]_, [D_2011]_ .

    The default parameters are taken from figure 4 of [WC_1972]_, page 10.

    Parameters
    ----------
    par : dict, optional
        Plain ``{name: value}`` overrides of the entries returned by
        :meth:`get_default_parameters`. ``weights`` has no usable default and
        must be supplied.

    Raises
    ------
    ValueError
        If ``par`` contains an unknown key or ``weights`` is missing.
    """

    # Scalar parameters that prepare_input() casts to float before they are
    # handed to the C++ constructor.
    _FLOAT_PARAMS = (
        "t_end", "t_cut", "dt",
        "c_ee", "c_ei", "c_ie", "c_ii",
        "tau_e", "tau_i",
        "a_e", "a_i", "b_e", "b_i", "c_e", "c_i",
        "theta_e", "theta_i",
        "r_e", "r_i", "k_e", "k_i",
        "alpha_e", "alpha_i",
        "g_e", "g_i",
    )

    # Maps the ``method`` parameter to the integrator method of ``_WC_ode``.
    _INTEGRATORS = {
        "euler": "eulerIntegrate",
        "heun": "heunIntegrate",
        "rk4": "rk4Integrate",
    }

    def __init__(self, par=None) -> None:
        # A fresh dict per call avoids sharing one mutable default object.
        par = {} if par is None else par

        self.valid_params = self.get_default_parameters().keys()
        self.check_parameters(par)
        self._par = self.get_default_parameters()
        self._par.update(par)

        # Expose every parameter as an instance attribute.
        for name, value in self._par.items():
            setattr(self, name, value)

        if self.seed is not None:
            # Seeds NumPy's global generator (used by set_initial_state).
            np.random.seed(self.seed)

        if self.weights is None:
            # Without this guard the shape lookup below fails with an opaque
            # "tuple index out of range".
            raise ValueError(
                "Parameter 'weights' must be provided (connectivity matrix)."
            )
        self.N = self.num_nodes = np.asarray(self.weights).shape[0]

    def __str__(self) -> str:
        """Return a multi-line summary of the current parameter dictionary."""
        lines = ["Wilson-Cowan model.", "--------------------"]
        lines.extend(
            f"{name}, : , {value}" for name, value in self._par.items()
        )
        return "\n".join(lines)

    def __call__(self):
        """Print the model name and return the parameter dictionary."""
        print("Wilson-Cowan model.")
        return self._par

    def check_parameters(self, par):
        """Raise ``ValueError`` on the first key of ``par`` that is unknown."""
        for key in par.keys():
            if key not in self.valid_params:
                raise ValueError(f"Invalid parameter: {key}")

    def get_default_parameters(self):
        """Return a new dictionary holding the default value of every parameter."""
        par = {
            'c_ee': 16.0,
            'c_ei': 12.0,
            'c_ie': 15.0,
            'c_ii': 3.0,
            'tau_e': 8.0,
            'tau_i': 8.0,
            'a_e': 1.3,
            'a_i': 2.0,
            'b_e': 4.0,
            'b_i': 3.7,
            'c_e': 1.0,
            'c_i': 1.0,
            'theta_e': 0.0,
            'theta_i': 0.0,
            'r_e': 1.0,
            'r_i': 1.0,
            'k_e': 0.994,
            'k_i': 0.999,
            'alpha_e': 1.0,
            'alpha_i': 1.0,
            'P': 1.25,
            'Q': 0.0,
            'g_e': 0.0,
            'g_i': 0.0,
            "method": "heun",
            "weights": None,
            'seed': None,
            "t_end": 300.0,
            "t_cut": 0.0,
            "dt": 0.01,
            "noise_seed": False,
            "output": "output",
        }
        return par

    def set_initial_state(self, seed=None):
        """
        Draw a uniform random initial state of length ``2 * num_nodes``.

        Parameters
        ----------
        seed : int, optional
            If given, NumPy's global generator is re-seeded first.
        """
        if seed is not None:
            np.random.seed(seed)
        self.initial_state = np.random.rand(2 * self.num_nodes)

    def prepare_input(self):
        """Cast the parameter attributes to the types passed to ``_WC_ode``."""
        self.noise_seed = int(self.noise_seed)
        for name in self._FLOAT_PARAMS:
            setattr(self, name, float(getattr(self, name)))
        # P and Q may be scalars or per-node sequences.
        self.P = check_sequence(self.P, self.num_nodes)
        self.Q = check_sequence(self.Q, self.num_nodes)
        self.method = str(self.method)
        self.weights = np.asarray(self.weights)

    def run(self, par=None, x0=None, verbose=False):
        '''
        Integrate the system of equations for the Wilson-Cowan model.

        Parameters
        ----------
        par : dict, optional
            Run-time parameter overrides. Unlike the constructor, each entry
            is read as ``par[name]['value']``.
        x0 : array-like, optional
            Initial state of the system. When omitted, a random state is
            drawn with :meth:`set_initial_state`.
        verbose : bool
            If True, report that the default initial state was used.

        Returns
        -------
        dict
            ``{"t": times, "x": states}`` where ``states`` is the transpose of
            what the backend returns from ``get_states()``.

        Raises
        ------
        ValueError
            If ``par`` contains an unknown key, or ``method`` is not one of
            ``"euler"``, ``"heun"``, ``"rk4"``.
        '''
        par = {} if par is None else par

        if x0 is None:
            self.set_initial_state()
            if verbose:
                print("Initial state set by default.")
        else:
            self.initial_state = x0

        for key in par.keys():
            if key not in self.valid_params:
                raise ValueError(f"Invalid parameter: {key}")
            setattr(self, key, par[key]['value'])

        self.prepare_input()

        # Fail loudly instead of returning a trajectory that was never
        # integrated when the method name is misspelled.
        if self.method not in self._INTEGRATORS:
            raise ValueError(
                f"Invalid integration method: {self.method!r}; "
                f"expected one of {sorted(self._INTEGRATORS)}"
            )

        obj = _WC_ode(
            self.N, self.dt, self.P, self.Q, self.initial_state, self.weights,
            self.t_end, self.t_cut, self.c_ee, self.c_ei, self.c_ie, self.c_ii,
            self.tau_e, self.tau_i, self.a_e, self.a_i, self.b_e, self.b_i,
            self.c_e, self.c_i, self.theta_e, self.theta_i, self.r_e, self.r_i,
            self.k_e, self.k_i, self.alpha_e, self.alpha_i, self.g_e, self.g_i,
            self.noise_seed
        )

        getattr(obj, self._INTEGRATORS[self.method])()

        t = np.asarray(obj.get_times())
        x = np.asarray(obj.get_states()).T

        del obj
        return {"t": t, "x": x}
197
+
198
+
199
def check_sequence(x, n):
    '''
    Broadcast a scalar to length ``n`` or validate an existing sequence.

    parameters
    ----------
    x: scalar or sequence of length n
        ``np.ndarray``, ``list`` and ``tuple`` are treated as sequences;
        anything else is multiplied by ``np.ones(n)``.
    n: number of nodes

    returns
    -------
    x: sequence of length n
        The input object itself when it already is a sequence, otherwise a
        new ``np.ndarray`` filled with the scalar.

    raises
    ------
    AssertionError
        If ``x`` is a sequence whose length differs from ``n``.
    '''
    if isinstance(x, (np.ndarray, list, tuple)):
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if len(x) != n:
            raise AssertionError(f" variable must be a sequence of length {n}")
        return x
    return x * np.ones(n)
File without changes
@@ -0,0 +1,111 @@
1
+ import numpy as np
2
+
3
+
4
class BoldStephan2008:
    """
    Parameter container and explicit-Euler stepper for a BOLD signal model.

    Parameters
    ----------
    par : dict, optional
        Overrides for the entries of :meth:`get_default_parameters`.
        ``k1``, ``k2`` and ``k3`` are derived from ``theta0``, ``Eo``, ``TE``,
        ``epsilon`` and ``r0``; they are recomputed from the effective values
        unless given explicitly in ``par``.

    Raises
    ------
    ValueError
        If ``par`` contains a key that is not a known parameter.

    Notes
    -----
    NOTE(review): :meth:`_prepare` reads ``self.dtype``, which is not one of
    the parameters set here; presumably the owning model assigns it before
    calling ``_prepare`` - confirm against the caller.
    """

    def __init__(self, par: dict = None) -> None:
        # A fresh dict per call avoids sharing one mutable default object.
        par = {} if par is None else par

        self._par = self.get_default_parameters()
        self.valid_parameters = list(self._par.keys())
        self.check_parameters(par)
        self._par.update(par)

        # Keep the derived constants consistent with overridden inputs;
        # an explicitly supplied k1/k2/k3 always wins.
        derived = self._derived_constants(
            self._par["theta0"],
            self._par["Eo"],
            self._par["TE"],
            self._par["epsilon"],
            self._par["r0"],
        )
        for name, value in derived.items():
            if name not in par:
                self._par[name] = value

        for key, value in self._par.items():
            setattr(self, key, value)

    @staticmethod
    def _derived_constants(theta0, Eo, TE, epsilon, r0):
        """Return ``k1``, ``k2`` and ``k3`` computed from their inputs."""
        return {
            "k1": 4.3 * theta0 * Eo * TE,
            "k2": epsilon * r0 * Eo * TE,
            "k3": 1 - epsilon,
        }

    def _prepare(self, nn, ns, xp, n_steps, bold_decimate):
        """
        Allocate the state buffers used by :meth:`bold_step`.

        Parameters
        ----------
        nn, ns : int
            Sizes of the second and third axis of every buffer.
        xp : module
            Array module used for the state buffers (must provide ``zeros``).
        n_steps : int
            Number of integration steps.
        bold_decimate : int
            Decimation factor; ``vv`` and ``qq`` get
            ``n_steps // bold_decimate`` rows.

        Returns
        -------
        dict
            State buffers of shape ``(2, nn, ns)`` (index 0 holds the initial
            value, index 1 starts at zero) plus the NumPy float32 arrays
            ``vv`` and ``qq``.
        """
        shape = (2, nn, ns)
        s = xp.zeros(shape, dtype=self.dtype)
        f = xp.zeros(shape, dtype=self.dtype)
        ftilde = xp.zeros(shape, dtype=self.dtype)
        vtilde = xp.zeros(shape, dtype=self.dtype)
        qtilde = xp.zeros(shape, dtype=self.dtype)
        v = xp.zeros(shape, dtype=self.dtype)
        q = xp.zeros(shape, dtype=self.dtype)
        # The recording buffers are always NumPy arrays, regardless of ``xp``.
        vv = np.zeros((n_steps // bold_decimate, nn, ns), dtype="f")
        qq = np.zeros((n_steps // bold_decimate, nn, ns), dtype="f")
        s[0] = 1
        f[0] = 1
        v[0] = 1
        q[0] = 1
        ftilde[0] = 0
        vtilde[0] = 0
        qtilde[0] = 0

        return {
            "s": s,
            "f": f,
            "ftilde": ftilde,
            "vtilde": vtilde,
            "qtilde": qtilde,
            "v": v,
            "q": q,
            "vv": vv,
            "qq": qq,
        }

    def check_parameters(self, par):
        """Raise ``ValueError`` on the first key of ``par`` that is unknown."""
        for key in par.keys():
            if key not in self.valid_parameters:
                raise ValueError(f"Invalid parameter {key:s} provided.")

    def get_default_parameters(self):
        """Return a new dictionary holding the default value of every parameter."""
        theta0 = 41.0
        Eo = 0.42
        TE = 0.05
        epsilon = 0.36
        r0 = 26.0

        par = {
            "kappa": 0.7,
            "gamma": 0.5,
            "tau": 1.0,
            "alpha": 0.35,
            "epsilon": epsilon,
            "Eo": Eo,
            "TE": TE,
            "vo": 0.09,
            "r0": r0,
            "theta0": theta0,
            "rtol": 1e-6,
            "atol": 1e-9,
        }
        par.update(self._derived_constants(theta0, Eo, TE, epsilon, r0))
        return par

    def bold_step(self, r_in, s, f, ftilde, vtilde, qtilde, v, q, dt, P):
        """
        Advance all state buffers by one explicit Euler step, in place.

        Parameters
        ----------
        r_in : array
            Input driving the ``s`` equation.
        s, f, ftilde, vtilde, qtilde, v, q : array
            Two-slot buffers as produced by :meth:`_prepare`; slot 0 is the
            current state and is overwritten with the new state.
        dt : float
            Step size.
        P : sequence
            ``(kappa, gamma, alpha, tau, Eo)``.

        Returns
        -------
        None
        """
        kappa, gamma, alpha, tau, Eo = P
        ialpha = 1 / alpha

        s[1] = s[0] + dt * (r_in - kappa * s[0] - gamma * (f[0] - 1))
        # f is clipped after the s update, so s sees the unclipped value.
        f[0] = np.clip(f[0], 1, None)
        ftilde[1] = ftilde[0] + dt * (s[0] / f[0])
        fv = v[0] ** ialpha  # outflow
        vtilde[1] = vtilde[0] + dt * ((f[0] - fv) / (tau * v[0]))
        q[0] = np.clip(q[0], 0.01, None)
        ff = (1 - (1 - Eo) ** (1 / f[0])) / Eo  # oxygen extraction
        qtilde[1] = qtilde[0] + dt * ((f[0] * ff - fv * q[0] / v[0]) / (tau * q[0]))

        # f, v and q are integrated through their logarithms.
        f[1] = np.exp(ftilde[1])
        v[1] = np.exp(vtilde[1])
        q[1] = np.exp(qtilde[1])

        # Roll the new state into slot 0 for the next call.
        f[0] = f[1]
        s[0] = s[1]
        ftilde[0] = ftilde[1]
        vtilde[0] = vtilde[1]
        qtilde[0] = qtilde[1]
        v[0] = v[1]
        q[0] = q[1]
106
+
107
+
108
class BoldTVB:
    """Placeholder class; it currently defines no state or behaviour."""

    def __init__(self):
        """Create an empty instance."""
        return None
vbi/models/cupy/ghb.py ADDED
@@ -0,0 +1,284 @@
1
+ import tqdm
2
+ import cupy as cp
3
+ from copy import copy
4
+ from vbi.models.cupy.utils import *
5
+
6
+
7
class GHB_sde:
    """
    Generic Hopf model cupy implementation

    The oscillator network is integrated with a stochastic Heun scheme and
    drives a four-variable haemodynamic model whose BOLD output is recorded.

    Parameters
    ----------
    par: dict
        Dictionary of parameters, see :meth:`get_default_parameters` for the
        full list, e.g.
        - 'G': Global coupling
        - 'dt': Time step

    Notes
    -----
    The class-level constants below are read by :meth:`f_fmri` and
    :meth:`intg_fmri`; ``epsilon`` is not referenced by any method shown here.
    """

    epsilon = 0.5
    itaus = 1.25
    itauf = 2.5
    itauo = 1.02040816327
    ialpha = 5.
    E0 = 0.4
    V0 = 4.
    K1 = 2.77264
    K2 = 0.572
    K3 = -0.43

    def __init__(self, par: dict = None) -> None:
        # A fresh dict per call avoids sharing one mutable default object.
        par = {} if par is None else par

        self.valid_params = list(self.get_default_parameters().keys())
        self.check_parameters(par)
        self.par_ = self.get_default_parameters()
        self.par_.update(par)

        # Expose every parameter as an instance attribute.
        for name, value in self.par_.items():
            setattr(self, name, value)

        self.xp = get_module(self.engine)
        if self.seed is not None:
            self.xp.random.seed(self.seed)

    def __call__(self):
        """Print the model name and return the parameter dictionary."""
        print("GHB model")
        return self.par_

    def __str__(self):
        """Return a multi-line summary of the current parameter dictionary."""
        lines = ["GHB model", "-" * 50]
        lines.extend(f"{name} : {value}" for name, value in self.par_.items())
        return "\n".join(lines)

    def set_initial_state(self):
        """Fill ``self.initial_state`` using the module-level helper."""
        self.initial_state = set_initial_state(
            self.nn,
            self.num_sim,
            self.engine,
            self.seed,
            self.same_initial_state,
            self.dtype,
        )

    def check_parameters(self, par):
        """Raise ``AssertionError`` on the first key of ``par`` that is unknown."""
        for key in par.keys():
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the exception type is kept for existing callers.
            if key not in self.valid_params:
                raise AssertionError(f"Invalid parameter: {key}")

    def get_default_parameters(self):
        """Return a new dictionary holding the default value of every parameter."""
        par = {
            "G": 25.0,
            "t_cut": 0,
            "dt": 0.01,
            "eta": None,
            "num_sim": 1,
            "sigma": 0.1,
            "seed": None,
            "decimate": 1,
            "omega": None,
            "t_end": 10.0,
            "engine": "cpu",
            "weights": None,
            "dtype": "float",
            "method": "euler",
            "output": "output",
            "initial_state": None,
            "same_initial_state": False,
        }
        return par

    def prepare_input(self):
        """Convert the parameters to arrays on the selected engine."""
        self.G = self.xp.array(self.G, dtype=self.dtype)
        assert self.weights is not None, "weights not provided"
        self.weights = self.xp.array(self.weights, dtype=self.dtype)
        # Trailing singleton axis lets the weights broadcast over simulations.
        self.weights = self.weights.reshape(self.weights.shape + (1,))
        self.weights = move_data(self.weights, self.engine)
        self.nn = self.num_nodes = self.weights.shape[0]

        if self.initial_state is None:
            self.set_initial_state()
        else:
            self.initial_state = move_data(self.initial_state, self.engine)

        self.eta = prepare_vec(self.eta, self.num_sim, self.engine, self.dtype)
        self.omega = prepare_vec(self.omega, self.num_sim, self.engine, self.dtype)

    def f_sys(self, x0, t):
        """
        Right-hand side of the coupled oscillator system.

        Parameters
        ----------
        x0 : array [2*nn, ns]
            First ``nn`` rows are ``x``, remaining rows are ``y``.
        t : float
            Current time (not used by the equations).

        Returns
        -------
        array [2*nn, ns]
            Time derivative of ``x0``.
        """
        G = self.G
        xp = self.xp
        nn = self.nn
        eta = self.eta
        x = x0[:nn, :]
        y = x0[nn:, :]
        ns = self.num_sim
        sc = self.weights
        omega = self.omega

        # Diffusive coupling: sum_j sc[i, j] * (state_j - state_i).
        gx = xp.sum(sc * (x - x[:, None]), axis=1)
        gy = xp.sum(sc * (y - y[:, None]), axis=1)
        dxdt = xp.zeros((2 * nn, ns)).astype(self.dtype)

        dxdt[:nn, :] = x * (eta - x * x - y * y) - omega * y + G * gx
        dxdt[nn:, :] = y * (eta - x * x - y * y) + omega * x + G * gy

        return dxdt

    def f_fmri(self, xin, x, t):
        """
        Right-hand side of the haemodynamic model.

        Parameters
        ----------
        xin : array
            Driving input; only its first ``nn`` rows are read.
        x : array [4*nn, ns]
            Stacked state ``(s, f, v, q)``.
        t : float
            Current time (not used by the equations).

        Returns
        -------
        array [4*nn, ns]
            Time derivative of ``x``.
        """
        E0 = self.E0
        xp = self.xp
        nn = self.num_nodes
        ns = self.num_sim
        itauf = self.itauf
        itauo = self.itauo
        itaus = self.itaus
        ialpha = self.ialpha

        dxdt = xp.zeros((4 * nn, ns)).astype(self.dtype)
        s = x[:nn, :]
        f = x[nn: 2 * nn, :]
        v = x[2 * nn: 3 * nn, :]
        q = x[3 * nn:, :]

        dxdt[:nn, :] = xin[:nn, :] - itaus * s - itauf * (f - 1.0)
        dxdt[nn: (2 * nn), :] = s
        dxdt[(2 * nn): (3 * nn), :] = itauo * (f - v ** (ialpha))
        dxdt[3 * nn:, :] = (itauo) * (
            (f * (1.0 - (1.0 - E0) ** (1.0 / f)) / E0) - (v ** (ialpha)) * (q / v)
        )

        return dxdt

    def heun_sde_step(self, x0, t):
        """Advance the oscillator state by one stochastic Heun step."""
        xp = self.xp
        dt = self.dt
        dx = self.f_sys(x0, t) * dt
        # The same noise increment is used in the predictor and the corrector.
        dW = self.sigma * xp.random.normal(0, 1, size=x0.shape) * xp.sqrt(dt)
        x1 = x0 + dx + dW
        dx1 = self.f_sys(x1, t + dt) * dt
        return x0 + 0.5 * (dx + dx1) + dW

    def heun_ode_step(self, yin, y, t):
        """Advance the haemodynamic state by one deterministic Heun step."""
        dt = self.dt
        dy = self.f_fmri(yin, y, t) * dt
        y1 = y + dy
        dy1 = self.f_fmri(yin, y1, t + dt) * dt
        return y + 0.5 * (dy + dy1)

    def intg_fmri(self, yin, y, t):
        """
        Integrate one step of Balloon model

        Parameters
        ----------
        yin: array
            input
        y: array [4*nn, ns]
            state
        t : float
            current time

        Returns
        -------
        bold: array [nn, ns]
            BOLD signal
        y: array [4*nn, ns]
            updated state

        """
        V0 = self.V0
        K1 = self.K1
        K2 = self.K2
        K3 = self.K3

        nn = self.num_nodes
        y = self.heun_ode_step(yin, y, t)
        bold = V0 * (
            K1 * (1.0 - y[(3 * nn):, :])
            + K2 * (1.0 - y[(3 * nn):, :] / y[(2 * nn): (3 * nn), :])
            + K3 * (1.0 - y[(2 * nn): (3 * nn), :])
        )

        return bold, y

    def sync(self, engine="gpu"):
        """Synchronise the default CUDA stream when running on the GPU."""
        if engine == "gpu":
            cp.cuda.Stream.null.synchronize()

    def run(self, x0=None, verbose=True):
        """
        run ghb model

        Parameters
        ----------
        x0 : array, optional
            NOTE(review): accepted but never read; the initial state always
            comes from the ``initial_state`` parameter (or a random draw).
            Confirm whether it should override ``initial_state``.
        verbose : bool
            If True, show a progress bar.

        Returns
        -------
        dict
            ``{"t": times, "bold": signal}`` restricted to ``t > t_cut``;
            ``signal`` has shape ``(nn, n_kept_samples, ns)``.
        """
        self.prepare_input()
        dt = self.dt
        xp = self.xp
        ns = self.num_sim
        nn = self.num_nodes
        dec = self.decimate
        engine = self.engine
        t_cut = self.t_cut
        n_steps = int(np.ceil(self.t_end / dt))
        # One sample per iteration with ``it % dec == 0``; rounding up keeps
        # the last sample in bounds when n_steps is not a multiple of dec.
        n_samples = (n_steps + dec - 1) // dec

        y0_state = xp.zeros((4 * nn, ns)).astype(self.dtype)
        y0_state[nn:, :] = 1.0
        y0 = copy(self.initial_state)
        bold = np.zeros((nn, n_samples, ns)).astype(np.float32)

        for it in tqdm.trange(n_steps, disable=not verbose, desc="Integrating"):
            y0 = self.heun_sde_step(y0, it * dt)
            bold_, y0_state = self.intg_fmri(y0, y0_state, it * dt)
            self.sync(engine)
            if it % dec == 0:
                bold[:, it // dec, :] = bold_.get() if engine == "gpu" else bold_

        # Build the time axis from the sample count so it always matches the
        # BOLD buffer (a float-stepped arange may be one element off).
        t = (np.arange(n_samples) * (dec * dt)).astype(np.float32)
        keep = t > t_cut
        bold = bold[:, keep, :]
        t_bold = t[keep]
        return {"t": t_bold, "bold": bold}
251
+
252
+
253
def set_initial_state(nn, ns, engine, seed=None, same_initial_state=False, dtype=float):
    """
    Draw a uniform random initial state with ``2 * nn`` rows.

    Parameters
    ----------
    nn : int
        number of nodes
    ns : int
        number of simulations
    engine : str
        cpu or gpu; forwarded to ``move_data`` / ``repmat_vec``
    seed : int
        random seed for NumPy's global generator; skipped when None
    same_initial_state : bool
        if True, one vector is drawn and passed to ``repmat_vec`` together
        with ``ns``; otherwise a ``(2 * nn, ns)`` array is drawn
    dtype : str
        float: float64
        f : float32

    Returns
    -------
    array
        The state returned by the helper, cast to ``dtype``.
    """
    if seed is not None:
        np.random.seed(seed)

    if not same_initial_state:
        # Independent draw for every simulation.
        independent = np.random.rand(2 * nn, ns)
        return move_data(independent, engine).astype(dtype)

    # One shared draw, handed to repmat_vec with the simulation count.
    shared = np.random.rand(2 * nn)
    return repmat_vec(shared, ns, engine).astype(dtype)