vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl → 0.2__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,443 @@
1
+
2
+ import warnings
3
+ import numpy as np
4
+ from numba import njit, jit
5
+ from numba.experimental import jitclass
6
+ from numba.extending import register_jitable
7
+ from numba import float64, boolean, int64, types
8
+ from numba.core.errors import NumbaPerformanceWarning
9
+
10
# Suppress NumbaPerformanceWarning. This installs an entry in the process-wide
# warnings filter list, so it also applies outside this module.
warnings.filterwarnings("ignore", category=NumbaPerformanceWarning)
11
+
12
+
13
+ # ---------- utilities ----------
14
+
15
def _to_1d_array(x):
    """Return *x* as a float64 ndarray; a 0-d scalar is promoted to shape (1,).

    Inputs that already have one or more dimensions keep their shape.
    """
    arr = np.array(x, dtype=np.float64)
    return arr.reshape(1) if arr.ndim == 0 else arr
20
+
21
def _to_2d_array(x):
    """Return *x* as a square float64 matrix.

    A scalar or flat (1-D) input whose length is a perfect square n*n is
    reshaped row-major to (n, n). Any input that does not end up as a
    square 2-D matrix raises ``ValueError("weights must be square (nxn).")``.
    """
    arr = np.array(x, dtype=np.float64)
    if arr.ndim <= 1:
        flat = arr.reshape(-1)
        # Round instead of truncating so floating-point error in sqrt
        # cannot push a perfect square just below its integer root.
        n = int(round(float(np.sqrt(flat.size))))
        if n * n != flat.size:
            raise ValueError("weights must be square (nxn).")
        arr = flat.reshape(n, n)
    # Catch non-square 2-D (and higher-rank) input here rather than letting
    # it surface later as an unrelated failure.
    if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
        raise ValueError("weights must be square (nxn).")
    return arr
31
+
32
def check_vec_size(x, nn):
    """Return a length-nn vector from scalar/1-vector or already-length-nn input.

    *x* may be a scalar or any single-element array, which is broadcast to
    ``nn`` copies. It may also be an array of any shape with exactly ``nn``
    elements, which is flattened to 1-D. The result is always a new 1-D
    float64 array that does not alias *x*.

    Raises ValueError when *x* has neither 1 nor ``nn`` elements.
    """
    # Flatten first so (1, 1) and (nn, 1) inputs are handled like 1-D ones;
    # np.array copies, so the returned vector never aliases the caller's data.
    flat = np.array(x, dtype=np.float64).reshape(-1)
    if flat.size == 1:
        return np.full(nn, flat[0], dtype=np.float64)
    if flat.size != nn:
        raise ValueError(f"Vector parameter has size {flat.size} but nn={nn}.")
    return flat
42
+
43
+
44
@register_jitable
def set_seed_compat(x):
    """Seed ``np.random`` with *x*.

    Registered as jitable so the same helper works in both worlds: inside
    Numba-compiled code the call is compiled and seeds Numba's own
    generator, while a plain-Python call seeds NumPy's global generator.
    """
    np.random.seed(x)
    return None
