vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vbi/__init__.py +37 -0
- vbi/_version.py +17 -0
- vbi/dataset/__init__.py +0 -0
- vbi/dataset/connectivity_84/centers.txt +84 -0
- vbi/dataset/connectivity_84/centres.txt +84 -0
- vbi/dataset/connectivity_84/cortical.txt +84 -0
- vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
- vbi/dataset/connectivity_84/weights.txt +84 -0
- vbi/dataset/connectivity_88/Aud_88.txt +88 -0
- vbi/dataset/connectivity_88/Bold.npz +0 -0
- vbi/dataset/connectivity_88/Labels.txt +17 -0
- vbi/dataset/connectivity_88/Region_labels.txt +88 -0
- vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
- vbi/dataset/connectivity_88/weights.txt +88 -0
- vbi/feature_extraction/__init__.py +1 -0
- vbi/feature_extraction/calc_features.py +293 -0
- vbi/feature_extraction/features.json +535 -0
- vbi/feature_extraction/features.py +2124 -0
- vbi/feature_extraction/features_settings.py +374 -0
- vbi/feature_extraction/features_utils.py +1357 -0
- vbi/feature_extraction/infodynamics.jar +0 -0
- vbi/feature_extraction/utility.py +507 -0
- vbi/inference.py +98 -0
- vbi/models/__init__.py +0 -0
- vbi/models/cpp/__init__.py +0 -0
- vbi/models/cpp/_src/__init__.py +0 -0
- vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
- vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
- vbi/models/cpp/_src/bold.hpp +303 -0
- vbi/models/cpp/_src/do.hpp +167 -0
- vbi/models/cpp/_src/do.i +17 -0
- vbi/models/cpp/_src/do.py +467 -0
- vbi/models/cpp/_src/do_wrap.cxx +12811 -0
- vbi/models/cpp/_src/jr_sdde.hpp +352 -0
- vbi/models/cpp/_src/jr_sdde.i +19 -0
- vbi/models/cpp/_src/jr_sdde.py +688 -0
- vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
- vbi/models/cpp/_src/jr_sde.hpp +264 -0
- vbi/models/cpp/_src/jr_sde.i +17 -0
- vbi/models/cpp/_src/jr_sde.py +470 -0
- vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
- vbi/models/cpp/_src/km_sde.hpp +158 -0
- vbi/models/cpp/_src/km_sde.i +19 -0
- vbi/models/cpp/_src/km_sde.py +671 -0
- vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
- vbi/models/cpp/_src/makefile +52 -0
- vbi/models/cpp/_src/mpr_sde.hpp +327 -0
- vbi/models/cpp/_src/mpr_sde.i +19 -0
- vbi/models/cpp/_src/mpr_sde.py +711 -0
- vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
- vbi/models/cpp/_src/utility.hpp +307 -0
- vbi/models/cpp/_src/vep.hpp +171 -0
- vbi/models/cpp/_src/vep.i +16 -0
- vbi/models/cpp/_src/vep.py +464 -0
- vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
- vbi/models/cpp/_src/wc_ode.hpp +294 -0
- vbi/models/cpp/_src/wc_ode.i +19 -0
- vbi/models/cpp/_src/wc_ode.py +686 -0
- vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
- vbi/models/cpp/damp_oscillator.py +143 -0
- vbi/models/cpp/jansen_rit.py +543 -0
- vbi/models/cpp/km.py +187 -0
- vbi/models/cpp/mpr.py +289 -0
- vbi/models/cpp/vep.py +150 -0
- vbi/models/cpp/wc.py +216 -0
- vbi/models/cupy/__init__.py +0 -0
- vbi/models/cupy/bold.py +111 -0
- vbi/models/cupy/ghb.py +284 -0
- vbi/models/cupy/jansen_rit.py +473 -0
- vbi/models/cupy/km.py +224 -0
- vbi/models/cupy/mpr.py +475 -0
- vbi/models/cupy/mpr_modified_bold.py +12 -0
- vbi/models/cupy/utils.py +184 -0
- vbi/models/numba/__init__.py +0 -0
- vbi/models/numba/_ww_EI.py +444 -0
- vbi/models/numba/damp_oscillator.py +162 -0
- vbi/models/numba/ghb.py +208 -0
- vbi/models/numba/mpr.py +383 -0
- vbi/models/pytorch/__init__.py +0 -0
- vbi/models/pytorch/data/default_parameters.npz +0 -0
- vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
- vbi/models/pytorch/data/input/fc_test.csv +68 -0
- vbi/models/pytorch/data/input/fc_train.csv +68 -0
- vbi/models/pytorch/data/input/fc_vali.csv +68 -0
- vbi/models/pytorch/data/input/fcd_test.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
- vbi/models/pytorch/data/input/fcd_train.mat +0 -0
- vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
- vbi/models/pytorch/data/input/myelin.csv +68 -0
- vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
- vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
- vbi/models/pytorch/data/input/sc_test.csv +68 -0
- vbi/models/pytorch/data/input/sc_train.csv +68 -0
- vbi/models/pytorch/data/input/sc_vali.csv +68 -0
- vbi/models/pytorch/data/obs_kong0.npz +0 -0
- vbi/models/pytorch/ww_sde_kong.py +570 -0
- vbi/models/tvbk/__init__.py +9 -0
- vbi/models/tvbk/tvbk_wrapper.py +166 -0
- vbi/models/tvbk/utils.py +72 -0
- vbi/papers/__init__.py +0 -0
- vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
- vbi/tests/__init__.py +0 -0
- vbi/tests/_test_mpr_nb.py +36 -0
- vbi/tests/test_features.py +355 -0
- vbi/tests/test_ghb_cupy.py +90 -0
- vbi/tests/test_mpr_cupy.py +49 -0
- vbi/tests/test_mpr_numba.py +84 -0
- vbi/tests/test_suite.py +19 -0
- vbi/utils.py +402 -0
- vbi-0.1.3.dist-info/METADATA +166 -0
- vbi-0.1.3.dist-info/RECORD +121 -0
- vbi-0.1.3.dist-info/WHEEL +5 -0
- vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
- vbi-0.1.3.dist-info/top_level.txt +1 -0
Binary file
|
@@ -0,0 +1,507 @@
|
|
1
|
+
import torch
|
2
|
+
import logging
|
3
|
+
import numpy as np
|
4
|
+
import pandas as pd
|
5
|
+
from torch import Tensor
|
6
|
+
from typing import Union, List
|
7
|
+
|
8
|
+
|
9
|
+
def count_depth(ls):
    """
    Count the nesting depth of a list/tuple.

    Parameters
    ----------
    ls : any
        Object to inspect. Only ``list`` and ``tuple`` containers add a
        nesting level; anything else (scalars, numpy arrays, ...) has depth 0.

    Returns
    -------
    int
        Number of nested list/tuple levels along the deepest branch.
        An empty list/tuple has depth 1.
    """
    if isinstance(ls, (list, tuple)):
        # ``default=0`` keeps an empty container from making ``max`` raise.
        return 1 + max((count_depth(item) for item in ls), default=0)
    return 0
