vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (121)
  1. vbi/__init__.py +37 -0
  2. vbi/_version.py +17 -0
  3. vbi/dataset/__init__.py +0 -0
  4. vbi/dataset/connectivity_84/centers.txt +84 -0
  5. vbi/dataset/connectivity_84/centres.txt +84 -0
  6. vbi/dataset/connectivity_84/cortical.txt +84 -0
  7. vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
  8. vbi/dataset/connectivity_84/weights.txt +84 -0
  9. vbi/dataset/connectivity_88/Aud_88.txt +88 -0
  10. vbi/dataset/connectivity_88/Bold.npz +0 -0
  11. vbi/dataset/connectivity_88/Labels.txt +17 -0
  12. vbi/dataset/connectivity_88/Region_labels.txt +88 -0
  13. vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
  14. vbi/dataset/connectivity_88/weights.txt +88 -0
  15. vbi/feature_extraction/__init__.py +1 -0
  16. vbi/feature_extraction/calc_features.py +293 -0
  17. vbi/feature_extraction/features.json +535 -0
  18. vbi/feature_extraction/features.py +2124 -0
  19. vbi/feature_extraction/features_settings.py +374 -0
  20. vbi/feature_extraction/features_utils.py +1357 -0
  21. vbi/feature_extraction/infodynamics.jar +0 -0
  22. vbi/feature_extraction/utility.py +507 -0
  23. vbi/inference.py +98 -0
  24. vbi/models/__init__.py +0 -0
  25. vbi/models/cpp/__init__.py +0 -0
  26. vbi/models/cpp/_src/__init__.py +0 -0
  27. vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
  28. vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
  29. vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
  30. vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  31. vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  32. vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  33. vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
  34. vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
  35. vbi/models/cpp/_src/bold.hpp +303 -0
  36. vbi/models/cpp/_src/do.hpp +167 -0
  37. vbi/models/cpp/_src/do.i +17 -0
  38. vbi/models/cpp/_src/do.py +467 -0
  39. vbi/models/cpp/_src/do_wrap.cxx +12811 -0
  40. vbi/models/cpp/_src/jr_sdde.hpp +352 -0
  41. vbi/models/cpp/_src/jr_sdde.i +19 -0
  42. vbi/models/cpp/_src/jr_sdde.py +688 -0
  43. vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
  44. vbi/models/cpp/_src/jr_sde.hpp +264 -0
  45. vbi/models/cpp/_src/jr_sde.i +17 -0
  46. vbi/models/cpp/_src/jr_sde.py +470 -0
  47. vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
  48. vbi/models/cpp/_src/km_sde.hpp +158 -0
  49. vbi/models/cpp/_src/km_sde.i +19 -0
  50. vbi/models/cpp/_src/km_sde.py +671 -0
  51. vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
  52. vbi/models/cpp/_src/makefile +52 -0
  53. vbi/models/cpp/_src/mpr_sde.hpp +327 -0
  54. vbi/models/cpp/_src/mpr_sde.i +19 -0
  55. vbi/models/cpp/_src/mpr_sde.py +711 -0
  56. vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
  57. vbi/models/cpp/_src/utility.hpp +307 -0
  58. vbi/models/cpp/_src/vep.hpp +171 -0
  59. vbi/models/cpp/_src/vep.i +16 -0
  60. vbi/models/cpp/_src/vep.py +464 -0
  61. vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
  62. vbi/models/cpp/_src/wc_ode.hpp +294 -0
  63. vbi/models/cpp/_src/wc_ode.i +19 -0
  64. vbi/models/cpp/_src/wc_ode.py +686 -0
  65. vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
  66. vbi/models/cpp/damp_oscillator.py +143 -0
  67. vbi/models/cpp/jansen_rit.py +543 -0
  68. vbi/models/cpp/km.py +187 -0
  69. vbi/models/cpp/mpr.py +289 -0
  70. vbi/models/cpp/vep.py +150 -0
  71. vbi/models/cpp/wc.py +216 -0
  72. vbi/models/cupy/__init__.py +0 -0
  73. vbi/models/cupy/bold.py +111 -0
  74. vbi/models/cupy/ghb.py +284 -0
  75. vbi/models/cupy/jansen_rit.py +473 -0
  76. vbi/models/cupy/km.py +224 -0
  77. vbi/models/cupy/mpr.py +475 -0
  78. vbi/models/cupy/mpr_modified_bold.py +12 -0
  79. vbi/models/cupy/utils.py +184 -0
  80. vbi/models/numba/__init__.py +0 -0
  81. vbi/models/numba/_ww_EI.py +444 -0
  82. vbi/models/numba/damp_oscillator.py +162 -0
  83. vbi/models/numba/ghb.py +208 -0
  84. vbi/models/numba/mpr.py +383 -0
  85. vbi/models/pytorch/__init__.py +0 -0
  86. vbi/models/pytorch/data/default_parameters.npz +0 -0
  87. vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
  88. vbi/models/pytorch/data/input/fc_test.csv +68 -0
  89. vbi/models/pytorch/data/input/fc_train.csv +68 -0
  90. vbi/models/pytorch/data/input/fc_vali.csv +68 -0
  91. vbi/models/pytorch/data/input/fcd_test.mat +0 -0
  92. vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
  93. vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
  94. vbi/models/pytorch/data/input/fcd_train.mat +0 -0
  95. vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
  96. vbi/models/pytorch/data/input/myelin.csv +68 -0
  97. vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
  98. vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
  99. vbi/models/pytorch/data/input/sc_test.csv +68 -0
  100. vbi/models/pytorch/data/input/sc_train.csv +68 -0
  101. vbi/models/pytorch/data/input/sc_vali.csv +68 -0
  102. vbi/models/pytorch/data/obs_kong0.npz +0 -0
  103. vbi/models/pytorch/ww_sde_kong.py +570 -0
  104. vbi/models/tvbk/__init__.py +9 -0
  105. vbi/models/tvbk/tvbk_wrapper.py +166 -0
  106. vbi/models/tvbk/utils.py +72 -0
  107. vbi/papers/__init__.py +0 -0
  108. vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
  109. vbi/tests/__init__.py +0 -0
  110. vbi/tests/_test_mpr_nb.py +36 -0
  111. vbi/tests/test_features.py +355 -0
  112. vbi/tests/test_ghb_cupy.py +90 -0
  113. vbi/tests/test_mpr_cupy.py +49 -0
  114. vbi/tests/test_mpr_numba.py +84 -0
  115. vbi/tests/test_suite.py +19 -0
  116. vbi/utils.py +402 -0
  117. vbi-0.1.3.dist-info/METADATA +166 -0
  118. vbi-0.1.3.dist-info/RECORD +121 -0
  119. vbi-0.1.3.dist-info/WHEEL +5 -0
  120. vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
  121. vbi-0.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,444 @@
