vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. vbi/__init__.py +37 -0
  2. vbi/_version.py +17 -0
  3. vbi/dataset/__init__.py +0 -0
  4. vbi/dataset/connectivity_84/centers.txt +84 -0
  5. vbi/dataset/connectivity_84/centres.txt +84 -0
  6. vbi/dataset/connectivity_84/cortical.txt +84 -0
  7. vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
  8. vbi/dataset/connectivity_84/weights.txt +84 -0
  9. vbi/dataset/connectivity_88/Aud_88.txt +88 -0
  10. vbi/dataset/connectivity_88/Bold.npz +0 -0
  11. vbi/dataset/connectivity_88/Labels.txt +17 -0
  12. vbi/dataset/connectivity_88/Region_labels.txt +88 -0
  13. vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
  14. vbi/dataset/connectivity_88/weights.txt +88 -0
  15. vbi/feature_extraction/__init__.py +1 -0
  16. vbi/feature_extraction/calc_features.py +293 -0
  17. vbi/feature_extraction/features.json +535 -0
  18. vbi/feature_extraction/features.py +2124 -0
  19. vbi/feature_extraction/features_settings.py +374 -0
  20. vbi/feature_extraction/features_utils.py +1357 -0
  21. vbi/feature_extraction/infodynamics.jar +0 -0
  22. vbi/feature_extraction/utility.py +507 -0
  23. vbi/inference.py +98 -0
  24. vbi/models/__init__.py +0 -0
  25. vbi/models/cpp/__init__.py +0 -0
  26. vbi/models/cpp/_src/__init__.py +0 -0
  27. vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
  28. vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
  29. vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
  30. vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  31. vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  32. vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  33. vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
  34. vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
  35. vbi/models/cpp/_src/bold.hpp +303 -0
  36. vbi/models/cpp/_src/do.hpp +167 -0
  37. vbi/models/cpp/_src/do.i +17 -0
  38. vbi/models/cpp/_src/do.py +467 -0
  39. vbi/models/cpp/_src/do_wrap.cxx +12811 -0
  40. vbi/models/cpp/_src/jr_sdde.hpp +352 -0
  41. vbi/models/cpp/_src/jr_sdde.i +19 -0
  42. vbi/models/cpp/_src/jr_sdde.py +688 -0
  43. vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
  44. vbi/models/cpp/_src/jr_sde.hpp +264 -0
  45. vbi/models/cpp/_src/jr_sde.i +17 -0
  46. vbi/models/cpp/_src/jr_sde.py +470 -0
  47. vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
  48. vbi/models/cpp/_src/km_sde.hpp +158 -0
  49. vbi/models/cpp/_src/km_sde.i +19 -0
  50. vbi/models/cpp/_src/km_sde.py +671 -0
  51. vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
  52. vbi/models/cpp/_src/makefile +52 -0
  53. vbi/models/cpp/_src/mpr_sde.hpp +327 -0
  54. vbi/models/cpp/_src/mpr_sde.i +19 -0
  55. vbi/models/cpp/_src/mpr_sde.py +711 -0
  56. vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
  57. vbi/models/cpp/_src/utility.hpp +307 -0
  58. vbi/models/cpp/_src/vep.hpp +171 -0
  59. vbi/models/cpp/_src/vep.i +16 -0
  60. vbi/models/cpp/_src/vep.py +464 -0
  61. vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
  62. vbi/models/cpp/_src/wc_ode.hpp +294 -0
  63. vbi/models/cpp/_src/wc_ode.i +19 -0
  64. vbi/models/cpp/_src/wc_ode.py +686 -0
  65. vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
  66. vbi/models/cpp/damp_oscillator.py +143 -0
  67. vbi/models/cpp/jansen_rit.py +543 -0
  68. vbi/models/cpp/km.py +187 -0
  69. vbi/models/cpp/mpr.py +289 -0
  70. vbi/models/cpp/vep.py +150 -0
  71. vbi/models/cpp/wc.py +216 -0
  72. vbi/models/cupy/__init__.py +0 -0
  73. vbi/models/cupy/bold.py +111 -0
  74. vbi/models/cupy/ghb.py +284 -0
  75. vbi/models/cupy/jansen_rit.py +473 -0
  76. vbi/models/cupy/km.py +224 -0
  77. vbi/models/cupy/mpr.py +475 -0
  78. vbi/models/cupy/mpr_modified_bold.py +12 -0
  79. vbi/models/cupy/utils.py +184 -0
  80. vbi/models/numba/__init__.py +0 -0
  81. vbi/models/numba/_ww_EI.py +444 -0
  82. vbi/models/numba/damp_oscillator.py +162 -0
  83. vbi/models/numba/ghb.py +208 -0
  84. vbi/models/numba/mpr.py +383 -0
  85. vbi/models/pytorch/__init__.py +0 -0
  86. vbi/models/pytorch/data/default_parameters.npz +0 -0
  87. vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
  88. vbi/models/pytorch/data/input/fc_test.csv +68 -0
  89. vbi/models/pytorch/data/input/fc_train.csv +68 -0
  90. vbi/models/pytorch/data/input/fc_vali.csv +68 -0
  91. vbi/models/pytorch/data/input/fcd_test.mat +0 -0
  92. vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
  93. vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
  94. vbi/models/pytorch/data/input/fcd_train.mat +0 -0
  95. vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
  96. vbi/models/pytorch/data/input/myelin.csv +68 -0
  97. vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
  98. vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
  99. vbi/models/pytorch/data/input/sc_test.csv +68 -0
  100. vbi/models/pytorch/data/input/sc_train.csv +68 -0
  101. vbi/models/pytorch/data/input/sc_vali.csv +68 -0
  102. vbi/models/pytorch/data/obs_kong0.npz +0 -0
  103. vbi/models/pytorch/ww_sde_kong.py +570 -0
  104. vbi/models/tvbk/__init__.py +9 -0
  105. vbi/models/tvbk/tvbk_wrapper.py +166 -0
  106. vbi/models/tvbk/utils.py +72 -0
  107. vbi/papers/__init__.py +0 -0
  108. vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
  109. vbi/tests/__init__.py +0 -0
  110. vbi/tests/_test_mpr_nb.py +36 -0
  111. vbi/tests/test_features.py +355 -0
  112. vbi/tests/test_ghb_cupy.py +90 -0
  113. vbi/tests/test_mpr_cupy.py +49 -0
  114. vbi/tests/test_mpr_numba.py +84 -0
  115. vbi/tests/test_suite.py +19 -0
  116. vbi/utils.py +402 -0
  117. vbi-0.1.3.dist-info/METADATA +166 -0
  118. vbi-0.1.3.dist-info/RECORD +121 -0
  119. vbi-0.1.3.dist-info/WHEEL +5 -0
  120. vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
  121. vbi-0.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,208 @@