|
18
|
+
|
19
|
+
|
20
|
+
def prepare_input(ts, dtype=np.float32):
    """
    Bring time-series input into a 3-D numpy array.

    Parameters
    ----------
    ts : numpy.ndarray, list, tuple or pandas.DataFrame
        Input from which the features are extracted.
        - ndarray: 3-D is returned unchanged, 2-D of shape (a, t) becomes
          (a, 1, t), anything else is wrapped as ``ts[None, None, :]``.
        - list/tuple of ndarrays: stacked with ``dtype``; 1-D elements give
          (n_elements, 1, t), 2-D elements give (n_elements, a, t).
        - nested lists: depth 3 is converted as is, depth 2 gives
          (n_rows, 1, t); a flat list of scalars gives (1, 1, n).
        - DataFrame: columns are treated as individual series and rows as
          time, giving (n_columns, 1, n_rows).
        Any other type is returned unchanged.
    dtype : numpy dtype, optional
        dtype used when ``ts`` is a list/tuple of numpy arrays. The other
        input forms keep the dtype numpy infers.

    Returns
    -------
    ts : numpy.ndarray
        Formatted input.
    n_trial : int
        Always 0; it is not computed by this function.

    Raises
    ------
    ValueError
        If ``ts`` is an empty list or tuple.
    """
    n_trial = 0

    if isinstance(ts, np.ndarray):
        if ts.ndim == 2:
            ts = ts[:, np.newaxis, :]  # n_region = 1
        elif ts.ndim != 3:
            ts = ts[np.newaxis, np.newaxis, :]  # n_region, n_trial = 1

    elif isinstance(ts, (list, tuple)):
        if len(ts) == 0:
            raise ValueError("ts is empty.")
        first = ts[0]

        if isinstance(first, np.ndarray):
            arr = np.array(ts, dtype=dtype)
            if first.ndim == 1:
                arr = arr[:, np.newaxis, :]  # n_region = 1
            elif first.ndim != 2:
                arr = arr[np.newaxis, np.newaxis, :]
            ts = arr

        elif isinstance(first, (list, tuple)):
            depth = count_depth(ts)
            if depth == 3:
                ts = np.asarray(ts)
            elif depth == 2:
                ts = np.array(ts)[:, np.newaxis, :]  # n_region = 1
            else:
                ts = np.array(ts)[np.newaxis, np.newaxis, :]

        else:
            # Flat sequence of scalars: a single series -> (1, 1, n_times).
            ts = np.array(ts)[np.newaxis, np.newaxis, :]  # n_region, n_trial = 1

    elif isinstance(ts, pd.DataFrame):
        # Assumes columns hold the individual time series and rows are time.
        ts = ts.values.T[:, np.newaxis, :]  # n_region = 1

    return ts, n_trial