1
+ import warnings
2
+ import numpy as np
3
+ from numba import njit
4
+ from numba.experimental import jitclass
5
+ from numba.core.errors import NumbaPerformanceWarning
6
+ from numba import float64, boolean, int64, types
7
+
8
# NOTE: this silences NumbaPerformanceWarning for the whole process, not just
# this module.
warnings.simplefilter("ignore", category=NumbaPerformanceWarning)


# Field layout of the ParWW jitclass: (attribute name, numba type).
# Every attribute that ParWW.__init__ assigns must be declared here, otherwise
# numba cannot compile the constructor (I0, sigma_I and sigma_E were missing).
w_spec = [
    ("G", float64),

    ("a_I", float64),
    ("b_I", float64),
    ("d_I", float64),
    ("tau_I", float64),

    ("a_E", float64),
    ("b_E", float64),
    ("d_E", float64),
    ("tau_E", float64),

    ("w_II", float64),
    ("w_EE", float64),
    ("w_IE", float64),
    ("w_EI", float64),

    ("W_E", float64),
    ("W_I", float64),

    ("gamma", float64),
    ("dt", float64),
    ("J_NMDA", float64),
    ("J_I", float64),
    ("I0", float64),

    ("I_I", float64[:]),
    ("I_E", float64[:]),

    ("sigma_I", float64),
    ("sigma_E", float64),

    ("initial_state", float64[:]),
    ("weights", float64[:, :]),
    ("seed", int64),
    ("method", types.string),
    ("t_end", float64),
    ("t_cut", float64),
    ("nn", int64),
    ("ts_decimate", int64),
    ("fmri_decimate", int64),
    ("RECORD_TS", boolean),
    ("RECORD_FMRI", boolean),
]

