vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl → 0.2__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/feature_extraction/features.json +4 -1
- vbi/feature_extraction/features.py +10 -4
- vbi/inference.py +50 -22
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/jr_sde.hpp +5 -6
- vbi/models/cpp/_src/jr_sde_wrap.cxx +28 -28
- vbi/models/cpp/jansen_rit.py +2 -9
- vbi/models/cupy/bold.py +117 -0
- vbi/models/cupy/jansen_rit.py +1 -1
- vbi/models/cupy/km.py +62 -34
- vbi/models/cupy/mpr.py +24 -4
- vbi/models/cupy/utils.py +163 -2
- vbi/models/cupy/wilson_cowan.py +317 -0
- vbi/models/cupy/ww.py +342 -0
- vbi/models/numba/__init__.py +4 -0
- vbi/models/numba/jansen_rit.py +532 -0
- vbi/models/numba/mpr.py +8 -0
- vbi/models/numba/wilson_cowan.py +443 -0
- vbi/models/numba/ww.py +564 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/METADATA +30 -11
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/RECORD +30 -26
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/WHEEL +1 -1
- vbi/models/numba/_ww_EI.py +0 -444
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/licenses/LICENSE +0 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/top_level.txt +0 -0
@@ -359,6 +359,7 @@
|
|
359
359
|
"verbose": false,
|
360
360
|
"pca_num_components": 3,
|
361
361
|
"quantiles": [0.05, 0.25, 0.5, 0.75, 0.95],
|
362
|
+
"k": null,
|
362
363
|
"features": ["sum", "max", "min", "mean", "std", "skew", "kurtosis"]
|
363
364
|
},
|
364
365
|
"use": "yes",
|
@@ -459,7 +460,9 @@
|
|
459
460
|
"moments": [2, 3, 4, 5, 6],
|
460
461
|
"normalize": false,
|
461
462
|
"verbose": false,
|
462
|
-
"indices": null
|
463
|
+
"indices": null,
|
464
|
+
"average": false
|
465
|
+
|
463
466
|
},
|
464
467
|
"use": "yes",
|
465
468
|
"tag": "all"
|
@@ -1128,6 +1128,7 @@ def fcd_stat(
|
|
1128
1128
|
pca_num_components=3,
|
1129
1129
|
quantiles=[0.05, 0.25, 0.5, 0.75, 0.95],
|
1130
1130
|
features=["sum", "max", "min", "mean", "std", "skew", "kurtosis"],
|
1131
|
+
k=None,
|
1131
1132
|
verbose=False,
|
1132
1133
|
):
|
1133
1134
|
|
@@ -1142,8 +1143,7 @@ def fcd_stat(
|
|
1142
1143
|
|
1143
1144
|
Values = []
|
1144
1145
|
Labels = []
|
1145
|
-
|
1146
|
-
k = int(win_len / TR)
|
1146
|
+
k = k if k is not None else int(win_len / TR)
|
1147
1147
|
fcd = get_fcd(ts=ts, TR=TR, win_len=win_len, positive=positive, masks=masks)
|
1148
1148
|
for key in fcd.keys():
|
1149
1149
|
values, labels = matrix_stat(
|
@@ -1598,6 +1598,7 @@ def spectrum_moments(
|
|
1598
1598
|
moments=[2, 3, 4, 5, 6],
|
1599
1599
|
normalize=False,
|
1600
1600
|
indices=None,
|
1601
|
+
average=False,
|
1601
1602
|
verbose=False,
|
1602
1603
|
):
|
1603
1604
|
"""
|
@@ -1660,8 +1661,13 @@ def spectrum_moments(
|
|
1660
1661
|
|
1661
1662
|
for i in moments:
|
1662
1663
|
_m = moment(psd, i, axis=1)
|
1663
|
-
|
1664
|
-
|
1664
|
+
if not average:
|
1665
|
+
Values = np.append(Values, _m)
|
1666
|
+
Labels = Labels + [f"spectrum_moment_{i}_{j}" for j in range(len(_m))]
|
1667
|
+
else:
|
1668
|
+
Values = np.append(Values, np.mean(_m))
|
1669
|
+
Labels = Labels + [f"spectrum_moment_{i}"]
|
1670
|
+
|
1665
1671
|
return Values, Labels
|
1666
1672
|
|
1667
1673
|
|
vbi/inference.py
CHANGED
@@ -3,37 +3,67 @@ from vbi.utils import *
|
|
3
3
|
from sbi.inference import SNPE, SNLE, SNRE
|
4
4
|
from sbi.utils.user_input_checks import process_prior
|
5
5
|
|
6
|
+
|
6
7
|
class Inference(object):
|
7
8
|
def __init__(self) -> None:
|
8
9
|
pass
|
9
10
|
|
10
11
|
@timer
|
11
|
-
def train(
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
12
|
+
def train(
|
13
|
+
self,
|
14
|
+
theta,
|
15
|
+
x,
|
16
|
+
prior,
|
17
|
+
num_threads=1,
|
18
|
+
method="SNPE",
|
19
|
+
device="cpu",
|
20
|
+
density_estimator="maf",
|
21
|
+
):
|
22
|
+
'''
|
23
|
+
train the inference model
|
24
|
+
|
25
|
+
Parameters
|
26
|
+
----------
|
27
|
+
theta: torch.tensor float32 (n, d)
|
28
|
+
parameter samples, where n is the number of samples and d is the dimension of the parameter space
|
29
|
+
x: torch.tensor float32 (n, d)
|
30
|
+
feature samples, where n is the number of samples and d is the dimension of the feature space
|
31
|
+
prior: sbi.utils object
|
32
|
+
prior distribution object
|
33
|
+
num_threads: int
|
34
|
+
number of threads to use for training, for multi-threading support, default is 1
|
35
|
+
method: str
|
36
|
+
inference method to use, one of "SNPE", "SNLE", "SNRE", default is "SNPE"
|
37
|
+
device: str
|
38
|
+
device to use for training, one of "cpu", "cuda", default is "cpu"
|
39
|
+
density_estimator: str
|
40
|
+
density estimator to use, one of "maf", "nsf", default is "maf"
|
41
|
+
Returns
|
42
|
+
-------
|
43
|
+
posterior: sbi.utils object
|
44
|
+
posterior distribution object trained on the given data
|
45
|
+
|
46
|
+
'''
|
20
47
|
|
21
48
|
torch.set_num_threads(num_threads)
|
22
49
|
|
23
|
-
if
|
50
|
+
if len(x.shape) == 1:
|
24
51
|
x = x[:, None]
|
25
|
-
if
|
52
|
+
if len(theta.shape) == 1:
|
26
53
|
theta = theta[:, None]
|
27
54
|
|
28
55
|
if method == "SNPE":
|
29
56
|
inference = SNPE(
|
30
|
-
prior=prior, density_estimator=density_estimator, device=device
|
57
|
+
prior=prior, density_estimator=density_estimator, device=device
|
58
|
+
)
|
31
59
|
elif method == "SNLE":
|
32
60
|
inference = SNLE(
|
33
|
-
prior=prior, density_estimator=density_estimator, device=device
|
61
|
+
prior=prior, density_estimator=density_estimator, device=device
|
62
|
+
)
|
34
63
|
elif method == "SNRE":
|
35
64
|
inference = SNRE(
|
36
|
-
prior=prior, density_estimator=density_estimator, device=device
|
65
|
+
prior=prior, density_estimator=density_estimator, device=device
|
66
|
+
)
|
37
67
|
else:
|
38
68
|
raise ValueError("Invalid method: " + method)
|
39
69
|
|
@@ -45,7 +75,7 @@ class Inference(object):
|
|
45
75
|
|
46
76
|
@staticmethod
|
47
77
|
def sample_prior(prior, n, seed=None):
|
48
|
-
|
78
|
+
"""
|
49
79
|
sample from prior distribution
|
50
80
|
|
51
81
|
Parameters
|
@@ -58,19 +88,17 @@ class Inference(object):
|
|
58
88
|
Returns
|
59
89
|
-------
|
60
90
|
|
61
|
-
|
91
|
+
"""
|
62
92
|
if seed is not None:
|
63
93
|
torch.manual_seed(seed)
|
64
|
-
|
94
|
+
|
65
95
|
prior, _, _ = process_prior(prior)
|
66
96
|
theta = prior.sample((n,))
|
67
97
|
return theta
|
68
98
|
|
69
99
|
@staticmethod
|
70
|
-
def sample_posterior(xo,
|
71
|
-
|
72
|
-
posterior):
|
73
|
-
'''
|
100
|
+
def sample_posterior(xo, num_samples, posterior):
|
101
|
+
"""
|
74
102
|
sample from the posterior using the given observation point.
|
75
103
|
|
76
104
|
Parameters
|
@@ -87,7 +115,7 @@ class Inference(object):
|
|
87
115
|
samples: torch.tensor float32 (num_samples, d)
|
88
116
|
samples from the posterior
|
89
117
|
|
90
|
-
|
118
|
+
"""
|
91
119
|
|
92
120
|
if not isinstance(xo, torch.Tensor):
|
93
121
|
xo = torch.tensor(xo, dtype=torch.float32)
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
vbi/models/cpp/_src/jr_sde.hpp
CHANGED
@@ -30,7 +30,7 @@ private:
|
|
30
30
|
size_t index_transition;
|
31
31
|
vector<vector<unsigned>> adjlist;
|
32
32
|
|
33
|
-
|
33
|
+
dim1 A;
|
34
34
|
double par_a;
|
35
35
|
double par_B;
|
36
36
|
double par_b;
|
@@ -64,7 +64,7 @@ public:
|
|
64
64
|
double coupling,
|
65
65
|
dim2 adj,
|
66
66
|
dim1 y,
|
67
|
-
|
67
|
+
dim1 A,
|
68
68
|
double B,
|
69
69
|
double a,
|
70
70
|
double b,
|
@@ -81,7 +81,7 @@ public:
|
|
81
81
|
{
|
82
82
|
assert(t_final > t_transition);
|
83
83
|
|
84
|
-
|
84
|
+
this->A = A;
|
85
85
|
par_B = B;
|
86
86
|
par_a = a;
|
87
87
|
par_b = b;
|
@@ -131,7 +131,6 @@ public:
|
|
131
131
|
|
132
132
|
double a2 = par_a * par_a;
|
133
133
|
double b2 = par_b * par_b;
|
134
|
-
double Aa = par_A * par_a;
|
135
134
|
double Bb = par_B * par_b;
|
136
135
|
|
137
136
|
size_t N2 = 2 * N;
|
@@ -152,8 +151,8 @@ public:
|
|
152
151
|
dxdt[i] = y[i + N3];
|
153
152
|
dxdt[i + N] = y[i + N4];
|
154
153
|
dxdt[i + N2] = y[i + N5];
|
155
|
-
dxdt[i + N3] =
|
156
|
-
dxdt[i + N4] =
|
154
|
+
dxdt[i + N3] = A[i] * par_a * sigma(y[i + N] - y[i + N2]) - 2 * par_a * y[i + N3] - a2 * y[i];
|
155
|
+
dxdt[i + N4] = A[i] * par_a * (noise_mu + C1[i] * sigma(C0[i] * y[i]) + coupling * coupling_term) - 2 * par_a * y[i + N4] - a2 * y[i + N];
|
157
156
|
dxdt[i + N5] = Bb * C3[i] * sigma(C2[i] * y[i]) - 2 * par_b * y[i + N5] -
|
158
157
|
b2 * y[i + N2];
|
159
158
|
}
|
@@ -11803,7 +11803,7 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_0(PyObject *self, Py_ssize_t nobjs,
|
|
11803
11803
|
double arg5 ;
|
11804
11804
|
dim2 arg6 ;
|
11805
11805
|
dim1 arg7 ;
|
11806
|
-
|
11806
|
+
dim1 arg8 ;
|
11807
11807
|
double arg9 ;
|
11808
11808
|
double arg10 ;
|
11809
11809
|
double arg11 ;
|
@@ -11827,8 +11827,6 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_0(PyObject *self, Py_ssize_t nobjs,
|
|
11827
11827
|
int ecode4 = 0 ;
|
11828
11828
|
double val5 ;
|
11829
11829
|
int ecode5 = 0 ;
|
11830
|
-
double val8 ;
|
11831
|
-
int ecode8 = 0 ;
|
11832
11830
|
double val9 ;
|
11833
11831
|
int ecode9 = 0 ;
|
11834
11832
|
double val10 ;
|
@@ -11894,11 +11892,15 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_0(PyObject *self, Py_ssize_t nobjs,
|
|
11894
11892
|
arg7 = *ptr;
|
11895
11893
|
if (SWIG_IsNewObj(res)) delete ptr;
|
11896
11894
|
}
|
11897
|
-
|
11898
|
-
|
11899
|
-
|
11900
|
-
|
11901
|
-
|
11895
|
+
{
|
11896
|
+
std::vector< double,std::allocator< double > > *ptr = (std::vector< double,std::allocator< double > > *)0;
|
11897
|
+
int res = swig::asptr(swig_obj[7], &ptr);
|
11898
|
+
if (!SWIG_IsOK(res) || !ptr) {
|
11899
|
+
SWIG_exception_fail(SWIG_ArgError((ptr ? res : SWIG_TypeError)), "in method '" "new_JR_sde" "', argument " "8"" of type '" "dim1""'");
|
11900
|
+
}
|
11901
|
+
arg8 = *ptr;
|
11902
|
+
if (SWIG_IsNewObj(res)) delete ptr;
|
11903
|
+
}
|
11902
11904
|
ecode9 = SWIG_AsVal_double(swig_obj[8], &val9);
|
11903
11905
|
if (!SWIG_IsOK(ecode9)) {
|
11904
11906
|
SWIG_exception_fail(SWIG_ArgError(ecode9), "in method '" "new_JR_sde" "', argument " "9"" of type '" "double""'");
|
@@ -11980,7 +11982,7 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_0(PyObject *self, Py_ssize_t nobjs,
|
|
11980
11982
|
SWIG_exception_fail(SWIG_ArgError(ecode21), "in method '" "new_JR_sde" "', argument " "21"" of type '" "int""'");
|
11981
11983
|
}
|
11982
11984
|
arg21 = static_cast< int >(val21);
|
11983
|
-
result = (JR_sde *)new JR_sde(SWIG_STD_MOVE(arg1),arg2,arg3,arg4,arg5,SWIG_STD_MOVE(arg6),SWIG_STD_MOVE(arg7),arg8,arg9,arg10,arg11,arg12,arg13,arg14,SWIG_STD_MOVE(arg15),SWIG_STD_MOVE(arg16),SWIG_STD_MOVE(arg17),SWIG_STD_MOVE(arg18),arg19,arg20,arg21);
|
11985
|
+
result = (JR_sde *)new JR_sde(SWIG_STD_MOVE(arg1),arg2,arg3,arg4,arg5,SWIG_STD_MOVE(arg6),SWIG_STD_MOVE(arg7),SWIG_STD_MOVE(arg8),arg9,arg10,arg11,arg12,arg13,arg14,SWIG_STD_MOVE(arg15),SWIG_STD_MOVE(arg16),SWIG_STD_MOVE(arg17),SWIG_STD_MOVE(arg18),arg19,arg20,arg21);
|
11984
11986
|
resultobj = SWIG_NewPointerObj(SWIG_as_voidptr(result), SWIGTYPE_p_JR_sde, SWIG_POINTER_NEW | 0 );
|
11985
11987
|
return resultobj;
|
11986
11988
|
fail:
|
@@ -11997,7 +11999,7 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_1(PyObject *self, Py_ssize_t nobjs,
|
|
11997
11999
|
double arg5 ;
|
11998
12000
|
dim2 arg6 ;
|
11999
12001
|
dim1 arg7 ;
|
12000
|
-
|
12002
|
+
dim1 arg8 ;
|
12001
12003
|
double arg9 ;
|
12002
12004
|
double arg10 ;
|
12003
12005
|
double arg11 ;
|
@@ -12020,8 +12022,6 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_1(PyObject *self, Py_ssize_t nobjs,
|
|
12020
12022
|
int ecode4 = 0 ;
|
12021
12023
|
double val5 ;
|
12022
12024
|
int ecode5 = 0 ;
|
12023
|
-
double val8 ;
|
12024
|
-
int ecode8 = 0 ;
|
12025
12025
|
double val9 ;
|
12026
12026
|
int ecode9 = 0 ;
|
12027
12027
|
double val10 ;
|
@@ -12085,11 +12085,15 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_1(PyObject *self, Py_ssize_t nobjs,
|
|
12085
12085
|
arg7 = *ptr;
|
12086
12086
|
if (SWIG_IsNewObj(res)) delete ptr;
|
12087
12087
|
}
|
12088
|
-
|
12089
|
-
|
12090
|
-
|
12091
|
-
|
12092
|
-
|
12088
|
+
{
|
12089
|
+
std::vector< double,std::allocator< double > > *ptr = (std::vector< double,std::allocator< double > > *)0;
|
12090
|
+
int res = swig::asptr(swig_obj[7], &ptr);
|
12091
|
+
if (!SWIG_IsOK(res) || !ptr) {
|
12092
|
+
SWIG_exception_fail(SWIG_ArgError((ptr ? res : SWIG_TypeError)), "in method '" "new_JR_sde" "', argument " "8"" of type '" "dim1""'");
|
12093
|
+
}
|
12094
|
+
arg8 = *ptr;
|
12095
|
+
if (SWIG_IsNewObj(res)) delete ptr;
|
12096
|
+
}
|
12093
12097
|
ecode9 = SWIG_AsVal_double(swig_obj[8], &val9);
|
12094
12098
|
if (!SWIG_IsOK(ecode9)) {
|
12095
12099
|
SWIG_exception_fail(SWIG_ArgError(ecode9), "in method '" "new_JR_sde" "', argument " "9"" of type '" "double""'");
|
@@ -12166,7 +12170,7 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_1(PyObject *self, Py_ssize_t nobjs,
|
|
12166
12170
|
SWIG_exception_fail(SWIG_ArgError(ecode20), "in method '" "new_JR_sde" "', argument " "20"" of type '" "double""'");
|
12167
12171
|
}
|
12168
12172
|
arg20 = static_cast< double >(val20);
|
12169
|
-
result = (JR_sde *)new JR_sde(SWIG_STD_MOVE(arg1),arg2,arg3,arg4,arg5,SWIG_STD_MOVE(arg6),SWIG_STD_MOVE(arg7),arg8,arg9,arg10,arg11,arg12,arg13,arg14,SWIG_STD_MOVE(arg15),SWIG_STD_MOVE(arg16),SWIG_STD_MOVE(arg17),SWIG_STD_MOVE(arg18),arg19,arg20);
|
12173
|
+
result = (JR_sde *)new JR_sde(SWIG_STD_MOVE(arg1),arg2,arg3,arg4,arg5,SWIG_STD_MOVE(arg6),SWIG_STD_MOVE(arg7),SWIG_STD_MOVE(arg8),arg9,arg10,arg11,arg12,arg13,arg14,SWIG_STD_MOVE(arg15),SWIG_STD_MOVE(arg16),SWIG_STD_MOVE(arg17),SWIG_STD_MOVE(arg18),arg19,arg20);
|
12170
12174
|
resultobj = SWIG_NewPointerObj(SWIG_as_voidptr(result), SWIGTYPE_p_JR_sde, SWIG_POINTER_NEW | 0 );
|
12171
12175
|
return resultobj;
|
12172
12176
|
fail:
|
@@ -12215,10 +12219,8 @@ SWIGINTERN PyObject *_wrap_new_JR_sde(PyObject *self, PyObject *args) {
|
|
12215
12219
|
int res = swig::asptr(argv[6], (std::vector< double,std::allocator< double > >**)(0));
|
12216
12220
|
_v = SWIG_CheckState(res);
|
12217
12221
|
if (_v) {
|
12218
|
-
|
12219
|
-
|
12220
|
-
_v = SWIG_CheckState(res);
|
12221
|
-
}
|
12222
|
+
int res = swig::asptr(argv[7], (std::vector< double,std::allocator< double > >**)(0));
|
12223
|
+
_v = SWIG_CheckState(res);
|
12222
12224
|
if (_v) {
|
12223
12225
|
{
|
12224
12226
|
int res = SWIG_AsVal_double(argv[8], NULL);
|
@@ -12327,10 +12329,8 @@ SWIGINTERN PyObject *_wrap_new_JR_sde(PyObject *self, PyObject *args) {
|
|
12327
12329
|
int res = swig::asptr(argv[6], (std::vector< double,std::allocator< double > >**)(0));
|
12328
12330
|
_v = SWIG_CheckState(res);
|
12329
12331
|
if (_v) {
|
12330
|
-
|
12331
|
-
|
12332
|
-
_v = SWIG_CheckState(res);
|
12333
|
-
}
|
12332
|
+
int res = swig::asptr(argv[7], (std::vector< double,std::allocator< double > >**)(0));
|
12333
|
+
_v = SWIG_CheckState(res);
|
12334
12334
|
if (_v) {
|
12335
12335
|
{
|
12336
12336
|
int res = SWIG_AsVal_double(argv[8], NULL);
|
@@ -12416,8 +12416,8 @@ SWIGINTERN PyObject *_wrap_new_JR_sde(PyObject *self, PyObject *args) {
|
|
12416
12416
|
fail:
|
12417
12417
|
SWIG_Python_RaiseOrModifyTypeError("Wrong number or type of arguments for overloaded function 'new_JR_sde'.\n"
|
12418
12418
|
" Possible C/C++ prototypes are:\n"
|
12419
|
-
" JR_sde::JR_sde(size_t,double,double,double,double,dim2,dim1,
|
12420
|
-
" JR_sde::JR_sde(size_t,double,double,double,double,dim2,dim1,
|
12419
|
+
" JR_sde::JR_sde(size_t,double,double,double,double,dim2,dim1,dim1,double,double,double,double,double,double,dim1,dim1,dim1,dim1,double,double,int)\n"
|
12420
|
+
" JR_sde::JR_sde(size_t,double,double,double,double,dim2,dim1,dim1,double,double,double,double,double,double,dim1,dim1,dim1,dim1,double,double)\n");
|
12421
12421
|
return 0;
|
12422
12422
|
}
|
12423
12423
|
|
vbi/models/cpp/jansen_rit.py
CHANGED
@@ -124,7 +124,7 @@ class JR_sde:
|
|
124
124
|
"method": "heun",
|
125
125
|
"t_transition": 500.0, # ms
|
126
126
|
"t_end": 2501.0, # ms
|
127
|
-
"output": "output",
|
127
|
+
"output": "output", # output directory
|
128
128
|
"RECORD_AVG": False # true to store large time series in file
|
129
129
|
}
|
130
130
|
return par
|
@@ -181,17 +181,13 @@ class JR_sde:
|
|
181
181
|
self.t_transition = float(self.t_transition)
|
182
182
|
self.t_end = float(self.t_end)
|
183
183
|
self.G = float(self.G)
|
184
|
-
self.A = float(self.A)
|
185
184
|
self.B = float(self.B)
|
186
185
|
self.a = float(self.a)
|
187
186
|
self.b = float(self.b)
|
188
187
|
self.r = float(self.r)
|
189
188
|
self.v0 = float(self.v0)
|
190
189
|
self.vmax = float(self.vmax)
|
191
|
-
|
192
|
-
# self.C1 = np.asarray(self.C1)
|
193
|
-
# self.C2 = np.asarray(self.C2)
|
194
|
-
# self.C3 = np.asarray(self.C3)
|
190
|
+
self.A = check_sequence(self.A, self.N)
|
195
191
|
self.C0 = check_sequence(self.C0, self.N)
|
196
192
|
self.C1 = check_sequence(self.C1, self.N)
|
197
193
|
self.C2 = check_sequence(self.C2, self.N)
|
@@ -236,9 +232,6 @@ class JR_sde:
|
|
236
232
|
for key in par.keys():
|
237
233
|
if key not in self.valid_params:
|
238
234
|
raise ValueError("Invalid parameter: " + key)
|
239
|
-
# if key in ["C0", "C1", "C2", "C3"]:
|
240
|
-
# self.set_C(key, par[key])
|
241
|
-
# else:
|
242
235
|
setattr(self, key, par[key])
|
243
236
|
|
244
237
|
self.prepare_input()
|
vbi/models/cupy/bold.py
CHANGED
@@ -105,6 +105,123 @@ class BoldStephan2008:
|
|
105
105
|
q[0] = q[1]
|
106
106
|
|
107
107
|
|
108
|
+
class Bold:
|
109
|
+
|
110
|
+
def __init__(self, par: dict = {}) -> None:
|
111
|
+
|
112
|
+
self._par = self.get_default_parameters()
|
113
|
+
self.valid_parameters = list(self._par.keys())
|
114
|
+
self.check_parameters(par)
|
115
|
+
self._par.update(par)
|
116
|
+
|
117
|
+
for item in self._par.items():
|
118
|
+
setattr(self, item[0], item[1])
|
119
|
+
self.update_dependent_parameters()
|
120
|
+
|
121
|
+
|
122
|
+
def get_default_parameters(self):
|
123
|
+
"""get balloon model parameters."""
|
124
|
+
|
125
|
+
vo = 0.08
|
126
|
+
theta = 40.3
|
127
|
+
TE = 0.04
|
128
|
+
Eo = 0.4
|
129
|
+
r0 = 25.0
|
130
|
+
epsilon = 0.34
|
131
|
+
k1 = 4.3 * theta * Eo * TE
|
132
|
+
k2 = epsilon * r0 * Eo * TE
|
133
|
+
k3 = 1 - epsilon
|
134
|
+
|
135
|
+
par = {
|
136
|
+
"kappa": 0.65,
|
137
|
+
"gamma": 0.41,
|
138
|
+
"tau": 0.98,
|
139
|
+
"alpha": 0.32,
|
140
|
+
"epsilon": epsilon,
|
141
|
+
"Eo": Eo,
|
142
|
+
"TE": TE,
|
143
|
+
"vo": vo,
|
144
|
+
"r0": r0,
|
145
|
+
"theta": theta,
|
146
|
+
"t_min": 0.0,
|
147
|
+
"rtol": 1e-5,
|
148
|
+
"atol": 1e-8,
|
149
|
+
"k1": k1,
|
150
|
+
"k2": k2,
|
151
|
+
"k3": k3
|
152
|
+
}
|
153
|
+
return par
|
154
|
+
|
155
|
+
def update_dependent_parameters(self):
|
156
|
+
self.k1 = 4.3 * self.theta * self.Eo * self.TE
|
157
|
+
self.k2 = self.epsilon * self.r0 * self.Eo * self.TE
|
158
|
+
self.k3 = 1 - self.epsilon
|
159
|
+
|
160
|
+
def check_parameters(self, par):
|
161
|
+
for key in par.keys():
|
162
|
+
if key not in self.valid_parameters:
|
163
|
+
raise ValueError(f"Invalid parameter {key:s} provided.")
|
164
|
+
|
165
|
+
def allocate_memory(self, xp, nn, ns, n_steps, bold_decimate, dtype):
|
166
|
+
|
167
|
+
self.s = xp.zeros((2, nn, ns), dtype=dtype)
|
168
|
+
self.f = xp.zeros((2, nn, ns), dtype=dtype)
|
169
|
+
self.ftilde = xp.zeros((2, nn, ns), dtype=dtype)
|
170
|
+
self.vtilde = xp.zeros((2, nn, ns), dtype=dtype)
|
171
|
+
self.qtilde = xp.zeros((2, nn, ns), dtype=dtype)
|
172
|
+
self.v = xp.zeros((2, nn, ns), dtype=dtype)
|
173
|
+
self.q = xp.zeros((2, nn, ns), dtype=dtype)
|
174
|
+
self.vv = np.zeros((n_steps // bold_decimate, nn, ns), dtype="f")
|
175
|
+
self.qq = np.zeros((n_steps // bold_decimate, nn, ns), dtype="f")
|
176
|
+
self.s[0] = 1
|
177
|
+
self.f[0] = 1
|
178
|
+
self.v[0] = 1
|
179
|
+
self.q[0] = 1
|
180
|
+
self.ftilde[0] = 0
|
181
|
+
self.vtilde[0] = 0
|
182
|
+
self.qtilde[0] = 0
|
183
|
+
|
184
|
+
def do_bold_step(self, r_in, dtt):
|
185
|
+
|
186
|
+
Eo = self.Eo
|
187
|
+
tau = self.tau
|
188
|
+
kappa = self.kappa
|
189
|
+
gamma = self.gamma
|
190
|
+
alpha = self.alpha
|
191
|
+
ialpha = 1 / alpha
|
192
|
+
|
193
|
+
v = self.v
|
194
|
+
q = self.q
|
195
|
+
s = self.s
|
196
|
+
f = self.f
|
197
|
+
ftilde = self.ftilde
|
198
|
+
vtilde = self.vtilde
|
199
|
+
qtilde = self.qtilde
|
200
|
+
|
201
|
+
s[1] = s[0] + dtt * (r_in - kappa * s[0] - gamma * (f[0] - 1))
|
202
|
+
f[0] = np.clip(f[0], 1, None)
|
203
|
+
ftilde[1] = ftilde[0] + dtt * (s[0] / f[0])
|
204
|
+
fv = v[0] ** ialpha # outflow
|
205
|
+
vtilde[1] = vtilde[0] + dtt * ((f[0] - fv) / (tau * v[0]))
|
206
|
+
q[0] = np.clip(q[0], 0.01, None)
|
207
|
+
ff = (1 - (1 - Eo) ** (1 / f[0])) / Eo # oxygen extraction
|
208
|
+
qtilde[1] = qtilde[0] + dtt * ((f[0] * ff - fv * q[0] / v[0]) / (tau * q[0]))
|
209
|
+
|
210
|
+
f[1] = np.exp(ftilde[1])
|
211
|
+
v[1] = np.exp(vtilde[1])
|
212
|
+
q[1] = np.exp(qtilde[1])
|
213
|
+
|
214
|
+
f[0] = f[1]
|
215
|
+
s[0] = s[1]
|
216
|
+
ftilde[0] = ftilde[1]
|
217
|
+
vtilde[0] = vtilde[1]
|
218
|
+
qtilde[0] = qtilde[1]
|
219
|
+
v[0] = v[1]
|
220
|
+
q[0] = q[1]
|
221
|
+
|
222
|
+
|
223
|
+
|
224
|
+
|
108
225
|
class BoldTVB:
|
109
226
|
|
110
227
|
def __init__(self):
|
vbi/models/cupy/jansen_rit.py
CHANGED
@@ -199,7 +199,7 @@ class JR_sde:
|
|
199
199
|
|
200
200
|
self.G = self.xp.array(self.G)
|
201
201
|
assert self.weights is not None, "weights must be provided"
|
202
|
-
self.weights = self.xp.array(self.weights)
|
202
|
+
self.weights = self.xp.array(self.weights)
|
203
203
|
self.weights = move_data(self.weights, self.engine)
|
204
204
|
self.nn = self.num_nodes = self.weights.shape[0]
|
205
205
|
|