47
+
48
+
49
+ # ---------- core model (Numba) ----------
50
+
51
# Field name -> Numba type for the ParWC jitclass. The order matches the
# original hand-written listing (it fixes the jitclass field layout).
wc_spec = (
    # per-node coupling strengths and time constants
    [(name, float64[:]) for name in ("c_ee", "c_ei", "c_ie", "c_ii", "tau_e", "tau_i")]
    # scalar sigmoid / rate parameters
    + [
        (name, float64)
        for name in (
            "a_e", "a_i", "b_e", "b_i", "c_e", "c_i",
            "theta_e", "theta_i", "r_e", "r_i", "k_e", "k_i",
            "alpha_e", "alpha_i",
        )
    ]
    # per-node external inputs
    + [("P", float64[:]), ("Q", float64[:])]
    # global coupling gains and time grid
    + [(name, float64) for name in ("g_e", "g_i", "dt", "t_end", "t_cut")]
    # network structure, noise, recording and initial condition
    + [
        ("nn", int64),
        ("weights", float64[:, :]),
        ("seed", int64),
        ("noise_amp", float64),
        ("decimate", int64),
        ("RECORD_EI", types.string),
        ("initial_state", float64[:]),
        ("shift_sigmoid", boolean),
    ]
)
88
+
89
+
90
# Numba jitclass holding every Wilson-Cowan model parameter (field types in
# ``wc_spec``). ``nn`` is not an argument: it is derived from ``weights``.
# Array arguments are copied on construction. The array defaults in the
# signature are evaluated once and shared by every call, so storing them by
# reference would let an in-place edit on one instance leak into every later
# default-constructed instance.
@jitclass(wc_spec)
class ParWC:
    def __init__(
        self,
        c_ee=np.array([16.0]),
        c_ei=np.array([12.0]),
        c_ie=np.array([15.0]),
        c_ii=np.array([3.0]),
        tau_e=np.array([8.0]),
        tau_i=np.array([8.0]),
        a_e=1.3,
        a_i=2.0,
        b_e=4.0,
        b_i=3.7,
        c_e=1.0,
        c_i=1.0,
        theta_e=0.0,
        theta_i=0.0,
        r_e=1.0,
        r_i=1.0,
        k_e=0.994,
        k_i=0.999,
        alpha_e=1.0,
        alpha_i=1.0,
        P=np.array([0.0]),
        Q=np.array([0.0]),
        g_e=0.0,
        g_i=0.0,
        dt=0.01,
        t_end=300.0,
        t_cut=0.0,
        weights=np.empty((0, 0), dtype=np.float64),
        seed=-1,
        noise_amp=0.0,
        decimate=1,
        RECORD_EI="E",
        initial_state=np.empty(0, dtype=np.float64),
        shift_sigmoid=False,
    ):
        """Store all parameters; array arguments are copied, scalars stored as-is."""
        # per-node arrays (copied; see note above the class)
        self.c_ee = c_ee.copy()
        self.c_ei = c_ei.copy()
        self.c_ie = c_ie.copy()
        self.c_ii = c_ii.copy()
        self.tau_e = tau_e.copy()
        self.tau_i = tau_i.copy()
        # scalar sigmoid / rate parameters
        self.a_e = a_e
        self.a_i = a_i
        self.b_e = b_e
        self.b_i = b_i
        self.c_e = c_e
        self.c_i = c_i
        self.theta_e = theta_e
        self.theta_i = theta_i
        self.r_e = r_e
        self.r_i = r_i
        self.k_e = k_e
        self.k_i = k_i
        self.alpha_e = alpha_e
        self.alpha_i = alpha_i
        self.P = P.copy()
        self.Q = Q.copy()
        self.g_e = g_e
        self.g_i = g_i
        # time grid
        self.dt = dt
        self.t_end = t_end
        self.t_cut = t_cut
        # network size is derived from the (first dimension of the) weights
        self.nn = len(weights)
        self.weights = weights.copy()
        self.seed = seed
        self.noise_amp = noise_amp
        self.decimate = decimate
        self.RECORD_EI = RECORD_EI
        self.initial_state = initial_state.copy()
        self.shift_sigmoid = shift_sigmoid
164
+
165
+
166
@njit
def sigmoid_vec(x, a, b, c, shift_sigmoid):
    """Element-wise logistic ``c / (1 + exp(-a (x - b)))`` over a 1-D array.

    With ``shift_sigmoid`` the curve is lowered by its value at x = 0,
    i.e. ``c * (sigmoid(a(x - b)) - sigmoid(-a b))``, so the output is 0
    at x = 0. The result has the same shape and dtype as *x*.
    """
    out = np.empty_like(x)
    n = x.size

    if not shift_sigmoid:
        for j in range(n):
            out[j] = c / (1.0 + np.exp(-a * (x[j] - b)))
        return out

    # Plain logistic evaluated at x = 0; subtracting it zeroes the curve there.
    offset = 1.0 / (1.0 + np.exp(-a * (-b)))
    for j in range(n):
        out[j] = c * (1.0 / (1.0 + np.exp(-a * (x[j] - b))) - offset)
    return out