# Field layout of the ParBaloon jitclass (Balloon hemodynamic model).
# K1..K3 and the inv_* entries are derived in ParBaloon.__init__; nn is set by
# the caller.
b_spec = [
    ("eps", float64),
    ("E0", float64),
    ("V0", float64),
    ("alpha", float64),
    ("inv_alpha", float64),
    ("K1", float64),
    ("K2", float64),
    ("K3", float64),
    ("taus", float64),
    ("tauo", float64),
    ("tauf", float64),
    ("inv_tauo", float64),
    ("inv_taus", float64),
    ("inv_tauf", float64),
    ("nn", int64),
    ("dt_bold", float64),
]
74
+
75
+
76
# Parameter container for the Wong-Wang excitatory-inhibitory model, compiled
# as a numba jitclass whose field layout is given by ``w_spec``.
@jitclass(w_spec)
class ParWW:
    def __init__(
        self,
        G=0.0,
        a_I=615.0,
        b_I=177.0,
        d_I=0.087,
        tau_I=0.01,

        a_E=310.0,
        b_E=125.0,
        d_E=0.16,
        tau_E=0.1,

        gamma=0.641,

        w_II=1.0,
        w_IE=1.4,
        w_EI=1.0,
        w_EE=1.0,
        dt=0.01,

        W_E=1.0,
        W_I=0.7,

        I0=0.382,
        J_NMDA=0.15,

        I_I=np.array([0.296]),
        I_E=np.array([0.377]),
        sigma_I=0.001,
        sigma_E=0.001,
        initial_state=np.array([]),
        weights=np.array([[], []]),
        seed=-1,
        method="heun",
        t_end=300.0,
        t_cut=0.0,
        ts_decimate=10,
        fmri_decimate=10,
        RECORD_TS=True,
        RECORD_FMRI=True,
        J_I=0.0,
    ):
        """
        Copy the constructor arguments onto the jitclass fields.

        Every attribute assigned here (including ``I0``, ``sigma_I`` and
        ``sigma_E``) must also be declared in ``w_spec`` for numba to compile
        this constructor.

        ``J_I`` is accepted as a trailing keyword argument so positional
        callers are unaffected. It was previously never assigned, so it held
        the zero value numba gives unassigned jitclass storage; the default
        0.0 keeps that value. NOTE(review): no integration routine in this
        file reads ``J_I`` yet.

        ``nn`` (number of nodes) is ``len(initial_state)`` when an initial
        state is given and -1 otherwise; ``WW_sde`` overwrites it before
        integrating.
        """
        self.G = G

        # inhibitory population
        self.a_I = a_I
        self.b_I = b_I
        self.d_I = d_I
        self.tau_I = tau_I

        # excitatory population
        self.a_E = a_E
        self.b_E = b_E
        self.d_E = d_E
        self.tau_E = tau_E

        # local couplings
        self.w_II = w_II
        self.w_IE = w_IE
        self.w_EI = w_EI
        self.w_EE = w_EE
        self.gamma = gamma

        self.dt = dt

        self.W_E = W_E
        self.W_I = W_I

        self.I0 = I0
        self.I_E = I_E
        self.I_I = I_I
        self.J_NMDA = J_NMDA
        self.J_I = J_I

        self.sigma_I = sigma_I
        self.sigma_E = sigma_E

        # simulation settings
        self.initial_state = initial_state
        self.weights = weights
        self.seed = seed
        self.method = method
        self.t_end = t_end
        self.t_cut = t_cut
        self.ts_decimate = ts_decimate
        self.fmri_decimate = fmri_decimate
        self.RECORD_TS = RECORD_TS
        self.RECORD_FMRI = RECORD_FMRI
        if len(initial_state) > 0:
            self.nn = len(initial_state)
        else:
            self.nn = -1
165
+
166
+
167
# Balloon hemodynamic model parameters, compiled as a numba jitclass whose
# field layout is given by ``b_spec``.
@jitclass(b_spec)
class ParBaloon:
    def __init__(
        self, eps=0.5, E0=0.4, V0=4.0,
        alpha=0.32, taus=1.54, tauo=0.98, tauf=1.44, dt_bold=0.01,
    ):
        """
        Store the Balloon parameters and precompute derived constants.

        K1, K2, K3 (BOLD output coefficients) and the reciprocals of
        alpha/taus/tauo/tauf are computed once here; they are NOT recomputed
        if the base fields are reassigned later.

        ``dt_bold`` is the Balloon integration step. It was previously
        hard-coded to 0.01, and it remains the default. It is a trailing
        keyword so positional callers are unaffected. ``nn`` is not set here
        and must be assigned by the caller.
        """
        self.eps = eps
        self.E0 = E0
        self.V0 = V0
        self.alpha = alpha
        self.inv_alpha = 1.0 / alpha
        self.K1 = 7.0 * E0
        self.K2 = 2 * E0
        self.K3 = 1 - eps
        self.taus = taus
        self.tauo = tauo
        self.tauf = tauf
        self.inv_tauo = 1.0 / tauo
        self.inv_taus = 1.0 / taus
        self.inv_tauf = 1.0 / tauf
        self.dt_bold = dt_bold
