vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/__init__.py +37 -0
- vbi/_version.py +17 -0
- vbi/dataset/__init__.py +0 -0
- vbi/dataset/connectivity_84/centers.txt +84 -0
- vbi/dataset/connectivity_84/centres.txt +84 -0
- vbi/dataset/connectivity_84/cortical.txt +84 -0
- vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
- vbi/dataset/connectivity_84/weights.txt +84 -0
- vbi/dataset/connectivity_88/Aud_88.txt +88 -0
- vbi/dataset/connectivity_88/Bold.npz +0 -0
- vbi/dataset/connectivity_88/Labels.txt +17 -0
- vbi/dataset/connectivity_88/Region_labels.txt +88 -0
- vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
- vbi/dataset/connectivity_88/weights.txt +88 -0
- vbi/feature_extraction/__init__.py +1 -0
- vbi/feature_extraction/calc_features.py +293 -0
- vbi/feature_extraction/features.json +535 -0
- vbi/feature_extraction/features.py +2124 -0
- vbi/feature_extraction/features_settings.py +374 -0
- vbi/feature_extraction/features_utils.py +1357 -0
- vbi/feature_extraction/infodynamics.jar +0 -0
- vbi/feature_extraction/utility.py +507 -0
- vbi/inference.py +98 -0
- vbi/models/__init__.py +0 -0
- vbi/models/cpp/__init__.py +0 -0
- vbi/models/cpp/_src/__init__.py +0 -0
- vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/bold.hpp +303 -0
- vbi/models/cpp/_src/do.hpp +167 -0
- vbi/models/cpp/_src/do.i +17 -0
- vbi/models/cpp/_src/do.py +467 -0
- vbi/models/cpp/_src/do_wrap.cxx +12811 -0
- vbi/models/cpp/_src/jr_sdde.hpp +352 -0
- vbi/models/cpp/_src/jr_sdde.i +19 -0
- vbi/models/cpp/_src/jr_sdde.py +688 -0
- vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
- vbi/models/cpp/_src/jr_sde.hpp +264 -0
- vbi/models/cpp/_src/jr_sde.i +17 -0
- vbi/models/cpp/_src/jr_sde.py +470 -0
- vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
- vbi/models/cpp/_src/km_sde.hpp +158 -0
- vbi/models/cpp/_src/km_sde.i +19 -0
- vbi/models/cpp/_src/km_sde.py +671 -0
- vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
- vbi/models/cpp/_src/makefile +52 -0
- vbi/models/cpp/_src/mpr_sde.hpp +327 -0
- vbi/models/cpp/_src/mpr_sde.i +19 -0
- vbi/models/cpp/_src/mpr_sde.py +711 -0
- vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
- vbi/models/cpp/_src/utility.hpp +307 -0
- vbi/models/cpp/_src/vep.hpp +171 -0
- vbi/models/cpp/_src/vep.i +16 -0
- vbi/models/cpp/_src/vep.py +464 -0
- vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
- vbi/models/cpp/_src/wc_ode.hpp +294 -0
- vbi/models/cpp/_src/wc_ode.i +19 -0
- vbi/models/cpp/_src/wc_ode.py +686 -0
- vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
- vbi/models/cpp/damp_oscillator.py +143 -0
- vbi/models/cpp/jansen_rit.py +543 -0
- vbi/models/cpp/km.py +187 -0
- vbi/models/cpp/mpr.py +289 -0
- vbi/models/cpp/vep.py +150 -0
- vbi/models/cpp/wc.py +216 -0
- vbi/models/cupy/__init__.py +0 -0
- vbi/models/cupy/bold.py +111 -0
- vbi/models/cupy/ghb.py +284 -0
- vbi/models/cupy/jansen_rit.py +473 -0
- vbi/models/cupy/km.py +224 -0
- vbi/models/cupy/mpr.py +475 -0
- vbi/models/cupy/mpr_modified_bold.py +12 -0
- vbi/models/cupy/utils.py +184 -0
- vbi/models/numba/__init__.py +0 -0
- vbi/models/numba/_ww_EI.py +444 -0
- vbi/models/numba/damp_oscillator.py +162 -0
- vbi/models/numba/ghb.py +208 -0
- vbi/models/numba/mpr.py +383 -0
- vbi/models/pytorch/__init__.py +0 -0
- vbi/models/pytorch/data/default_parameters.npz +0 -0
- vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
- vbi/models/pytorch/data/input/fc_test.csv +68 -0
- vbi/models/pytorch/data/input/fc_train.csv +68 -0
- vbi/models/pytorch/data/input/fc_vali.csv +68 -0
- vbi/models/pytorch/data/input/fcd_test.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_train.mat +0 -0
- vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
- vbi/models/pytorch/data/input/myelin.csv +68 -0
- vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
- vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
- vbi/models/pytorch/data/input/sc_test.csv +68 -0
- vbi/models/pytorch/data/input/sc_train.csv +68 -0
- vbi/models/pytorch/data/input/sc_vali.csv +68 -0
- vbi/models/pytorch/data/obs_kong0.npz +0 -0
- vbi/models/pytorch/ww_sde_kong.py +570 -0
- vbi/models/tvbk/__init__.py +9 -0
- vbi/models/tvbk/tvbk_wrapper.py +166 -0
- vbi/models/tvbk/utils.py +72 -0
- vbi/papers/__init__.py +0 -0
- vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
- vbi/tests/__init__.py +0 -0
- vbi/tests/_test_mpr_nb.py +36 -0
- vbi/tests/test_features.py +355 -0
- vbi/tests/test_ghb_cupy.py +90 -0
- vbi/tests/test_mpr_cupy.py +49 -0
- vbi/tests/test_mpr_numba.py +84 -0
- vbi/tests/test_suite.py +19 -0
- vbi/utils.py +402 -0
- vbi-0.1.3.dist-info/METADATA +166 -0
- vbi-0.1.3.dist-info/RECORD +121 -0
- vbi-0.1.3.dist-info/WHEEL +5 -0
- vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
- vbi-0.1.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,166 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import collections
|
3
|
+
from vbi.models.tvbk.utils import prepare_vec, setup_connectivity
|
4
|
+
|
5
|
+
try:
|
6
|
+
import tvbk as m
|
7
|
+
|
8
|
+
TVBK_AVAILABLE = True
|
9
|
+
except ImportError:
|
10
|
+
TVBK_AVAILABLE = False
|
11
|
+
|
12
|
+
|
13
|
+
class MPR:
|
14
|
+
|
15
|
+
MPRTheta = collections.namedtuple(
|
16
|
+
typename="MPRTheta", field_names="tau I Delta J eta G".split(" ")
|
17
|
+
)
|
18
|
+
num_svar = 2 # number of state variables
|
19
|
+
num_parm = 6 # number of parameters
|
20
|
+
|
21
|
+
def __init__(self, par: dict = {}) -> None:
|
22
|
+
|
23
|
+
self._par = self.get_default_parameters()
|
24
|
+
self.valid_parameters = list(self._par.keys())
|
25
|
+
self.check_parameters(par)
|
26
|
+
self._par.update(par)
|
27
|
+
|
28
|
+
for item in self._par.items():
|
29
|
+
setattr(self, item[0], item[1])
|
30
|
+
|
31
|
+
self.mpr_default_theta = self.MPRTheta(
|
32
|
+
tau=self._par["tau"],
|
33
|
+
I=self._par["I"],
|
34
|
+
Delta=self._par["Delta"],
|
35
|
+
J=self._par["J"],
|
36
|
+
eta=self._par["eta"],
|
37
|
+
G=self._par["G"],
|
38
|
+
)
|
39
|
+
|
40
|
+
def __str__(self) -> str:
|
41
|
+
return f"MPR model with parameters: {self._par}"
|
42
|
+
|
43
|
+
def get_default_parameters(self) -> dict:
|
44
|
+
return {
|
45
|
+
"tau": 1.0,
|
46
|
+
"Delta": 1.0,
|
47
|
+
"I": 0.0,
|
48
|
+
"J": 15.0,
|
49
|
+
"eta": -5.0,
|
50
|
+
"G": 1.0,
|
51
|
+
"num_batch": 1,
|
52
|
+
"horizon": 256,
|
53
|
+
"width": 8,
|
54
|
+
"dt": 0.01,
|
55
|
+
"dtype": np.float32,
|
56
|
+
"weights": None,
|
57
|
+
"delays": None,
|
58
|
+
"num_node": None,
|
59
|
+
"noise_amp": None,
|
60
|
+
"num_time": 1000,
|
61
|
+
"decimate_rv": 10,
|
62
|
+
"RECORD_RV": True,
|
63
|
+
"RECORD_BOLD": False, # TODO: Add BOLD recording
|
64
|
+
}
|
65
|
+
|
66
|
+
def check_parameters(self, par):
|
67
|
+
for key in par.keys():
|
68
|
+
if key not in self.valid_parameters:
|
69
|
+
raise ValueError(f"Invalid parameter {key:s} provided.")
|
70
|
+
|
71
|
+
|
72
|
+
def initialize_buffers(self):
|
73
|
+
|
74
|
+
total_volume = self.num_batch * self.num_node * self.horizon * self.width
|
75
|
+
self.cx = m.Cx8s(self.num_node, self.horizon, self.num_batch)
|
76
|
+
buf_val = (
|
77
|
+
np.r_[: 1.0 : 1j * total_volume]
|
78
|
+
.reshape(self.num_batch, self.num_node, self.horizon, self.width)
|
79
|
+
.astype(self.dtype)
|
80
|
+
* 4.0
|
81
|
+
)
|
82
|
+
self.cx.buf[:] = buf_val
|
83
|
+
self.cx.cx1[:] = self.cx.cx2[:] = 0.0
|
84
|
+
|
85
|
+
|
86
|
+
def prepare_input(self):
|
87
|
+
|
88
|
+
width = self.width
|
89
|
+
num_batch = self.num_batch
|
90
|
+
num_parm = self.num_parm
|
91
|
+
|
92
|
+
assert self.weights is not None, "weights must be provided"
|
93
|
+
self.weights = np.array(self.weights)
|
94
|
+
num_node = self.num_node = self.weights.shape[0]
|
95
|
+
|
96
|
+
self.G = prepare_vec(self.G, num_batch, self.dtype)
|
97
|
+
self.J = prepare_vec(self.J, num_batch, self.dtype)
|
98
|
+
self.I = prepare_vec(self.I, num_batch, self.dtype)
|
99
|
+
self.eta = prepare_vec(self.eta, num_batch, self.dtype)
|
100
|
+
self.tau = prepare_vec(self.tau, num_batch, self.dtype)
|
101
|
+
self.Delta = prepare_vec(self.Delta, num_batch, self.dtype)
|
102
|
+
self.noise_amp = np.array(self.noise_amp, self.dtype)
|
103
|
+
self.p = np.zeros((num_batch, num_node, num_parm, width), self.dtype)
|
104
|
+
|
105
|
+
self.p[:, :, 0, :] = self.tau
|
106
|
+
self.p[:, :, 1, :] = self.I
|
107
|
+
self.p[:, :, 2, :] = self.Delta
|
108
|
+
self.p[:, :, 3, :] = self.J
|
109
|
+
self.p[:, :, 4, :] = self.eta
|
110
|
+
self.p[:, :, 5, :] = self.G
|
111
|
+
|
112
|
+
self.initialize_buffers()
|
113
|
+
self.conn = setup_connectivity(self.weights, self.delays)
|
114
|
+
|
115
|
+
|
116
|
+
def run(self):
|
117
|
+
|
118
|
+
self.prepare_input()
|
119
|
+
|
120
|
+
x = np.zeros((self.num_batch, self.num_svar, self.num_node, self.width), self.dtype)
|
121
|
+
y = np.zeros_like(x)
|
122
|
+
z = np.zeros((self.num_batch, self.num_svar, 8), self.dtype)+ self.noise_amp
|
123
|
+
seed = np.zeros((self.num_batch, 8, 4), np.uint64)
|
124
|
+
num_samples = self.num_time // self.decimate_rv + 1
|
125
|
+
|
126
|
+
if self.RECORD_RV:
|
127
|
+
trace_c = np.zeros(
|
128
|
+
(
|
129
|
+
num_samples,
|
130
|
+
self.num_batch,
|
131
|
+
self.num_svar,
|
132
|
+
self.num_node,
|
133
|
+
self.width,
|
134
|
+
)
|
135
|
+
)
|
136
|
+
|
137
|
+
for i in range(num_samples):
|
138
|
+
if self.RECORD_RV:
|
139
|
+
m.step_mpr(
|
140
|
+
self.cx,
|
141
|
+
self.conn,
|
142
|
+
x,
|
143
|
+
y,
|
144
|
+
z,
|
145
|
+
self.p,
|
146
|
+
i * self.decimate_rv,
|
147
|
+
self.decimate_rv,
|
148
|
+
self.dt,
|
149
|
+
seed
|
150
|
+
)
|
151
|
+
if self.RECORD_RV:
|
152
|
+
trace_c[i] = x
|
153
|
+
|
154
|
+
# TODO: Calculate BOLD signal
|
155
|
+
if self.RECORD_BOLD:
|
156
|
+
pass
|
157
|
+
# add BOLD signal calculation pytorch code here
|
158
|
+
|
159
|
+
|
160
|
+
return {
|
161
|
+
"rv_t": ...,
|
162
|
+
"rv_d": trace_c,
|
163
|
+
"fmri_t": ...,
|
164
|
+
"fmri_d": None,
|
165
|
+
}
|
166
|
+
|
vbi/models/tvbk/utils.py
ADDED
@@ -0,0 +1,72 @@
|
|
1
|
+
import tvbk as m
|
2
|
+
import numpy as np
|
3
|
+
import scipy.sparse
|
4
|
+
import collections
|
5
|
+
from vbi.models.cupy.utils import repmat_vec, is_seq # TODO move it to vbi.utils
|
6
|
+
|
7
|
+
|
8
|
+
def setup_connectivity(weights, delays=None, horizon=None):
    """
    Sets up connectivity using provided weights and delays.

    Only the non-zero entries of ``weights`` become connections. They are
    stored in CSR form (``indptr`` / ``indices`` / ``weights``) on a ``tvbk``
    ``Conn`` object, together with one integer delay per connection.

    Args:
        weights (array_like): Dense ``(num_node, num_node)`` weight matrix.
        delays (array_like, optional): Dense matrix with the same shape as
            ``weights``. Entries are truncated to unsigned integers, so they
            are presumably already expressed in integration steps (TODO
            confirm). ``None`` means zero delay for every connection.
        horizon (int, optional): Length of the delay history buffer. When
            given, every stored delay (including the +2 offset) must be
            smaller than it.

    Returns:
        conn: The ``tvbk`` connection object.

    Raises:
        ValueError: If ``delays`` and ``weights`` differ in shape, if a delay
            of an existing connection is negative, or if a delay does not fit
            into ``horizon``.
    """
    weights = np.asarray(weights)
    s_w = scipy.sparse.csr_matrix(weights)
    num_node = weights.shape[0]

    # TODO! check how to handle delays if not provided
    delays = np.zeros_like(weights) if delays is None else np.asarray(delays)
    if delays.shape != weights.shape:
        raise ValueError(
            f"delays shape {delays.shape} does not match weights shape {weights.shape}"
        )

    # Boolean masking walks the matrix row-major, which is the order in which
    # CSR stores the non-zero weights, so conn_delays[k] pairs with s_w.data[k].
    conn_delays = delays[weights != 0]
    if np.any(conn_delays < 0):
        # Negative values would silently wrap around when cast to uint32.
        raise ValueError("delays must be non-negative")
    # The +2 offset is kept from the original code (presumably required by
    # tvbk's history indexing — confirm).
    idelays = conn_delays.astype(np.uint32) + 2

    # Delays index into the history buffer, so the meaningful upper bound is
    # the buffer length, not the number of nodes.
    if horizon is not None and idelays.size and idelays.max() >= horizon:
        raise ValueError(
            f"largest delay ({int(idelays.max())} steps, including the +2 offset) "
            f"must be smaller than horizon ({horizon})"
        )

    # Create the connection object
    conn = m.Conn(num_node, s_w.data.size)
    conn.weights[:] = s_w.data.astype(np.float32)
    conn.indptr[:] = s_w.indptr.astype(np.uint32)
    conn.indices[:] = s_w.indices.astype(np.uint32)
    conn.idelays[:] = idelays

    return conn