1
+ import warnings
2
+ import numpy as np
3
+ from numba import njit
4
+ from numba.experimental import jitclass
5
+ from numba.core.errors import NumbaPerformanceWarning
6
+ from numba import float64, int64
7
+
8
# Silence NumbaPerformanceWarning for the whole process: this installs a
# global "ignore" entry in the warnings filter list when the module is imported.
warnings.filterwarnings("ignore", category=NumbaPerformanceWarning)
9
+
10
+
11
@njit(nogil=True)
def run(P, times):
    """Integrate the GHB network and return its BOLD signal.

    Each node carries a Stuart-Landau (Hopf normal form) oscillator (x, y),
    advanced with Euler-Maruyama using additive noise of amplitude
    ``P.sigma``, and a Balloon-Windkessel hemodynamic state (s, f, v, q)
    driven by ``x`` and advanced with forward Euler. The BOLD signal is
    computed from (v, q).

    Parameters
    ----------
    P : ParGHB
        Reads G, dt, eta, weights, sigma, omega, tcut, decimate and
        init_state (x = first nn entries, y = the remaining nn entries).
    times : 1-D float array
        Time grid; ``len(times) - 1`` integration steps are taken.

    Returns
    -------
    t_bold : 1-D array
        ``times[::P.decimate]`` restricted to values ``> P.tcut``.
    bold : 2-D array, shape (nn, len(t_bold))
        BOLD signal sampled at ``t_bold``.
    """
    G = P.G
    dt = P.dt
    eta = P.eta
    SC = P.weights
    sigma = P.sigma
    omega = P.omega
    tcut = P.tcut
    decimate = P.decimate
    init_state = P.init_state

    # Hemodynamic constants.
    epsilon = 0.5
    itaus = 1.25
    itauf = 2.5
    itauo = 1.02040816327
    ialpha = 5.0
    Eo = 0.4
    V0 = 4.0
    k1 = 2.77264
    k2 = 0.572
    k3 = -0.43

    nt = times.shape[0]
    nn = SC.shape[0]
    noise_scale = np.sqrt(dt) * sigma

    # BOLD is stored at times[0], times[decimate], times[2 * decimate], ...
    # so the buffer has exactly one column per entry of times[::decimate]
    # and the same mask can be applied to both the times and the signal.
    t_sampled = times[::decimate]
    n_samples = t_sampled.shape[0]
    bold = np.zeros((nn, n_samples))

    x = np.zeros(nn)
    y = np.zeros(nn)
    x[:] = init_state[:nn]
    y[:] = init_state[nn:]
    # Hemodynamic state packed as [s | f | v | q], each block of length nn.
    z = np.array([0.0] * nn + [1.0] * 3 * nn)

    # State at the start of the current step.
    x_old = np.empty(nn)
    y_old = np.empty(nn)

    if n_samples > 0:
        for i in range(nn):
            v = z[2 * nn + i]
            q = z[3 * nn + i]
            bold[i, 0] = V0 * (k1 - k1 * q + k2 - k2 * (q / v) + k3 - k3 * v)

    for it in range(nt - 1):
        # Freeze the state at time t so every node's coupling and drift are
        # evaluated at the same time level. Updating in place would feed
        # already-advanced neighbour values into later nodes.
        x_old[:] = x
        y_old[:] = y
        for i in range(nn):
            xi = x_old[i]
            yi = y_old[i]
            gx = 0.0
            gy = 0.0
            for j in range(nn):
                gx += SC[i, j] * (x_old[j] - xi)
                gy += SC[i, j] * (y_old[j] - yi)
            r2 = xi * xi + yi * yi
            dx = xi * (eta[i] - r2) - omega[i] * yi + G * gx
            dy = yi * (eta[i] - r2) + omega[i] * xi + G * gy

            s = z[i]
            f = z[nn + i]
            v = z[2 * nn + i]
            q = z[3 * nn + i]
            outflow = v ** ialpha
            dz0 = epsilon * xi - itaus * s - itauf * (f - 1)
            dz1 = s
            dz2 = itauo * (f - outflow)
            dz3 = itauo * (f * (1 - (1 - Eo) ** (1 / f)) / Eo - outflow * q / v)

            x[i] = xi + dt * dx + noise_scale * np.random.randn()
            y[i] = yi + dt * dy + noise_scale * np.random.randn()
            z[i] = s + dt * dz0
            z[nn + i] = f + dt * dz1
            z[2 * nn + i] = v + dt * dz2
            z[3 * nn + i] = q + dt * dz3

        # The state now corresponds to times[it + 1]; record it when that
        # index is one of the sampling points times[::decimate].
        step = it + 1
        if step % decimate == 0:
            col = step // decimate
            for i in range(nn):
                v = z[2 * nn + i]
                q = z[3 * nn + i]
                bold[i, col] = V0 * (k1 - k1 * q + k2 - k2 * (q / v) + k3 - k3 * v)

    keep = t_sampled > tcut
    return t_sampled[keep], bold[:, keep]
