vbi 0.1.3__cp310-cp310-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. vbi/__init__.py +37 -0
  2. vbi/_version.py +17 -0
  3. vbi/dataset/__init__.py +0 -0
  4. vbi/dataset/connectivity_84/centers.txt +84 -0
  5. vbi/dataset/connectivity_84/centres.txt +84 -0
  6. vbi/dataset/connectivity_84/cortical.txt +84 -0
  7. vbi/dataset/connectivity_84/tract_lengths.txt +84 -0
  8. vbi/dataset/connectivity_84/weights.txt +84 -0
  9. vbi/dataset/connectivity_88/Aud_88.txt +88 -0
  10. vbi/dataset/connectivity_88/Bold.npz +0 -0
  11. vbi/dataset/connectivity_88/Labels.txt +17 -0
  12. vbi/dataset/connectivity_88/Region_labels.txt +88 -0
  13. vbi/dataset/connectivity_88/tract_lengths.txt +88 -0
  14. vbi/dataset/connectivity_88/weights.txt +88 -0
  15. vbi/feature_extraction/__init__.py +1 -0
  16. vbi/feature_extraction/calc_features.py +293 -0
  17. vbi/feature_extraction/features.json +535 -0
  18. vbi/feature_extraction/features.py +2124 -0
  19. vbi/feature_extraction/features_settings.py +374 -0
  20. vbi/feature_extraction/features_utils.py +1357 -0
  21. vbi/feature_extraction/infodynamics.jar +0 -0
  22. vbi/feature_extraction/utility.py +507 -0
  23. vbi/inference.py +98 -0
  24. vbi/models/__init__.py +0 -0
  25. vbi/models/cpp/__init__.py +0 -0
  26. vbi/models/cpp/_src/__init__.py +0 -0
  27. vbi/models/cpp/_src/__pycache__/mpr_sde.cpython-310.pyc +0 -0
  28. vbi/models/cpp/_src/_do.cpython-310-x86_64-linux-gnu.so +0 -0
  29. vbi/models/cpp/_src/_jr_sdde.cpython-310-x86_64-linux-gnu.so +0 -0
  30. vbi/models/cpp/_src/_jr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  31. vbi/models/cpp/_src/_km_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  32. vbi/models/cpp/_src/_mpr_sde.cpython-310-x86_64-linux-gnu.so +0 -0
  33. vbi/models/cpp/_src/_vep.cpython-310-x86_64-linux-gnu.so +0 -0
  34. vbi/models/cpp/_src/_wc_ode.cpython-310-x86_64-linux-gnu.so +0 -0
  35. vbi/models/cpp/_src/bold.hpp +303 -0
  36. vbi/models/cpp/_src/do.hpp +167 -0
  37. vbi/models/cpp/_src/do.i +17 -0
  38. vbi/models/cpp/_src/do.py +467 -0
  39. vbi/models/cpp/_src/do_wrap.cxx +12811 -0
  40. vbi/models/cpp/_src/jr_sdde.hpp +352 -0
  41. vbi/models/cpp/_src/jr_sdde.i +19 -0
  42. vbi/models/cpp/_src/jr_sdde.py +688 -0
  43. vbi/models/cpp/_src/jr_sdde_wrap.cxx +18718 -0
  44. vbi/models/cpp/_src/jr_sde.hpp +264 -0
  45. vbi/models/cpp/_src/jr_sde.i +17 -0
  46. vbi/models/cpp/_src/jr_sde.py +470 -0
  47. vbi/models/cpp/_src/jr_sde_wrap.cxx +13406 -0
  48. vbi/models/cpp/_src/km_sde.hpp +158 -0
  49. vbi/models/cpp/_src/km_sde.i +19 -0
  50. vbi/models/cpp/_src/km_sde.py +671 -0
  51. vbi/models/cpp/_src/km_sde_wrap.cxx +17367 -0
  52. vbi/models/cpp/_src/makefile +52 -0
  53. vbi/models/cpp/_src/mpr_sde.hpp +327 -0
  54. vbi/models/cpp/_src/mpr_sde.i +19 -0
  55. vbi/models/cpp/_src/mpr_sde.py +711 -0
  56. vbi/models/cpp/_src/mpr_sde_wrap.cxx +18618 -0
  57. vbi/models/cpp/_src/utility.hpp +307 -0
  58. vbi/models/cpp/_src/vep.hpp +171 -0
  59. vbi/models/cpp/_src/vep.i +16 -0
  60. vbi/models/cpp/_src/vep.py +464 -0
  61. vbi/models/cpp/_src/vep_wrap.cxx +12968 -0
  62. vbi/models/cpp/_src/wc_ode.hpp +294 -0
  63. vbi/models/cpp/_src/wc_ode.i +19 -0
  64. vbi/models/cpp/_src/wc_ode.py +686 -0
  65. vbi/models/cpp/_src/wc_ode_wrap.cxx +24263 -0
  66. vbi/models/cpp/damp_oscillator.py +143 -0
  67. vbi/models/cpp/jansen_rit.py +543 -0
  68. vbi/models/cpp/km.py +187 -0
  69. vbi/models/cpp/mpr.py +289 -0
  70. vbi/models/cpp/vep.py +150 -0
  71. vbi/models/cpp/wc.py +216 -0
  72. vbi/models/cupy/__init__.py +0 -0
  73. vbi/models/cupy/bold.py +111 -0
  74. vbi/models/cupy/ghb.py +284 -0
  75. vbi/models/cupy/jansen_rit.py +473 -0
  76. vbi/models/cupy/km.py +224 -0
  77. vbi/models/cupy/mpr.py +475 -0
  78. vbi/models/cupy/mpr_modified_bold.py +12 -0
  79. vbi/models/cupy/utils.py +184 -0
  80. vbi/models/numba/__init__.py +0 -0
  81. vbi/models/numba/_ww_EI.py +444 -0
  82. vbi/models/numba/damp_oscillator.py +162 -0
  83. vbi/models/numba/ghb.py +208 -0
  84. vbi/models/numba/mpr.py +383 -0
  85. vbi/models/pytorch/__init__.py +0 -0
  86. vbi/models/pytorch/data/default_parameters.npz +0 -0
  87. vbi/models/pytorch/data/input/ROI_sim.mat +0 -0
  88. vbi/models/pytorch/data/input/fc_test.csv +68 -0
  89. vbi/models/pytorch/data/input/fc_train.csv +68 -0
  90. vbi/models/pytorch/data/input/fc_vali.csv +68 -0
  91. vbi/models/pytorch/data/input/fcd_test.mat +0 -0
  92. vbi/models/pytorch/data/input/fcd_test_high_window.mat +0 -0
  93. vbi/models/pytorch/data/input/fcd_test_low_window.mat +0 -0
  94. vbi/models/pytorch/data/input/fcd_train.mat +0 -0
  95. vbi/models/pytorch/data/input/fcd_vali.mat +0 -0
  96. vbi/models/pytorch/data/input/myelin.csv +68 -0
  97. vbi/models/pytorch/data/input/rsfc_gradient.csv +68 -0
  98. vbi/models/pytorch/data/input/run_label_testset.mat +0 -0
  99. vbi/models/pytorch/data/input/sc_test.csv +68 -0
  100. vbi/models/pytorch/data/input/sc_train.csv +68 -0
  101. vbi/models/pytorch/data/input/sc_vali.csv +68 -0
  102. vbi/models/pytorch/data/obs_kong0.npz +0 -0
  103. vbi/models/pytorch/ww_sde_kong.py +570 -0
  104. vbi/models/tvbk/__init__.py +9 -0
  105. vbi/models/tvbk/tvbk_wrapper.py +166 -0
  106. vbi/models/tvbk/utils.py +72 -0
  107. vbi/papers/__init__.py +0 -0
  108. vbi/papers/pavlides_pcb_2015/pavlides.py +211 -0
  109. vbi/tests/__init__.py +0 -0
  110. vbi/tests/_test_mpr_nb.py +36 -0
  111. vbi/tests/test_features.py +355 -0
  112. vbi/tests/test_ghb_cupy.py +90 -0
  113. vbi/tests/test_mpr_cupy.py +49 -0
  114. vbi/tests/test_mpr_numba.py +84 -0
  115. vbi/tests/test_suite.py +19 -0
  116. vbi/utils.py +402 -0
  117. vbi-0.1.3.dist-info/METADATA +166 -0
  118. vbi-0.1.3.dist-info/RECORD +121 -0
  119. vbi-0.1.3.dist-info/WHEEL +5 -0
  120. vbi-0.1.3.dist-info/licenses/LICENSE +201 -0
  121. vbi-0.1.3.dist-info/top_level.txt +1 -0