178
+
179
+
180
@njit
def f_wc(x, t, P):
    """
    Wilson-Cowan ODE right-hand side (per-node, single simulation).

    ``x`` stacks the excitatory rates of all ``P.nn`` nodes followed by
    the inhibitory rates (length ``2 * P.nn``). ``t`` is accepted for the
    integrator's calling convention but does not enter the equations.

    For node k, with S the sigmoid from ``sigmoid_vec``:
        dE_k/dt = (1/tau_e,k) * (-E_k + (k_e - r_e E_k) * S_e(u_e,k))
        dI_k/dt = (1/tau_i,k) * (-I_k + (k_i - r_i I_k) * S_i(u_i,k))
    where u is alpha times the summed local, external and network input.
    """
    n = P.nn
    exc = x[:n]
    inh = x[n:]

    # Linear network coupling g * (weights @ rates); the matrix-vector
    # product is skipped entirely when the gain is zero.
    if P.g_e != 0.0:
        coup_e = P.g_e * np.dot(P.weights, exc)
    else:
        coup_e = np.zeros(n)
    if P.g_i != 0.0:
        coup_i = P.g_i * np.dot(P.weights, inh)
    else:
        coup_i = np.zeros(n)

    u_e = P.alpha_e * (P.c_ee * exc - P.c_ei * inh + P.P - P.theta_e + coup_e)
    u_i = P.alpha_i * (P.c_ie * exc - P.c_ii * inh + P.Q - P.theta_i + coup_i)

    act_e = sigmoid_vec(u_e, P.a_e, P.b_e, P.c_e, P.shift_sigmoid)
    act_i = sigmoid_vec(u_i, P.a_i, P.b_i, P.c_i, P.shift_sigmoid)

    deriv = np.zeros_like(x)
    for k in range(n):
        deriv[k] = (1.0 / P.tau_e[k]) * (
            -exc[k] + (P.k_e - P.r_e * exc[k]) * act_e[k]
        )
        deriv[n + k] = (1.0 / P.tau_i[k]) * (
            -inh[k] + (P.k_i - P.r_i * inh[k]) * act_i[k]
        )
    return deriv
215
+
216
+
217
@njit
def heun_sde(x, t, P):
    """Advance *x* by one stochastic Heun (predictor-corrector) step of size ``P.dt``.

    One Gaussian increment of standard deviation ``P.noise_amp * sqrt(P.dt)``
    per state component is drawn and added to both the Euler predictor and
    the trapezoidal corrector (additive noise).
    """
    h = P.dt
    noise = P.noise_amp * np.sqrt(h) * np.random.randn(x.size)

    drift_now = f_wc(x, t, P)
    x_pred = x + h * drift_now + noise
    drift_pred = f_wc(x_pred, t + h, P)

    return x + 0.5 * h * (drift_now + drift_pred) + noise
228
+
229
+
230
@njit
def set_initial_state(nn, seed=-1):
    """Draw a random initial state of length ``2 * nn`` from U[0, 1).

    A non-negative *seed* first reseeds the (Numba-side) ``np.random``
    generator, making the draw reproducible.
    """
    if seed >= 0:
        np.random.seed(seed)
    return np.random.rand(2 * nn)
236
+
237
+
238
+ # ---------- high-level API (Python) ----------
239
+
240
@njit
def _seed_numba_rng(seed):
    """Seed the generator behind ``np.random`` inside Numba-compiled code.

    Numba keeps its own random state, separate from NumPy's global one, so
    ``np.random.seed`` called from plain Python does not affect the noise
    drawn by the jitted integrator. This compiled wrapper does.
    """
    np.random.seed(seed)


