vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl → 0.2__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,532 @@
1
+ import warnings
2
+ import numpy as np
3
+ from numba import njit
4
+ from numba.experimental import jitclass
5
+ from numba import float64, boolean, int64
6
+ from numba.extending import register_jitable
7
+ from numba.core.errors import NumbaPerformanceWarning
8
+
9
+ warnings.simplefilter("ignore", category=NumbaPerformanceWarning)
10
+
11
+ # ---------------------------------------------------------------
12
+ # Helper utilities (broadcasting, seeding, initial state)
13
+ # ---------------------------------------------------------------
14
+
15
@register_jitable
def set_seed_compat(x):
    """Seed the ``np.random`` generator with ``x``.

    The function is registered with ``register_jitable`` so that it can be
    called from jitted code as well as from ordinary Python.
    """
    np.random.seed(x)
18
+
19
+
20
@register_jitable
def _as_1d_array_like(x, nn):
    """Return ``x`` as a float64 vector of length ``nn``.

    A scalar (or 0-d array) is repeated ``nn`` times; a 1D array that already
    has ``nn`` entries is converted to float64.

    Raises
    ------
    ValueError
        If ``x`` is neither a scalar nor a 1D array of length ``nn``.
    """
    values = x if isinstance(x, np.ndarray) else np.array(x)

    # Scalar input: broadcast to every node.
    if values.ndim == 0:
        return np.ones(nn) * float(values)

    # Anything that is not a length-nn vector is rejected.
    if values.ndim != 1 or values.shape[0] != nn:
        raise ValueError("Parameter must be scalar or 1D array of length nn")

    return values.astype(np.float64)
29
+
30
+
31
@njit
def set_initial_state_jr(nn, seed=-1):
    """Draw a random initial state for a Jansen-Rit network of ``nn`` nodes.

    The result is a flat vector of length ``6 * nn`` laid out in six
    consecutive blocks ``[x, y, z, x', y', z']``. Each block is sampled
    uniformly from its own fixed range (the original notes say the ranges are
    similar to the CuPy reference implementation).

    The generator is seeded first when ``seed`` is a non-negative integer;
    a negative seed (or ``None``) leaves the generator untouched.
    """
    if seed is not None and seed >= 0:
        set_seed_compat(seed)

    state = np.zeros(6 * nn)

    # The draws are made in block order so the random stream is consumed in
    # the sequence x, y, z, x', y', z'.
    state[0 * nn:1 * nn] = np.random.uniform(-1.0, 1.0, nn)      # x
    state[1 * nn:2 * nn] = np.random.uniform(-500.0, 500.0, nn)  # y
    state[2 * nn:3 * nn] = np.random.uniform(-50.0, 50.0, nn)    # z
    state[3 * nn:4 * nn] = np.random.uniform(-6.0, 6.0, nn)      # x'
    state[4 * nn:5 * nn] = np.random.uniform(-20.0, 20.0, nn)    # y'
    state[5 * nn:6 * nn] = np.random.uniform(-500.0, 500.0, nn)  # z'
    return state