Binary file
@@ -0,0 +1,507 @@
1
+ import torch
2
+ import logging
3
+ import numpy as np
4
+ import pandas as pd
5
+ from torch import Tensor
6
+ from typing import Union, List
7
+
8
+
9
def count_depth(ls):
    """
    Count the nesting depth of a list/tuple.

    Parameters
    ----------
    ls : any
        Object to inspect. Only ``list`` and ``tuple`` count as a nesting level.

    Returns
    -------
    int
        0 for anything that is not a list/tuple, 1 for a flat (or empty)
        list/tuple, and 1 + the maximum depth of its items otherwise.
    """
    if isinstance(ls, (list, tuple)):
        # default=0 so an empty list has depth 1 instead of raising
        # "max() arg is an empty sequence"
        return 1 + max((count_depth(item) for item in ls), default=0)
    return 0
18
+
19
+
20
def prepare_input(ts, dtype=np.float32):
    """
    Prepare input in the 3-D layout used for feature extraction.

    Parameters
    ----------
    ts : array-like, list, tuple or pandas.DataFrame
        Input from which the features are extracted.
    dtype : numpy dtype, optional
        dtype used when list/tuple input is converted to an array.
        NumPy-array and DataFrame input keep their own dtype.

    Returns
    -------
    ts : nd-array
        Formatted input:
        - 1-D input of length ``n`` becomes shape ``(1, 1, n)``;
        - 2-D input of shape ``(a, n)`` becomes ``(a, 1, n)``;
        - 3-D input is returned unchanged;
        - a DataFrame is transposed first (columns -> first axis, rows -> last).
        Inputs of any other type are returned unchanged.
    n_trial : int
        Always 0 in the current implementation.
    """
    n_trial = 0  # placeholder: not computed here

    if isinstance(ts, np.ndarray):
        if ts.ndim == 2:
            ts = ts[:, np.newaxis, :]  # n_region = 1
        elif ts.ndim != 3:
            ts = ts[np.newaxis, np.newaxis, :]  # n_region , n_trial = 1

    elif isinstance(ts, (list, tuple)):
        first = ts[0]
        if isinstance(first, np.ndarray):
            arr = np.array(ts, dtype=dtype)
            if first.ndim == 2:
                ts = arr
            elif first.ndim == 1:
                ts = arr[:, np.newaxis, :]  # n_region = 1
            else:
                ts = arr[np.newaxis, np.newaxis, :]
        elif isinstance(first, (list, tuple)):
            depth = count_depth(ts)
            # apply `dtype` like the list-of-arrays branch does
            arr = np.array(ts, dtype=dtype)
            if depth == 3:
                ts = arr
            elif depth == 2:
                ts = arr[:, np.newaxis, :]  # n_region = 1
            else:
                ts = arr[np.newaxis, np.newaxis, :]  # n_region , n_trial = 1
        else:
            # flat sequence of scalars: previously returned as a raw list
            ts = np.array(ts, dtype=dtype)[np.newaxis, np.newaxis, :]

    elif isinstance(ts, pd.DataFrame):
        # assume columns are time series and rows are time samples
        ts = ts.values.T
        ts = ts[:, np.newaxis, :]  # n_region = 1

    return ts, n_trial