188
+
189
+
190
@njit
def f_ww(S, P):
    """
    Right-hand side of the (excitatory) Wong-Wang gating equation.

    Element-wise over nodes:

        x  = w_EE * J_NMDA * S + W_E * I0 + G * J_NMDA * (weights @ S)
        H  = (a_E x - b_E) / (1 - exp(-d_E (a_E x - b_E)))
        dS = -S / tau_E + (1 - S) * H * gamma

    Returns dS, an array with the same shape as S.

    NOTE(review): the previous body read P.w, P.J_N, P.I_o, P.a, P.b, P.d
    and P.tau_s. None of these exist on ParWW, so this function could not
    compile. The reduced-model symbols are mapped here onto ParWW's
    excitatory fields:

        w -> w_EE, J_N -> J_NMDA, I_o -> W_E * I0,
        a/b/d -> a_E/b_E/d_E, tau_s -> tau_E

    The inhibitory population (a_I, b_I, d_I, tau_I, J_I, ...) is not
    simulated. Confirm this mapping against the intended E-I model.
    """
    coupling = np.dot(P.weights, S)
    x = P.w_EE * P.J_NMDA * S + P.W_E * P.I0 + P.G * P.J_NMDA * coupling
    drive = P.a_E * x - P.b_E
    H = drive / (1.0 - np.exp(-P.d_E * drive))
    return -(S / P.tau_E) + (1.0 - S) * H * P.gamma
200
+
201
+
202
@njit
def f_fmri(xin, x, t, B):
    """
    Right-hand side of the Balloon hemodynamic model.

    ``x`` stacks four blocks of length ``B.nn``: [s, f, v, q]. ``xin`` is the
    input driving ``s``. ``t`` is not used. Returns a new array of length
    ``4 * B.nn`` holding the time derivatives in the same block order.
    """
    n = B.nn
    s = x[:n]
    f = x[n:2 * n]
    v = x[2 * n:3 * n]
    q = x[3 * n:]

    v_pow = v ** B.inv_alpha
    extraction = f * (1.0 - (1.0 - B.E0) ** (1.0 / f)) / B.E0

    out = np.zeros(4 * n)
    out[:n] = xin - B.inv_taus * s - B.inv_tauf * (f - 1.0)
    out[n:2 * n] = s
    out[2 * n:3 * n] = B.inv_tauo * (f - v_pow)
    out[3 * n:] = B.inv_tauo * (extraction - v_pow * (q / v))
    return out
227
+
228
+
229
@njit
def euler_sde_step(S, P):
    """
    One Euler-Maruyama step of the Wong-Wang SDE with additive noise.

    Returns S + dt * f_ww(S, P) + sqrt(dt) * sigma * N(0, 1) with one normal
    draw per node (P.nn draws).

    NOTE(review): the previous body read ``P.sigma_noise``, which ParWW does
    not define, so the step could not compile. The noise amplitude now comes
    from ``P.sigma_E``, assuming ``S`` is the excitatory gating variable.
    Confirm this assumption.
    """
    dW = np.sqrt(P.dt) * P.sigma_E * np.random.randn(P.nn)
    return S + P.dt * f_ww(S, P) + dW
233
+
234
+
235
@njit
def heun_sde_step(S, P):
    """
    One stochastic Heun step of the Wong-Wang SDE with additive noise.

    The same Wiener increment ``dW`` is used in the predictor and in the
    corrector.

    NOTE(review): the previous body read ``P.sigma_noise``, which ParWW does
    not define, so the step could not compile. The noise amplitude now comes
    from ``P.sigma_E``, assuming ``S`` is the excitatory gating variable.
    Confirm this assumption.
    """
    dW = np.sqrt(P.dt) * P.sigma_E * np.random.randn(P.nn)
    k0 = f_ww(S, P)
    S1 = S + P.dt * k0 + dW
    k1 = f_ww(S1, P)
    return S + 0.5 * P.dt * (k0 + k1) + dW
242
+
243
+
244
@njit
def heun_ode_step(yin, y, t, B):
    """
    One deterministic Heun (predictor-corrector) step of the Balloon model.

    Uses step size ``B.dt_bold``. ``y`` is updated in place, and that same
    array is returned.
    """
    h = B.dt_bold
    slope_start = f_fmri(yin, y, t, B)
    predictor = y + h * slope_start
    slope_end = f_fmri(yin, predictor, t + h, B)
    y += 0.5 * h * (slope_start + slope_end)
    return y
254
+
255
+
256
@njit
def integrate_fmri(yin, y, t, B):
    """
    Advance the Balloon model by one step and compute the BOLD signal.

    Parameters
    ----------
    yin : array [nn]
        Input driving the Balloon model (``integrate`` passes the current
        neural state).
    y : array [4*nn]
        Hemodynamic state stacked as [s, f, v, q]; updated in place.
    t : float
        Current time; only forwarded, because the Balloon right-hand side
        ignores it.
    B : ParBaloon
        Balloon-model parameters.

    Returns
    -------
    y : array [4*nn]
        The updated hemodynamic state (the same array object as the input).
    yb : array [nn]
        BOLD signal ``V0 * (K1 (1 - q) + K2 (1 - q / v) + K3 (1 - v))``.
    """
    nn = yin.shape[0]
    y = heun_ode_step(yin, y, t, B)
    v = y[(2 * nn):(3 * nn)]
    q = y[(3 * nn):]
    yb = B.V0 * (
        B.K1 * (1.0 - q)
        + B.K2 * (1.0 - q / v)
        + B.K3 * (1.0 - v)
    )
    return y, yb