|
41
|
+
|
42
|
+
|
43
|
+
def prepare_vec(x, num_batch, dtype=np.float32):
    """
    Check and prepare vector dimension and type.

    Parameters
    ----------
    x: scalar or array-like (1d or 2d)
        Value to prepare. Anything ``is_seq`` rejects is treated as a scalar
        and only cast to ``dtype``. 1-D input is replicated across the batch
        with ``repmat_vec``. 2-D input must already have ``num_batch`` columns.
    num_batch: int
        number of batched simulations.
    dtype: numpy scalar type
        Type of the returned value.

    Returns
    -------
    x: dtype scalar or array
        Prepared value. For array input, presumably of shape
        ``[len(x), num_batch]``; the 1-D case depends on ``repmat_vec``
        (confirm).

    Raises
    ------
    ValueError
        If ``x`` is not 1-D or 2-D, or if the second dimension of a 2-D ``x``
        differs from ``num_batch``.
    """
    if not is_seq(x):
        return dtype(x)

    x = np.array(x)
    if x.ndim == 1:
        x = repmat_vec(x, num_batch, "cpu")
    elif x.ndim == 2:
        if x.shape[1] != num_batch:
            raise ValueError("second dimension of x must be equal to num_batch")
    else:
        raise ValueError("x.ndim must be 1 or 2")

    # Cast in every array branch so the replicated 1-D result honours dtype too.
    return x.astype(dtype)