class WC_sde_numba:
    """
    Numba implementation of the Wilson-Cowan SDE, modeled after mpr.py and
    translated from the CuPy/Numpy reference.

    Parameters are given as a dict whose keys are ``ParWC`` constructor
    arguments. ``weights`` is mandatory and fixes the number of nodes ``nn``.
    """

    # Per-node parameters: accepted as scalars or length-nn vectors.
    _VEC_KEYS = ("c_ee", "c_ei", "c_ie", "c_ii", "tau_e", "tau_i", "P", "Q")
    # Values used for per-node parameters missing from the input dict.
    _VEC_DEFAULTS = {
        "c_ee": 16.0, "c_ei": 12.0, "c_ie": 15.0, "c_ii": 3.0,
        "tau_e": 8.0, "tau_i": 8.0, "P": 0.0, "Q": 0.0,
    }
    # Every ParWC constructor argument, i.e. every field except derived ``nn``.
    _FIELDS = (
        "c_ee", "c_ei", "c_ie", "c_ii", "tau_e", "tau_i",
        "a_e", "a_i", "b_e", "b_i", "c_e", "c_i", "theta_e", "theta_i",
        "r_e", "r_i", "k_e", "k_i", "alpha_e", "alpha_i",
        "P", "Q", "g_e", "g_i", "dt", "t_end", "t_cut",
        "weights", "seed", "noise_amp", "decimate", "RECORD_EI",
        "initial_state", "shift_sigmoid",
    )
    # Fields held as arrays; _par_to_dict exports copies of these.
    _ARRAY_FIELDS = frozenset(_VEC_KEYS) | {"weights", "initial_state"}

    def __init__(self, par: dict = None):
        # None instead of a shared mutable ``{}`` default; behaviour is the same.
        self.P = self._get_par_wc({} if par is None else par)

        # Kept for backward compatibility: this seeds NumPy's global generator
        # only. The jitted noise generator is seeded in run().
        if self.P.seed >= 0:
            np.random.seed(self.P.seed)

    def __call__(self):
        """Return the underlying ``ParWC`` parameter object."""
        return self.P

    def __str__(self) -> str:
        params = [
            "nn", "dt", "t_end", "t_cut", "decimate", "noise_amp",
            "g_e", "g_i", "a_e", "a_i", "b_e", "b_i", "k_e", "k_i",
        ]
        s = ["Wilson-Cowan (Numba) parameters:"]
        for k in params:
            s.append(f"{k} = {getattr(self.P, k)}")
        return "\n".join(s)

    # ----- builders & checks -----
    def _get_par_wc(self, par: dict):
        """Validate/normalise a parameter dict and build the ``ParWC`` jitclass.

        Per-node keys are broadcast to length nn (missing ones get defaults),
        and ``initial_state`` is flattened and length-checked.

        Raises:
            ValueError: missing ``weights``, bad vector sizes, or an
                ``initial_state`` that is neither empty nor length ``2 * nn``.
        """
        par = dict(par)  # shallow copy; the caller's dict is not modified

        # weights first (they determine nn)
        if "weights" not in par:
            raise ValueError("weights (nxn) must be provided.")
        W = _to_2d_array(par["weights"])
        nn = W.shape[0]

        for k in self._VEC_KEYS:
            if k in par:
                par[k] = check_vec_size(par[k], nn)
            else:
                par[k] = np.full(nn, self._VEC_DEFAULTS[k], dtype=np.float64)

        par["weights"] = W

        # initial_state (optional); an empty array means "draw it at run time"
        if "initial_state" in par:
            arr = np.array(par["initial_state"], dtype=np.float64).reshape(-1)
            if arr.size != 0 and arr.size != 2 * nn:
                raise ValueError(f"initial_state must have length {2*nn}.")
            par["initial_state"] = arr
        else:
            par["initial_state"] = np.empty(0, dtype=np.float64)

        par.setdefault("RECORD_EI", "E")
        par.setdefault("decimate", 1)
        par.setdefault("noise_amp", 0.0)

        return ParWC(**par)

    def set_initial_state(self):
        """Draw a random initial state (seeded by ``P.seed`` when >= 0)."""
        # Module-level set_initial_state (a global lookup), not this method.
        self.P.initial_state = set_initial_state(self.P.nn, self.P.seed)

    def check_input(self):
        """Check parameter consistency, drawing an initial state if none is set."""
        P = self.P
        assert P.weights.shape[0] == P.weights.shape[1], "weights must be square"
        assert P.nn == P.weights.shape[0], "nn must match weights shape"
        if P.initial_state.size == 0:
            self.set_initial_state()
        assert P.initial_state.size == 2 * P.nn, "initial_state length mismatch"
        assert P.dt > 0, "dt must be positive"
        assert P.t_cut < P.t_end, "t_cut must be less than t_end"

        # vector parameters are sized by the builder; re-check at runtime
        # because attributes of the jitclass can be reassigned afterwards
        for k in self._VEC_KEYS:
            assert getattr(P, k).size == P.nn, f"{k} must be length nn"

    def run(self, par: dict = None, x0=None, verbose: bool = True):
        """Simulate the model and return the dict produced by ``integrate``.

        Args:
            par: optional overrides merged onto the current parameters (the
                jitclass is rebuilt). If new ``weights`` change nn, a stale
                initial state is discarded and per-node parameters that are
                uniform are re-broadcast. Heterogeneous per-node parameters
                must then be supplied again, otherwise ValueError is raised.
            x0: optional initial state of length ``2 * nn``.
            verbose: forwarded to ``integrate``.
        """
        if par:
            self.P = self._get_par_wc(self._merge_par(par))

        if x0 is not None:
            x0 = np.array(x0, dtype=np.float64).reshape(-1)
            if x0.size != 2 * self.P.nn:
                raise ValueError(f"x0 must be length {2*self.P.nn}")
            self.P.initial_state = x0

        # Seed the generator the jitted stepper actually uses, so noise is
        # reproducible even when the initial state is supplied rather than
        # drawn. A drawn initial state reseeds with the same value, so that
        # path is unchanged.
        if self.P.seed >= 0:
            _seed_numba_rng(self.P.seed)

        self.check_input()

        return integrate(self.P, verbose=verbose)

    def _merge_par(self, par: dict) -> dict:
        """Overlay *par* on the current parameters, adapting per-node data when nn changes."""
        merged = {**self._par_to_dict(), **par}
        if "weights" not in par:
            return merged
        new_nn = _to_2d_array(par["weights"]).shape[0]
        if new_nn == self.P.nn:
            return merged

        # The old initial state no longer fits; let it be redrawn.
        if "initial_state" not in par:
            merged["initial_state"] = np.empty(0, dtype=np.float64)
        # Uniform per-node vectors collapse to their scalar so they can be
        # re-broadcast to the new size. Heterogeneous ones are left for
        # check_vec_size to reject.
        for k in self._VEC_KEYS:
            old = merged[k]
            if k not in par and old.size > 0 and np.all(old == old[0]):
                merged[k] = float(old[0])
        return merged

    def _par_to_dict(self):
        """Export the current parameters as a dict of ``ParWC`` constructor arguments.

        Array fields are copied, so editing the dict does not touch ``self.P``.
        """
        P = self.P
        return {
            k: np.array(getattr(P, k)) if k in self._ARRAY_FIELDS else getattr(P, k)
            for k in self._FIELDS
        }