290
+
291
+
292
@njit
def integrate(P, B, intg=heun_sde_step):
    """
    Integrate the Wong-Wang model and, optionally, the Balloon model.

    The neural state starts at ``P.initial_state``. It is advanced with
    ``intg(S, P)`` once for each step index 1 .. nt-1, where
    ``nt = len(np.arange(0, P.t_end, P.dt))``. When ``P.RECORD_FMRI`` is set,
    the Balloon model is advanced once per neural step, with the current
    neural state as its input.

    Parameters
    ----------
    P : ParWW
        Neural-model and simulation settings.
    B : ParBaloon
        Balloon-model parameters; ``B.nn`` must equal ``P.nn``.
    intg : callable
        Jitted one-step integrator ``intg(S, P) -> S_next``.

    Returns
    -------
    T : array [n_ts]
        Times of the recorded neural samples (every ``P.ts_decimate`` steps).
    S : array [n_ts, nn]
        Recorded neural states; shape (0, 1) when ``P.RECORD_TS`` is False.
    t_fmri : array [n_fmri]
        Times of the BOLD samples (step index times ``B.dt_bold``).
    d_fmri : array [n_fmri, nn]
        BOLD samples; shape (0, 1) when ``P.RECORD_FMRI`` is False.

    Only samples whose time is >= ``P.t_cut`` are returned.
    """
    if P.seed > 0:
        # Seed numba's own generator. np.random.seed called from plain Python
        # (as WW_sde.set_initial_state does) does not affect jitted code.
        np.random.seed(P.seed)

    nt = len(np.arange(0, P.t_end, P.dt))
    nn = P.nn
    # Samples are taken at steps i = 1 .. nt-1 with i % decimate == 0, so
    # exactly (nt - 1) // decimate rows get written. Sizing the buffers with
    # ceil(nt / decimate) left a trailing row of uninitialised np.empty data
    # that could slip through the t_cut filter below.
    n_steps = max(nt - 1, 0)

    if P.RECORD_TS:
        n_ts = n_steps // P.ts_decimate
        T = np.empty(n_ts)
        S = np.empty((n_ts, nn))
    else:
        T = np.empty(0)
        S = np.empty((0, 1))

    if P.RECORD_FMRI:
        n_fmri = n_steps // P.fmri_decimate
        t_fmri = np.empty(n_fmri)
        d_fmri = np.empty((n_fmri, nn))
    else:
        t_fmri = np.empty(0)
        d_fmri = np.empty((0, 1))

    x = P.initial_state
    # Balloon state [s, f, v, q]: s starts at 0, f, v and q start at 1.
    y = np.zeros(4 * nn)
    y[nn:] = 1.0

    ii = 0
    jj = 0
    for i in range(1, nt):
        t_now = i * P.dt
        x = intg(x, P)

        if P.RECORD_TS:
            if i % P.ts_decimate == 0:
                S[ii, :] = x
                T[ii] = t_now
                ii += 1
        if P.RECORD_FMRI:
            y, fmri_i = integrate_fmri(x, y, t_now, B)
            if i % P.fmri_decimate == 0:
                d_fmri[jj, :] = fmri_i
                t_fmri[jj] = i * B.dt_bold
                jj += 1

    keep_ts = T >= P.t_cut
    S = S[keep_ts, :]
    T = T[keep_ts]
    keep_fmri = t_fmri >= P.t_cut
    d_fmri = d_fmri[keep_fmri, :]
    t_fmri = t_fmri[keep_fmri]

    return T, S, t_fmri, d_fmri