107
+
108
+
109
class GHB_sde(object):
    """Front end for the numba GHB integrator :func:`run`.

    Parameters
    ----------
    par : dict, optional
        Overrides for fields listed in ``par_spec``. The dict is copied,
        never modified.

    Raises
    ------
    ValueError
        If ``par`` contains a key that is not in ``par_spec``.

    Notes
    -----
    NOTE(review): ``seed`` only seeds NumPy's global RNG, which is used for
    the random initial state. Per the numba docs, the noise drawn inside the
    jitted :func:`run` uses numba's own generator, which is not seeded here.
    """

    # ParGHB stores these as float64 arrays; list inputs are converted.
    _array_params = ("init_state", "weights", "eta", "omega")

    def __init__(self, par: dict = None) -> None:
        par = {} if par is None else par
        self.valid_par = [name for name, _ in par_spec]
        self.check_parameters(par)
        self.P = self.get_par_obj(par)

    def _as_param(self, key, value):
        """Return ``value`` as a float64 array if ``key`` is array-valued."""
        if key in self._array_params:
            return np.asarray(value, dtype=np.float64)
        return value

    def get_par_obj(self, par: dict):
        """Build a ``ParGHB`` from ``par`` without mutating the caller's dict.

        ``seed`` is assigned after construction because the ``ParGHB``
        constructor does not accept it as a keyword.
        """
        kwargs = {key: self._as_param(key, value) for key, value in par.items()}
        seed = kwargs.pop("seed", None)
        P = ParGHB(**kwargs)
        if seed is not None:
            P.seed = seed
        return P

    def __str__(self) -> str:
        lines = ["GHB model"]
        lines.extend(f"{key}: {getattr(self.P, key)}" for key in self.valid_par)
        return "\n".join(lines)

    def check_parameters(self, par: dict) -> None:
        """Raise ValueError for any key of ``par`` not listed in ``par_spec``."""
        for key in par:
            if key not in self.valid_par:
                raise ValueError(f"Invalid parameter: {key}")

    def set_initial_state(self, seed=None):
        """Return a uniform [0, 1) initial state of length ``2 * nn``.

        Seeds NumPy's global RNG first when ``seed`` is given.
        """
        if seed is not None:
            np.random.seed(seed)
        assert self.P.weights is not None
        return np.random.uniform(0, 1, 2 * self.P.weights.shape[0])

    def check_input(self):
        """Assert that the parameter shapes are mutually consistent."""
        assert self.P.weights is not None
        nn = self.P.weights.shape[0]
        assert nn == self.P.weights.shape[1]
        assert self.P.eta is not None
        assert self.P.omega is not None
        assert self.P.eta.shape[0] == nn
        assert self.P.omega.shape[0] == nn
        # run() splits init_state into x = [:nn] and y = [nn:].
        assert self.P.init_state.shape[0] == 2 * nn

    def run(self, par=None, tspan=None, x0=None, verbose=True):
        """Simulate the model.

        Parameters
        ----------
        par : dict, optional
            Parameter overrides applied to ``self.P`` before anything else.
            They persist across calls.
        tspan : (float, float), optional
            Start and end time; defaults to ``(0, P.tend)``.
        x0 : array-like, optional
            Initial state of length ``2 * nn``. It is ignored if ``par``
            contains ``init_state``. If both are absent, a random state is
            drawn.
        verbose : bool
            Unused.

        Returns
        -------
        dict
            ``{'t': times, 'bold': signal}`` as returned by :func:`run`.
        """
        par = {} if par is None else par
        # Apply overrides first so a new weights/seed affects the initial state.
        if par:
            self.check_parameters(par)
            for key, value in par.items():
                setattr(self.P, key, self._as_param(key, value))

        if "init_state" not in par:
            if x0 is None:
                self.seed = self.P.seed if self.P.seed > 0 else None
                self.P.init_state = self.set_initial_state(seed=self.seed)
            else:
                self.P.init_state = np.asarray(x0, dtype=np.float64)

        if tspan is None:
            times = np.arange(0, self.P.tend, self.P.dt)
        else:
            times = np.arange(tspan[0], tspan[1], self.P.dt)

        self.check_input()
        t, b = run(self.P, times)
        return {'t': t, 'bold': b}
