vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl → 0.2__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -359,6 +359,7 @@
359
359
  "verbose": false,
360
360
  "pca_num_components": 3,
361
361
  "quantiles": [0.05, 0.25, 0.5, 0.75, 0.95],
362
+ "k": null,
362
363
  "features": ["sum", "max", "min", "mean", "std", "skew", "kurtosis"]
363
364
  },
364
365
  "use": "yes",
@@ -459,7 +460,9 @@
459
460
  "moments": [2, 3, 4, 5, 6],
460
461
  "normalize": false,
461
462
  "verbose": false,
462
- "indices": null
463
+ "indices": null,
464
+ "average": false
465
+
463
466
  },
464
467
  "use": "yes",
465
468
  "tag": "all"
@@ -1128,6 +1128,7 @@ def fcd_stat(
1128
1128
  pca_num_components=3,
1129
1129
  quantiles=[0.05, 0.25, 0.5, 0.75, 0.95],
1130
1130
  features=["sum", "max", "min", "mean", "std", "skew", "kurtosis"],
1131
+ k=None,
1131
1132
  verbose=False,
1132
1133
  ):
1133
1134
 
@@ -1142,8 +1143,7 @@ def fcd_stat(
1142
1143
 
1143
1144
  Values = []
1144
1145
  Labels = []
1145
-
1146
- k = int(win_len / TR)
1146
+ k = k if k is not None else int(win_len / TR)
1147
1147
  fcd = get_fcd(ts=ts, TR=TR, win_len=win_len, positive=positive, masks=masks)
1148
1148
  for key in fcd.keys():
1149
1149
  values, labels = matrix_stat(
@@ -1598,6 +1598,7 @@ def spectrum_moments(
1598
1598
  moments=[2, 3, 4, 5, 6],
1599
1599
  normalize=False,
1600
1600
  indices=None,
1601
+ average=False,
1601
1602
  verbose=False,
1602
1603
  ):
1603
1604
  """
@@ -1660,8 +1661,13 @@ def spectrum_moments(
1660
1661
 
1661
1662
  for i in moments:
1662
1663
  _m = moment(psd, i, axis=1)
1663
- Values = np.append(Values, _m)
1664
- Labels = Labels + [f"spectrum_moment_{i}_{j}" for j in range(len(_m))]
1664
+ if not average:
1665
+ Values = np.append(Values, _m)
1666
+ Labels = Labels + [f"spectrum_moment_{i}_{j}" for j in range(len(_m))]
1667
+ else:
1668
+ Values = np.append(Values, np.mean(_m))
1669
+ Labels = Labels + [f"spectrum_moment_{i}"]
1670
+
1665
1671
  return Values, Labels
1666
1672
 
1667
1673
 
vbi/inference.py CHANGED
@@ -3,37 +3,67 @@ from vbi.utils import *
3
3
  from sbi.inference import SNPE, SNLE, SNRE
4
4
  from sbi.utils.user_input_checks import process_prior
5
5
 
6
+
6
7
  class Inference(object):
7
8
  def __init__(self) -> None:
8
9
  pass
9
10
 
10
11
  @timer
11
- def train(self,
12
- theta,
13
- x,
14
- prior,
15
- num_threads=1,
16
- method="SNPE",
17
- device="cpu",
18
- density_estimator="maf"
19
- ):
12
+ def train(
13
+ self,
14
+ theta,
15
+ x,
16
+ prior,
17
+ num_threads=1,
18
+ method="SNPE",
19
+ device="cpu",
20
+ density_estimator="maf",
21
+ ):
22
+ '''
23
+ train the inference model
24
+
25
+ Parameters
26
+ ----------
27
+ theta: torch.tensor float32 (n, d)
28
+ parameter samples, where n is the number of samples and d is the dimension of the parameter space
29
+ x: torch.tensor float32 (n, d)
30
+ feature samples, where n is the number of samples and d is the dimension of the feature space
31
+ prior: sbi.utils object
32
+ prior distribution object
33
+ num_threads: int
34
+ number of threads to use for training, for multi-threading support, default is 1
35
+ method: str
36
+ inference method to use, one of "SNPE", "SNLE", "SNRE", default is "SNPE"
37
+ device: str
38
+ device to use for training, one of "cpu", "cuda", default is "cpu"
39
+ density_estimator: str
40
+ density estimator to use, one of "maf", "nsf", default is "maf"
41
+ Returns
42
+ -------
43
+ posterior: sbi.utils object
44
+ posterior distribution object trained on the given data
45
+
46
+ '''
20
47
 
21
48
  torch.set_num_threads(num_threads)
22
49
 
23
- if (len(x.shape) == 1):
50
+ if len(x.shape) == 1:
24
51
  x = x[:, None]
25
- if (len(theta.shape) == 1):
52
+ if len(theta.shape) == 1:
26
53
  theta = theta[:, None]
27
54
 
28
55
  if method == "SNPE":
29
56
  inference = SNPE(
30
- prior=prior, density_estimator=density_estimator, device=device)
57
+ prior=prior, density_estimator=density_estimator, device=device
58
+ )
31
59
  elif method == "SNLE":
32
60
  inference = SNLE(
33
- prior=prior, density_estimator=density_estimator, device=device)
61
+ prior=prior, density_estimator=density_estimator, device=device
62
+ )
34
63
  elif method == "SNRE":
35
64
  inference = SNRE(
36
- prior=prior, density_estimator=density_estimator, device=device)
65
+ prior=prior, density_estimator=density_estimator, device=device
66
+ )
37
67
  else:
38
68
  raise ValueError("Invalid method: " + method)
39
69
 
@@ -45,7 +75,7 @@ class Inference(object):
45
75
 
46
76
  @staticmethod
47
77
  def sample_prior(prior, n, seed=None):
48
- '''
78
+ """
49
79
  sample from prior distribution
50
80
 
51
81
  Parameters
@@ -58,19 +88,17 @@ class Inference(object):
58
88
  Returns
59
89
  -------
60
90
 
61
- '''
91
+ """
62
92
  if seed is not None:
63
93
  torch.manual_seed(seed)
64
-
94
+
65
95
  prior, _, _ = process_prior(prior)
66
96
  theta = prior.sample((n,))
67
97
  return theta
68
98
 
69
99
  @staticmethod
70
- def sample_posterior(xo,
71
- num_samples,
72
- posterior):
73
- '''
100
+ def sample_posterior(xo, num_samples, posterior):
101
+ """
74
102
  sample from the posterior using the given observation point.
75
103
 
76
104
  Parameters
@@ -87,7 +115,7 @@ class Inference(object):
87
115
  samples: torch.tensor float32 (num_samples, d)
88
116
  samples from the posterior
89
117
 
90
- '''
118
+ """
91
119
 
92
120
  if not isinstance(xo, torch.Tensor):
93
121
  xo = torch.tensor(xo, dtype=torch.float32)
@@ -30,7 +30,7 @@ private:
30
30
  size_t index_transition;
31
31
  vector<vector<unsigned>> adjlist;
32
32
 
33
- double par_A;
33
+ dim1 A;
34
34
  double par_a;
35
35
  double par_B;
36
36
  double par_b;
@@ -64,7 +64,7 @@ public:
64
64
  double coupling,
65
65
  dim2 adj,
66
66
  dim1 y,
67
- double A,
67
+ dim1 A,
68
68
  double B,
69
69
  double a,
70
70
  double b,
@@ -81,7 +81,7 @@ public:
81
81
  {
82
82
  assert(t_final > t_transition);
83
83
 
84
- par_A = A;
84
+ this->A = A;
85
85
  par_B = B;
86
86
  par_a = a;
87
87
  par_b = b;
@@ -131,7 +131,6 @@ public:
131
131
 
132
132
  double a2 = par_a * par_a;
133
133
  double b2 = par_b * par_b;
134
- double Aa = par_A * par_a;
135
134
  double Bb = par_B * par_b;
136
135
 
137
136
  size_t N2 = 2 * N;
@@ -152,8 +151,8 @@ public:
152
151
  dxdt[i] = y[i + N3];
153
152
  dxdt[i + N] = y[i + N4];
154
153
  dxdt[i + N2] = y[i + N5];
155
- dxdt[i + N3] = par_A * par_a * sigma(y[i + N] - y[i + N2]) - 2 * par_a * y[i + N3] - a2 * y[i];
156
- dxdt[i + N4] = Aa * (noise_mu + C1[i] * sigma(C0[i] * y[i]) + coupling * coupling_term) - 2 * par_a * y[i + N4] - a2 * y[i + N];
154
+ dxdt[i + N3] = A[i] * par_a * sigma(y[i + N] - y[i + N2]) - 2 * par_a * y[i + N3] - a2 * y[i];
155
+ dxdt[i + N4] = A[i] * par_a * (noise_mu + C1[i] * sigma(C0[i] * y[i]) + coupling * coupling_term) - 2 * par_a * y[i + N4] - a2 * y[i + N];
157
156
  dxdt[i + N5] = Bb * C3[i] * sigma(C2[i] * y[i]) - 2 * par_b * y[i + N5] -
158
157
  b2 * y[i + N2];
159
158
  }
@@ -11803,7 +11803,7 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_0(PyObject *self, Py_ssize_t nobjs,
11803
11803
  double arg5 ;
11804
11804
  dim2 arg6 ;
11805
11805
  dim1 arg7 ;
11806
- double arg8 ;
11806
+ dim1 arg8 ;
11807
11807
  double arg9 ;
11808
11808
  double arg10 ;
11809
11809
  double arg11 ;
@@ -11827,8 +11827,6 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_0(PyObject *self, Py_ssize_t nobjs,
11827
11827
  int ecode4 = 0 ;
11828
11828
  double val5 ;
11829
11829
  int ecode5 = 0 ;
11830
- double val8 ;
11831
- int ecode8 = 0 ;
11832
11830
  double val9 ;
11833
11831
  int ecode9 = 0 ;
11834
11832
  double val10 ;
@@ -11894,11 +11892,15 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_0(PyObject *self, Py_ssize_t nobjs,
11894
11892
  arg7 = *ptr;
11895
11893
  if (SWIG_IsNewObj(res)) delete ptr;
11896
11894
  }
11897
- ecode8 = SWIG_AsVal_double(swig_obj[7], &val8);
11898
- if (!SWIG_IsOK(ecode8)) {
11899
- SWIG_exception_fail(SWIG_ArgError(ecode8), "in method '" "new_JR_sde" "', argument " "8"" of type '" "double""'");
11900
- }
11901
- arg8 = static_cast< double >(val8);
11895
+ {
11896
+ std::vector< double,std::allocator< double > > *ptr = (std::vector< double,std::allocator< double > > *)0;
11897
+ int res = swig::asptr(swig_obj[7], &ptr);
11898
+ if (!SWIG_IsOK(res) || !ptr) {
11899
+ SWIG_exception_fail(SWIG_ArgError((ptr ? res : SWIG_TypeError)), "in method '" "new_JR_sde" "', argument " "8"" of type '" "dim1""'");
11900
+ }
11901
+ arg8 = *ptr;
11902
+ if (SWIG_IsNewObj(res)) delete ptr;
11903
+ }
11902
11904
  ecode9 = SWIG_AsVal_double(swig_obj[8], &val9);
11903
11905
  if (!SWIG_IsOK(ecode9)) {
11904
11906
  SWIG_exception_fail(SWIG_ArgError(ecode9), "in method '" "new_JR_sde" "', argument " "9"" of type '" "double""'");
@@ -11980,7 +11982,7 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_0(PyObject *self, Py_ssize_t nobjs,
11980
11982
  SWIG_exception_fail(SWIG_ArgError(ecode21), "in method '" "new_JR_sde" "', argument " "21"" of type '" "int""'");
11981
11983
  }
11982
11984
  arg21 = static_cast< int >(val21);
11983
- result = (JR_sde *)new JR_sde(SWIG_STD_MOVE(arg1),arg2,arg3,arg4,arg5,SWIG_STD_MOVE(arg6),SWIG_STD_MOVE(arg7),arg8,arg9,arg10,arg11,arg12,arg13,arg14,SWIG_STD_MOVE(arg15),SWIG_STD_MOVE(arg16),SWIG_STD_MOVE(arg17),SWIG_STD_MOVE(arg18),arg19,arg20,arg21);
11985
+ result = (JR_sde *)new JR_sde(SWIG_STD_MOVE(arg1),arg2,arg3,arg4,arg5,SWIG_STD_MOVE(arg6),SWIG_STD_MOVE(arg7),SWIG_STD_MOVE(arg8),arg9,arg10,arg11,arg12,arg13,arg14,SWIG_STD_MOVE(arg15),SWIG_STD_MOVE(arg16),SWIG_STD_MOVE(arg17),SWIG_STD_MOVE(arg18),arg19,arg20,arg21);
11984
11986
  resultobj = SWIG_NewPointerObj(SWIG_as_voidptr(result), SWIGTYPE_p_JR_sde, SWIG_POINTER_NEW | 0 );
11985
11987
  return resultobj;
11986
11988
  fail:
@@ -11997,7 +11999,7 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_1(PyObject *self, Py_ssize_t nobjs,
11997
11999
  double arg5 ;
11998
12000
  dim2 arg6 ;
11999
12001
  dim1 arg7 ;
12000
- double arg8 ;
12002
+ dim1 arg8 ;
12001
12003
  double arg9 ;
12002
12004
  double arg10 ;
12003
12005
  double arg11 ;
@@ -12020,8 +12022,6 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_1(PyObject *self, Py_ssize_t nobjs,
12020
12022
  int ecode4 = 0 ;
12021
12023
  double val5 ;
12022
12024
  int ecode5 = 0 ;
12023
- double val8 ;
12024
- int ecode8 = 0 ;
12025
12025
  double val9 ;
12026
12026
  int ecode9 = 0 ;
12027
12027
  double val10 ;
@@ -12085,11 +12085,15 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_1(PyObject *self, Py_ssize_t nobjs,
12085
12085
  arg7 = *ptr;
12086
12086
  if (SWIG_IsNewObj(res)) delete ptr;
12087
12087
  }
12088
- ecode8 = SWIG_AsVal_double(swig_obj[7], &val8);
12089
- if (!SWIG_IsOK(ecode8)) {
12090
- SWIG_exception_fail(SWIG_ArgError(ecode8), "in method '" "new_JR_sde" "', argument " "8"" of type '" "double""'");
12091
- }
12092
- arg8 = static_cast< double >(val8);
12088
+ {
12089
+ std::vector< double,std::allocator< double > > *ptr = (std::vector< double,std::allocator< double > > *)0;
12090
+ int res = swig::asptr(swig_obj[7], &ptr);
12091
+ if (!SWIG_IsOK(res) || !ptr) {
12092
+ SWIG_exception_fail(SWIG_ArgError((ptr ? res : SWIG_TypeError)), "in method '" "new_JR_sde" "', argument " "8"" of type '" "dim1""'");
12093
+ }
12094
+ arg8 = *ptr;
12095
+ if (SWIG_IsNewObj(res)) delete ptr;
12096
+ }
12093
12097
  ecode9 = SWIG_AsVal_double(swig_obj[8], &val9);
12094
12098
  if (!SWIG_IsOK(ecode9)) {
12095
12099
  SWIG_exception_fail(SWIG_ArgError(ecode9), "in method '" "new_JR_sde" "', argument " "9"" of type '" "double""'");
@@ -12166,7 +12170,7 @@ SWIGINTERN PyObject *_wrap_new_JR_sde__SWIG_1(PyObject *self, Py_ssize_t nobjs,
12166
12170
  SWIG_exception_fail(SWIG_ArgError(ecode20), "in method '" "new_JR_sde" "', argument " "20"" of type '" "double""'");
12167
12171
  }
12168
12172
  arg20 = static_cast< double >(val20);
12169
- result = (JR_sde *)new JR_sde(SWIG_STD_MOVE(arg1),arg2,arg3,arg4,arg5,SWIG_STD_MOVE(arg6),SWIG_STD_MOVE(arg7),arg8,arg9,arg10,arg11,arg12,arg13,arg14,SWIG_STD_MOVE(arg15),SWIG_STD_MOVE(arg16),SWIG_STD_MOVE(arg17),SWIG_STD_MOVE(arg18),arg19,arg20);
12173
+ result = (JR_sde *)new JR_sde(SWIG_STD_MOVE(arg1),arg2,arg3,arg4,arg5,SWIG_STD_MOVE(arg6),SWIG_STD_MOVE(arg7),SWIG_STD_MOVE(arg8),arg9,arg10,arg11,arg12,arg13,arg14,SWIG_STD_MOVE(arg15),SWIG_STD_MOVE(arg16),SWIG_STD_MOVE(arg17),SWIG_STD_MOVE(arg18),arg19,arg20);
12170
12174
  resultobj = SWIG_NewPointerObj(SWIG_as_voidptr(result), SWIGTYPE_p_JR_sde, SWIG_POINTER_NEW | 0 );
12171
12175
  return resultobj;
12172
12176
  fail:
@@ -12215,10 +12219,8 @@ SWIGINTERN PyObject *_wrap_new_JR_sde(PyObject *self, PyObject *args) {
12215
12219
  int res = swig::asptr(argv[6], (std::vector< double,std::allocator< double > >**)(0));
12216
12220
  _v = SWIG_CheckState(res);
12217
12221
  if (_v) {
12218
- {
12219
- int res = SWIG_AsVal_double(argv[7], NULL);
12220
- _v = SWIG_CheckState(res);
12221
- }
12222
+ int res = swig::asptr(argv[7], (std::vector< double,std::allocator< double > >**)(0));
12223
+ _v = SWIG_CheckState(res);
12222
12224
  if (_v) {
12223
12225
  {
12224
12226
  int res = SWIG_AsVal_double(argv[8], NULL);
@@ -12327,10 +12329,8 @@ SWIGINTERN PyObject *_wrap_new_JR_sde(PyObject *self, PyObject *args) {
12327
12329
  int res = swig::asptr(argv[6], (std::vector< double,std::allocator< double > >**)(0));
12328
12330
  _v = SWIG_CheckState(res);
12329
12331
  if (_v) {
12330
- {
12331
- int res = SWIG_AsVal_double(argv[7], NULL);
12332
- _v = SWIG_CheckState(res);
12333
- }
12332
+ int res = swig::asptr(argv[7], (std::vector< double,std::allocator< double > >**)(0));
12333
+ _v = SWIG_CheckState(res);
12334
12334
  if (_v) {
12335
12335
  {
12336
12336
  int res = SWIG_AsVal_double(argv[8], NULL);
@@ -12416,8 +12416,8 @@ SWIGINTERN PyObject *_wrap_new_JR_sde(PyObject *self, PyObject *args) {
12416
12416
  fail:
12417
12417
  SWIG_Python_RaiseOrModifyTypeError("Wrong number or type of arguments for overloaded function 'new_JR_sde'.\n"
12418
12418
  " Possible C/C++ prototypes are:\n"
12419
- " JR_sde::JR_sde(size_t,double,double,double,double,dim2,dim1,double,double,double,double,double,double,double,dim1,dim1,dim1,dim1,double,double,int)\n"
12420
- " JR_sde::JR_sde(size_t,double,double,double,double,dim2,dim1,double,double,double,double,double,double,double,dim1,dim1,dim1,dim1,double,double)\n");
12419
+ " JR_sde::JR_sde(size_t,double,double,double,double,dim2,dim1,dim1,double,double,double,double,double,double,dim1,dim1,dim1,dim1,double,double,int)\n"
12420
+ " JR_sde::JR_sde(size_t,double,double,double,double,dim2,dim1,dim1,double,double,double,double,double,double,dim1,dim1,dim1,dim1,double,double)\n");
12421
12421
  return 0;
12422
12422
  }
12423
12423
 
@@ -124,7 +124,7 @@ class JR_sde:
124
124
  "method": "heun",
125
125
  "t_transition": 500.0, # ms
126
126
  "t_end": 2501.0, # ms
127
- "output": "output", # output directory
127
+ "output": "output", # output directory
128
128
  "RECORD_AVG": False # true to store large time series in file
129
129
  }
130
130
  return par
@@ -181,17 +181,13 @@ class JR_sde:
181
181
  self.t_transition = float(self.t_transition)
182
182
  self.t_end = float(self.t_end)
183
183
  self.G = float(self.G)
184
- self.A = float(self.A)
185
184
  self.B = float(self.B)
186
185
  self.a = float(self.a)
187
186
  self.b = float(self.b)
188
187
  self.r = float(self.r)
189
188
  self.v0 = float(self.v0)
190
189
  self.vmax = float(self.vmax)
191
- # self.C0 = np.asarray(self.C0)
192
- # self.C1 = np.asarray(self.C1)
193
- # self.C2 = np.asarray(self.C2)
194
- # self.C3 = np.asarray(self.C3)
190
+ self.A = check_sequence(self.A, self.N)
195
191
  self.C0 = check_sequence(self.C0, self.N)
196
192
  self.C1 = check_sequence(self.C1, self.N)
197
193
  self.C2 = check_sequence(self.C2, self.N)
@@ -236,9 +232,6 @@ class JR_sde:
236
232
  for key in par.keys():
237
233
  if key not in self.valid_params:
238
234
  raise ValueError("Invalid parameter: " + key)
239
- # if key in ["C0", "C1", "C2", "C3"]:
240
- # self.set_C(key, par[key])
241
- # else:
242
235
  setattr(self, key, par[key])
243
236
 
244
237
  self.prepare_input()
vbi/models/cupy/bold.py CHANGED
@@ -105,6 +105,123 @@ class BoldStephan2008:
105
105
  q[0] = q[1]
106
106
 
107
107
 
108
+ class Bold:
109
+
110
+ def __init__(self, par: dict = {}) -> None:
111
+
112
+ self._par = self.get_default_parameters()
113
+ self.valid_parameters = list(self._par.keys())
114
+ self.check_parameters(par)
115
+ self._par.update(par)
116
+
117
+ for item in self._par.items():
118
+ setattr(self, item[0], item[1])
119
+ self.update_dependent_parameters()
120
+
121
+
122
+ def get_default_parameters(self):
123
+ """get balloon model parameters."""
124
+
125
+ vo = 0.08
126
+ theta = 40.3
127
+ TE = 0.04
128
+ Eo = 0.4
129
+ r0 = 25.0
130
+ epsilon = 0.34
131
+ k1 = 4.3 * theta * Eo * TE
132
+ k2 = epsilon * r0 * Eo * TE
133
+ k3 = 1 - epsilon
134
+
135
+ par = {
136
+ "kappa": 0.65,
137
+ "gamma": 0.41,
138
+ "tau": 0.98,
139
+ "alpha": 0.32,
140
+ "epsilon": epsilon,
141
+ "Eo": Eo,
142
+ "TE": TE,
143
+ "vo": vo,
144
+ "r0": r0,
145
+ "theta": theta,
146
+ "t_min": 0.0,
147
+ "rtol": 1e-5,
148
+ "atol": 1e-8,
149
+ "k1": k1,
150
+ "k2": k2,
151
+ "k3": k3
152
+ }
153
+ return par
154
+
155
+ def update_dependent_parameters(self):
156
+ self.k1 = 4.3 * self.theta * self.Eo * self.TE
157
+ self.k2 = self.epsilon * self.r0 * self.Eo * self.TE
158
+ self.k3 = 1 - self.epsilon
159
+
160
+ def check_parameters(self, par):
161
+ for key in par.keys():
162
+ if key not in self.valid_parameters:
163
+ raise ValueError(f"Invalid parameter {key:s} provided.")
164
+
165
+ def allocate_memory(self, xp, nn, ns, n_steps, bold_decimate, dtype):
166
+
167
+ self.s = xp.zeros((2, nn, ns), dtype=dtype)
168
+ self.f = xp.zeros((2, nn, ns), dtype=dtype)
169
+ self.ftilde = xp.zeros((2, nn, ns), dtype=dtype)
170
+ self.vtilde = xp.zeros((2, nn, ns), dtype=dtype)
171
+ self.qtilde = xp.zeros((2, nn, ns), dtype=dtype)
172
+ self.v = xp.zeros((2, nn, ns), dtype=dtype)
173
+ self.q = xp.zeros((2, nn, ns), dtype=dtype)
174
+ self.vv = np.zeros((n_steps // bold_decimate, nn, ns), dtype="f")
175
+ self.qq = np.zeros((n_steps // bold_decimate, nn, ns), dtype="f")
176
+ self.s[0] = 1
177
+ self.f[0] = 1
178
+ self.v[0] = 1
179
+ self.q[0] = 1
180
+ self.ftilde[0] = 0
181
+ self.vtilde[0] = 0
182
+ self.qtilde[0] = 0
183
+
184
+ def do_bold_step(self, r_in, dtt):
185
+
186
+ Eo = self.Eo
187
+ tau = self.tau
188
+ kappa = self.kappa
189
+ gamma = self.gamma
190
+ alpha = self.alpha
191
+ ialpha = 1 / alpha
192
+
193
+ v = self.v
194
+ q = self.q
195
+ s = self.s
196
+ f = self.f
197
+ ftilde = self.ftilde
198
+ vtilde = self.vtilde
199
+ qtilde = self.qtilde
200
+
201
+ s[1] = s[0] + dtt * (r_in - kappa * s[0] - gamma * (f[0] - 1))
202
+ f[0] = np.clip(f[0], 1, None)
203
+ ftilde[1] = ftilde[0] + dtt * (s[0] / f[0])
204
+ fv = v[0] ** ialpha # outflow
205
+ vtilde[1] = vtilde[0] + dtt * ((f[0] - fv) / (tau * v[0]))
206
+ q[0] = np.clip(q[0], 0.01, None)
207
+ ff = (1 - (1 - Eo) ** (1 / f[0])) / Eo # oxygen extraction
208
+ qtilde[1] = qtilde[0] + dtt * ((f[0] * ff - fv * q[0] / v[0]) / (tau * q[0]))
209
+
210
+ f[1] = np.exp(ftilde[1])
211
+ v[1] = np.exp(vtilde[1])
212
+ q[1] = np.exp(qtilde[1])
213
+
214
+ f[0] = f[1]
215
+ s[0] = s[1]
216
+ ftilde[0] = ftilde[1]
217
+ vtilde[0] = vtilde[1]
218
+ qtilde[0] = qtilde[1]
219
+ v[0] = v[1]
220
+ q[0] = q[1]
221
+
222
+
223
+
224
+
108
225
  class BoldTVB:
109
226
 
110
227
  def __init__(self):
@@ -199,7 +199,7 @@ class JR_sde:
199
199
 
200
200
  self.G = self.xp.array(self.G)
201
201
  assert self.weights is not None, "weights must be provided"
202
- self.weights = self.xp.array(self.weights).T # ! check this
202
+ self.weights = self.xp.array(self.weights)
203
203
  self.weights = move_data(self.weights, self.engine)
204
204
  self.nn = self.num_nodes = self.weights.shape[0]
205
205