54
+
55
+
56
+ # ---------------------------------------------------------------
57
+ # JR parameters as a jitclass (Numba-friendly container)
58
+ # ---------------------------------------------------------------
59
+
60
# Field specification for the ParJR jitclass. The list is assembled from
# groups of names sharing a type; the resulting order is: float scalars,
# integer scalars, sigma, the weight matrix, then the per-node vectors.
jr_spec = (
    [
        (field, float64)
        for field in (
            "G", "A", "B", "a", "b", "v0", "vmax", "r", "mu",
            "noise_amp", "dt", "t_cut", "t_end",
        )
    ]
    + [(field, int64) for field in ("decimate", "nn", "seed")]
    + [("sigma", float64)]  # sqrt(dt) * noise_amp
    + [("weights", float64[:, :])]
    + [
        (field, float64[:])
        for field in ("C0", "C1", "C2", "C3", "initial_state")
    ]
)
86
+
87
+
88
# Numba-typed container holding every Jansen-Rit parameter plus the derived
# quantities ``nn`` (taken from ``weights``) and ``sigma``.
# NOTE(review): C0..C3 are declared as float64[:] in jr_spec while their
# defaults here are scalars; the driver class broadcasts them to arrays
# before constructing ParJR. Relying on the scalar defaults directly looks
# unsupported -- confirm.
@jitclass(jr_spec)
class ParJR:
    def __init__(
        self,
        weights,
        G=1.0,
        A=3.25,
        B=22.0,
        a=0.1,
        b=0.05,
        v0=6.0,
        vmax=0.005,
        r=0.56,
        mu=0.24,
        noise_amp=0.01,
        dt=0.01,
        t_cut=0.0,
        t_end=1000.0,
        decimate=1,
        C0=135.0,
        C1=0.8 * 135.0,
        C2=0.25 * 135.0,
        C3=0.25 * 135.0,
        seed=-1,
    ):
        """Store the parameters and derive ``nn`` and ``sigma``."""
        # Connectivity and network size.
        self.weights = weights
        self.nn = weights.shape[0]

        # Synaptic gains and inverse time constants.
        self.A = A
        self.B = B
        self.a = a
        self.b = b

        # Sigmoid shape.
        self.v0 = v0
        self.vmax = vmax
        self.r = r

        # Coupling, external drive and noise.
        self.G = G
        self.mu = mu
        self.noise_amp = noise_amp

        # Per-node connectivity constants (pre-processed by the caller).
        self.C0 = C0
        self.C1 = C1
        self.C2 = C2
        self.C3 = C3

        # Simulation control.
        self.dt = dt
        self.t_cut = t_cut
        self.t_end = t_end
        self.decimate = decimate
        self.seed = seed

        # Noise increment scale: the sqrt(dt) factor is folded in here.
        self.sigma = noise_amp * np.sqrt(dt)

        # Placeholder; the caller assigns the real initial state afterwards.
        self.initial_state = np.zeros(6 * self.nn)
140
+
141
+
142
+ # ---------------------------------------------------------------
143
+ # JR model equations + integrator (Numba-jitted)
144
+ # ---------------------------------------------------------------
145
+
146
@register_jitable
def S_sigmoid(x, vmax, r, v0):
    """Sigmoid ``vmax / (1 + exp(r * (v0 - x)))`` with a clipped exponent.

    The exponent is limited to [-700, 700] before calling ``exp`` so that
    extreme inputs do not overflow.
    """
    exponent = np.clip(r * (v0 - x), -700, 700)
    return vmax / (1.0 + np.exp(exponent))
153
+
154
+
155
@njit
def f_jr(x, t, P):
    """Evaluate the deterministic Jansen-Rit right-hand side.

    Parameters
    ----------
    x : 1D array of length ``6 * P.nn``
        Stacked state ``[x, y, z, x', y', z']``.
    t : float
        Current time. It is part of the signature but not used by the
        equations below.
    P : ParJR
        Parameter container.

    Returns
    -------
    1D array with the same shape as ``x`` holding the time derivative.
    """
    n = P.nn

    # Slices of the six state blocks.
    pos_x = x[0 * n:1 * n]
    pos_y = x[1 * n:2 * n]
    pos_z = x[2 * n:3 * n]
    vel_x = x[3 * n:4 * n]
    vel_y = x[4 * n:5 * n]
    vel_z = x[5 * n:6 * n]

    # Constant factors shared by all nodes.
    gain_a = P.A * P.a
    gain_b = P.B * P.b
    a_sq = P.a * P.a
    b_sq = P.b * P.b

    # y - z drives both the local sigmoid and the network coupling.
    y_minus_z = pos_y - pos_z
    network_rate = S_sigmoid(P.weights.dot(y_minus_z), P.vmax, P.r, P.v0)
    local_rate = S_sigmoid(y_minus_z, P.vmax, P.r, P.v0)
    rate_c0 = S_sigmoid(P.C0 * pos_x, P.vmax, P.r, P.v0)
    rate_c2 = S_sigmoid(P.C2 * pos_x, P.vmax, P.r, P.v0)

    deriv = np.zeros_like(x)

    # First-order part: d(position)/dt = velocity.
    deriv[0 * n:1 * n] = vel_x
    deriv[1 * n:2 * n] = vel_y
    deriv[2 * n:3 * n] = vel_z

    # Second-order part: driven, damped linear response of each population.
    deriv[3 * n:4 * n] = gain_a * local_rate - 2.0 * P.a * vel_x - a_sq * pos_x
    deriv[4 * n:5 * n] = (
        gain_a * (P.mu + P.C1 * rate_c0 + P.G * network_rate)
        - 2.0 * P.a * vel_y - a_sq * pos_y
    )
    deriv[5 * n:6 * n] = gain_b * P.C3 * rate_c2 - 2.0 * P.b * vel_z - b_sq * pos_z

    return deriv