75
+
76
+
77
def prepare_input_ts(ts, indices: Union[List[int], np.ndarray, None] = None):
    """
    Select rows of a time-series array and check that the selection is usable.

    Parameters
    ----------
    ts : array-like
        Time series; the first axis indexes the individual series.
    indices : list, tuple or numpy.ndarray of int, optional
        Rows of `ts` to keep. Defaults to all rows. Any Python or NumPy
        integer type is accepted. Negative indices follow NumPy semantics
        and must lie in ``[-n, n)`` where ``n = ts.shape[0]``.

    Returns
    -------
    ok : bool
        False if the selection is empty or contains NaN/Inf, True otherwise.
    ts : numpy.ndarray
        The selected data; a 1-D selection is reshaped to ``(1, -1)``.

    Raises
    ------
    ValueError
        If `indices` is not a list/tuple/array of integers, or if any index
        is out of range.
    """
    if not isinstance(ts, np.ndarray):
        ts = np.array(ts)
    n_series = ts.shape[0]
    if indices is None:
        indices = np.arange(n_series, dtype=np.int32)

    # check indices validity
    if not isinstance(indices, (list, tuple, np.ndarray)):
        raise ValueError("indices must be a list, tuple, or numpy array.")
    # np.integer covers every NumPy integer width, signed and unsigned
    if not all(isinstance(i, (int, np.integer)) for i in indices):
        raise ValueError("indices must be a list of integers.")
    # compare as Python ints to avoid NumPy fixed-width comparison quirks
    index_values = [int(i) for i in indices]
    if not all(i < n_series for i in index_values):
        raise ValueError("indices must be smaller than the number of time series.")
    if not all(i >= -n_series for i in index_values):
        raise ValueError(
            "negative indices must not be smaller than -(number of time series)."
        )

    ts = ts[indices]

    if ts.ndim == 1:
        ts = ts.reshape(1, -1)

    if ts.size == 0:
        return False, ts

    if np.isnan(ts).any() or np.isinf(ts).any():
        return False, ts
    return True, ts