|
vbi/papers/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,211 @@
|
|
1
|
+
import os
|
2
|
+
import os.path
|
3
|
+
import numpy as np
|
4
|
+
from numpy import pi
|
5
|
+
from os.path import join
|
6
|
+
import matplotlib.pyplot as plt
|
7
|
+
from jitcdde import jitcdde, y, t
|
8
|
+
from symengine import sin, cos, Symbol, symarray, exp
|
9
|
+
import warnings
|
10
|
+
|
11
|
+
# NOTE(review): this runs at import time and silences *every* warning for the
# whole process, not only those raised by jitcdde/symengine in this module.
warnings.simplefilter("ignore")
|
12
|
+
|
13
|
+
|
14
|
+
class Pav:
|
15
|
+
"""
|
16
|
+
This class represents Wilson-Cowan model for Parkinson's disease.
|
17
|
+
|
18
|
+
Reference:
|
19
|
+
- Pavlides, A., Hogan, S.J. and Bogacz, R., 2015. Computational models describing possible mechanisms for generation of excessive beta oscillations in Parkinson's disease. PLoS computational biology, 11(12), p.e1004609.
|
20
|
+
"""
|
21
|
+
|
22
|
+
def __init__(self, par: dict = {}):
|
23
|
+
|
24
|
+
_par = self.get_default_params()
|
25
|
+
_par.update(par)
|
26
|
+
for item in _par.items():
|
27
|
+
if item[0] not in _par["control"]:
|
28
|
+
setattr(self, item[0], item[1])
|
29
|
+
|
30
|
+
self.control_pars = []
|
31
|
+
if len(_par["control"]) > 0:
|
32
|
+
for i in _par["control"]:
|
33
|
+
value = Symbol(i)
|
34
|
+
setattr(self, i, value)
|
35
|
+
self.control_pars.append(value)
|
36
|
+
|
37
|
+
if not "modulename" in par.keys():
|
38
|
+
self.modulename = "pav"
|
39
|
+
os.makedirs(self.output, exist_ok=True)
|
40
|
+
|
41
|
+
def get_default_params(self):
|
42
|
+
"""
|
43
|
+
Return a dictionary of default parameters for the model.
|
44
|
+
"""
|
45
|
+
par = {
|
46
|
+
"control": "", # list of control parameters
|
47
|
+
"verbose": False, # print compilation information
|
48
|
+
"openmp": False, # use openmp
|
49
|
+
"output": "output", # output directory
|
50
|
+
"initial_state": None, # initial state of the system
|
51
|
+
"t_end": 1000.0, # end time of the simulation
|
52
|
+
"t_cut": 0.0, # cut time of the simulation
|
53
|
+
"seed": None, # seed for random number generator
|
54
|
+
"n_components": 4, # number of components in the system
|
55
|
+
"interval": 0.1, # interval for saving the state of the system
|
56
|
+
|
57
|
+
"Tsg": 6.0, # ms delay between subthalamic and globus pallidus
|
58
|
+
"Tgs": 6.0, # ms delay between globus pallidus and subthalamic
|
59
|
+
"Tgg": 4.0, # ms delay between globus pallidus and globus pallidus
|
60
|
+
"Tcs": 5.5, # ms delay between cortex and subthalamic
|
61
|
+
"Tsc": 21.5, # ms delay between subthalamic and cortex
|
62
|
+
"Tcc": 4.65, # ms delay between cortex and cortex
|
63
|
+
|
64
|
+
"taus": 12.8, # ms time constant for subthalamic
|
65
|
+
"taug": 20.0, # ms time constant for globus pallidus
|
66
|
+
"taue": 11.59, # ms time constant for excitatory neurons
|
67
|
+
"taui": 13.02, # ms time constant for inhibitory neurons
|
68
|
+
|
69
|
+
"Ms": 300.0/1000, # spk/ms maximum firing rate of subthalamic
|
70
|
+
"Mg": 400.0/1000, # spk/ms maximum firing rate of globus pallidus
|
71
|
+
"Me": 75.77/1000, # spk/ms maximum firing rate of excitatory neurons
|
72
|
+
"Mi": 205.72/1000, # spk/ms maximum firing rate of inhibitory neurons
|
73
|
+
|
74
|
+
"Bs": 10.0/1000, # spk/ms baseline firing rate of subthalamic
|
75
|
+
"Bg": 20.0/1000, # spk/ms baseline firing rate of globus pallidus
|
76
|
+
"Be": 17.85/1000, # spk/ms population firing rate of excitatory neurons
|
77
|
+
"Bi": 9.87/1000, # spk/ms population firing rate of inhibitory neurons
|
78
|
+
|
79
|
+
"C": 172.18/1000, # spk/s external input to cortex
|
80
|
+
"Str": 8.46/1000, # spk/s external input to striatum
|
81
|
+
|
82
|
+
"wgs": 1.33, # synaptic weight from globus pallidus to subthalamic
|
83
|
+
"wsg": 4.87, # synaptic weight from subthalamic to globus pallidus
|
84
|
+
"wgg": 0.53, # synaptic weight from globus pallidus to globus pallidus
|
85
|
+
"wcs": 9.97, # synaptic weight from cortex to subthalamic
|
86
|
+
"wsc": 8.93, # synaptic weight from subthalamic to cortex
|
87
|
+
"wcc": 6.17, # synaptic weight from cortex to cortex
|
88
|
+
}
|
89
|
+
return par
|
90
|
+
|
91
|
+
|
92
|
+
def sys_eqs(self):
|
93
|
+
|
94
|
+
inS = self.wcs * y(2, t - self.Tcs) - self.wgs * y(1, t - self.Tgs)
|
95
|
+
inG = self.wsg * y(0, t - self.Tsg) - self.wgg * y(1, t - self.Tgg) - self.Str
|
96
|
+
inE = -self.wsc * y(0, t - self.Tsc) - self.wcc * y(3, t - self.Tcc) + self.C
|
97
|
+
inI = self.wcc * y(2, t - self.Tcc)
|
98
|
+
|
99
|
+
yield ((self.Ms/((1+exp(-4*inS/self.Ms)*((self.Ms-self.Bs)/self.Bs)))) - y(0))*(1/self.taus)
|
100
|
+
yield ((self.Mg/((1+exp(-4*inG/self.Mg)*((self.Mg-self.Bg)/self.Bg)))) - y(1))*(1/self.taug)
|
101
|
+
yield ((self.Me/((1+exp(-4*inE/self.Me)*((self.Me-self.Be)/self.Be)))) - y(2))*(1/self.taue)
|
102
|
+
yield ((self.Mi/((1+exp(-4*inI/self.Mi)*((self.Mi-self.Bi)/self.Bi)))) - y(3))*(1/self.taui)
|
103
|
+
|
104
|
+
|
105
|
+
def compile(self, **kwargs):
|
106
|
+
control_pars = self.control_pars if len(self.control_pars) > 0 else ()
|
107
|
+
I = jitcdde(
|
108
|
+
self.sys_eqs,
|
109
|
+
n=self.n_components,
|
110
|
+
verbose=self.verbose,
|
111
|
+
control_pars=control_pars,
|
112
|
+
)
|
113
|
+
I.compile_C(omp=self.openmp, **kwargs)
|
114
|
+
I.save_compiled(overwrite=True, destination=join(self.output, self.modulename))
|
115
|
+
|
116
|
+
def set_initial_state(self, seed=None):
|
117
|
+
if seed is not None:
|
118
|
+
np.random.seed(seed)
|
119
|
+
initial_state = np.zeros(4)
|
120
|
+
return initial_state
|
121
|
+
|
122
|
+
def run(
|
123
|
+
self,
|
124
|
+
par=[],
|
125
|
+
disc="step_on",
|
126
|
+
step=0.001,
|
127
|
+
propagations=1,
|
128
|
+
min_distance=1e-5,
|
129
|
+
max_step=None,
|
130
|
+
shift_ratio=1e-4,
|
131
|
+
**integrator_params
|
132
|
+
):
|
133
|
+
"""
|
134
|
+
integrate the system of equations and return the
|
135
|
+
computed state of the system after integration and times
|
136
|
+
|
137
|
+
Parameters
|
138
|
+
------------
|
139
|
+
|
140
|
+
par : list
|
141
|
+
values of control parameters in order of appearance in `control`
|
142
|
+
disc : str
|
143
|
+
type of discontinuities handling. The default value is blind
|
144
|
+
- step_on [step_on_discontinuities]
|
145
|
+
- blind [integrate_blindly]
|
146
|
+
- adjust [adjust_diff]
|
147
|
+
step : float
|
148
|
+
argument for integrate_blindly aspired step size. The actual step size may be slightly adapted to make it divide the integration time. If `None`, `0`, or otherwise falsy, the maximum step size as set with `max_step` of `set_integration_parameters` is used.
|
149
|
+
|
150
|
+
propagations : int
|
151
|
+
argument for step_on_discontinuities: how often the discontinuity has to propagate to before it's considered smoothed.
|
152
|
+
min_distance : float
|
153
|
+
argument for step_on_discontinuities: If two required steps are closer than this, they will be treated as one.
|
154
|
+
max_step : float
|
155
|
+
argument for step_on_discontinuities: Retired parameter. Steps are now automatically adapted.
|
156
|
+
shift_ratio : float
|
157
|
+
argument for adjust_diff. Performs a zero-amplitude (backwards) `jump` whose `width` is `shift_ratio` times the distance to the previous anchor into the past. See the documentation of `jump` for the caveats of this and see `discontinuities` for more information on why you almost certainly need to use this or an alternative way to address initial discontinuities.
|
158
|
+
|
159
|
+
Return : dict(t, x)
|
160
|
+
- **t** times
|
161
|
+
- **x** coordinates.
|
162
|
+
"""
|
163
|
+
|
164
|
+
if self.initial_state is None:
|
165
|
+
self.initial_state = self.set_initial_state(self.seed)
|
166
|
+
|
167
|
+
I = jitcdde(
|
168
|
+
self.sys_eqs,
|
169
|
+
n=4,
|
170
|
+
control_pars=self.control_pars,
|
171
|
+
module_location=join(self.output, self.modulename + ".so"),
|
172
|
+
)
|
173
|
+
I.set_integration_parameters(**integrator_params)
|
174
|
+
I.constant_past(self.initial_state, time=0.0)
|
175
|
+
|
176
|
+
if disc == "blind":
|
177
|
+
I.integrate_blindly(self.initial_state, step=step)
|
178
|
+
elif disc == "step_on":
|
179
|
+
I.step_on_discontinuities(
|
180
|
+
propagations=propagations, min_distance=min_distance, max_step=max_step
|
181
|
+
)
|
182
|
+
else:
|
183
|
+
I.adjust_diff(shift_ratio=shift_ratio)
|
184
|
+
|
185
|
+
if len(self.control_pars) > 0:
|
186
|
+
I.set_parameters(par)
|
187
|
+
tcut = max(self.t_cut, I.t)
|
188
|
+
times = tcut + np.arange(0, self.t_end - tcut, self.interval)
|
189
|
+
|
190
|
+
x = np.zeros((len(times), self.n_components))
|
191
|
+
for i in range(len(times)):
|
192
|
+
x[i, :] = I.integrate(times[i])
|
193
|
+
|
194
|
+
return {"t": times, "x": x}
|
195
|
+
|
196
|
+
if __name__ == "__main__":
    # Compile the model once, integrate it with the default settings and plot
    # all four population rates.
    model = Pav(par={"control": "", "output": "output"})
    model.compile()
    result = model.run(disc="step_on")
    t_axis, states = result["t"], result["x"]

    print("Times: ", t_axis.shape)
    print("Coordinates: ", states.shape)

    plt.plot(t_axis, states)
    plt.legend(["Subthalamic", "Globus Pallidus", "Excitatory", "Inhibitory"])
    plt.show()