|
75
|
+
|
76
|
+
|
77
|
+
def prepare_input_ts(ts, indices: List[int] = None):
    """
    Select rows of ``ts`` and report whether the selection is usable.

    Parameters
    ----------
    ts : array-like
        Time series; the first axis indexes the individual series.
        Non-ndarray input is converted with ``np.array``.
    indices : list, tuple or numpy.ndarray of int, optional
        Rows to keep; defaults to all rows. Negative indices count from the
        end, as in numpy.

    Returns
    -------
    ok : bool
        False if the selection is empty or contains NaN/inf values,
        True otherwise.
    ts : numpy.ndarray
        The selected rows; a 1-D result is reshaped to (1, -1).

    Raises
    ------
    ValueError
        If ``indices`` is not a list/tuple/ndarray, contains non-integers
        (booleans included), or contains an out-of-range index.
    """
    if not isinstance(ts, np.ndarray):
        ts = np.array(ts)
    n_series = ts.shape[0]
    if indices is None:
        indices = np.arange(n_series, dtype=np.int32)

    # check indices validity
    if not isinstance(indices, (list, tuple, np.ndarray)):
        raise ValueError("indices must be a list, tuple, or numpy array.")
    # bool is a subclass of int, but numpy treats a list of bools as a
    # boolean mask rather than as row positions, so reject it explicitly.
    if not all(
        isinstance(i, (int, np.integer)) and not isinstance(i, bool) for i in indices
    ):
        raise ValueError("indices must be a list of integers.")
    if not all(i < n_series for i in indices):
        raise ValueError("indices must be smaller than the number of time series.")
    if not all(i >= -n_series for i in indices):
        raise ValueError(
            "negative indices must be >= -(number of time series)."
        )

    # Convert to an integer array: indexing with a tuple would otherwise be
    # interpreted as a multi-axis index (ts[0, 2]) instead of rows 0 and 2.
    ts = ts[np.asarray(indices, dtype=np.intp)]

    if ts.ndim == 1:
        ts = ts.reshape(1, -1)

    if ts.size == 0 or not np.isfinite(ts).all():
        return False, ts
    return True, ts
