vbi-0.1.3-cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. vbi/__init__.py +37 -0
  2. vbi/_version.py +17 -0
  3. vbi/dataset/__init__.py +0 -0
  4. vbi/dataset/connectivity_84/centers.txt +84 -0
  5. vbi/dataset/connectivity_84/centres.txt +84 -0
  6. vbi/dataset/connectivity_84/cortical.txt +84 -0
  7. vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
  8. vbi/dataset/connectivity_84/weights.txt +84 -0
  9. vbi/dataset/connectivity_88/Aud_88.txt +88 -0
  10. vbi/dataset/connectivity_88/Bold.npz +0 -0
  11. vbi/dataset/connectivity_88/Labels.txt +17 -0
  12. vbi/dataset/connectivity_88/Region_labels.txt +88 -0
  13. vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
  14. vbi/dataset/connectivity_88/weights.txt +88 -0
  15. vbi/feature_extraction/__init__.py +1 -0
  16. vbi/feature_extraction/calc_features.py +293 -0
  17. vbi/feature_extraction/features.json +535 -0
  18. vbi/feature_extraction/features.py +2124 -0
  19. vbi/feature_extraction/features_settings.py +374 -0
  20. vbi/feature_extraction/features_utils.py +1357 -0
  21. vbi/feature_extraction/infodynamics.jar +0 -0
  22. vbi/feature_extraction/utility.py +507 -0
  23. vbi/inference.py +98 -0
  24. vbi/models/__init__.py +0 -0
  25. vbi/models/cpp/__init__.py +0 -0
  26. vbi/models/cpp/_src/__init__.py +0 -0
  27. vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
  28. vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
  29. vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
  30. vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  31. vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  32. vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  33. vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
  34. vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
  35. vbi/models/cpp/_src/bold.hpp +303 -0
  36. vbi/models/cpp/_src/do.hpp +167 -0
  37. vbi/models/cpp/_src/do.i +17 -0
  38. vbi/models/cpp/_src/do.py +467 -0
  39. vbi/models/cpp/_src/do_wrap.cxx +12811 -0
  40. vbi/models/cpp/_src/jr_sdde.hpp +352 -0
  41. vbi/models/cpp/_src/jr_sdde.i +19 -0
  42. vbi/models/cpp/_src/jr_sdde.py +688 -0
  43. vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
  44. vbi/models/cpp/_src/jr_sde.hpp +264 -0
  45. vbi/models/cpp/_src/jr_sde.i +17 -0
  46. vbi/models/cpp/_src/jr_sde.py +470 -0
  47. vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
  48. vbi/models/cpp/_src/km_sde.hpp +158 -0
  49. vbi/models/cpp/_src/km_sde.i +19 -0
  50. vbi/models/cpp/_src/km_sde.py +671 -0
  51. vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
  52. vbi/models/cpp/_src/makefile +52 -0
  53. vbi/models/cpp/_src/mpr_sde.hpp +327 -0
  54. vbi/models/cpp/_src/mpr_sde.i +19 -0
  55. vbi/models/cpp/_src/mpr_sde.py +711 -0
  56. vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
  57. vbi/models/cpp/_src/utility.hpp +307 -0
  58. vbi/models/cpp/_src/vep.hpp +171 -0
  59. vbi/models/cpp/_src/vep.i +16 -0
  60. vbi/models/cpp/_src/vep.py +464 -0
  61. vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
  62. vbi/models/cpp/_src/wc_ode.hpp +294 -0
  63. vbi/models/cpp/_src/wc_ode.i +19 -0
  64. vbi/models/cpp/_src/wc_ode.py +686 -0
  65. vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
  66. vbi/models/cpp/damp_oscillator.py +143 -0
  67. vbi/models/cpp/jansen_rit.py +543 -0
  68. vbi/models/cpp/km.py +187 -0
  69. vbi/models/cpp/mpr.py +289 -0
  70. vbi/models/cpp/vep.py +150 -0
  71. vbi/models/cpp/wc.py +216 -0
  72. vbi/models/cupy/__init__.py +0 -0
  73. vbi/models/cupy/bold.py +111 -0
  74. vbi/models/cupy/ghb.py +284 -0
  75. vbi/models/cupy/jansen_rit.py +473 -0
  76. vbi/models/cupy/km.py +224 -0
  77. vbi/models/cupy/mpr.py +475 -0
  78. vbi/models/cupy/mpr_modified_bold.py +12 -0
  79. vbi/models/cupy/utils.py +184 -0
  80. vbi/models/numba/__init__.py +0 -0
  81. vbi/models/numba/_ww_EI.py +444 -0
  82. vbi/models/numba/damp_oscillator.py +162 -0
  83. vbi/models/numba/ghb.py +208 -0
  84. vbi/models/numba/mpr.py +383 -0
  85. vbi/models/pytorch/__init__.py +0 -0
  86. vbi/models/pytorch/data/default_parameters.npz +0 -0
  87. vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
  88. vbi/models/pytorch/data/input/fc_test.csv +68 -0
  89. vbi/models/pytorch/data/input/fc_train.csv +68 -0
  90. vbi/models/pytorch/data/input/fc_vali.csv +68 -0
  91. vbi/models/pytorch/data/input/fcd_test.mat +0 -0
  92. vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
  93. vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
  94. vbi/models/pytorch/data/input/fcd_train.mat +0 -0
  95. vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
  96. vbi/models/pytorch/data/input/myelin.csv +68 -0
  97. vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
  98. vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
  99. vbi/models/pytorch/data/input/sc_test.csv +68 -0
  100. vbi/models/pytorch/data/input/sc_train.csv +68 -0
  101. vbi/models/pytorch/data/input/sc_vali.csv +68 -0
  102. vbi/models/pytorch/data/obs_kong0.npz +0 -0
  103. vbi/models/pytorch/ww_sde_kong.py +570 -0
  104. vbi/models/tvbk/__init__.py +9 -0
  105. vbi/models/tvbk/tvbk_wrapper.py +166 -0
  106. vbi/models/tvbk/utils.py +72 -0
  107. vbi/papers/__init__.py +0 -0
  108. vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
  109. vbi/tests/__init__.py +0 -0
  110. vbi/tests/_test_mpr_nb.py +36 -0
  111. vbi/tests/test_features.py +355 -0
  112. vbi/tests/test_ghb_cupy.py +90 -0
  113. vbi/tests/test_mpr_cupy.py +49 -0
  114. vbi/tests/test_mpr_numba.py +84 -0
  115. vbi/tests/test_suite.py +19 -0
  116. vbi/utils.py +402 -0
  117. vbi-0.1.3.dist-info/METADATA +166 -0
  118. vbi-0.1.3.dist-info/RECORD +121 -0
  119. vbi-0.1.3.dist-info/WHEEL +5 -0
  120. vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
  121. vbi-0.1.3.dist-info/top_level.txt +1 -0
vbi/models/tvbk/tvbk_wrapper.py ADDED
@@ -0,0 +1,166 @@
+ import numpy as np
+ import collections
+ from vbi.models.tvbk.utils import prepare_vec, setup_connectivity
+
+ try:
+     import tvbk as m
+
+     TVBK_AVAILABLE = True
+ except ImportError:
+     TVBK_AVAILABLE = False
+
+
+ class MPR:
+
+     MPRTheta = collections.namedtuple(
+         typename="MPRTheta", field_names="tau I Delta J eta G".split(" ")
+     )
+     num_svar = 2  # number of state variables
+     num_parm = 6  # number of parameters
+
+     def __init__(self, par: dict = {}) -> None:
+
+         self._par = self.get_default_parameters()
+         self.valid_parameters = list(self._par.keys())
+         self.check_parameters(par)
+         self._par.update(par)
+
+         for item in self._par.items():
+             setattr(self, item[0], item[1])
+
+         self.mpr_default_theta = self.MPRTheta(
+             tau=self._par["tau"],
+             I=self._par["I"],
+             Delta=self._par["Delta"],
+             J=self._par["J"],
+             eta=self._par["eta"],
+             G=self._par["G"],
+         )
+
+     def __str__(self) -> str:
+         return f"MPR model with parameters: {self._par}"
+
+     def get_default_parameters(self) -> dict:
+         return {
+             "tau": 1.0,
+             "Delta": 1.0,
+             "I": 0.0,
+             "J": 15.0,
+             "eta": -5.0,
+             "G": 1.0,
+             "num_batch": 1,
+             "horizon": 256,
+             "width": 8,
+             "dt": 0.01,
+             "dtype": np.float32,
+             "weights": None,
+             "delays": None,
+             "num_node": None,
+             "noise_amp": None,
+             "num_time": 1000,
+             "decimate_rv": 10,
+             "RECORD_RV": True,
+             "RECORD_BOLD": False,  # TODO: Add BOLD recording
+         }
+
+     def check_parameters(self, par):
+         for key in par.keys():
+             if key not in self.valid_parameters:
+                 raise ValueError(f"Invalid parameter {key:s} provided.")
+
+     def initialize_buffers(self):
+
+         total_volume = self.num_batch * self.num_node * self.horizon * self.width
+         self.cx = m.Cx8s(self.num_node, self.horizon, self.num_batch)
+         buf_val = (
+             np.r_[: 1.0 : 1j * total_volume]
+             .reshape(self.num_batch, self.num_node, self.horizon, self.width)
+             .astype(self.dtype)
+             * 4.0
+         )
+         self.cx.buf[:] = buf_val
+         self.cx.cx1[:] = self.cx.cx2[:] = 0.0
+
+     def prepare_input(self):
+
+         width = self.width
+         num_batch = self.num_batch
+         num_parm = self.num_parm
+
+         assert self.weights is not None, "weights must be provided"
+         self.weights = np.array(self.weights)
+         num_node = self.num_node = self.weights.shape[0]
+
+         self.G = prepare_vec(self.G, num_batch, self.dtype)
+         self.J = prepare_vec(self.J, num_batch, self.dtype)
+         self.I = prepare_vec(self.I, num_batch, self.dtype)
+         self.eta = prepare_vec(self.eta, num_batch, self.dtype)
+         self.tau = prepare_vec(self.tau, num_batch, self.dtype)
+         self.Delta = prepare_vec(self.Delta, num_batch, self.dtype)
+         self.noise_amp = np.array(self.noise_amp, self.dtype)
+         self.p = np.zeros((num_batch, num_node, num_parm, width), self.dtype)
+
+         self.p[:, :, 0, :] = self.tau
+         self.p[:, :, 1, :] = self.I
+         self.p[:, :, 2, :] = self.Delta
+         self.p[:, :, 3, :] = self.J
+         self.p[:, :, 4, :] = self.eta
+         self.p[:, :, 5, :] = self.G
+
+         self.initialize_buffers()
+         self.conn = setup_connectivity(self.weights, self.delays)
+
+     def run(self):
+
+         self.prepare_input()
+
+         x = np.zeros((self.num_batch, self.num_svar, self.num_node, self.width), self.dtype)
+         y = np.zeros_like(x)
+         z = np.zeros((self.num_batch, self.num_svar, 8), self.dtype) + self.noise_amp
+         seed = np.zeros((self.num_batch, 8, 4), np.uint64)
+         num_samples = self.num_time // self.decimate_rv + 1
+
+         if self.RECORD_RV:
+             trace_c = np.zeros(
+                 (
+                     num_samples,
+                     self.num_batch,
+                     self.num_svar,
+                     self.num_node,
+                     self.width,
+                 )
+             )
+
+         for i in range(num_samples):
+             if self.RECORD_RV:
+                 m.step_mpr(
+                     self.cx,
+                     self.conn,
+                     x,
+                     y,
+                     z,
+                     self.p,
+                     i * self.decimate_rv,
+                     self.decimate_rv,
+                     self.dt,
+                     seed,
+                 )
+             if self.RECORD_RV:
+                 trace_c[i] = x
+
+         # TODO: Calculate BOLD signal
+         if self.RECORD_BOLD:
+             pass
+             # add BOLD signal calculation pytorch code here
+
+         return {
+             "rv_t": ...,
+             "rv_d": trace_c,
+             "fmri_t": ...,
+             "fmri_d": None,
+         }
+
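For orientation, a minimal usage sketch of the wrapper above (not part of the package diff): it assumes the optional tvbk backend is importable, and the connectivity and noise values here are purely illustrative; only parameter names defined in get_default_parameters() are passed.

    import numpy as np
    from vbi.models.tvbk.tvbk_wrapper import MPR, TVBK_AVAILABLE

    if TVBK_AVAILABLE:
        nn = 8
        weights = np.abs(np.random.randn(nn, nn))  # illustrative dense connectome
        model = MPR(par={
            "weights": weights,     # required; delays default to zeros
            "noise_amp": 0.037,     # illustrative noise amplitude
            "G": 0.5,
            "num_time": 1000,
            "decimate_rv": 10,
        })
        out = model.run()
        # rv_d has shape (num_samples, num_batch, num_svar, num_node, width)
        print(out["rv_d"].shape)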
vbi/models/tvbk/utils.py ADDED
@@ -0,0 +1,72 @@
+ import tvbk as m
+ import numpy as np
+ import scipy.sparse
+ import collections
+ from vbi.models.cupy.utils import repmat_vec, is_seq  # TODO move it to vbi.utils
+
+
+ def setup_connectivity(weights, delays=None):
+     """
+     Sets up connectivity using the provided weights and delays.
+
+     Args:
+         weights (np.ndarray): The weight matrix.
+         delays (np.ndarray, optional): The delay matrix; defaults to zeros.
+
+     Returns:
+         conn: The connection object.
+     """
+
+     # Convert weights to a sparse CSR matrix
+     s_w = scipy.sparse.csr_matrix(weights)
+     num_node = weights.shape[0]
+
+     # TODO! check how to handle delays if not provided
+     if delays is None:
+         delays = np.zeros_like(weights)
+
+     # Ensure delays are valid
+     idelays = (delays[weights != 0]).astype(np.uint32) + 2
+     assert idelays.max() < delays.shape[1]
+     assert idelays.min() >= 2
+
+     # Create the connection object
+     conn = m.Conn(num_node, s_w.data.size)
+     conn.weights[:] = s_w.data.astype(np.float32)
+     conn.indptr[:] = s_w.indptr.astype(np.uint32)
+     conn.indices[:] = s_w.indices.astype(np.uint32)
+     conn.idelays[:] = idelays
+
+     return conn
+
+
+ def prepare_vec(x, num_batch, dtype=np.float32):
+     """
+     Check and prepare vector dimension and type.
+
+     Parameters
+     ----------
+     x : array 1d
+         vector to be prepared; if x is a scalar, only the type is modified.
+     num_batch : int
+         number of batched simulations.
+
+     Returns
+     -------
+     x : array [len(x), num_batch]
+         prepared vector.
+     """
+
+     if not is_seq(x):
+         return dtype(x)
+     else:
+         x = np.array(x)
+         if x.ndim == 1:
+             x = repmat_vec(x, num_batch, "cpu")
+         elif x.ndim == 2:
+             assert x.shape[1] == num_batch, "second dimension of x must be equal to num_batch"
+             x = x.astype(dtype)
+         else:
+             raise ValueError("x.ndim must be 1 or 2")
+         return x
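A quick sketch of how prepare_vec behaves (illustrative, not part of the package; it assumes tvbk is installed, since this module imports it at the top): scalars are only cast to the requested dtype, while 1-d vectors are tiled across the batch dimension as described in the docstring.

    import numpy as np
    from vbi.models.tvbk.utils import prepare_vec

    num_batch = 4
    print(prepare_vec(0.5, num_batch))         # scalar -> np.float32(0.5)
    eta = np.array([-5.0, -4.9, -4.8])         # one value per node
    print(prepare_vec(eta, num_batch).shape)   # tiled to (3, num_batch) per the docstring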
vbi/papers/__init__.py ADDED
File without changes
vbi/papers/pavlides_pcb_2015/pavlides.py ADDED
@@ -0,0 +1,211 @@
+ import os
+ import os.path
+ import numpy as np
+ from numpy import pi
+ from os.path import join
+ import matplotlib.pyplot as plt
+ from jitcdde import jitcdde, y, t
+ from symengine import sin, cos, Symbol, symarray, exp
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+
+ class Pav:
+     """
+     This class implements the Wilson-Cowan model for Parkinson's disease.
+
+     Reference:
+         - Pavlides, A., Hogan, S.J. and Bogacz, R., 2015. Computational models describing possible mechanisms for generation of excessive beta oscillations in Parkinson's disease. PLoS Computational Biology, 11(12), p.e1004609.
+     """
+
+     def __init__(self, par: dict = {}):
+
+         _par = self.get_default_params()
+         _par.update(par)
+         for item in _par.items():
+             if item[0] not in _par["control"]:
+                 setattr(self, item[0], item[1])
+
+         self.control_pars = []
+         if len(_par["control"]) > 0:
+             for i in _par["control"]:
+                 value = Symbol(i)
+                 setattr(self, i, value)
+                 self.control_pars.append(value)
+
+         if "modulename" not in par.keys():
+             self.modulename = "pav"
+         os.makedirs(self.output, exist_ok=True)
+
+     def get_default_params(self):
+         """
+         Return a dictionary of default parameters for the model.
+         """
+         par = {
+             "control": "",  # list of control parameters
+             "verbose": False,  # print compilation information
+             "openmp": False,  # use openmp
+             "output": "output",  # output directory
+             "initial_state": None,  # initial state of the system
+             "t_end": 1000.0,  # end time of the simulation
+             "t_cut": 0.0,  # cut time of the simulation
+             "seed": None,  # seed for the random number generator
+             "n_components": 4,  # number of components in the system
+             "interval": 0.1,  # interval for saving the state of the system
+
+             "Tsg": 6.0,  # ms, delay from subthalamic nucleus to globus pallidus
+             "Tgs": 6.0,  # ms, delay from globus pallidus to subthalamic nucleus
+             "Tgg": 4.0,  # ms, delay from globus pallidus to globus pallidus
+             "Tcs": 5.5,  # ms, delay from cortex to subthalamic nucleus
+             "Tsc": 21.5,  # ms, delay from subthalamic nucleus to cortex
+             "Tcc": 4.65,  # ms, delay from cortex to cortex
+
+             "taus": 12.8,  # ms, time constant of the subthalamic nucleus
+             "taug": 20.0,  # ms, time constant of the globus pallidus
+             "taue": 11.59,  # ms, time constant of the excitatory neurons
+             "taui": 13.02,  # ms, time constant of the inhibitory neurons
+
+             "Ms": 300.0 / 1000,  # spk/ms, maximum firing rate of the subthalamic nucleus
+             "Mg": 400.0 / 1000,  # spk/ms, maximum firing rate of the globus pallidus
+             "Me": 75.77 / 1000,  # spk/ms, maximum firing rate of the excitatory neurons
+             "Mi": 205.72 / 1000,  # spk/ms, maximum firing rate of the inhibitory neurons
+
+             "Bs": 10.0 / 1000,  # spk/ms, baseline firing rate of the subthalamic nucleus
+             "Bg": 20.0 / 1000,  # spk/ms, baseline firing rate of the globus pallidus
+             "Be": 17.85 / 1000,  # spk/ms, baseline firing rate of the excitatory neurons
+             "Bi": 9.87 / 1000,  # spk/ms, baseline firing rate of the inhibitory neurons
+
+             "C": 172.18 / 1000,  # spk/ms, external input to the cortex
+             "Str": 8.46 / 1000,  # spk/ms, external input from the striatum
+
+             "wgs": 1.33,  # synaptic weight from globus pallidus to subthalamic nucleus
+             "wsg": 4.87,  # synaptic weight from subthalamic nucleus to globus pallidus
+             "wgg": 0.53,  # synaptic weight from globus pallidus to globus pallidus
+             "wcs": 9.97,  # synaptic weight from cortex to subthalamic nucleus
+             "wsc": 8.93,  # synaptic weight from subthalamic nucleus to cortex
+             "wcc": 6.17,  # synaptic weight from cortex to cortex
+         }
+         return par
+
+     def sys_eqs(self):
+
+         inS = self.wcs * y(2, t - self.Tcs) - self.wgs * y(1, t - self.Tgs)
+         inG = self.wsg * y(0, t - self.Tsg) - self.wgg * y(1, t - self.Tgg) - self.Str
+         inE = -self.wsc * y(0, t - self.Tsc) - self.wcc * y(3, t - self.Tcc) + self.C
+         inI = self.wcc * y(2, t - self.Tcc)
+
+         yield ((self.Ms / (1 + exp(-4 * inS / self.Ms) * ((self.Ms - self.Bs) / self.Bs))) - y(0)) * (1 / self.taus)
+         yield ((self.Mg / (1 + exp(-4 * inG / self.Mg) * ((self.Mg - self.Bg) / self.Bg))) - y(1)) * (1 / self.taug)
+         yield ((self.Me / (1 + exp(-4 * inE / self.Me) * ((self.Me - self.Be) / self.Be))) - y(2)) * (1 / self.taue)
+         yield ((self.Mi / (1 + exp(-4 * inI / self.Mi) * ((self.Mi - self.Bi) / self.Bi))) - y(3)) * (1 / self.taui)
+
+     def compile(self, **kwargs):
+         control_pars = self.control_pars if len(self.control_pars) > 0 else ()
+         I = jitcdde(
+             self.sys_eqs,
+             n=self.n_components,
+             verbose=self.verbose,
+             control_pars=control_pars,
+         )
+         I.compile_C(omp=self.openmp, **kwargs)
+         I.save_compiled(overwrite=True, destination=join(self.output, self.modulename))
+
+     def set_initial_state(self, seed=None):
+         if seed is not None:
+             np.random.seed(seed)
+         initial_state = np.zeros(4)
+         return initial_state
+
+     def run(
+         self,
+         par=[],
+         disc="step_on",
+         step=0.001,
+         propagations=1,
+         min_distance=1e-5,
+         max_step=None,
+         shift_ratio=1e-4,
+         **integrator_params
+     ):
+         """
+         Integrate the system of equations and return the computed state
+         of the system together with the corresponding times.
+
+         Parameters
+         ----------
+         par : list
+             values of the control parameters, in order of appearance in `control`
+         disc : str
+             how initial discontinuities are handled. The default is "step_on".
+             - step_on [step_on_discontinuities]
+             - blind [integrate_blindly]
+             - adjust [adjust_diff]
+         step : float
+             argument for integrate_blindly: aspired step size. The actual step size may be slightly adapted to make it divide the integration time. If `None`, `0`, or otherwise falsy, the maximum step size as set with `max_step` of `set_integration_parameters` is used.
+         propagations : int
+             argument for step_on_discontinuities: how often the discontinuity has to propagate before it is considered smoothed.
+         min_distance : float
+             argument for step_on_discontinuities: if two required steps are closer than this, they will be treated as one.
+         max_step : float
+             argument for step_on_discontinuities: retired parameter; steps are now automatically adapted.
+         shift_ratio : float
+             argument for adjust_diff: performs a zero-amplitude (backwards) `jump` whose `width` is `shift_ratio` times the distance to the previous anchor into the past. See the documentation of `jump` for the caveats of this and see `discontinuities` for more information on why you almost certainly need to use this or an alternative way to address initial discontinuities.
+
+         Returns
+         -------
+         dict(t, x)
+             - **t** times
+             - **x** coordinates
+         """
+
+         if self.initial_state is None:
+             self.initial_state = self.set_initial_state(self.seed)
+
+         I = jitcdde(
+             self.sys_eqs,
+             n=4,
+             control_pars=self.control_pars,
+             module_location=join(self.output, self.modulename + ".so"),
+         )
+         I.set_integration_parameters(**integrator_params)
+         I.constant_past(self.initial_state, time=0.0)
+
+         if disc == "blind":
+             I.integrate_blindly(self.initial_state, step=step)
+         elif disc == "step_on":
+             I.step_on_discontinuities(
+                 propagations=propagations, min_distance=min_distance, max_step=max_step
+             )
+         else:
+             I.adjust_diff(shift_ratio=shift_ratio)
+
+         if len(self.control_pars) > 0:
+             I.set_parameters(par)
+         tcut = max(self.t_cut, I.t)
+         times = tcut + np.arange(0, self.t_end - tcut, self.interval)
+
+         x = np.zeros((len(times), self.n_components))
+         for i in range(len(times)):
+             x[i, :] = I.integrate(times[i])
+
+         return {"t": times, "x": x}
+
+
+ if __name__ == "__main__":
+
+     par = {
+         "control": "",
+         "output": "output",
+     }
+     ode = Pav(par=par)
+     ode.compile()
+     data = ode.run(disc="step_on")
+     times = data["t"]
+     x = data["x"]
+
+     print("Times: ", times.shape)
+     print("Coordinates: ", x.shape)
+     plt.plot(times, x)
+     plt.legend(["Subthalamic", "Globus Pallidus", "Excitatory", "Inhibitory"])
+     plt.show()
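Each of the four delay-differential equations in sys_eqs uses the same sigmoid activation, F(x) = M / (1 + exp(-4x/M) * (M - B) / B), which equals the baseline rate B at zero input and saturates at the maximum rate M. A standalone sketch (not part of the package), using the default subthalamic values from get_default_params:

    import numpy as np

    def firing_rate(x, M, B):
        # Sigmoid used in Pav.sys_eqs: firing_rate(0, M, B) == B, and -> M for large input
        return M / (1.0 + np.exp(-4.0 * x / M) * (M - B) / B)

    Ms, Bs = 300.0 / 1000, 10.0 / 1000   # default subthalamic values (spk/ms)
    print(firing_rate(0.0, Ms, Bs))      # 0.01 == Bs
    print(firing_rate(1.0, Ms, Bs))      # close to Ms = 0.3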
vbi/tests/__init__.py ADDED
File without changes
vbi/tests/_test_mpr_nb.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import unittest
+ import numpy as np
+ import networkx as nx
+ from vbi.models.numba.mpr import MPR_sde
+
+ seed = 2
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+
+ nn = 3
+ g = nx.complete_graph(nn)
+ sc = nx.to_numpy_array(g) / 10.0
+
+
+ class testMPRSDE(unittest.TestCase):
+
+     mpr = MPR_sde()
+     p = mpr.get_default_parameters()
+     p['weights'] = sc
+     p['seed'] = seed
+     p['t_cut'] = 0.01 * 60 * 1000
+     p['t_end'] = 0.02 * 60 * 1000
+
+     def test_invalid_parameter_raises_value_error(self):
+         invalid_params = {"invalid_param": 42}
+         with self.assertRaises(ValueError):
+             MPR_sde(par=invalid_params)
+
+     def test_run(self):
+
+         control = {"G": 0.1, "eta": -4.7}
+         mpr = MPR_sde(self.p)
+         sol = mpr.run(par=control)
+         x = sol["x"]
+         t = sol["t"]
+         self.assertEqual(x.shape[0], nn)
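The leading underscore in _test_mpr_nb.py keeps it out of the default test*.py discovery pattern, so it has to be invoked explicitly. A possible way to do that (assuming vbi, torch, and networkx are installed):

    import unittest
    from vbi.tests import _test_mpr_nb

    # Load and run the hidden MPR_sde tests explicitly
    suite = unittest.defaultTestLoader.loadTestsFromModule(_test_mpr_nb)
    unittest.TextTestRunner(verbosity=2).run(suite)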