345
+
346
+
347
class WW_sde(object):
    r"""
    Wong-Wang model.

    .. math::
        x_k &= w\,J_N \, S_k + I_o + G\,J_N \sum_j \, C_{kj} \,S_j \\
        H(x_k) &= \dfrac{ax_k - b}{1 - \exp(-d(ax_k -b))}\\
        \dot{S}_k &= -\dfrac{S_k}{\tau_s} + (1 - S_k) \, H(x_k) \, \gamma + \sigma \, \Xi_k

    - Kong-Fatt Wong and Xiao-Jing Wang, A Recurrent Network Mechanism of Time Integration in Perceptual Decisions. Journal of Neuroscience 26(4), 1314-1328, 2006.
    - Deco Gustavo, Ponce Alvarez Adrian, Dante Mantini, Gian Luca Romani, Patric Hagmann and Maurizio Corbetta. Resting-State Functional Connectivity Emerges from Structurally and Dynamically Shaped Slow Linear Fluctuations. The Journal of Neuroscience 33(27), 11239-11252, 2013. Equations taken from DPA 2013, page 11242.
    """

    # Arguments accepted by the ParBaloon constructor. Every other Balloon
    # field is derived from these or assigned after construction.
    _BALLOON_CTOR_KEYS = ("eps", "E0", "V0", "alpha", "taus", "tauo", "tauf")
    # ParWW fields assigned after construction rather than passed to the
    # constructor ("nn" is recomputed from the initial state anyway).
    _WW_POST_KEYS = ("nn", "J_I")
    # ParWW fields typed as float64 arrays; user values are converted so that
    # lists and integer arrays are accepted.
    _ARRAY_KEYS = ("initial_state", "weights", "I_E", "I_I")

    def __init__(self, par: "dict | None" = None, parB: "dict | None" = None) -> None:
        """
        Build the neural (``self.P``) and Balloon (``self.B``) parameter objects.

        ``par`` may only contain ParWW fields and ``parB`` only ParBaloon
        fields; any other key raises ValueError. Neither dict is modified.
        """
        par = {} if par is None else par
        parB = {} if parB is None else parB

        self.valid_parW = [name for name, _ in w_spec]
        self.valid_parB = [name for name, _ in b_spec]
        self.valid_par = self.valid_parW + self.valid_parB

        self.check_parameters(par, self.valid_parW)
        self.check_parameters(parB, self.valid_parB)
        self.P = self.get_par_WW_obj(par)
        self.B = self.get_par_Baloon_obj(parB)

    def __str__(self):
        lines = [
            "Wong-Wang model of neural population dynamics",
            "Parameters:----------------------------",
        ]
        lines += [f"{key} :  {getattr(self.P, key)}" for key in self.valid_parW]
        lines.append("---------------------------------------")
        lines += [f"{key} :  {getattr(self.B, key)}" for key in self.valid_parB]
        return "\n".join(lines)

    @classmethod
    def _as_float_arrays(cls, par: dict) -> dict:
        """Return a copy of ``par`` with array-typed fields as float64 ndarrays."""
        out = dict(par)
        for key in cls._ARRAY_KEYS:
            if key in out:
                out[key] = np.asarray(out[key], dtype=np.float64)
        return out

    def get_par_WW_obj(self, par: "dict | None" = None) -> ParWW:
        """
        Return a ParWW built from defaults overridden by ``par``.

        ``par`` is copied, never mutated. Weights, if given, must form a
        square matrix (AssertionError otherwise).
        """
        par = self._as_float_arrays(par or {})
        if "weights" in par:
            assert par["weights"].ndim == 2
            assert par["weights"].shape[0] == par["weights"].shape[1]
        post = {key: par.pop(key) for key in self._WW_POST_KEYS if key in par}
        parobj = ParWW(**par)
        for key, value in post.items():
            setattr(parobj, key, value)
        return parobj

    def get_par_Baloon_obj(self, par: "dict | None" = None) -> ParBaloon:
        """
        Return a ParBaloon built from defaults overridden by ``par``.

        Constructor arguments are passed to ParBaloon. Any other Balloon field
        in ``par`` (e.g. ``dt_bold`` or a derived constant) is assigned
        afterwards.
        """
        par = par or {}
        ctor = {k: v for k, v in par.items() if k in self._BALLOON_CTOR_KEYS}
        parobj = ParBaloon(**ctor)
        for key, value in par.items():
            if key not in self._BALLOON_CTOR_KEYS:
                setattr(parobj, key, value)
        return parobj

    def check_parameters(self, par: dict, valid=None) -> None:
        """
        Raise ValueError for any key of ``par`` not in ``valid``.

        ``valid`` defaults to the union of neural and Balloon field names.
        """
        valid = self.valid_par if valid is None else valid
        for key in par:
            if key not in valid:
                raise ValueError(f"Invalid parameter {key}")

    def set_initial_state(self, seed=None):
        """Draw a uniform [0, 1) initial state, one value per row of ``weights``."""
        if seed is not None:
            np.random.seed(seed)
        self.P.nn = self.P.weights.shape[0]
        self.initial_state = np.random.rand(self.P.nn)
        self.B.nn = self.P.nn

    def check_input(self):
        """Validate weights/initial state and propagate the node count to P and B."""
        weights = self.P.weights
        assert weights is not None
        assert weights.shape[0] == weights.shape[1]
        assert self.P.initial_state is not None
        assert len(self.P.initial_state) == weights.shape[0]
        self.P.nn = len(self.P.initial_state)
        self.B.nn = self.P.nn

    def _update_balloon(self, parB: dict) -> None:
        """
        Apply ``parB`` to ``self.B``.

        The object is rebuilt whenever constructor fields change, so that the
        derived K1/K2/K3 and inv_* constants stay consistent; plain setattr
        would leave them stale. ``nn`` and ``dt_bold`` carry over unless
        overridden.
        """
        old = self.B
        ctor = {k: parB.get(k, getattr(old, k)) for k in self._BALLOON_CTOR_KEYS}
        self.B = ParBaloon(**ctor)
        self.B.nn = old.nn
        self.B.dt_bold = old.dt_bold
        for key, value in parB.items():
            if key not in self._BALLOON_CTOR_KEYS:
                setattr(self.B, key, value)

    def _integrator(self):
        """Return the jitted SDE step function selected by ``P.method``."""
        method = self.P.method
        if method == "heun":
            return heun_sde_step
        if method == "euler":
            return euler_sde_step
        raise ValueError(
            f"Invalid integration method {method!r}; expected 'heun' or 'euler'"
        )

    def run(self, par=None, parB=None, x0=None, verbose=True):
        """
        Update parameters, set the initial state and integrate the model.

        Parameters
        ----------
        par, parB : dict, optional
            Neural / Balloon parameter overrides. They are applied BEFORE the
            initial state is chosen, so a new ``weights`` matrix sizes the
            random initial state correctly.
        x0 : array-like, optional
            Initial state. If omitted, ``par["initial_state"]`` is used when
            given. Otherwise a uniform random state is drawn, seeded by
            ``P.seed`` when it is positive.
        verbose : bool
            Accepted for API compatibility; currently unused.

        Returns
        -------
        dict with keys "t", "s", "t_fmri", "d_fmri" (see ``integrate``).
        """
        par = {} if par is None else par
        parB = {} if parB is None else parB

        if par:
            self.check_parameters(par, self.valid_parW)
            for key, value in self._as_float_arrays(par).items():
                setattr(self.P, key, value)
        if parB:
            self.check_parameters(parB, self.valid_parB)
            self._update_balloon(parB)

        if x0 is not None:
            self.P.initial_state = np.asarray(x0, dtype=np.float64)
            self.P.nn = len(x0)
        else:
            self.seed = self.P.seed if self.P.seed > 0 else None
            if "initial_state" not in par:
                self.set_initial_state(self.seed)
                self.P.initial_state = self.initial_state
        self.check_input()

        T, S, t_fmri, d_fmri = integrate(self.P, self.B, self._integrator())

        return {"t": T, "s": S, "t_fmri": t_fmri, "d_fmri": d_fmri}