392
+
393
+
394
def _n_steps(t_end, dt):
    """Return the number of whole ``dt`` steps in ``t_end``, tolerant of round-off.

    Ratios within a relative 1e-9 of an integer snap to that integer, so
    ``0.3 / 0.1 == 2.9999999999999996`` gives 3 steps, not 2. Otherwise the
    ratio is floored. Negative counts are clamped to 0.
    """
    if dt <= 0:
        raise ValueError("dt must be positive.")
    ratio = t_end / dt
    nearest = round(ratio)
    if abs(ratio - nearest) <= 1e-9 * max(1.0, abs(ratio)):
        return max(0, int(nearest))
    return max(0, int(np.floor(ratio)))


def integrate(P: ParWC, verbose=True):
    """
    Pure-Python driver (Numba-accelerated inner steps).
    Returns dict with t, E, I (float32).

    Runs ``nt`` Heun steps of size ``P.dt`` starting from ``P.initial_state``.
    Every ``P.decimate``-th step records the state *before* that step,
    stamped with its own time ``i * dt``, so each sample's time matches its
    state. Samples with ``t < P.t_cut`` are dropped.

    Returns:
        dict with "t" (float32, shape (n,)) and "E"/"I" (float32, shape
        (n, nn)). "E"/"I" are None when the letter "e"/"i" is absent
        (case-insensitive) from ``P.RECORD_EI``.

    ``verbose`` is currently unused.
    """
    nn = P.nn
    dt = P.dt
    nt = _n_steps(P.t_end, dt)
    dec = max(1, int(P.decimate))

    # buffers sized after decimation; t_cut is applied afterwards
    nbuf = nt // dec
    channels = P.RECORD_EI.lower()
    record_e = "e" in channels
    record_i = "i" in channels

    t_buf = np.zeros(nbuf, dtype=np.float32)
    E_buf = np.zeros((nbuf, nn), dtype=np.float32) if record_e else None
    I_buf = np.zeros((nbuf, nn), dtype=np.float32) if record_i else None

    x = P.initial_state.copy()
    buf_idx = 0

    for i in range(nt):
        t_curr = i * dt
        # Record before stepping: x is the state at t_curr here. After
        # heun_sde it would be the state at t_curr + dt.
        if (i % dec) == 0 and buf_idx < nbuf:
            t_buf[buf_idx] = t_curr
            if record_e:
                E_buf[buf_idx] = x[:nn]  # float64 -> float32 on assignment
            if record_i:
                I_buf[buf_idx] = x[nn:]
            buf_idx += 1
        x = heun_sde(x, t_curr, P)

    # trim to actual filled length
    t_buf = t_buf[:buf_idx]
    if record_e:
        E_buf = E_buf[:buf_idx]
    if record_i:
        I_buf = I_buf[:buf_idx]

    # apply t_cut
    keep = t_buf >= P.t_cut
    t_out = t_buf[keep]
    E_out = E_buf[keep] if record_e else None
    I_out = I_buf[keep] if record_i else None

    return {"t": t_out, "E": E_out, "I": I_out}
441
+
442
+
443
# Public short name: WC_sde and WC_sde_numba refer to the same class object.
WC_sde = WC_sde_numba