|
vbi/tests/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,36 @@
|
|
1
|
+
import torch
|
2
|
+
import unittest
|
3
|
+
import numpy as np
|
4
|
+
import networkx as nx
|
5
|
+
from vbi.models.numba.mpr import MPR_sde
|
6
|
+
|
7
|
+
# Seed NumPy's and torch's global RNGs with the same value; the seed is reused
# in the model parameters below.
seed = 2
for seed_rng in (np.random.seed, torch.manual_seed):
    seed_rng(seed)

# Complete graph on nn nodes scaled by 1/10: every pair of distinct nodes is
# coupled with weight 0.1.
nn = 3
g = nx.complete_graph(nn)
sc = np.asarray(nx.to_numpy_array(g)) / 10.0
|
14
|
+
|
15
|
+
class testMPRSDE(unittest.TestCase):
    """Tests for the numba ``MPR_sde`` model on the small graph ``sc``."""

    # Shared parameter set: model defaults plus the module-level connectivity
    # and seed, and a short window (t_cut=600, t_end=1200; presumably ms).
    mpr = MPR_sde()
    p = mpr.get_default_parameters()
    p.update(
        {
            "weights": sc,
            "seed": seed,
            "t_cut": 0.01 * 60 * 1000,
            "t_end": 0.02 * 60 * 1000,
        }
    )

    def test_invalid_parameter_raises_value_error(self):
        with self.assertRaises(ValueError):
            MPR_sde(par={"invalid_param": 42})

    def test_run(self):
        model = MPR_sde(self.p)
        sol = model.run(par={"G": 0.1, "eta": -4.7})
        x, t = sol["x"], sol["t"]
        self.assertEqual(x.shape[0], nn)
|