192
+
193
+
194
@njit
def heun_sde(x, t, P):
    """Advance the state by one stochastic Heun step of size ``P.dt``.

    One Gaussian increment per node (scaled by ``P.sigma``, which already
    contains the sqrt(dt) factor) is added to the y' block, both in the
    predictor and in the corrected state. The input array is not modified;
    a new state array is returned.
    """
    n = P.nn
    h = P.dt
    first = 4 * n
    last = 5 * n

    # Single noise draw reused by predictor and corrector.
    increment = P.sigma * np.random.randn(n)

    # Predictor (Euler) step.
    slope_start = f_jr(x, t, P)
    predictor = x + h * slope_start
    predictor[first:last] += increment

    # Corrector: average of the slopes at both ends of the step.
    slope_end = f_jr(predictor, t + h, P)
    updated = x + 0.5 * h * (slope_start + slope_end)
    updated[first:last] += increment

    return updated
211
+
212
+
213
+ # ---------------------------------------------------------------
214
+ # Top-level integrate and driver class (mirrors mpr.py style)
215
+ # ---------------------------------------------------------------
216
+
217
def integrate(P):
    """Integrate the Jansen-Rit SDE described by ``P`` and record ``y - z``.

    The state starts from a copy of ``P.initial_state`` and is advanced with
    ``heun_sde`` for ``int(P.t_end / P.dt)`` steps. From the first grid time
    that is not smaller than ``P.t_cut`` onwards, every ``P.decimate``-th
    step is stored.

    Returns
    -------
    dict
        ``{"t": ts, "x": ys}`` with float32 arrays: ``ts`` holds the grid
        time at which each recorded step started, ``ys`` (one row per
        sample, one column per node) holds ``y - z`` after that step.
    """
    n_nodes = P.nn
    step_size = P.dt
    stride = P.decimate

    state = P.initial_state.copy()

    n_steps = int(P.t_end / step_size)
    times = np.linspace(0.0, (n_steps - 1) * step_size, n_steps)

    # Index of the first step to record and number of samples to keep.
    first_kept = int(np.searchsorted(times, P.t_cut, side='left'))
    n_samples = (n_steps - first_kept + (stride - 1)) // stride

    ts = np.zeros(n_samples, dtype=np.float32)
    ys = np.zeros((n_samples, n_nodes), dtype=np.float32)

    written = 0
    for step, t_now in enumerate(times):
        state = heun_sde(state, t_now, P)

        # Skip burn-in steps and steps that fall between decimation points.
        if step < first_kept or (step - first_kept) % stride != 0:
            continue

        ts[written] = t_now
        y_block = state[1 * n_nodes:2 * n_nodes]
        z_block = state[2 * n_nodes:3 * n_nodes]
        ys[written, :] = (y_block - z_block).astype(np.float32)
        written += 1
        if written >= n_samples:
            break

    return {"t": ts, "x": ys}