166
+
167
+
168
# numba jitclass field specification for ParGHB as (name, numba type) pairs.
# The names also define the keys accepted by GHB_sde.check_parameters, and
# their order is the order GHB_sde.__str__ prints them in.
par_spec = [
    ("G", float64),  # global coupling, multiplies the network input in run()
    ("dt", float64),  # integration time step
    ("seed", int64),  # GHB_sde.run seeds the initial-state draw only when > 0
    ("tend", float64),  # end of the default time grid np.arange(0, tend, dt)
    ("tcut", float64),  # BOLD samples with time <= tcut are dropped
    ("sigma", float64),  # additive noise amplitude for x and y
    ("eta", float64[:]),  # per-node parameter in x * (eta - x^2 - y^2)
    ("decimate", int64),  # BOLD sampling stride, in integration steps
    ("omega", float64[:]),  # per-node rotation rate (omega * y, omega * x terms)
    ("weights", float64[:, :]),  # connectivity matrix, asserted square
    ("init_state", float64[:]),  # initial [x, y], split at nn by run()
]
181
+
182
+
183
# numba jitclass holding the GHB model parameters; field types come from
# par_spec. Every field in par_spec, including seed, can be passed here.
@jitclass(par_spec)
class ParGHB:
    def __init__(
        self,
        G=1.0,
        dt=0.001,
        sigma=0.1,
        tend=10.0,
        tcut=0.0,
        eta=np.array([]),
        init_state=np.array([]),
        omega=np.array([]),
        weights=np.array([[], []]),
        decimate=1,
        seed=-1,
    ):
        """Store the GHB parameters.

        ``seed`` was previously fixed to -1 even though it is a valid
        parameter. It is now a trailing keyword with the same default, so
        existing callers are unaffected.
        """
        self.G = G
        self.dt = dt
        self.seed = seed
        self.eta = eta
        self.tend = tend
        self.tcut = tcut
        self.sigma = sigma
        self.omega = omega
        self.weights = weights
        self.decimate = decimate
        self.init_state = init_state
@@ -0,0 +1,383 @@
1
+ import warnings
2
+ import numpy as np
3
+ from copy import copy
4
+ from numba import njit, jit
5
+ from numba.experimental import jitclass
6
+ from numba.extending import register_jitable
7
+ from numba import float64, boolean, int64, types
8
+ from numba.core.errors import NumbaPerformanceWarning
9
+
10
+ warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
11
+ np.random.seed(42)
12
+
13
+
14
@njit
def f_mpr(x, t, P):
    """
    Right-hand side of the MPR (Montbrio-Pazo-Roxin) network.

    ``x[:P.nn]`` holds r and ``x[P.nn:]`` holds v. The r-equation is
    (delta / (tau * pi) + 2 r v) / tau. The v-equation is
    (v^2 + eta + iapp + J tau r - (pi tau r)^2 + G * W @ r) / tau.
    ``t`` is accepted for a standard ODE signature but is not used.

    Returns a new array with the same shape as ``x``.
    """
    nn = P.nn
    r = x[:nn]
    v = x[nn:]
    inv_tau = 1.0 / P.tau

    network_input = P.G * np.dot(P.weights, r)

    dxdt = np.zeros_like(x)
    dxdt[:nn] = inv_tau * (P.delta / (P.tau * np.pi) + 2 * r * v)
    dxdt[nn:] = inv_tau * (
        v * v
        + P.eta
        + P.iapp
        + (P.J * P.tau) * r
        - (np.pi * np.pi * (P.tau * P.tau) * r * r)
        + network_input
    )
    return dxdt