|
103
|
+
|
104
|
+
|
105
|
+
def make_mask(n, indices):
    """
    Make an (n, n) connectivity mask linking every pair of the given indices.

    Parameters
    ----------
    n : int
        size of the mask matrix
    indices : list, tuple or numpy.ndarray of int
        node indices to connect; negative values count from the end.

    Returns
    -------
    mask : numpy.ndarray of int64, shape (n, n)
        1 at (i, j) for every pair i != j drawn from ``indices``,
        0 elsewhere (the diagonal is always 0).

    Raises
    ------
    ValueError
        If ``indices`` is not a list/tuple/ndarray, contains non-integers
        (booleans included), or contains an index outside [-n, n).
    """
    # check validity of indices
    if not isinstance(indices, (list, tuple, np.ndarray)):
        raise ValueError("indices must be a list, tuple, or numpy array.")
    # bool is a subclass of int, but np.ix_ would treat bools as a mask.
    if not all(
        isinstance(i, (int, np.integer)) and not isinstance(i, bool) for i in indices
    ):
        raise ValueError("indices must be a list of integers.")
    if not all(i < n for i in indices):
        raise ValueError("indices must be smaller than n.")
    if not all(i >= -n for i in indices):
        raise ValueError("indices must be greater than or equal to -n.")

    idx = np.asarray(indices, dtype=np.intp)
    mask = np.zeros((n, n), dtype=np.int64)
    mask[np.ix_(idx, idx)] = 1
    np.fill_diagonal(mask, 0)  # no self-connections

    return mask
|
134
|
+
|
135
|
+
|
136
|
+
def get_intrah_mask(n_nodes):
    """
    Build a float mask selecting connections within each hemisphere.

    The first ``n_nodes // 2`` nodes form one hemisphere and the remaining
    nodes the other; for odd ``n_nodes`` the second group is one node larger.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.

    Returns
    -------
    mask_intrah: 2d array
        (n_nodes, n_nodes) array with 1.0 on the two diagonal blocks
        (diagonal included) and 0.0 elsewhere.
    """
    split = n_nodes // 2
    mask = np.zeros((n_nodes, n_nodes))
    mask[:split, :split] = 1  # first hemisphere <-> first hemisphere
    mask[split:, split:] = 1  # second hemisphere <-> second hemisphere
    return mask
|
158
|
+
|
159
|
+
|
160
|
+
def get_interh_mask(n_nodes):
    """
    Get a mask for interhemispheric connections.

    The first ``n_nodes // 2`` nodes form one hemisphere and the remaining
    nodes the other (the same split as ``get_intrah_mask``), so the two masks
    are complementary: their sum is all ones, also for odd ``n_nodes``.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.

    Returns
    -------
    mask_interh: 2d array
        (n_nodes, n_nodes) float array with 1.0 on the two off-diagonal
        blocks and 0.0 elsewhere (the diagonal is always 0).
    """
    # Slice-based blocks; the previous eye(k=-n_nodes // 2) construction
    # floor-divided the negated size and was wrong for odd n_nodes.
    half = n_nodes // 2
    mask_interh = np.zeros((n_nodes, n_nodes))
    mask_interh[:half, half:] = 1  # first hemisphere -> second hemisphere
    mask_interh[half:, :half] = 1  # second hemisphere -> first hemisphere
    return mask_interh