103
+
104
+
105
def make_mask(n, indices):
    """
    Make an ``n x n`` mask connecting every pair of the given indices.

    Parameters
    ----------
    n : int
        Size of the mask matrix.
    indices : list, tuple or numpy.ndarray of int
        Node indices to connect. Any Python or NumPy integer type is
        accepted; negative indices follow NumPy semantics and must lie in
        ``[-n, n)``.

    Returns
    -------
    mask : numpy.ndarray
        int64 matrix with 1 at ``(i, j)`` for every pair of distinct nodes
        i, j in `indices` and 0 elsewhere (the main diagonal is always 0).

    Raises
    ------
    ValueError
        If `indices` is not a list/tuple/array of integers, or if any index
        is out of range.
    """
    # check validity of indices
    if not isinstance(indices, (list, tuple, np.ndarray)):
        raise ValueError("indices must be a list, tuple, or numpy array.")
    # np.integer covers every NumPy integer width, signed and unsigned
    if not all(isinstance(i, (int, np.integer)) for i in indices):
        raise ValueError("indices must be a list of integers.")
    index_values = [int(i) for i in indices]
    if not all(i < n for i in index_values):
        raise ValueError("indices must be smaller than n.")
    if not all(i >= -n for i in index_values):
        raise ValueError("negative indices must not be smaller than -n.")

    mask = np.zeros((n, n), dtype=np.int64)
    mask[np.ix_(indices, indices)] = 1
    # no self-connections
    np.fill_diagonal(mask, 0)

    return mask
134
+
135
+
136
def get_intrah_mask(n_nodes):
    """
    Build the intrahemispheric connection mask.

    Nodes ``[0, n_nodes // 2)`` form one hemisphere and the remaining nodes
    form the other. Entries linking two nodes of the same hemisphere
    (diagonal included) are 1; all other entries are 0.

    Parameters
    ----------
    n_nodes: int
        Total number of nodes in the data.

    Returns
    -------
    mask_intrah: 2d array
        Float ``(n_nodes, n_nodes)`` mask of intrahemispheric connections.
    """
    half = n_nodes // 2
    mask = np.zeros((n_nodes, n_nodes))
    mask[:half, :half] = 1  # first hemisphere with itself
    mask[half:, half:] = 1  # second hemisphere with itself
    return mask
158
+
159
+
160
def get_interh_mask(n_nodes):
    """
    Get a mask for interhemispheric connections.

    Nodes ``[0, n_nodes // 2)`` form one hemisphere and the remaining nodes
    form the other (the same split used by ``get_intrah_mask``). Entries
    linking nodes of different hemispheres are 1; all others are 0. The
    result is symmetric for any `n_nodes`, odd or even.

    Parameters
    ----------
    n_nodes: int
        Number of total nodes that constitute the data.

    Returns
    -------
    mask_interh: 2d array
        Float ``(n_nodes, n_nodes)`` mask for interhemispheric connections.
    """
    # The previous version derived the blocks from np.eye(n, k=-n // 2);
    # `-n // 2` parses as (-n) // 2, which is off by one for odd n and
    # produced an asymmetric mask. Plain block slicing avoids that.
    half = n_nodes // 2
    mask_interh = np.zeros((n_nodes, n_nodes))
    mask_interh[:half, half:] = 1
    mask_interh[half:, :half] = 1
    return mask_interh