36
+
37
+
38
@njit
def heun_sde(x, t, P):
    """Advance the MPR state ``x`` by one stochastic Heun step, in place.

    The same additive noise increment (``P.sigma_r`` for the first ``P.nn``
    components, ``P.sigma_v`` for the rest) is used by the predictor and
    the corrector. After the corrector, negative values among the first
    ``P.nn`` components are replaced by ``False * value``.

    Returns ``x`` itself (mutated).
    """
    nn = P.nn
    dt = P.dt
    # r-noise is drawn before v-noise (keeps the RNG stream order).
    noise = np.concatenate(
        (P.sigma_r * np.random.randn(nn), P.sigma_v * np.random.randn(nn))
    )

    drift_now = f_mpr(x, t, P)
    predictor = x + dt * drift_now + noise
    drift_next = f_mpr(predictor, t + dt, P)

    x += 0.5 * dt * (drift_now + drift_next)
    x += noise
    # Keep the first nn components non-negative.
    x[:nn] = (x[:nn] > 0.0) * x[:nn]
    return x
55
+
56
+
57
@njit
def do_bold_step(r_in, s, f, ftilde, vtilde, qtilde, v, q, dtt, P):
    """Advance the hemodynamic state by one forward-Euler step, in place.

    Every state array has two rows: row 0 is the current value and row 1 is
    the next value. At the end, row 1 is copied back into row 0. f, v and q
    are integrated in log space (ftilde, vtilde, qtilde) and recovered with
    ``np.exp``. The s/f/v/q names presumably follow the usual Balloon model
    naming (signal, inflow, volume, deoxyhemoglobin); TODO confirm.

    Parameters
    ----------
    r_in : array
        Drive for ``s``. integrate() passes the first nn MPR state components.
    s, f, ftilde, vtilde, qtilde, v, q : arrays
        Two-row state buffers (integrate() allocates them as (2, nn)).
        They are updated in place.
    dtt : float
        Step size.
    P : ParBold
        Hemodynamic parameters; reads kappa, gamma, alpha, tau and Eo.
        NOTE(review): named ``P`` but integrate() passes its ParBold ``B``.
    """
    kappa = P.kappa
    gamma = P.gamma
    ialpha = 1 / P.alpha
    tau = P.tau
    Eo = P.Eo

    # s[1] uses f[0] before the clamp below.
    s[1] = s[0] + dtt * (r_in - kappa * s[0] - gamma * (f[0] - 1))
    # Clamp the current f to >= 1 before it is used by the remaining updates.
    f[0] = np.clip(f[0], 1, None)
    ftilde[1] = ftilde[0] + dtt * (s[0] / f[0])
    fv = v[0] ** ialpha  # outflow
    vtilde[1] = vtilde[0] + dtt * ((f[0] - fv) / (tau * v[0]))
    # Floor q at 0.01 since it appears in a denominator below.
    q[0] = np.clip(q[0], 0.01, None)
    ff = (1 - (1 - Eo) ** (1 / f[0])) / Eo  # oxygen extraction
    qtilde[1] = qtilde[0] + dtt * ((f[0] * ff - fv * q[0] / v[0]) / (tau * q[0]))

    # Recover the linear-space values from the log-space updates.
    f[1] = np.exp(ftilde[1])
    v[1] = np.exp(vtilde[1])
    q[1] = np.exp(qtilde[1])

    # Shift next -> current for the following call.
    f[0] = f[1]
    s[0] = s[1]
    ftilde[0] = ftilde[1]
    vtilde[0] = vtilde[1]
    qtilde[0] = qtilde[1]
    v[0] = v[1]
    q[0] = q[1]