251
+
252
+
253
class JR_sde:
    """
    Numba implementation of the Jansen-Rit neural mass model.

    .. list-table:: Parameters
       :widths: 25 50 25
       :header-rows: 1

       * - Name
         - Explanation
         - Default Value
       * - `A`
         - Excitatory post synaptic potential amplitude.
         - 3.25
       * - `B`
         - Inhibitory post synaptic potential amplitude.
         - 22.0
       * - `a`
         - Inverse time constant of the excitatory postsynaptic potential (1/a = time constant).
         - 0.1 (time constant: 10.0)
       * - `b`
         - Inverse time constant of the inhibitory postsynaptic potential (1/b = time constant).
         - 0.05 (time constant: 20.0)
       * - `C0`
         - Average number of synapses between pyramidal cells and excitatory interneurons. If array-like, it should be of length `nn` (number of nodes).
         - 135.0
       * - `C1`
         - Average number of synapses between excitatory interneurons and pyramidal cells. If array-like, it should be of length `nn`.
         - 0.8 * 135.0
       * - `C2`
         - Average number of synapses between pyramidal cells and inhibitory interneurons. If array-like, it should be of length `nn`.
         - 0.25 * 135.0
       * - `C3`
         - Average number of synapses between inhibitory interneurons and pyramidal cells. If array-like, it should be of length `nn`.
         - 0.25 * 135.0
       * - `vmax`
         - Maximum firing rate of the sigmoid function.
         - 0.005
       * - `v0`
         - Potential at half of maximum firing rate (inflection point of sigmoid).
         - 6.0
       * - `r`
         - Slope of sigmoid function at `v0`.
         - 0.56
       * - `G`
         - Global coupling strength scaling the network connections.
         - 1.0
       * - `mu`
         - Mean input to the excitatory population (external drive).
         - 0.24
       * - `noise_amp`
         - Amplitude of the stochastic noise applied to the excitatory population.
         - 0.01
       * - `weights`
         - Structural connectivity matrix of shape (`nn`, `nn`). Must be provided.
         - None
       * - `dt`
         - Integration time step.
         - 0.01
       * - `t_end`
         - End time of simulation.
         - 1000.0
       * - `t_cut`
         - Time from which to start collecting output (burn-in period).
         - 0.0
       * - `decimate`
         - Decimation factor for the output time series (every `decimate`-th point is saved).
         - 1
       * - `seed`
         - Random seed for reproducible simulations. If -1 or None, no seeding is applied.
         - -1
       * - `initial_state`
         - Initial state vector of shape (6*nn,). If None, random initial conditions are generated.
         - None

    Usage example (single simulation):
        >>> import numpy as np
        >>> from vbi.models.numba.jansen_rit import JR_sde
        >>> W = np.eye(2)  # 2-node demo connectivity
        >>> jr = JR_sde({"weights": W, "dt": 0.01, "t_end": 200.0, "t_cut": 100.0, "decimate": 1})
        >>> out = jr.run()
        >>> t, x = out["t"], out["x"]  # x has shape (n_step, nn)

    Notes
    -----
    The Jansen-Rit model describes the dynamics of a cortical column with three neural populations:
    - Pyramidal cells (main excitatory population)
    - Excitatory interneurons
    - Inhibitory interneurons

    The model equations are integrated using the Heun stochastic integration scheme.
    The output represents the difference between excitatory and inhibitory postsynaptic potentials (y - z),
    which corresponds to the local field potential that can be measured experimentally.
    """

    # Defaults for the per-node connectivity constants; broadcast to length
    # nn before they are handed to the parameter jitclass.
    _C_DEFAULTS = {
        "C0": 135.0,
        "C1": 0.8 * 135.0,
        "C2": 0.25 * 135.0,
        "C3": 0.25 * 135.0,
    }

    def __init__(self, par_jr: dict):
        """
        Initialize the Jansen-Rit model.

        Parameters
        ----------
        par_jr : dict
            Dictionary containing model parameters. See class documentation for available parameters.
            The 'weights' parameter is required and must be a square connectivity matrix.

        Raises
        ------
        ValueError
            If 'weights' is missing or not a square 2D array, if a C parameter
            cannot be broadcast to length nn, or if 'initial_state' does not
            have length 6*nn.
        """
        if "weights" not in par_jr or par_jr["weights"] is None:
            raise ValueError("'weights' must be provided (square connectivity matrix)")

        W = np.array(par_jr["weights"], dtype=np.float64)
        if W.ndim != 2 or W.shape[0] != W.shape[1]:
            raise ValueError("'weights' must be a square 2D array")

        nn = len(W)

        # Work on a copy so the caller's dictionary is never modified.
        params = dict(par_jr)
        params["weights"] = W

        # 'initial_state' is not a ParJR constructor argument; it is assigned
        # after construction, so it must not be forwarded as a keyword.
        initial_state = params.pop("initial_state", None)

        # The seed field is an integer; None is documented as "no seeding".
        params["seed"] = self._normalize_seed(params.get("seed", -1))

        # Broadcast C parameters (given or default) to one value per node.
        for c_name, c_default in self._C_DEFAULTS.items():
            params[c_name] = _as_1d_array_like(params.get(c_name, c_default), nn)

        self.P = ParJR(**params)

        # Seed handling
        # NOTE(review): this seeds NumPy's generator on the Python side; the
        # jitted code appears to rely on seeding done inside
        # set_initial_state_jr. When an explicit initial_state is given that
        # call is skipped -- confirm whether runs are still reproducible.
        self.seed = int(self.P.seed)
        if self.seed >= 0:
            np.random.seed(self.seed)

        if initial_state is not None:
            self.P.initial_state = self._coerce_initial_state(initial_state)
        else:
            self.P.initial_state = set_initial_state_jr(self.P.nn, self.seed)

        self._checked = False

    @staticmethod
    def _normalize_seed(seed):
        """Map ``None`` to -1 (no seeding) and anything else to ``int``."""
        return -1 if seed is None else int(seed)

    def _coerce_initial_state(self, x0):
        """Convert ``x0`` to a float64 vector and verify it has 6*nn entries."""
        x0 = np.array(x0, dtype=np.float64)
        if x0.ndim != 1 or x0.shape[0] != 6 * self.P.nn:
            raise ValueError("initial_state must have length 6*nn")
        return x0

    def __str__(self) -> str:
        """
        Return a string representation of the model parameters.

        Returns
        -------
        str
            Formatted string showing all model parameters and their values,
            one per line.
        """
        P = self.P
        lines = [
            "Jansen-Rit Model (Numba)",
            "------------------------",
            # Model parameters
            f"G = {P.G}",
            f"A = {P.A}",
            f"B = {P.B}",
            f"a = {P.a}",
            f"b = {P.b}",
            f"v0 = {P.v0}",
            f"vmax = {P.vmax}",
            f"r = {P.r}",
            f"mu = {P.mu}",
            f"noise_amp = {P.noise_amp}",
            # Connectivity parameters
            f"C0 = {P.C0}",
            f"C1 = {P.C1}",
            f"C2 = {P.C2}",
            f"C3 = {P.C3}",
            # Simulation parameters
            f"dt = {P.dt}",
            f"t_end = {P.t_end}",
            f"t_cut = {P.t_cut}",
            f"decimate = {P.decimate}",
            f"nn = {P.nn}",
            f"seed = {P.seed}",
            f"sigma = {P.sigma}",
            f"weights shape = {P.weights.shape}",
        ]
        return "\n".join(lines)

    def check_input(self):
        """
        Validate model parameters.

        Raises
        ------
        ValueError
            If any parameter values are invalid (dt <= 0, t_cut >= t_end,
            decimate < 1, or dimension mismatches between nn, weights, the
            C arrays and the initial state).
        """
        P = self.P
        if P.dt <= 0:
            raise ValueError("dt must be positive")
        if P.t_cut >= P.t_end:
            raise ValueError("t_cut must be less than t_end")
        if P.decimate < 1:
            raise ValueError("decimate must be >= 1")
        if P.nn != P.weights.shape[0]:
            raise ValueError("nn != weights.shape[0]")
        # The C arrays and the state can go stale when weights are replaced
        # with a matrix of a different size.
        for c_name in self._C_DEFAULTS:
            if len(getattr(P, c_name)) != P.nn:
                raise ValueError(f"{c_name} must have length nn")
        if len(P.initial_state) != 6 * P.nn:
            raise ValueError("initial_state must have length 6*nn")
        self._checked = True

    def set_initial_state(self, seed: int = None):
        """
        Set random initial state for the simulation.

        Parameters
        ----------
        seed : int, optional
            Random seed for reproducible initial conditions.
            If None, uses the seed specified during initialization.
        """
        seed_ = self.seed if seed is None else seed
        self.P.initial_state = set_initial_state_jr(self.P.nn, seed_)

    def _apply_updates(self, par: dict):
        """Apply the parameter overrides in ``par`` to the jitclass instance."""
        # Weights go first so that nn is current before any C parameter or
        # initial state in the same dictionary is sized against it.
        if "weights" in par:
            W = np.array(par["weights"], dtype=np.float64)
            if W.ndim != 2 or W.shape[0] != W.shape[1]:
                raise ValueError("'weights' must be a square 2D array")
            setattr(self.P, "weights", W)
            setattr(self.P, "nn", len(W))

        for k, v in par.items():
            if k == "weights":
                continue
            if k in self._C_DEFAULTS:
                setattr(self.P, k, _as_1d_array_like(v, self.P.nn))
            elif k == "initial_state":
                self.P.initial_state = self._coerce_initial_state(v)
            elif k == "seed":
                self.P.seed = self._normalize_seed(v)
                self.seed = int(self.P.seed)
            elif hasattr(self.P, k):
                setattr(self.P, k, v)
            else:
                raise ValueError(f"Invalid parameter: {k}")

        # sigma is derived from dt and noise_amp; refresh it unless the
        # caller supplied an explicit value in the same update.
        if ("dt" in par or "noise_amp" in par) and "sigma" not in par:
            self.P.sigma = np.sqrt(self.P.dt) * self.P.noise_amp

        # Parameters changed, so any earlier validation is out of date.
        self._checked = False

    def run(self, par: dict = None, x0: np.ndarray = None):
        """
        Run the Jansen-Rit simulation.

        Parameters
        ----------
        par : dict, optional
            Dictionary of parameters to update for this simulation run.
            Any parameter from the class documentation can be updated.
        x0 : np.ndarray, optional
            Initial state vector of shape (6*nn,). If None, uses the
            initial state set during initialization or by set_initial_state().

        Returns
        -------
        dict
            Dictionary containing simulation results:
            - 't': np.ndarray of shape (n_steps,) - time points
            - 'x': np.ndarray of shape (n_steps, nn) - simulated time series (y - z)
              representing local field potentials

        Raises
        ------
        ValueError
            If ``par`` contains an unknown key or the resulting configuration
            fails ``check_input``.
        """
        # Optionally update parameters on the jitclass (Numba allows setattr)
        if par:
            self._apply_updates(par)

        # Optionally replace initial state
        if x0 is not None:
            self.P.initial_state = self._coerce_initial_state(x0)

        # Validation is cheap, so it is repeated on every run; this also
        # covers attributes changed directly on self.P between runs.
        self.check_input()

        return integrate(self.P)