|
184
|
+
|
185
|
+
|
186
|
+
def get_masks(n_nodes, networks):
    """
    Get a dictionary of masks based on the requested networks.

    Parameters
    ----------
    n_nodes: int
        number of total nodes that constitute the data.
    networks: str or list/tuple/array of str
        network(s) to be included in the dictionary; a single name is
        accepted as well.
        'full': full-network connections (all ones)
        'intrah': intrahemispheric connections (see ``get_intrah_mask``)
        'interh': interhemispheric connections (see ``get_interh_mask``)
        To get a custom mask over specific node indices, use
        ``make_mask(n, indices)`` from this module.

    Returns
    -------
    masks: dict
        maps each requested network name to an (n_nodes, n_nodes) float array.

    Raises
    ------
    ValueError
        if a requested network is not one of 'full', 'intrah', 'interh'.
    """
    valid_networks = ["full", "intrah", "interh"]
    # accept a single network name as well as a collection of names
    if not is_sequence(networks):
        networks = [networks]

    masks = {}
    for ntw in networks:
        if ntw not in valid_networks:
            raise ValueError(
                f"Invalid network: {ntw}. Please choose from {valid_networks}."
            )
        if ntw == "full":
            masks[ntw] = np.ones((n_nodes, n_nodes))
        elif ntw == "intrah":
            masks[ntw] = get_intrah_mask(n_nodes)
        else:  # "interh"
            masks[ntw] = get_interh_mask(n_nodes)

    return masks
|
226
|
+
|
227
|
+
|
228
|
+
def is_sequence(arg):
    """
    Tell whether ``arg`` is a list, a tuple or a numpy array.

    Strings, dicts, sets and other iterables are deliberately *not*
    considered sequences here.

    Parameters
    ----------
    arg : any
        Object to test.

    Returns
    -------
    bool
        True for list, tuple and numpy.ndarray instances (including
        subclasses), False for everything else.
    """
    sequence_types = (list, tuple, np.ndarray)
    return isinstance(arg, sequence_types)
|
244
|
+
|
245
|
+
|
246
|
+
def set_k_diagonals(A, k=0, value=0):
    """
    Set the main diagonal and the ``k`` diagonals on each side of it to ``value``.

    Parameters
    ----------
    A : array-like
        2-D input matrix. A numpy array is modified in place; any other
        input is first converted with ``np.array``.
    k : int
        number of off-diagonals on each side of the main diagonal to set,
        i.e. every entry with ``|col - row| <= k``. ``k=0`` touches only
        the main diagonal; a negative ``k`` leaves ``A`` unchanged.
    value : int or float, optional
        value to be set. The default is 0.

    Returns
    -------
    numpy.ndarray
        ``A`` with the selected band set to ``value``.

    Raises
    ------
    ValueError
        If ``A`` is not 2-D, ``k`` is not an integer, ``value`` is not a
        number, or ``k >= A.shape[0]``.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    if not isinstance(k, (int, np.integer)):
        raise ValueError("k must be an integer.")
    if not isinstance(value, (int, float, np.integer, np.floating)):
        raise ValueError("value must be a number.")
    if k >= A.shape[0]:
        raise ValueError("k must be smaller than the size of A.")

    # Select the band directly from row/column offsets. No random numbers are
    # drawn, so the global numpy RNG state is left untouched, and the
    # selection always fits A's shape (also for non-square matrices).
    rows, cols = np.indices(A.shape)
    A[np.abs(cols - rows) <= k] = value
    return A
|
279
|
+
|
280
|
+
|
281
|
+
def if_symmetric(A, tol=1e-8):
    """
    Check if the input matrix is symmetric.

    Parameters
    ----------
    A : array-like
        input matrix (converted with ``np.array`` if needed).
    tol : float, optional
        absolute tolerance passed to ``np.allclose`` as ``atol``; the
        ``np.allclose`` default relative tolerance also applies.
        The default is 1e-8.

    Returns
    -------
    bool
        True if the input matrix is square and symmetric within tolerance,
        False otherwise.

    Raises
    ------
    ValueError
        If ``A`` is not 2-D.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")

    # A non-square matrix cannot be symmetric. Without this check allclose
    # would broadcast e.g. (1, n) against (n, 1) and could report True.
    if A.shape[0] != A.shape[1]:
        return False

    return bool(np.allclose(A, A.T, atol=tol))