85
+
86
+
87
def integrate(P, B):
    """Simulate the MPR network and, optionally, its BOLD signal.

    The network state is advanced with :func:`heun_sde` for
    ``int(P.t_end / P.dt) - 1`` steps. When ``P.RECORD_BOLD`` is set, each
    step also drives the hemodynamic model (:func:`do_bold_step`).

    Parameters
    ----------
    P : ParMPR
        Model parameters. ``P.initial_state`` is copied and never modified.
    B : ParBold
        Hemodynamic parameters.

    Returns
    -------
    dict
        ``rv_d`` (float32, shape (nt // rv_decimate, 2 * nn)) is the state
        stored every ``P.rv_decimate`` steps, and ``rv_t`` holds the matching
        times. Both are empty if ``P.RECORD_RV`` is false.
        ``bold_d``/``bold_t`` (float32) are the BOLD signal stored every
        ``bold_decimate`` steps, and are empty if ``P.RECORD_BOLD`` is false.
        Both time vectors are multiplied by 10.

    Notes
    -----
    ``P.t_cut`` is not applied here; no initial transient is discarded.
    NOTE(review): ``rv_t[k]`` is the time at the start of the step whose
    result is stored in ``rv_d[k]``.
    """
    nn = P.nn
    dt = P.dt
    nt = int(P.t_end / dt)
    rv_decimate = max(1, int(P.rv_decimate))
    record_rv = P.RECORD_RV
    record_bold = P.RECORD_BOLD

    # Duration of one step in the units of P.tr (the same factor of 10 is
    # applied to the returned times). dtt is that duration / 1000 and is
    # used as the hemodynamic step.
    r_period = dt * 10
    # At least 1, so a TR shorter than half a step cannot cause a modulo by zero.
    bold_decimate = max(1, int(np.round(P.tr / r_period)))
    dtt = r_period / 1000.0

    k1 = 4.3 * B.theta0 * B.Eo * B.TE
    k2 = B.epsilon * B.r0 * B.Eo * B.TE
    k3 = 1 - B.epsilon
    vo = B.vo

    # heun_sde updates its argument in place. Work on a copy so the caller's
    # initial state (e.g. an x0 array given to MPR_sde.run) is not overwritten.
    rv_current = np.array(P.initial_state, dtype=np.float64)

    if record_rv:
        n_rv = nt // rv_decimate
        rv_d = np.zeros((n_rv, 2 * nn), dtype=np.float32)
        rv_t = np.zeros(n_rv, dtype=np.float32)
    else:
        rv_d = np.array([])
        rv_t = np.array([])

    if record_bold:
        # Two-row buffers: row 0 = current value, row 1 = next value.
        s = np.zeros((2, nn))
        f = np.zeros((2, nn))
        ftilde = np.zeros((2, nn))
        vtilde = np.zeros((2, nn))
        qtilde = np.zeros((2, nn))
        v = np.zeros((2, nn))
        q = np.zeros((2, nn))
        s[0] = 1
        f[0] = 1
        v[0] = 1
        q[0] = 1
        n_bold = nt // bold_decimate
        vv = np.zeros((n_bold, nn))
        qq = np.zeros((n_bold, nn))

    for i in range(nt - 1):
        t_current = i * dt
        heun_sde(rv_current, t_current, P)

        if record_rv:
            k, rem = divmod(i, rv_decimate)
            if rem == 0 and k < rv_d.shape[0]:
                rv_d[k, :] = rv_current
                rv_t[k] = t_current

        if record_bold:
            do_bold_step(
                rv_current[:nn], s, f, ftilde, vtilde, qtilde, v, q, dtt, B
            )
            k, rem = divmod(i, bold_decimate)
            if rem == 0 and k < vv.shape[0]:
                vv[k] = v[1]
                qq[k] = q[1]

    if record_bold:
        bold_d = vo * (k1 * (1 - qq) + k2 * (1 - qq / vv) + k3 * (1 - vv))
        # Sample k was taken at step k * bold_decimate.
        bold_t = np.arange(bold_d.shape[0]) * (dt * bold_decimate)
    else:
        bold_d = np.array([])
        bold_t = np.array([])

    return {
        "rv_t": rv_t * 10,
        "rv_d": rv_d,
        "bold_t": bold_t.astype("f") * 10,
        "bold_d": bold_d.astype("f"),
    }
