vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl → 0.2__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/feature_extraction/features.json +4 -1
- vbi/feature_extraction/features.py +10 -4
- vbi/inference.py +50 -22
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/jr_sde.hpp +5 -6
- vbi/models/cpp/_src/jr_sde_wrap.cxx +28 -28
- vbi/models/cpp/jansen_rit.py +2 -9
- vbi/models/cupy/bold.py +117 -0
- vbi/models/cupy/jansen_rit.py +1 -1
- vbi/models/cupy/km.py +62 -34
- vbi/models/cupy/mpr.py +24 -4
- vbi/models/cupy/utils.py +163 -2
- vbi/models/cupy/wilson_cowan.py +317 -0
- vbi/models/cupy/ww.py +342 -0
- vbi/models/numba/__init__.py +4 -0
- vbi/models/numba/jansen_rit.py +532 -0
- vbi/models/numba/mpr.py +8 -0
- vbi/models/numba/wilson_cowan.py +443 -0
- vbi/models/numba/ww.py +564 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/METADATA +30 -11
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/RECORD +30 -26
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/WHEEL +1 -1
- vbi/models/numba/_ww_EI.py +0 -444
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/licenses/LICENSE +0 -0
- {vbi-0.1.3.dist-info → vbi-0.2.dist-info}/top_level.txt +0 -0
vbi/models/numba/jansen_rit.py
ADDED
@@ -0,0 +1,532 @@
+import warnings
+import numpy as np
+from numba import njit
+from numba.experimental import jitclass
+from numba import float64, boolean, int64
+from numba.extending import register_jitable
+from numba.core.errors import NumbaPerformanceWarning
+
+warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
+
+# ---------------------------------------------------------------
+# Helper utilities (broadcasting, seeding, initial state)
+# ---------------------------------------------------------------
+
+@register_jitable
+def set_seed_compat(x):
+    np.random.seed(x)
+
+
+@register_jitable
+def _as_1d_array_like(x, nn):
+    """Broadcast scalar to 1D array of length nn if needed."""
+    x_arr = np.array(x) if not isinstance(x, np.ndarray) else x
+    if x_arr.ndim == 0:
+        return np.ones(nn) * float(x_arr)
+    if x_arr.ndim == 1 and x_arr.shape[0] == nn:
+        return x_arr.astype(np.float64)
+    raise ValueError("Parameter must be scalar or 1D array of length nn")
+
+
+@njit
+def set_initial_state_jr(nn, seed=-1):
+    """Initial state for JR: stack 6*n vectors.
+    Mirrors ranges similar to the CuPy reference implementation.
+    """
+    if seed is not None and seed >= 0:
+        set_seed_compat(seed)
+
+    y0 = np.random.uniform(-1.0, 1.0, nn)  # x
+    y1 = np.random.uniform(-500.0, 500.0, nn)  # y
+    y2 = np.random.uniform(-50.0, 50.0, nn)  # z
+    y3 = np.random.uniform(-6.0, 6.0, nn)  # x'
+    y4 = np.random.uniform(-20.0, 20.0, nn)  # y'
+    y5 = np.random.uniform(-500.0, 500.0, nn)  # z'
+
+    y = np.zeros(6 * nn)
+    y[:nn] = y0
+    y[nn:2*nn] = y1
+    y[2*nn:3*nn] = y2
+    y[3*nn:4*nn] = y3
+    y[4*nn:5*nn] = y4
+    y[5*nn:6*nn] = y5
+    return y
+
+
+# ---------------------------------------------------------------
+# JR parameters as a jitclass (Numba-friendly container)
+# ---------------------------------------------------------------
+
+jr_spec = [
+    ("G", float64),
+    ("A", float64),
+    ("B", float64),
+    ("a", float64),
+    ("b", float64),
+    ("v0", float64),
+    ("vmax", float64),
+    ("r", float64),
+    ("mu", float64),
+    ("noise_amp", float64),
+    ("dt", float64),
+    ("t_cut", float64),
+    ("t_end", float64),
+    ("decimate", int64),
+    ("nn", int64),
+    ("seed", int64),
+    ("sigma", float64),  # sqrt(dt) * noise_amp
+    # arrays
+    ("weights", float64[:, :]),
+    ("C0", float64[:]),
+    ("C1", float64[:]),
+    ("C2", float64[:]),
+    ("C3", float64[:]),
+    ("initial_state", float64[:]),
+]
+
+
+@jitclass(jr_spec)
+class ParJR:
+    def __init__(
+        self,
+        weights,
+        G=1.0,
+        A=3.25,
+        B=22.0,
+        a=0.1,
+        b=0.05,
+        v0=6.0,
+        vmax=0.005,
+        r=0.56,
+        mu=0.24,
+        noise_amp=0.01,
+        dt=0.01,
+        t_cut=0.0,
+        t_end=1000.0,
+        decimate=1,
+        C0=135.0,
+        C1=0.8*135.0,
+        C2=0.25*135.0,
+        C3=0.25*135.0,
+        seed=-1
+    ):
+        self.weights = weights
+        self.nn = len(weights)
+
+        self.G = G
+        self.A = A
+        self.B = B
+        self.a = a
+        self.b = b
+        self.v0 = v0
+        self.vmax = vmax
+        self.r = r
+        self.mu = mu
+        self.noise_amp = noise_amp
+        self.dt = dt
+        self.t_cut = t_cut
+        self.t_end = t_end
+        self.decimate = decimate
+        self.seed = seed
+
+        # C arrays are now passed pre-processed from outside
+        self.C0 = C0
+        self.C1 = C1
+        self.C2 = C2
+        self.C3 = C3
+
+        self.sigma = np.sqrt(dt) * noise_amp
+        self.initial_state = np.zeros(6 * self.nn)  # set by caller later
+
+
+# ---------------------------------------------------------------
+# JR model equations + integrator (Numba-jitted)
+# ---------------------------------------------------------------
+
+@register_jitable
+def S_sigmoid(x, vmax, r, v0):
+    """Numerically stable sigmoid function to avoid overflow."""
+    z = r * (v0 - x)
+    # Clip z to avoid overflow: exp(700) is near overflow limit
+    z_clipped = np.clip(z, -700, 700)
+    return vmax / (1.0 + np.exp(z_clipped))
+
+
+@njit
+def f_jr(x, t, P):
+    nn = P.nn
+
+    # Unpack state
+    x0 = x[0*nn:1*nn]  # x
+    y0 = x[1*nn:2*nn]  # y
+    z0 = x[2*nn:3*nn]  # z
+    xp = x[3*nn:4*nn]  # x'
+    yp = x[4*nn:5*nn]  # y'
+    zp = x[5*nn:6*nn]  # z'
+
+    # Precompute constants
+    Aa = P.A * P.a
+    Bb = P.B * P.b
+    aa = P.a * P.a
+    bb = P.b * P.b
+
+    # Coupling term: weights @ (y - z)
+    couplings = S_sigmoid(P.weights.dot(y0 - z0), P.vmax, P.r, P.v0)
+
+    # Allocate derivative
+    dxdt = np.zeros_like(x)
+
+    # Dynamics
+    dxdt[0*nn:1*nn] = xp
+    dxdt[1*nn:2*nn] = yp
+    dxdt[2*nn:3*nn] = zp
+
+    dxdt[3*nn:4*nn] = Aa * S_sigmoid(y0 - z0, P.vmax, P.r, P.v0) - 2.0 * P.a * xp - aa * x0
+    dxdt[4*nn:5*nn] = (
+        Aa * (P.mu + P.C1 * S_sigmoid(P.C0 * x0, P.vmax, P.r, P.v0) + P.G * couplings)
+        - 2.0 * P.a * yp - aa * y0
+    )
+    dxdt[5*nn:6*nn] = Bb * P.C3 * S_sigmoid(P.C2 * x0, P.vmax, P.r, P.v0) - 2.0 * P.b * zp - bb * z0
+
+    return dxdt
+
+
+@njit
+def heun_sde(x, t, P):
+    nn = P.nn
+    dt = P.dt
+
+    # Stochastic drive on the y' block, sigma already includes sqrt(dt)
+    dW = P.sigma * np.random.randn(nn)
+
+    k1 = f_jr(x, t, P)
+    x1 = x + dt * k1
+    x1[4*nn:5*nn] += dW
+
+    k2 = f_jr(x1, t + dt, P)
+    x = x + 0.5 * dt * (k1 + k2)
+    x[4*nn:5*nn] += dW
+
+    return x
+
+
+# ---------------------------------------------------------------
+# Top-level integrate and driver class (mirrors mpr.py style)
+# ---------------------------------------------------------------
+
+def integrate(P):
+    nn = P.nn
+    dt = P.dt
+    dec = P.decimate
+
+    # Ensure initial state is defined
+    x = P.initial_state.copy()
+
+    nt = int(P.t_end / dt)
+    tspan = np.linspace(0.0, (nt - 1) * dt, nt)
+
+    # Cut & decimate bookkeeping
+    i_cut = int(np.searchsorted(tspan, P.t_cut, side='left'))
+    n_keep = (nt - i_cut + (dec - 1)) // dec
+
+    # Output: y - z
+    ts = np.zeros(n_keep, dtype=np.float32)
+    ys = np.zeros((n_keep, nn), dtype=np.float32)
+
+    k = 0
+    for i in range(nt):
+        t = tspan[i]
+        x = heun_sde(x, t, P)
+
+        if i >= i_cut and ((i - i_cut) % dec == 0):
+            ts[k] = t
+            y0 = x[1*nn:2*nn]
+            z0 = x[2*nn:3*nn]
+            ys[k, :] = (y0 - z0).astype(np.float32)
+            k += 1
+            if k >= n_keep:
+                break
+
+    return {"t": ts, "x": ys}
+
+
+class JR_sde:
+    """
+    Numba implementation of the Jansen-Rit neural mass model.
+
+    .. list-table:: Parameters
+       :widths: 25 50 25
+       :header-rows: 1
+
+       * - Name
+         - Explanation
+         - Default Value
+       * - `A`
+         - Excitatory post synaptic potential amplitude.
+         - 3.25
+       * - `B`
+         - Inhibitory post synaptic potential amplitude.
+         - 22.0
+       * - `a`
+         - Inverse time constant of the excitatory postsynaptic potential (1/a = time constant).
+         - 0.1 (time constant: 10.0)
+       * - `b`
+         - Inverse time constant of the inhibitory postsynaptic potential (1/b = time constant).
+         - 0.05 (time constant: 20.0)
+       * - `C0`
+         - Average number of synapses between pyramidal cells and excitatory interneurons. If array-like, it should be of length `nn` (number of nodes).
+         - 135.0
+       * - `C1`
+         - Average number of synapses between excitatory interneurons and pyramidal cells. If array-like, it should be of length `nn`.
+         - 0.8 * 135.0
+       * - `C2`
+         - Average number of synapses between pyramidal cells and inhibitory interneurons. If array-like, it should be of length `nn`.
+         - 0.25 * 135.0
+       * - `C3`
+         - Average number of synapses between inhibitory interneurons and pyramidal cells. If array-like, it should be of length `nn`.
+         - 0.25 * 135.0
+       * - `vmax`
+         - Maximum firing rate of the sigmoid function.
+         - 0.005
+       * - `v0`
+         - Potential at half of maximum firing rate (inflection point of sigmoid).
+         - 6.0
+       * - `r`
+         - Slope of sigmoid function at `v0`.
+         - 0.56
+       * - `G`
+         - Global coupling strength scaling the network connections.
+         - 1.0
+       * - `mu`
+         - Mean input to the excitatory population (external drive).
+         - 0.24
+       * - `noise_amp`
+         - Amplitude of the stochastic noise applied to the excitatory population.
+         - 0.01
+       * - `weights`
+         - Structural connectivity matrix of shape (`nn`, `nn`). Must be provided.
+         - None
+       * - `dt`
+         - Integration time step.
+         - 0.01
+       * - `t_end`
+         - End time of simulation.
+         - 1000.0
+       * - `t_cut`
+         - Time from which to start collecting output (burn-in period).
+         - 0.0
+       * - `decimate`
+         - Decimation factor for the output time series (every `decimate`-th point is saved).
+         - 1
+       * - `seed`
+         - Random seed for reproducible simulations. If -1 or None, no seeding is applied.
+         - -1
+       * - `initial_state`
+         - Initial state vector of shape (6*nn,). If None, random initial conditions are generated.
+         - None
+
+    Usage example (single simulation):
+    >>> import numpy as np
+    >>> from vbi.models.numba.jansen_rit import JR_sde
+    >>> W = np.eye(2)  # 2-node demo connectivity
+    >>> jr = JR_sde({"weights": W, "dt": 0.01, "t_end": 200.0, "t_cut": 100.0, "decimate": 1})
+    >>> out = jr.run()
+    >>> t, x = out["t"], out["x"]  # x has shape (n_step, nn)
+
+    Notes
+    -----
+    The Jansen-Rit model describes the dynamics of a cortical column with three neural populations:
+    - Pyramidal cells (main excitatory population)
+    - Excitatory interneurons
+    - Inhibitory interneurons
+
+    The model equations are integrated using the Heun stochastic integration scheme.
+    The output represents the difference between excitatory and inhibitory postsynaptic potentials (y - z),
+    which corresponds to the local field potential that can be measured experimentally.
+    """
+
+    def __init__(self, par_jr: dict):
+        """
+        Initialize the Jansen-Rit model.
+
+        Parameters
+        ----------
+        par_jr : dict
+            Dictionary containing model parameters. See class documentation for available parameters.
+            The 'weights' parameter is required and must be a square connectivity matrix.
+        """
+        # Validate weights early and create parameter jitclass
+        if "weights" not in par_jr or par_jr["weights"] is None:
+            raise ValueError("'weights' must be provided (square connectivity matrix)")
+
+        W = np.array(par_jr["weights"], dtype=np.float64)
+        if W.ndim != 2 or W.shape[0] != W.shape[1]:
+            raise ValueError("'weights' must be a square 2D array")
+
+        nn = len(W)
+
+        # Pre-process C parameters before passing to jitclass
+        params = dict(par_jr)
+        params["weights"] = W
+
+        # Handle C parameters - broadcast them here outside jitclass
+        for c_name in ["C0", "C1", "C2", "C3"]:
+            if c_name in params:
+                c_val = params[c_name]
+                params[c_name] = _as_1d_array_like(c_val, nn)
+            else:
+                # Set defaults
+                if c_name == "C0":
+                    params[c_name] = _as_1d_array_like(135.0, nn)
+                elif c_name == "C1":
+                    params[c_name] = _as_1d_array_like(0.8*135.0, nn)
+                elif c_name == "C2":
+                    params[c_name] = _as_1d_array_like(0.25*135.0, nn)
+                elif c_name == "C3":
+                    params[c_name] = _as_1d_array_like(0.25*135.0, nn)
+
+        # Create jitclass instance
+        self.P = ParJR(**params)
+
+        # Seed handling
+        self.seed = int(self.P.seed)
+        if self.seed >= 0:
+            np.random.seed(self.seed)
+
+        # Ensure initial state
+        if "initial_state" in par_jr and par_jr["initial_state"] is not None:
+            x0 = np.array(par_jr["initial_state"], dtype=np.float64)
+            if x0.shape[0] != 6 * self.P.nn:
+                raise ValueError("initial_state must have length 6*nn")
+            self.P.initial_state = x0
+        else:
+            self.P.initial_state = set_initial_state_jr(self.P.nn, self.seed)
+
+        self._checked = False
+
+    def __str__(self) -> str:
+        """
+        Return a string representation of the model parameters.
+
+        Returns
+        -------
+        str
+            Formatted string showing all model parameters and their values.
+        """
+        print("Jansen-Rit Model (Numba)")
+        print("------------------------")
+
+        # Model parameters
+        print(f"G = {self.P.G}")
+        print(f"A = {self.P.A}")
+        print(f"B = {self.P.B}")
+        print(f"a = {self.P.a}")
+        print(f"b = {self.P.b}")
+        print(f"v0 = {self.P.v0}")
+        print(f"vmax = {self.P.vmax}")
+        print(f"r = {self.P.r}")
+        print(f"mu = {self.P.mu}")
+        print(f"noise_amp = {self.P.noise_amp}")
+
+        # Connectivity parameters
+        print(f"C0 = {self.P.C0}")
+        print(f"C1 = {self.P.C1}")
+        print(f"C2 = {self.P.C2}")
+        print(f"C3 = {self.P.C3}")
+
+        # Simulation parameters
+        print(f"dt = {self.P.dt}")
+        print(f"t_end = {self.P.t_end}")
+        print(f"t_cut = {self.P.t_cut}")
+        print(f"decimate = {self.P.decimate}")
+        print(f"nn = {self.P.nn}")
+        print(f"seed = {self.P.seed}")
+        print(f"sigma = {self.P.sigma}")
+        print(f"weights shape = {self.P.weights.shape}")
+
+        return ""
+
+    def check_input(self):
+        """
+        Validate model parameters.
+
+        Raises
+        ------
+        ValueError
+            If any parameter values are invalid (e.g., t_cut >= t_end,
+            decimate < 1, or dimension mismatches).
+        """
+        if self.P.t_cut >= self.P.t_end:
+            raise ValueError("t_cut must be less than t_end")
+        if self.P.decimate < 1:
+            raise ValueError("decimate must be >= 1")
+        if self.P.nn != self.P.weights.shape[0]:
+            raise ValueError("nn != weights.shape[0]")
+        self._checked = True
+
+    def set_initial_state(self, seed: int = None):
+        """
+        Set random initial state for the simulation.
+
+        Parameters
+        ----------
+        seed : int, optional
+            Random seed for reproducible initial conditions.
+            If None, uses the seed specified during initialization.
+        """
+        seed_ = self.seed if seed is None else seed
+        self.P.initial_state = set_initial_state_jr(self.P.nn, seed_)
+
+    def run(self, par: dict = None, x0: np.ndarray = None):
+        """
+        Run the Jansen-Rit simulation.
+
+        Parameters
+        ----------
+        par : dict, optional
+            Dictionary of parameters to update for this simulation run.
+            Any parameter from the class documentation can be updated.
+        x0 : np.ndarray, optional
+            Initial state vector of shape (6*nn,). If None, uses the
+            initial state set during initialization or by set_initial_state().

+        Returns
+        -------
+        dict
+            Dictionary containing simulation results:
+            - 't': np.ndarray of shape (n_steps,) - time points
+            - 'x': np.ndarray of shape (n_steps, nn) - simulated time series (y - z)
+              representing local field potentials
+        """
+        # Optionally update parameters on the jitclass (Numba allows setattr)
+        if par:
+            for k, v in par.items():
+                if k == "weights":
+                    W = np.array(v, dtype=np.float64)
+                    if W.ndim != 2 or W.shape[0] != W.shape[1]:
+                        raise ValueError("'weights' must be a square 2D array")
+                    setattr(self.P, "weights", W)
+                    setattr(self.P, "nn", len(W))
+                elif k in ("C0", "C1", "C2", "C3"):
+                    arr = _as_1d_array_like(v, self.P.nn)
+                    setattr(self.P, k, arr)
+                elif hasattr(self.P, k):
+                    setattr(self.P, k, v)
+                else:
+                    raise ValueError(f"Invalid parameter: {k}")
+
+        # Optionally replace initial state
+        if x0 is not None:
+            x0 = np.array(x0, dtype=np.float64)
+            if x0.shape[0] != 6 * self.P.nn:
+                raise ValueError("initial_state must have length 6*nn")
+            self.P.initial_state = x0
+
+        if not self._checked:
+            self.check_input()
+
+        return integrate(self.P)
+
+
+# Alias for consistency with naming convention
+JR_sde_numba = JR_sde
vbi/models/numba/mpr.py
CHANGED
@@ -116,6 +116,7 @@ def integrate(P, B):
     rv_t = np.zeros((nt // rv_decimate), dtype=np.float32)
 
     def compute():
+        nonlocal rv_d, rv_t, bold_d, bold_t
 
         bold_d = np.array([])
        bold_t = np.array([])
@@ -152,10 +153,16 @@
             if (i % bold_decimate == 0) and ((i // bold_decimate) < vv.shape[0]):
                 vv[i // bold_decimate] = v[1]
                 qq[i // bold_decimate] = q[1]
+
+        if RECORD_RV:
+            rv_d = rv_d[rv_t >= P.t_cut, :]
+            rv_t = rv_t[rv_t >= P.t_cut]
 
         if RECORD_BOLD:
             bold_d = vo * (k1 * (1 - qq) + k2 * (1 - qq / vv) + k3 * (1 - vv))
             bold_t = np.linspace(0, P.t_end - dt * bold_decimate, len(bold_d))
+            bold_d = bold_d[bold_t >= P.t_cut, :]
+            bold_t = bold_t[bold_t >= P.t_cut]
 
         return rv_t, rv_d, bold_t, bold_d
 
@@ -215,6 +222,7 @@ class MPR_sde:
         assert self.P.weights.shape[0] == self.P.weights.shape[1]
         assert self.P.initial_state is not None
         assert len(self.P.initial_state) == 2 * self.P.weights.shape[0]
+        assert self.P.t_cut < self.P.t_end, "t_cut must be less than t_end"
        self.P.eta = check_vec_size(self.P.eta, self.P.nn)
        self.P.t_end /= 10
        self.P.t_cut /= 10