|
304
|
+
|
305
|
+
|
306
|
+
def scipy_iir_filter_data(
    x, sfreq, l_freq, h_freq, l_trans_bandwidth=None, h_trans_bandwidth=None, **kwargs
):
    """
    Custom, scipy based filtering function with basic butterworth filter.
    #comes from neurolib

    The filter type follows the mne convention:
    low-pass (l_freq is None), high-pass (h_freq is None),
    band-pass (l_freq < h_freq) and band-stop (l_freq > h_freq, stopping
    the band between h_freq and l_freq). Filtering is zero-phase
    (``sosfiltfilt``).

    Parameters
    ----------
    x : np.ndarray
        data to be filtered, time is the last axis
    sfreq : float
        sampling frequency of the data in Hz
    l_freq : float|None
        frequency below which to filter the data in Hz
    h_freq : float|None
        frequency above which to filter the data in Hz
    l_trans_bandwidth : unused, kept for compatibility with mne
    h_trans_bandwidth : unused, kept for compatibility with mne
    **kwargs : only ``order`` (Butterworth order, default 8) is used;
        other keywords are ignored.

    Returns
    -------
    np.ndarray
        filtered data

    Raises
    ------
    ValueError
        If both ``l_freq`` and ``h_freq`` are None, or if they are equal.
    """

    from scipy.signal import butter, sosfiltfilt

    if l_freq is None and h_freq is None:
        raise ValueError("At least one of l_freq and h_freq must be given.")

    nyq = 0.5 * sfreq
    if l_freq is not None and h_freq is not None:
        if l_freq == h_freq:
            raise ValueError(
                "l_freq and h_freq must differ (l_freq < h_freq: band-pass, "
                "l_freq > h_freq: band-stop)."
            )
        btype = "bandpass" if l_freq < h_freq else "bandstop"
        # butter expects [start, stop] in increasing order; for a band-stop
        # (l_freq > h_freq) the stop band is [h_freq, l_freq].
        Wn = sorted([l_freq / nyq, h_freq / nyq])
    elif h_freq is None:
        # so we have a high-pass filter
        Wn = l_freq / nyq
        btype = "highpass"
    else:
        # we have a low-pass
        Wn = h_freq / nyq
        btype = "lowpass"

    # get butter coeffs
    sos = butter(N=kwargs.pop("order", 8), Wn=Wn, btype=btype, output="sos")
    return sosfiltfilt(sos, x, axis=-1)
|
359
|
+
|
360
|
+
|
361
|
+
def filter(
    ts: np.ndarray,
    fs: float,
    low_freq: Union[float, None],
    high_freq: Union[float, None],
    l_trans_bandwidth: Union[float, str] = "auto",
    h_trans_bandwidth: Union[float, str] = "auto",
    **kwargs,
):
    """
    Filter time series along the last axis.

    The filter type follows from the cut-off frequencies:

    - low-pass  (``low_freq`` is None, ``high_freq`` is not None),
    - high-pass (``high_freq`` is None, ``low_freq`` is not None),
    - band-pass (``low_freq < high_freq``),
    - band-stop (``low_freq > high_freq``).

    Uses ``mne.filter.filter_data`` when ``mne`` can be imported; otherwise
    a warning is logged and ``scipy_iir_filter_data`` (a Butterworth IIR
    filter that ignores the transition-bandwidth arguments) is used instead.

    Parameters
    ----------
    ts : np.ndarray
        Time series data; time has to be the last axis.
    fs : float
        Sampling frequency of ``ts`` in Hz.
    low_freq : float | None
        Frequency below which to filter the data.
    high_freq : float | None
        Frequency above which to filter the data.
    l_trans_bandwidth : float | str
        Transition bandwidth for the low cut-off (forwarded to mne).
    h_trans_bandwidth : float | str
        Transition bandwidth for the high cut-off (forwarded to mne).
    **kwargs
        Extra keywords forwarded to the filtering function. For mne these
        include e.g. ``filter_length``, ``method``, ``iir_params``, ``phase``,
        ``fir_window`` and ``fir_design``; the scipy fallback only uses
        ``order``.

    Returns
    -------
    np.ndarray
        filtered data
    """

    try:
        from mne.filter import filter_data

    except ImportError:
        logging.warning(
            "`mne` module not found, falling back to basic scipy's function"
        )
        filter_data = scipy_iir_filter_data

    filtered = filter_data(
        ts,  # times has to be the last axis
        sfreq=fs,
        l_freq=low_freq,
        h_freq=high_freq,
        l_trans_bandwidth=l_trans_bandwidth,
        h_trans_bandwidth=h_trans_bandwidth,
        **kwargs,
    )
    return filtered