@@ -0,0 +1,162 @@
1
+ import os
2
+ import numpy as np
3
+ from numba.experimental import jitclass
4
+ from numba import float64, types
5
+ from numba import njit
6
+
7
# Field layout of the numba jitclass ``Param``: (attribute name, numba type).
# The model coefficients and time settings are float scalars, ``output`` and
# ``method`` are strings, and ``initial_state`` is a 1-D float array.
jit_spec = [
    *[(field, float64) for field in ("a", "b", "dt", "t_start", "t_end", "t_cut")],
    ("output", types.string),
    ("initial_state", float64[:]),
    ("method", types.string),
]
17
+
18
# Parameter container for the damped-oscillator model, compiled by numba with
# the field layout declared in ``jit_spec``.
@jitclass(jit_spec)
class Param:
    def __init__(self,
                 a=0.1,
                 b=0.05,
                 dt=0.01,
                 t_start=0,
                 t_end=100.0,
                 t_cut=20,
                 output="output",
                 method="euler",
                 initial_state=np.array([0.5, 1.0])
                 ):
        """Copy each constructor argument onto the field of the same name."""
        # time axis: integration runs from t_start to t_end in steps of dt;
        # everything before t_cut is discarded as transient
        self.t_start = t_start
        self.t_cut = t_cut
        self.t_end = t_end
        self.dt = dt

        # model coefficients
        self.a = a
        self.b = b

        # run configuration
        self.method = method
        self.output = output
        self.initial_state = initial_state
40
+
41
+
42
@njit
def _f_sys(x, P):
    '''
    Time derivative of the two-variable damped-oscillator system.

    With state (u, w) = (x[0], x[1]):
        du/dt = u - u*w - a*u*u
        dw/dt = u*w - w - b*w*w
    Returns a new length-2 array.
    '''
    u = x[0]
    w = x[1]
    return np.array([u - u * w - P.a * u * u,
                     u * w - w - P.b * w * w])
51
+
52
+
53
@njit
def euler(x, P):
    '''
    Advance the state by one explicit (forward) Euler step of size P.dt.
    '''
    slope = _f_sys(x, P)
    return x + P.dt * slope