184
+
185
+
186
def get_masks(n_nodes, networks):
    """
    Get a dictionary of masks based on the requested networks.

    Parameters
    ----------
    n_nodes: int
        Number of total nodes that constitute the data.
    networks: str or list of str
        Network(s) to include in the dictionary; a single name may be passed
        without wrapping it in a list.
        'full': full-network connections (all ones, diagonal included)
        'intrah': intrahemispheric connections
        'interh': interhemispheric connections
        To get a custom mask with specific indices, use
        ``make_mask(n, indices)`` from this module.

    Returns
    -------
    masks: dict
        Dictionary mapping each requested network name to its
        ``(n_nodes, n_nodes)`` mask.

    Raises
    ------
    ValueError
        If any requested network name is not one of the valid names.
    """
    valid_networks = ["full", "intrah", "interh"]
    if not is_sequence(networks):
        networks = [networks]

    # validate every requested name before building any mask
    for ntw in networks:
        if ntw not in valid_networks:
            raise ValueError(
                f"Invalid network: {ntw}. Please choose from {valid_networks}."
            )

    builders = {
        "full": lambda n: np.ones((n, n)),
        "intrah": get_intrah_mask,
        "interh": get_interh_mask,
    }
    return {ntw: builders[ntw](n_nodes) for ntw in networks}
226
+
227
+
228
def is_sequence(arg):
    """
    Tell whether `arg` is one of the supported sequence containers.

    Parameters
    ----------
    arg : any
        Object to test.

    Returns
    -------
    bool
        True for ``list``, ``tuple`` and ``numpy.ndarray`` instances
        (subclasses included), False for anything else, including strings.
    """
    sequence_types = (list, tuple, np.ndarray)
    return isinstance(arg, sequence_types)