|
424
|
+
|
425
|
+
|
426
|
+
|
427
|
+
|
428
|
+
def posterior_shrinkage(
    prior_samples: Union[Tensor, np.ndarray], post_samples: Union[Tensor, np.ndarray]
) -> Tensor:
    """
    Posterior shrinkage per parameter, ``1 - (std_post / std_prior) ** 2``.

    Measures how much the posterior distribution contracts relative to the
    prior. Standard deviations are taken over the sample axis (axis 0) with
    ``torch.std``'s default settings.

    References:
    https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    prior_samples : array_like or torch.Tensor [n_samples, n_params]
        Samples from the prior distribution; a 1-D input is treated as a
        single parameter. Non-tensor input is converted to float32.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distribution, same conventions.

    Returns
    -------
    shrinkage : torch.Tensor [n_params]
        The posterior shrinkage.

    Raises
    ------
    ValueError
        If either set of samples is empty.
    """
    if not len(prior_samples) or not len(post_samples):
        raise ValueError("Input samples are empty")

    def _as_2d_tensor(samples):
        # convert to float32 tensor (tensors are kept as is), then make 2-D
        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples, dtype=torch.float32)
        return samples.unsqueeze(1) if samples.ndim == 1 else samples

    prior_2d = _as_2d_tensor(prior_samples)
    post_2d = _as_2d_tensor(post_samples)

    std_ratio = post_2d.std(dim=0) / prior_2d.std(dim=0)
    return 1 - std_ratio**2
|
468
|
+
|
469
|
+
|
470
|
+
def posterior_zscore(
    true_theta: Union[Tensor, np.ndarray, float], post_samples: Union[Tensor, np.ndarray]
) -> Tensor:
    """
    Calculate the posterior z-score, quantifying how much the posterior
    distribution of a parameter encompasses its true value:
    ``|mean(post) - true_theta| / std(post)`` per parameter.

    References:
    https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    true_theta : float, array-like or torch.Tensor [n_params]
        The true value of the parameters. Non-tensor input is converted
        to float32.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distributions; a 1-D input is treated
        as a single parameter. Non-tensor input is converted to float32.

    Returns
    -------
    z : Tensor [n_params]
        The z-score of the posterior distributions.

    Raises
    ------
    ValueError
        If ``post_samples`` is empty.
    """

    if len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    if not isinstance(true_theta, torch.Tensor):
        true_theta = torch.tensor(true_theta, dtype=torch.float32)
    if not isinstance(post_samples, torch.Tensor):
        post_samples = torch.tensor(post_samples, dtype=torch.float32)

    # Stay in torch: np.atleast_1d would turn the tensor into a numpy array
    # (and fails outright for CUDA or grad-tracking tensors).
    true_theta = torch.atleast_1d(true_theta).to(post_samples.device)
    if post_samples.ndim == 1:
        post_samples = post_samples[:, None]

    post_mean = torch.mean(post_samples, dim=0)
    post_std = torch.std(post_samples, dim=0)

    return torch.abs((post_mean - true_theta) / post_std)