59
+
60
@njit
def heun(x, P):
    '''
    One Heun (explicit trapezoidal) step of size P.dt.

    Averages the slope at the current state with the slope at the
    Euler-predicted state.
    '''
    dt = P.dt
    slope_now = _f_sys(x, P)
    slope_next = _f_sys(x + dt * slope_now, P)
    return x + 0.5 * dt * (slope_now + slope_next)
69
+
70
@njit
def rk4(x, P):
    '''
    One classical fourth-order Runge-Kutta step of size P.dt.
    '''
    h = P.dt
    half = 0.5 * h
    s1 = _f_sys(x, P)
    s2 = _f_sys(x + half * s1, P)
    s3 = _f_sys(x + half * s2, P)
    s4 = _f_sys(x + h * s3, P)
    return x + h * (s1 + 2 * s2 + 2 * s3 + s4) / 6.0
80
+
81
+
82
@njit
def _integrate(x, P, intg=euler):
    '''
    Integrate the damped-oscillator model with the one-step rule ``intg``.

    Steps in [P.t_start, P.t_cut) are a discarded transient. Returns
    ``(t, x_out)`` where ``t = np.arange(P.t_cut, P.t_end, P.dt)`` and
    ``x_out`` has shape ``(len(t), len(x))``.

    Row i of x_out is recorded BEFORE taking the next step, so it is the
    state at time t[i] when (t_cut - t_start) is a multiple of dt. The
    previous version stepped first, which shifted every row by one dt
    relative to ``t``.
    '''
    n_transient = len(np.arange(P.t_start, P.t_cut, P.dt))
    for _ in range(n_transient):
        x = intg(x, P)

    t = np.arange(P.t_cut, P.t_end, P.dt)
    x_out = np.zeros((len(t), len(x)))

    for i in range(len(t)):
        x_out[i, :] = x
        x = intg(x, P)
    return t, x_out
97
+
98
+
99
class DO_nb:
    '''
    Damped oscillator model (numba backend).

    Wraps a ``Param`` jitclass instance (``self.P``). It integrates the
    two-variable system defined by ``_f_sys`` with the integrator named by
    ``P.method`` ("euler", "heun" or "rk4").
    '''

    def __init__(self, par=None):
        '''
        Validate ``par``, build the parameter object and create the output
        directory ``P.output``, falling back to "output" when it is empty.
        ``par`` is not modified.
        '''
        par = {} if par is None else par
        self.valid_params = [name for name, _ in jit_spec]
        self.check_parameters(par)
        self.P = self.get_parobj(par)

        # P.output is a numba string field: it can be empty but never None,
        # and os.makedirs("") would raise.
        if not self.P.output:
            self.P.output = "output"
        os.makedirs(self.P.output, exist_ok=True)

    def __str__(self) -> str:
        lines = ["Damp Oscillator model", "----------------"]
        lines.extend(f"{key} :  {getattr(self.P, key)}" for key in self.valid_params)
        return "\n".join(lines)

    def check_parameters(self, par):
        '''
        Raise ValueError if any key of ``par`` is not a Param field.
        '''
        for key in par:
            if key not in self.valid_params:
                raise ValueError("Invalid parameter: " + key)

    def get_parobj(self, par=None):
        '''
        Return a Param built from defaults overridden by ``par``.

        ``par`` is copied, never mutated. ``initial_state`` is converted to a
        float64 array, because the jitclass field is typed ``float64[:]``.
        '''
        par = dict(par or {})
        if "initial_state" in par:
            par["initial_state"] = np.asarray(par["initial_state"], dtype=np.float64)
        return Param(**par)

    def update_par(self, par=None):
        '''
        Validate ``par`` and assign each entry onto ``self.P``.
        '''
        if not par:
            return
        self.check_parameters(par)
        for key, value in par.items():
            if key == "initial_state":
                value = np.asarray(value, dtype=np.float64)
            setattr(self.P, key, value)

    def run(self, par=None, x0=None):
        '''
        Integrate the model and return ``(t, x_out)`` (see ``_integrate``).

        Parameters
        ----------
        par : dict, optional
            Parameter overrides applied before integrating.
        x0 : array-like of length 2, optional
            Initial state; defaults to ``P.initial_state``.

        Raises
        ------
        AssertionError
            If ``x0`` does not have length 2.
        ValueError
            If ``P.method`` is not "euler", "heun" or "rk4". Previously this
            case failed with an UnboundLocalError.
        '''
        self.update_par(par)
        if x0 is not None:
            assert len(x0) == 2, "Invalid initial state"
            self.P.initial_state = np.asarray(x0, dtype=np.float64)

        integrators = {"euler": euler, "heun": heun, "rk4": rk4}
        method = self.P.method
        if method not in integrators:
            raise ValueError(f"Invalid integration method: {method}")

        return _integrate(self.P.initial_state, self.P, intg=integrators[method])
162
+