170
+
171
+
172
class MPR_sde:
    """Front end for the numba MPR integrator :func:`integrate`.

    Parameters
    ----------
    par_mpr : dict, optional
        Overrides for fields listed in ``mpr_spec``. The dict is copied,
        never modified.

    Raises
    ------
    ValueError
        If ``par_mpr`` contains a key that is not in ``mpr_spec``.

    Notes
    -----
    A positive ``seed`` seeds NumPy's global RNG at construction. In
    :meth:`run` it is also passed to the jitted ``set_initial_state``.
    """

    # ParMPR fields that its constructor does not accept; they are assigned
    # after construction instead.
    _post_init_params = ("initial_state", "sigma_r", "sigma_v", "method", "nn")

    def __init__(self, par_mpr: dict = None) -> None:
        par_mpr = {} if par_mpr is None else par_mpr
        self.valid_par = [name for name, _ in mpr_spec]
        self.check_parameters(par_mpr)
        self.P = self.get_par_mpr(par_mpr)
        self.B = ParBold()

        self.seed = self.P.seed
        if self.seed > 0:
            np.random.seed(self.seed)

    def __str__(self) -> str:
        lines = ["MPR model", "Parameters: --------------------------------"]
        lines.extend(f"{key} = {getattr(self.P, key)}" for key in self.valid_par)
        lines.append("--------------------------------------------")
        return "\n".join(lines)

    def check_parameters(self, par: dict) -> None:
        """Raise ValueError for any key of ``par`` not listed in ``mpr_spec``."""
        for key in par:
            if key not in self.valid_par:
                raise ValueError(f"Invalid parameter: {key}")

    @staticmethod
    def _as_param(key, value):
        """Convert array-valued parameters to float64 arrays.

        ``eta`` may also be given as a scalar; check_input broadcasts it later.
        """
        if key == "eta":
            return np.atleast_1d(np.asarray(value, dtype=np.float64))
        if key in ("initial_state", "weights"):
            return np.asarray(value, dtype=np.float64)
        return value

    def get_par_mpr(self, par: dict):
        """
        return default parameters of MPR model and update with user defined parameters.

        The caller's dict is not modified.
        """
        if "weights" in par:
            assert par["weights"] is not None
        kwargs = {key: self._as_param(key, value) for key, value in par.items()}
        if "weights" in kwargs:
            assert kwargs["weights"].shape[0] == kwargs["weights"].shape[1]
        post_init = {
            key: kwargs.pop(key) for key in self._post_init_params if key in kwargs
        }
        parP = ParMPR(**kwargs)
        for key, value in post_init.items():
            setattr(parP, key, value)
        return parP

    def set_initial_state(self):
        """Draw a random initial state into ``self.initial_state``."""
        # Calls the module-level jitted set_initial_state(nn, seed).
        self.initial_state = set_initial_state(self.P.nn, self.seed)
        self.INITIAL_STATE_SET = True

    def check_input(self):
        """Validate ``self.P``, broadcast eta to length nn, and scale t_end/t_cut by 1/10.

        The scaling mutates ``self.P``; :meth:`run` restores the original
        values after integration.
        """
        assert self.P.weights is not None
        assert self.P.weights.shape[0] == self.P.weights.shape[1]
        assert self.P.initial_state is not None
        assert len(self.P.initial_state) == 2 * self.P.weights.shape[0]
        self.P.eta = check_vec_size(self.P.eta, self.P.nn)
        self.P.t_end /= 10
        self.P.t_cut /= 10

    def run(self, par=None, x0=None, verbose=True):
        """Simulate the model and return the dict produced by :func:`integrate`.

        Parameters
        ----------
        par : dict, optional
            Parameter overrides applied to ``self.P`` before anything else.
            They persist across calls.
        x0 : array-like, optional
            Initial state of length ``2 * nn``. It is ignored if ``par``
            contains ``initial_state``. If both are absent, a random state
            is drawn.
        verbose : bool
            Unused.
        """
        par = {} if par is None else par
        if par:
            self.check_parameters(par)
            for key, value in par.items():
                setattr(self.P, key, self._as_param(key, value))
            # Keep derived fields consistent with the new parameters.
            if "weights" in par and "nn" not in par:
                self.P.nn = self.P.weights.shape[0]
            if "noise_amp" in par or "dt" in par:
                sqrt_dt = np.sqrt(self.P.dt)
                if "sigma_r" not in par:
                    self.P.sigma_r = sqrt_dt * np.sqrt(2 * self.P.noise_amp)
                if "sigma_v" not in par:
                    self.P.sigma_v = sqrt_dt * np.sqrt(4 * self.P.noise_amp)

        # Precedence: par["initial_state"] > x0 > random draw.
        if "initial_state" not in par:
            if x0 is None:
                self.seed = self.P.seed if self.P.seed > 0 else None
                self.set_initial_state()
                self.P.initial_state = self.initial_state
            else:
                self.P.initial_state = np.asarray(x0, dtype=np.float64)

        # check_input rescales t_end/t_cut in place. Restore them so repeated
        # calls do not keep shrinking the simulated duration.
        t_end, t_cut = self.P.t_end, self.P.t_cut
        try:
            self.check_input()
            return integrate(self.P, self.B)
        finally:
            self.P.t_end = t_end
            self.P.t_cut = t_cut
239
+
240
+
241
@njit
def set_initial_state(nn, seed=None):
    """Return a random MPR initial state of length ``2 * nn``.

    The first ``nn`` entries are uniform in [0, 1.5) and the last ``nn`` are
    uniform in [-2, 2). When ``seed`` is given, the jitted RNG is seeded
    first through ``set_seed_compat``.
    """
    if seed is not None:
        set_seed_compat(seed)

    state = np.random.rand(2 * nn)
    # Views into `state`: the in-place ops below rescale it directly.
    first_half = state[:nn]
    second_half = state[nn:]
    first_half *= 1.5
    second_half *= 4
    second_half -= 2
    return state