|
vbi/inference.py
ADDED
@@ -0,0 +1,98 @@
|
|
1
|
+
import torch
|
2
|
+
from vbi.utils import *
|
3
|
+
from sbi.inference import SNPE, SNLE, SNRE
|
4
|
+
from sbi.utils.user_input_checks import process_prior
|
5
|
+
|
6
|
+
class Inference(object):
    """Thin wrapper around sbi's SNPE / SNLE / SNRE inference classes."""

    def __init__(self) -> None:
        pass

    @timer
    def train(
        self,
        theta,
        x,
        prior,
        num_threads=1,
        method="SNPE",
        device="cpu",
        density_estimator="maf",
    ):
        """
        Train an sbi estimator on simulated (theta, x) pairs and build a posterior.

        Parameters
        ----------
        theta : torch.Tensor or array-like
            simulation parameters, one row per simulation; a 1-D input is
            treated as a single parameter. Non-tensor input is converted to
            a float32 tensor.
        x : torch.Tensor or array-like
            simulation outputs (features), one row per simulation; same
            conventions as ``theta``.
        prior :
            prior distribution passed to the sbi inference class.
        num_threads : int
            passed to ``torch.set_num_threads`` (a process-wide setting).
        method : {"SNPE", "SNLE", "SNRE"}
            sbi inference method.
        device : str
            training device passed to sbi.
        density_estimator : str
            for SNPE/SNLE, the density estimator name passed to sbi.
            For SNRE it is passed as sbi's ``classifier`` argument; the
            default "maf" (not a classifier) is replaced by "resnet".

        Returns
        -------
        posterior
            object returned by ``inference.build_posterior``.

        Raises
        ------
        ValueError
            if ``method`` is not one of "SNPE", "SNLE", "SNRE".
        """

        torch.set_num_threads(num_threads)

        # sbi's append_simulations expects float32 tensors
        if not isinstance(theta, torch.Tensor):
            theta = torch.as_tensor(theta, dtype=torch.float32)
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x, dtype=torch.float32)

        if x.ndim == 1:
            x = x[:, None]
        if theta.ndim == 1:
            theta = theta[:, None]

        if method == "SNPE":
            inference = SNPE(
                prior=prior, density_estimator=density_estimator, device=device)
        elif method == "SNLE":
            inference = SNLE(
                prior=prior, density_estimator=density_estimator, device=device)
        elif method == "SNRE":
            # SNRE trains a classifier and takes `classifier=`, not
            # `density_estimator=`; "maf" is not a classifier name.
            classifier = "resnet" if density_estimator == "maf" else density_estimator
            inference = SNRE(prior=prior, classifier=classifier, device=device)
        else:
            raise ValueError(f"Invalid method: {method}")

        inference = inference.append_simulations(theta, x)
        estimator_ = inference.train()
        posterior = inference.build_posterior(estimator_)

        return posterior

    @staticmethod
    def sample_prior(prior, n, seed=None):
        '''
        sample from prior distribution

        Parameters
        ----------
        prior:
            prior distribution; it is first passed through sbi's
            ``process_prior``.
        n: int
            number of samples
        seed: int, optional
            if given, passed to ``torch.manual_seed`` (global torch RNG)
            before sampling.

        Returns
        -------
        theta:
            samples returned by ``prior.sample((n,))`` on the processed prior.
        '''
        if seed is not None:
            torch.manual_seed(seed)

        prior, _, _ = process_prior(prior)
        theta = prior.sample((n,))
        return theta

    @staticmethod
    def sample_posterior(xo,
                         num_samples,
                         posterior):
        '''
        sample from the posterior using the given observation point.

        Parameters
        ----------
        xo: torch.tensor float32 (1, d) or array-like
            observation point; non-tensor input is converted to a float32
            tensor and a 1-D input is reshaped to (1, d).
        num_samples: int
            number of samples
        posterior:
            posterior object exposing ``sample(sample_shape, x=...)``,
            e.g. the one returned by ``train``.

        Returns
        -------
        samples: torch.tensor float32 (num_samples, d)
            samples from the posterior
        '''

        if not isinstance(xo, torch.Tensor):
            xo = torch.tensor(xo, dtype=torch.float32)
        if xo.ndim == 1:
            xo = xo[None, :]

        samples = posterior.sample((num_samples,), x=xo)
        return samples
|
vbi/models/__init__.py
ADDED
File without changes
|
File without changes
|
File without changes
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|