244
+
245
+
246
def set_k_diagonals(A, k=0, value=0):
    """
    Set the main diagonal and the `k` diagonals on each side of it to `value`.

    Parameters
    ----------
    A : numpy.ndarray or array-like
        2-D input matrix (square or rectangular). A NumPy array is modified
        in place; any other input is first converted with ``np.array``.
    k : int
        Number of off-diagonals on each side of the main diagonal to set
        (``0`` sets only the main diagonal). The default is 0. Python and
        NumPy integers are accepted; a negative `k` sets nothing.
    value : int or float, optional
        Value to be set. The default is 0.

    Returns
    -------
    numpy.ndarray
        The modified matrix (the same object as `A` when `A` is a NumPy array).

    Raises
    ------
    ValueError
        If `A` is not 2-D, `k` is not an integer, `value` is not a real
        number, or ``k >= A.shape[0]``.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    if not isinstance(k, (int, np.integer)):
        raise ValueError("k must be an integer.")
    if not isinstance(value, (int, float, np.integer, np.floating)):
        raise ValueError("value must be a number.")
    if k >= A.shape[0]:
        raise ValueError("k must be smaller than the size of A.")

    n_rows, n_cols = A.shape
    for offset in range(-k, k + 1):
        # boolean selector of the `offset`-th diagonal; unlike the previous
        # square np.diag construction it is valid for rectangular A too
        A[np.eye(n_rows, n_cols, k=offset, dtype=bool)] = value
    return A
279
+
280
+
281
def if_symmetric(A, tol=1e-8):
    """
    Check if the input matrix is symmetric.

    Parameters
    ----------
    A : numpy.ndarray or array-like
        Input matrix.
    tol : float, optional
        Absolute tolerance (``atol`` of ``np.allclose``) used when comparing
        `A` with its transpose. The default is 1e-8. NumPy's default relative
        tolerance (``rtol=1e-5``) also applies.

    Returns
    -------
    bool
        True if the input matrix is square and symmetric, False otherwise.

    Raises
    ------
    ValueError
        If `A` is not 2-D.
    """
    if not isinstance(A, np.ndarray):
        A = np.array(A)
    if A.ndim != 2:
        raise ValueError("A must be a 2d array.")
    # A non-square matrix cannot be symmetric. Without this check
    # np.allclose would broadcast e.g. (1, n) against (n, 1) and could
    # wrongly return True, or raise for other shapes.
    if A.shape[0] != A.shape[1]:
        return False

    return np.allclose(A, A.T, atol=tol)
304
+
305
+
306
def scipy_iir_filter_data(
    x, sfreq, l_freq, h_freq, l_trans_bandwidth=None, h_trans_bandwidth=None, **kwargs
):
    """
    Custom, scipy based filtering function with basic butterworth filter.
    Adapted from neurolib.

    The filter is applied forward and backward with
    ``scipy.signal.sosfiltfilt``, so the result has zero phase shift.

    Parameters
    ----------
    x : np.ndarray
        data to be filtered, time is the last axis
    sfreq : float
        sampling frequency of the data in Hz
    l_freq : float|None
        frequency below which to filter the data in Hz
    h_freq : float|None
        frequency above which to filter the data in Hz
        - l_freq None, h_freq given: low-pass
        - h_freq None, l_freq given: high-pass
        - l_freq < h_freq: band-pass
        - l_freq > h_freq: band-stop between h_freq and l_freq
    l_trans_bandwidth : ignored; kept for compatibility with mne
    h_trans_bandwidth : ignored; kept for compatibility with mne
    **kwargs : only ``order`` (Butterworth order, default 8) is used;
        any other keyword is ignored.

    Returns
    -------
    np.ndarray
        filtered data

    Raises
    ------
    ValueError
        If both `l_freq` and `h_freq` are None, or if they are equal.
    """

    from scipy.signal import butter, sosfiltfilt

    nyq = 0.5 * sfreq
    if l_freq is None and h_freq is None:
        # previously reached butter() with an unbound Wn
        raise ValueError("At least one of l_freq and h_freq must be given.")

    if l_freq is not None and h_freq is not None:
        if l_freq == h_freq:
            # previously reached butter() with an unbound btype
            raise ValueError(
                "l_freq and h_freq must differ (l_freq < h_freq: band-pass, "
                "l_freq > h_freq: band-stop)."
            )
        btype = "bandpass" if l_freq < h_freq else "bandstop"
        # scipy requires Wn[0] < Wn[1]; for band-stop l_freq is the larger one
        Wn = sorted([l_freq / nyq, h_freq / nyq])
    elif h_freq is None:
        # high-pass filter
        btype = "highpass"
        Wn = l_freq / nyq
    else:
        # low-pass filter
        btype = "lowpass"
        Wn = h_freq / nyq

    # get butter coeffs
    sos = butter(N=kwargs.pop("order", 8), Wn=Wn, btype=btype, output="sos")
    return sosfiltfilt(sos, x, axis=-1)
359
+
360
+
361
def filter(
    ts: np.ndarray,
    fs: float,
    low_freq: Union[float, None],
    high_freq: Union[float, None],
    l_trans_bandwidth: Union[float, str] = "auto",
    h_trans_bandwidth: Union[float, str] = "auto",
    **kwargs,
):
    """
    Filter data. Can be:
    - low-pass (low_freq is None, high_freq is not None),
    - high-pass (high_freq is None, low_freq is not None),
    - band-pass (low_freq < high_freq),
    - band-stop (low_freq > high_freq) filter type

    Uses ``mne.filter.filter_data`` when `mne` is importable. Otherwise it
    logs a warning and falls back to ``scipy_iir_filter_data`` (a
    Butterworth filter applied with ``scipy.signal.sosfiltfilt``).

    Parameters
    ----------
    ts: np.ndarray
        Time series data; time has to be the last axis.
    fs : float
        Sampling frequency of `ts` in Hz.
    low_freq : float|None
        frequency below which to filter the data.
    high_freq : float|None
        frequency above which to filter the data.
    l_trans_bandwidth : float|str
        transition band width for low frequency (ignored by the scipy
        fallback)
    h_trans_bandwidth : float|str
        transition band width for high frequency (ignored by the scipy
        fallback)
    kwargs : extra keywords forwarded to the filtering function.
        With mne, possible keywords of ``mne.filter.filter_data``, e.g.:
            filter_length="auto",
            method="fir",
            iir_params=None,
            phase="zero",
            fir_window="hamming",
            fir_design="firwin"
        With the scipy fallback only ``order`` (default 8) is used.

    Returns
    -------
    np.ndarray
        filtered data
    """

    try:
        from mne.filter import filter_data

    except ImportError:
        logging.warning(
            "`mne` module not found, falling back to basic scipy's function"
        )
        filter_data = scipy_iir_filter_data

    filtered = filter_data(
        ts,  # times has to be the last axis
        sfreq=fs,
        l_freq=low_freq,
        h_freq=high_freq,
        l_trans_bandwidth=l_trans_bandwidth,
        h_trans_bandwidth=h_trans_bandwidth,
        **kwargs,
    )
    return filtered
424
+
425
+
426
+
427
+
428
def posterior_shrinkage(
    prior_samples: Union[Tensor, np.ndarray], post_samples: Union[Tensor, np.ndarray]
) -> Tensor:
    """
    Compute the posterior shrinkage ``1 - var_post / var_prior`` per parameter.

    This measures how much the posterior contracts relative to the prior
    (see https://arxiv.org/abs/1803.08393). Standard deviations are taken
    with ``torch.std`` along the sample axis.

    Parameters
    ----------
    prior_samples : array_like or torch.Tensor [n_samples, n_params]
        Draws from the prior; a 1-D input is treated as a single parameter.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Draws from the posterior; a 1-D input is treated as a single parameter.

    Returns
    -------
    shrinkage : torch.Tensor [n_params]
        Shrinkage value of each parameter.

    Raises
    ------
    ValueError
        If either sample set is empty.
    """
    if len(prior_samples) == 0 or len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    def _as_2d_tensor(samples):
        # non-tensors are converted to float32; 1-D becomes a column
        if not isinstance(samples, torch.Tensor):
            samples = torch.tensor(samples, dtype=torch.float32)
        return samples[:, None] if samples.ndim == 1 else samples

    prior_2d = _as_2d_tensor(prior_samples)
    post_2d = _as_2d_tensor(post_samples)

    variance_ratio = (post_2d.std(dim=0) / prior_2d.std(dim=0)) ** 2
    return 1 - variance_ratio
468
+
469
+
470
def posterior_zscore(
    true_theta: Union[Tensor, np.ndarray, float],
    post_samples: Union[Tensor, np.ndarray],
) -> Tensor:
    """
    Calculate the posterior z-score, quantifying how much the posterior
    distribution of a parameter encompasses its true value:
    ``|mean(post) - true_theta| / std(post)`` per parameter.
    References:
        https://arxiv.org/abs/1803.08393

    Parameters
    ----------
    true_theta : float, array-like or torch.Tensor [n_params]
        The true value of the parameters.
    post_samples : array-like or torch.Tensor [n_samples, n_params]
        Samples from the posterior distributions; a 1-D input is treated
        as a single parameter.

    Returns
    -------
    z : Tensor [n_params]
        The z-score of the posterior distributions.

    Raises
    ------
    ValueError
        If `post_samples` is empty.
    """

    if len(post_samples) == 0:
        raise ValueError("Input samples are empty")

    if not isinstance(true_theta, torch.Tensor):
        true_theta = torch.tensor(true_theta, dtype=torch.float32)
    if not isinstance(post_samples, torch.Tensor):
        post_samples = torch.tensor(post_samples, dtype=torch.float32)

    # Stay in torch: np.atleast_1d would turn the tensor into a NumPy array,
    # which fails for CUDA tensors and for tensors that require grad.
    true_theta = torch.atleast_1d(true_theta).to(post_samples.device)
    if post_samples.ndim == 1:
        post_samples = post_samples[:, None]

    post_mean = torch.mean(post_samples, dim=0)
    post_std = torch.std(post_samples, dim=0)

    return torch.abs((post_mean - true_theta) / post_std)
vbi/inference.py ADDED
@@ -0,0 +1,98 @@
1
+ import torch
2
+ from vbi.utils import *
3
+ from sbi.inference import SNPE, SNLE, SNRE
4
+ from sbi.utils.user_input_checks import process_prior
5
+
6
class Inference(object):
    """
    Thin wrapper around the `sbi` package: train a neural estimator on
    simulated ``(theta, x)`` pairs and sample from prior / posterior.
    """

    def __init__(self) -> None:
        pass

    @timer
    def train(
        self,
        theta,
        x,
        prior,
        num_threads=1,
        method="SNPE",
        device="cpu",
        density_estimator="maf",
    ):
        """
        Train an sbi estimator on simulations and build the posterior.

        Parameters
        ----------
        theta : torch.Tensor or array-like
            Simulation parameters, presumably ``(n_simulations, n_params)``.
            Non-tensor input is converted to a float32 tensor; 1-D input is
            reshaped to ``(n, 1)``.
        x : torch.Tensor or array-like
            Simulation outputs/features, presumably
            ``(n_simulations, n_features)``. Non-tensor input is converted to
            a float32 tensor; 1-D input is reshaped to ``(n, 1)``.
        prior : distribution
            Prior passed to the sbi inference object.
        num_threads : int
            Passed to ``torch.set_num_threads``, which changes torch's
            process-wide thread setting.
        method : {"SNPE", "SNLE", "SNRE"}
            sbi inference method.
        device : str
            Training device passed to sbi.
        density_estimator : str or callable
            Density estimator for SNPE/SNLE. For SNRE it is forwarded as the
            ``classifier`` argument; the default ``"maf"`` (not a classifier)
            is replaced by ``"resnet"``.

        Returns
        -------
        posterior
            Result of ``inference.build_posterior`` on the trained estimator.

        Raises
        ------
        ValueError
            If `method` is not one of the supported names.
        """

        torch.set_num_threads(num_threads)

        # NOTE(review): sbi's append_simulations expects float32 torch tensors;
        # convert array-like input here. Existing tensors are left untouched.
        if not isinstance(theta, torch.Tensor):
            theta = torch.as_tensor(theta, dtype=torch.float32)
        if not isinstance(x, torch.Tensor):
            x = torch.as_tensor(x, dtype=torch.float32)

        if len(x.shape) == 1:
            x = x[:, None]
        if len(theta.shape) == 1:
            theta = theta[:, None]

        if method == "SNPE":
            inference = SNPE(
                prior=prior, density_estimator=density_estimator, device=device)
        elif method == "SNLE":
            inference = SNLE(
                prior=prior, density_estimator=density_estimator, device=device)
        elif method == "SNRE":
            # SNRE trains a classifier and its constructor takes `classifier`,
            # not `density_estimator` (passing the latter raised TypeError).
            classifier = "resnet" if density_estimator == "maf" else density_estimator
            inference = SNRE(prior=prior, classifier=classifier, device=device)
        else:
            raise ValueError("Invalid method: " + method)

        inference = inference.append_simulations(theta, x)
        estimator_ = inference.train()
        posterior = inference.build_posterior(estimator_)

        return posterior

    @staticmethod
    def sample_prior(prior, n, seed=None):
        """
        Sample from a prior distribution.

        Parameters
        ----------
        prior : distribution
            Prior object; it is passed through
            ``sbi.utils.user_input_checks.process_prior`` before sampling.
        n : int
            Number of samples.
        seed : int, optional
            If given, passed to ``torch.manual_seed`` (reseeds torch's global
            RNG) before sampling.

        Returns
        -------
        theta : torch.Tensor
            Result of ``prior.sample((n,))``; presumably ``(n, n_params)``,
            depending on the prior.
        """
        if seed is not None:
            torch.manual_seed(seed)

        prior, _, _ = process_prior(prior)
        theta = prior.sample((n,))
        return theta

    @staticmethod
    def sample_posterior(xo,
                         num_samples,
                         posterior):
        """
        Sample from the posterior conditioned on an observation.

        Parameters
        ----------
        xo : torch.Tensor or array-like, (d,) or (1, d)
            Observation point. Non-tensor input is converted to a float32
            tensor; 1-D input is reshaped to ``(1, d)``.
        num_samples : int
            Number of samples.
        posterior : posterior object
            Object exposing ``sample(sample_shape, x=...)``, e.g. the
            return value of ``train``.

        Returns
        -------
        samples : torch.Tensor
            Samples from the posterior, presumably ``(num_samples, n_params)``.
        """

        if not isinstance(xo, torch.Tensor):
            xo = torch.tensor(xo, dtype=torch.float32)
        if len(xo.shape) == 1:
            xo = xo[None, :]

        samples = posterior.sample((num_samples,), x=xo)
        return samples
vbi/models/__init__.py ADDED
File without changes
File without changes
File without changes