251
+
252
+
253
# numba jitclass field specification for ParMPR as (name, numba type) pairs.
# The names also define the keys accepted by MPR_sde.check_parameters, and
# their order is the order MPR_sde.__str__ prints them in.
# NOTE: nn, method, initial_state, sigma_r and sigma_v are not keyword
# arguments of ParMPR.__init__.
mpr_spec = [
    ("G", float64),  # global coupling, multiplies W @ r in f_mpr
    ("dt", float64),  # integration time step
    ("J", float64),  # enters f_mpr as J * tau * r
    ("eta", float64[:]),  # per-node excitability; broadcast to nn by check_input
    ("tau", float64),  # time constant dividing both MPR equations
    ("weights", float64[:, :]),  # connectivity matrix, asserted square
    ("delta", float64),  # enters f_mpr as delta / (tau * pi)
    ("t_init", float64),  # not read by the visible integration code
    ("t_cut", float64),  # rescaled by check_input; not applied by integrate
    ("t_end", float64),  # simulated duration; nt = int(t_end / dt)
    ("nn", int64),  # number of nodes, len(weights) in ParMPR.__init__
    ("method", types.string),  # not read by the visible code
    ("seed", int64),  # seeds the initial-state draw when > 0
    ("initial_state", float64[:]),  # initial [r, v], length 2 * nn
    ("noise_amp", float64),  # used by ParMPR.__init__ to derive sigma_r/sigma_v
    ("sigma_r", float64),  # noise increment scale for the first nn components
    ("sigma_v", float64),  # noise increment scale for the last nn components
    ("iapp", float64),  # constant added to the v-equation
    ("output", types.string),  # not read by the visible code
    ("RECORD_RV", boolean),  # store the network state in integrate()
    ("RECORD_BOLD", boolean),  # run the hemodynamic model in integrate()
    ("rv_decimate", int64),  # state storage stride, in steps
    ("tr", float64),  # repetition time (milliseconds per ParMPR's comment)
]
278
+
279
+
280
# numba jitclass holding the MPR model parameters; field types come from
# mpr_spec. sigma_r and sigma_v are derived here from dt and noise_amp.
@jitclass(mpr_spec)
class ParMPR:
    def __init__(
        self,
        G=0.5,
        dt=0.01,
        J=14.5,
        eta=np.array([-4.6]),
        tau=1.0,
        delta=0.7,
        rv_decimate=1,
        noise_amp=0.037,
        weights=np.array([[], []]),
        t_init=0.0,
        t_cut=0.0,
        t_end=1000.0,
        iapp=0.0,
        seed=-1,
        output="output",
        RECORD_RV=True,
        RECORD_BOLD=True,
        tr=500.0,  # TR in milliseconds
    ):
        """Store the MPR parameters.

        ``nn`` is taken from ``len(weights)``. ``rv_decimate`` is an int64
        field, so its default is the integer 1 rather than a float that
        relies on an implicit truncating cast.
        """
        self.G = G
        self.dt = dt
        self.J = J
        self.eta = eta
        self.tau = tau
        self.delta = delta
        self.rv_decimate = rv_decimate
        self.noise_amp = noise_amp
        self.t_init = t_init
        self.t_cut = t_cut
        self.t_end = t_end
        self.iapp = iapp
        self.nn = len(weights)
        self.seed = seed
        self.output = output
        self.weights = weights
        self.RECORD_RV = RECORD_RV
        self.RECORD_BOLD = RECORD_BOLD
        # Per-step noise scales for the r and v components.
        self.sigma_r = np.sqrt(dt) * np.sqrt(2 * noise_amp)
        self.sigma_v = np.sqrt(dt) * np.sqrt(4 * noise_amp)
        self.tr = tr
325
+
326
+
327
# numba jitclass field specification for ParBold as (name, numba type) pairs.
bold_spec = [
    ("kappa", float64),  # read by do_bold_step (decay term on s)
    ("gamma", float64),  # read by do_bold_step (feedback term on f - 1)
    ("tau", float64),  # read by do_bold_step (v and q time constant)
    ("alpha", float64),  # do_bold_step uses v ** (1 / alpha)
    ("epsilon", float64),  # read by integrate for k2 and k3
    ("Eo", float64),  # read by do_bold_step and integrate
    ("TE", float64),  # read by integrate for k1 and k2
    ("vo", float64),  # read by integrate, scales the BOLD signal
    ("r0", float64),  # read by integrate for k2
    ("theta0", float64),  # read by integrate for k1
    ("t_min", float64),  # not read by the visible code
    ("rtol", float64),  # not read by the visible code
    ("atol", float64),  # not read by the visible code
]
342
+
343
+
344
# numba jitclass holding the hemodynamic parameters read by integrate() and
# do_bold_step(); field types come from bold_spec.
@jitclass(bold_spec)
class ParBold:
    def __init__(
        self,
        kappa=0.65,
        gamma=0.41,
        tau=0.98,
        alpha=0.32,
        epsilon=0.34,
        Eo=0.4,
        TE=0.04,
        vo=0.08,
        r0=25.0,
        theta0=40.3,
        t_min=0.0,
        rtol=1e-5,
        atol=1e-8,
    ):
        """Store the hemodynamic parameters unchanged (nothing is derived)."""
        # Read by do_bold_step.
        self.kappa = kappa
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.Eo = Eo
        # Read by integrate to build the BOLD observation coefficients.
        self.epsilon = epsilon
        self.TE = TE
        self.vo = vo
        self.r0 = r0
        self.theta0 = theta0
        # Not read by the visible code.
        self.t_min = t_min
        self.rtol = rtol
        self.atol = atol
375
+
376
+
377
def check_vec_size(x, nn):
    """Return ``x`` as a vector of length ``nn``.

    A scalar or a length-1 input is broadcast to ``nn`` float copies. An
    input that already has length ``nn`` is returned as a new array.

    Raises
    ------
    ValueError
        If ``x`` has neither length 1 nor length ``nn``.
    """
    arr = np.atleast_1d(x)  # accept plain scalars as well as sequences
    n = len(arr)
    if n == nn:
        return np.array(arr)
    if n == 1:
        return np.ones(nn) * arr
    raise ValueError(
        f"expected a scalar or a vector of length {nn}, got length {n}"
    )
379
+
380
+
381
@register_jitable
def set_seed_compat(x):
    """Seed the random generator with ``x``.

    ``register_jitable`` makes this callable from both Python and ``@njit``
    code. Per the numba docs, calling it from jitted code (as
    ``set_initial_state`` does) seeds numba's internal generator rather than
    NumPy's global one.
    """
    np.random.seed(x)
File without changes