# Alias for consistency with naming convention
JR_sde_numba = JR_sde
vbi/models/numba/mpr.py CHANGED
@@ -116,6 +116,7 @@ def integrate(P, B):
116
116
  rv_t = np.zeros((nt // rv_decimate), dtype=np.float32)
117
117
 
118
118
  def compute():
119
+ nonlocal rv_d, rv_t, bold_d, bold_t
119
120
 
120
121
  bold_d = np.array([])
121
122
  bold_t = np.array([])
@@ -152,10 +153,16 @@ def integrate(P, B):
152
153
  if (i % bold_decimate == 0) and ((i // bold_decimate) < vv.shape[0]):
153
154
  vv[i // bold_decimate] = v[1]
154
155
  qq[i // bold_decimate] = q[1]
156
+
157
+ if RECORD_RV:
158
+ rv_d = rv_d[rv_t >= P.t_cut, :]
159
+ rv_t = rv_t[rv_t >= P.t_cut]
155
160
 
156
161
  if RECORD_BOLD:
157
162
  bold_d = vo * (k1 * (1 - qq) + k2 * (1 - qq / vv) + k3 * (1 - vv))
158
163
  bold_t = np.linspace(0, P.t_end - dt * bold_decimate, len(bold_d))
164
+ bold_d = bold_d[bold_t >= P.t_cut, :]
165
+ bold_t = bold_t[bold_t >= P.t_cut]
159
166
 
160
167
  return rv_t, rv_d, bold_t, bold_d
161
168
 
@@ -215,6 +222,7 @@ class MPR_sde:
215
222
  assert self.P.weights.shape[0] == self.P.weights.shape[1]
216
223
  assert self.P.initial_state is not None
217
224
  assert len(self.P.initial_state) == 2 * self.P.weights.shape[0]
225
+ assert self.P.t_cut < self.P.t_end, "t_cut must be less than t_end"
218
226
  self.P.eta = check_vec_size(self.P.eta, self.P.nn)
219
227
  self.P.t_end /= 10
220
228
  self.P.t_cut /= 10