nimare 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- benchmarks/__init__.py +0 -0
- benchmarks/bench_cbma.py +57 -0
- nimare/__init__.py +45 -0
- nimare/_version.py +21 -0
- nimare/annotate/__init__.py +21 -0
- nimare/annotate/cogat.py +213 -0
- nimare/annotate/gclda.py +924 -0
- nimare/annotate/lda.py +147 -0
- nimare/annotate/text.py +75 -0
- nimare/annotate/utils.py +87 -0
- nimare/base.py +217 -0
- nimare/cli.py +124 -0
- nimare/correct.py +462 -0
- nimare/dataset.py +685 -0
- nimare/decode/__init__.py +33 -0
- nimare/decode/base.py +115 -0
- nimare/decode/continuous.py +462 -0
- nimare/decode/discrete.py +753 -0
- nimare/decode/encode.py +110 -0
- nimare/decode/utils.py +44 -0
- nimare/diagnostics.py +510 -0
- nimare/estimator.py +139 -0
- nimare/extract/__init__.py +19 -0
- nimare/extract/extract.py +466 -0
- nimare/extract/utils.py +295 -0
- nimare/generate.py +331 -0
- nimare/io.py +667 -0
- nimare/meta/__init__.py +39 -0
- nimare/meta/cbma/__init__.py +6 -0
- nimare/meta/cbma/ale.py +951 -0
- nimare/meta/cbma/base.py +947 -0
- nimare/meta/cbma/mkda.py +1361 -0
- nimare/meta/cbmr.py +970 -0
- nimare/meta/ibma.py +1683 -0
- nimare/meta/kernel.py +501 -0
- nimare/meta/models.py +1199 -0
- nimare/meta/utils.py +494 -0
- nimare/nimads.py +492 -0
- nimare/reports/__init__.py +24 -0
- nimare/reports/base.py +664 -0
- nimare/reports/default.yml +123 -0
- nimare/reports/figures.py +651 -0
- nimare/reports/report.tpl +160 -0
- nimare/resources/__init__.py +1 -0
- nimare/resources/atlases/Harvard-Oxford-LICENSE +93 -0
- nimare/resources/atlases/HarvardOxford-cort-maxprob-thr25-2mm.nii.gz +0 -0
- nimare/resources/database_file_manifest.json +142 -0
- nimare/resources/english_spellings.csv +1738 -0
- nimare/resources/filenames.json +32 -0
- nimare/resources/neurosynth_laird_studies.json +58773 -0
- nimare/resources/neurosynth_stoplist.txt +396 -0
- nimare/resources/nidm_pain_dset.json +1349 -0
- nimare/resources/references.bib +541 -0
- nimare/resources/semantic_knowledge_children.txt +325 -0
- nimare/resources/semantic_relatedness_children.txt +249 -0
- nimare/resources/templates/MNI152_2x2x2_brainmask.nii.gz +0 -0
- nimare/resources/templates/tpl-MNI152NLin6Asym_res-01_T1w.nii.gz +0 -0
- nimare/resources/templates/tpl-MNI152NLin6Asym_res-01_desc-brain_mask.nii.gz +0 -0
- nimare/resources/templates/tpl-MNI152NLin6Asym_res-02_T1w.nii.gz +0 -0
- nimare/resources/templates/tpl-MNI152NLin6Asym_res-02_desc-brain_mask.nii.gz +0 -0
- nimare/results.py +225 -0
- nimare/stats.py +276 -0
- nimare/tests/__init__.py +1 -0
- nimare/tests/conftest.py +229 -0
- nimare/tests/data/amygdala_roi.nii.gz +0 -0
- nimare/tests/data/data-neurosynth_version-7_coordinates.tsv.gz +0 -0
- nimare/tests/data/data-neurosynth_version-7_metadata.tsv.gz +0 -0
- nimare/tests/data/data-neurosynth_version-7_vocab-terms_source-abstract_type-tfidf_features.npz +0 -0
- nimare/tests/data/data-neurosynth_version-7_vocab-terms_vocabulary.txt +100 -0
- nimare/tests/data/neurosynth_dset.json +2868 -0
- nimare/tests/data/neurosynth_laird_studies.json +58773 -0
- nimare/tests/data/nidm_pain_dset.json +1349 -0
- nimare/tests/data/nimads_annotation.json +1 -0
- nimare/tests/data/nimads_studyset.json +1 -0
- nimare/tests/data/test_baseline.txt +2 -0
- nimare/tests/data/test_pain_dataset.json +1278 -0
- nimare/tests/data/test_pain_dataset_multiple_contrasts.json +1242 -0
- nimare/tests/data/test_sleuth_file.txt +18 -0
- nimare/tests/data/test_sleuth_file2.txt +10 -0
- nimare/tests/data/test_sleuth_file3.txt +5 -0
- nimare/tests/data/test_sleuth_file4.txt +5 -0
- nimare/tests/data/test_sleuth_file5.txt +5 -0
- nimare/tests/test_annotate_cogat.py +32 -0
- nimare/tests/test_annotate_gclda.py +86 -0
- nimare/tests/test_annotate_lda.py +27 -0
- nimare/tests/test_dataset.py +99 -0
- nimare/tests/test_decode_continuous.py +132 -0
- nimare/tests/test_decode_discrete.py +92 -0
- nimare/tests/test_diagnostics.py +168 -0
- nimare/tests/test_estimator_performance.py +385 -0
- nimare/tests/test_extract.py +46 -0
- nimare/tests/test_generate.py +247 -0
- nimare/tests/test_io.py +294 -0
- nimare/tests/test_meta_ale.py +298 -0
- nimare/tests/test_meta_cbmr.py +295 -0
- nimare/tests/test_meta_ibma.py +240 -0
- nimare/tests/test_meta_kernel.py +209 -0
- nimare/tests/test_meta_mkda.py +234 -0
- nimare/tests/test_nimads.py +21 -0
- nimare/tests/test_reports.py +110 -0
- nimare/tests/test_stats.py +101 -0
- nimare/tests/test_transforms.py +272 -0
- nimare/tests/test_utils.py +200 -0
- nimare/tests/test_workflows.py +221 -0
- nimare/tests/utils.py +126 -0
- nimare/transforms.py +907 -0
- nimare/utils.py +1367 -0
- nimare/workflows/__init__.py +14 -0
- nimare/workflows/base.py +189 -0
- nimare/workflows/cbma.py +165 -0
- nimare/workflows/ibma.py +108 -0
- nimare/workflows/macm.py +77 -0
- nimare/workflows/misc.py +65 -0
- nimare-0.4.2.dist-info/LICENSE +21 -0
- nimare-0.4.2.dist-info/METADATA +124 -0
- nimare-0.4.2.dist-info/RECORD +119 -0
- nimare-0.4.2.dist-info/WHEEL +5 -0
- nimare-0.4.2.dist-info/entry_points.txt +2 -0
- nimare-0.4.2.dist-info/top_level.txt +2 -0
nimare/meta/models.py
ADDED
@@ -0,0 +1,1199 @@
|
|
1
|
+
"""CBMR Models."""
|
2
|
+
|
3
|
+
import abc
|
4
|
+
import logging
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
|
9
|
+
try:
|
10
|
+
import torch
|
11
|
+
except ImportError as e:
|
12
|
+
raise ImportError(
|
13
|
+
"Torch is required to use `CBMR` models. Install with `pip install 'nimare[cbmr]'`."
|
14
|
+
) from e
|
15
|
+
|
16
|
+
# Module-level logger, named after this module so records can be filtered per module.
LGR = logging.getLogger(name=__name__)
|
17
|
+
|
18
|
+
|
19
|
+
class GeneralLinearModelEstimator(torch.nn.Module):
    """Base class for GLM estimators.

    Parameters
    ----------
    spatial_coef_dim : :obj:`int`, optional
        Number of spatial B-spline bases. Default is None.
    moderators_coef_dim : :obj:`int`, optional
        Number of study-level moderators. Default is None.
    penalty : :obj:`bool`
        Whether to apply a Firth-type regularization term. Default is False.
    lr : :obj:`float`
        Learning rate. Default is 1.
    lr_decay : :obj:`float`
        Learning rate decay for each iteration. Default is 0.999.
    n_iter : :obj:`int`
        Maximum number of iterations. Default is 2000.
    tol : :obj:`float`
        Tolerance for convergence. Default is 1e-9.
    device : :obj:`str`
        Device to use for computations. Default is "cpu".

    Notes
    -----
    ``torch.nn.Module`` does not use ``abc.ABCMeta``, so the ``abc.abstractmethod``
    decorators below are not enforced at instantiation time. Subclasses are still expected
    to implement them.
    """

    # Keyword arguments for `torch.autograd.functional.hessian` (used in `firth_penalty`).
    _hessian_kwargs = {
        "create_graph": False,
        "vectorize": True,
        "outer_jacobian_strategy": "forward-mode",
    }

    def __init__(
        self,
        spatial_coef_dim=None,
        moderators_coef_dim=None,
        penalty=False,
        lr=1,
        lr_decay=0.999,
        n_iter=2000,
        tol=1e-9,
        device="cpu",
    ):
        super().__init__()
        self.spatial_coef_dim = spatial_coef_dim
        self.moderators_coef_dim = moderators_coef_dim
        self.penalty = penalty
        self.lr = lr
        self.lr_decay = lr_decay
        self.n_iter = n_iter
        self.tol = tol
        self.device = device

        # Number of calls to `_update` performed so far.
        self.iter = 0

        # After fitting, the following attributes will be populated.
        self.spatial_regression_coef = None
        self.spatial_intensity_estimation = None
        self.moderators_coef = None
        self.moderators_effect = None
        self.spatial_regression_coef_se = None
        self.log_spatial_intensity_se = None
        self.spatial_intensity_se = None
        self.se_moderators = None

    @abc.abstractmethod
    def _log_likelihood_single_group(self, **kwargs):
        """Log-likelihood of a single group.

        Returns
        -------
        torch.Tensor
            Value of the log-likelihood of a single group.
        """
        pass

    @abc.abstractmethod
    def _log_likelihood_mult_group(self, **kwargs):
        """Total log-likelihood of all groups in the dataset.

        Returns
        -------
        torch.Tensor
            Value of total log-likelihood of all groups in the dataset.
        """
        pass

    @abc.abstractmethod
    def forward(self, **kwargs):
        """Define the loss function (negative log-likelihood function) for each model.

        Returns
        -------
        torch.Tensor
            Value of the loss (negative log-likelihood) function.
        """
        pass

    def init_spatial_weights(self):
        """Initialize spatial regression coefficients.

        One bias-free float64 ``torch.nn.Linear(spatial_coef_dim, 1)`` is created per group
        and stored in ``self.spatial_coef_linears``. Default initialization is a uniform
        distribution between -0.01 and 0.01.
        """
        spatial_coef_linears = dict()
        for group in self.groups:
            spatial_coef_linear_group = torch.nn.Linear(
                self.spatial_coef_dim, 1, bias=False
            ).double()
            torch.nn.init.uniform_(spatial_coef_linear_group.weight, a=-0.01, b=0.01)
            spatial_coef_linears[group] = spatial_coef_linear_group
        # ModuleDict registers the per-group layers so `self.parameters()` includes them.
        self.spatial_coef_linears = torch.nn.ModuleDict(spatial_coef_linears)

    def init_moderator_weights(self):
        """Initialize the regression coefficients for moderators.

        A single bias-free float64 ``torch.nn.Linear(moderators_coef_dim, 1)`` is shared by
        all groups. Default initialization is a uniform distribution between -0.01 and 0.01.
        """
        self.moderators_linear = torch.nn.Linear(self.moderators_coef_dim, 1, bias=False).double()
        torch.nn.init.uniform_(self.moderators_linear.weight, a=-0.01, b=0.01)

    def init_weights(self, groups, moderators, spatial_coef_dim, moderators_coef_dim):
        """Initialize regression coefficients of spatial structure and study-level moderators.

        Parameters
        ----------
        groups : iterable
            Group names; one set of spatial coefficients is created per group.
        moderators : :obj:`list` or None
            Moderator names (used as column labels in `summary`).
        spatial_coef_dim : :obj:`int`
            Number of spatial B-spline bases.
        moderators_coef_dim : :obj:`int` or None
            Number of study-level moderators; moderator weights are only created if truthy.
        """
        self.groups = groups
        self.moderators = moderators
        self.spatial_coef_dim = spatial_coef_dim
        self.moderators_coef_dim = moderators_coef_dim
        self.init_spatial_weights()
        if moderators_coef_dim:
            self.init_moderator_weights()

    def _update(
        self,
        optimizer,
        coef_spline_bases,
        moderators,
        foci_per_voxel,
        foci_per_study,
        prev_loss,
    ):
        """One iteration in optimization with L-BFGS.

        Run one ``optimizer.step`` and then decay the learning rate by ``lr_decay`` (default
        0.999). With the L-BFGS optimizer built in `_optimizer`, a single step may run up to
        ``n_iter`` inner iterations.

        Parameters
        ----------
        optimizer : :obj:`torch.optim.LBFGS`
            L-BFGS optimizer.
        coef_spline_bases : :obj:`torch.Tensor`
            Coefficient of B-spline bases evaluated at each voxel.
        moderators : :obj:`dict`, optional
            Dictionary of group-wise study-level moderators. Default is None.
        foci_per_voxel : :obj:`dict`
            Dictionary of group-wise number of foci per voxel.
        foci_per_study : :obj:`dict`
            Dictionary of group-wise number of foci per study.
        prev_loss : :obj:`torch.Tensor`
            Value of the loss function of the previous iteration. Currently unused; kept for
            backward compatibility.

        Returns
        -------
        torch.Tensor
            Updated value of the loss (negative log-likelihood) function.

        Raises
        ------
        ValueError
            If the loss evaluates to NaN after the optimization step.
        """
        self.iter += 1
        # learning rate decay
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.lr_decay)

        def closure():
            optimizer.zero_grad()
            loss = self(coef_spline_bases, moderators, foci_per_voxel, foci_per_study)
            loss.backward()
            return loss

        optimizer.step(closure)
        scheduler.step()
        # recalculate the loss function at the updated parameters
        loss = self(coef_spline_bases, moderators, foci_per_voxel, foci_per_study)

        if torch.isnan(loss):
            raise ValueError(
                f"The current learning rate {self.lr} or choice of model gives rise to NaN "
                "log-likelihood, please try Poisson model or adjust learning rate to a "
                "smaller value."
            )

        return loss

    def _optimizer(self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study):
        """Optimize the loss (negative log-likelihood) function with L-BFGS.

        Parameters
        ----------
        coef_spline_bases : :obj:`numpy.ndarray`
            Coefficient of B-spline bases evaluated at each voxel.
        moderators_by_group : :obj:`dict`, optional
            Dictionary of group-wise study-level moderators.
        foci_per_voxel : :obj:`dict`
            Dictionary of group-wise number of foci per voxel.
        foci_per_study : :obj:`dict`
            Dictionary of group-wise number of foci per study.
        """
        torch.manual_seed(100)
        optimizer = torch.optim.LBFGS(
            params=self.parameters(),
            lr=self.lr,
            max_iter=self.n_iter,
            tolerance_change=self.tol,
            line_search_fn="strong_wolfe",
        )
        # load dataset info to torch.tensor
        coef_spline_bases = torch.tensor(
            coef_spline_bases, dtype=torch.float64, device=self.device
        )
        if moderators_by_group:
            moderators_by_group_tensor = {
                group: torch.tensor(
                    moderators_by_group[group], dtype=torch.float64, device=self.device
                )
                for group in self.groups
            }
        else:
            moderators_by_group_tensor = None
        foci_per_voxel_tensor = {
            group: torch.tensor(foci_per_voxel[group], dtype=torch.float64, device=self.device)
            for group in self.groups
        }
        foci_per_study_tensor = {
            group: torch.tensor(foci_per_study[group], dtype=torch.float64, device=self.device)
            for group in self.groups
        }

        # Always define the initial previous loss: it used to be set only when
        # `self.iter == 0`, so fitting the same model twice raised UnboundLocalError.
        prev_loss = torch.tensor(float("inf"))

        self._update(
            optimizer,
            coef_spline_bases,
            moderators_by_group_tensor,
            foci_per_voxel_tensor,
            foci_per_study_tensor,
            prev_loss,
        )

    def fit(self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study):
        """Fit the model and estimate standard error of estimates."""
        self._optimizer(coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study)
        self.extract_optimized_params(coef_spline_bases, moderators_by_group)
        self.standard_error_estimation(
            coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study
        )

    def extract_optimized_params(self, coef_spline_bases, moderators_by_group):
        """Extract optimized regression coefficients and derived estimates from the model.

        Stores per-group spatial coefficients, per-group spatial intensity
        ``exp(coef_spline_bases @ coef)`` (keyed ``"spatialIntensity_group-<group>"``), and,
        when moderators are used, the moderator coefficients and per-group
        ``exp(moderators @ coef.T)`` effects.
        """
        spatial_regression_coef, spatial_intensity_estimation = dict(), dict()
        for group in self.groups:
            # Extract optimized spatial regression coefficients from the model
            group_spatial_coef_linear_weight = self.spatial_coef_linears[group].weight
            group_spatial_coef_linear_weight = (
                group_spatial_coef_linear_weight.cpu().detach().numpy().flatten()
            )
            spatial_regression_coef[group] = group_spatial_coef_linear_weight
            # Estimate group-specific spatial intensity
            group_spatial_intensity_estimation = np.exp(
                np.matmul(coef_spline_bases, group_spatial_coef_linear_weight)
            )
            spatial_intensity_estimation["spatialIntensity_group-" + group] = (
                group_spatial_intensity_estimation
            )

        # Extract optimized regression coefficient of study-level moderators from the model
        if self.moderators_coef_dim:
            moderators_effect = dict()
            moderators_coef = self.moderators_linear.weight
            moderators_coef = moderators_coef.cpu().detach().numpy()
            for group in self.groups:
                group_moderators = moderators_by_group[group]
                group_moderators_effect = np.exp(np.matmul(group_moderators, moderators_coef.T))
                moderators_effect[group] = group_moderators_effect.flatten()
        else:
            moderators_coef, moderators_effect = None, None

        self.spatial_regression_coef = spatial_regression_coef
        self.spatial_intensity_estimation = spatial_intensity_estimation
        self.moderators_coef = moderators_coef
        self.moderators_effect = moderators_effect

    def standard_error_estimation(
        self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study
    ):
        """Estimate standard error of estimates.

        For spatial regression coefficients, we estimate its covariance matrix using Fisher
        Information Matrix and then take the square root of the diagonal elements.
        For log spatial intensity, we use the delta method to estimate its standard error.
        For models with over-dispersion parameter, we also estimate its standard error.
        """
        spatial_regression_coef_se = dict()
        log_spatial_intensity_se = dict()
        spatial_intensity_se = dict()
        # The spline bases are shared by every group; convert them to a tensor only once.
        coef_spline_bases_tensor = torch.tensor(
            coef_spline_bases, dtype=torch.float64, device=self.device
        )
        # Defined up front so the attribute is set even when there are no groups.
        se_moderators = None
        for group in self.groups:
            group_foci_per_voxel = torch.tensor(
                foci_per_voxel[group], dtype=torch.float64, device=self.device
            )
            group_foci_per_study = torch.tensor(
                foci_per_study[group], dtype=torch.float64, device=self.device
            )
            group_spatial_coef = self.spatial_coef_linears[group].weight
            if self.moderators_coef_dim:
                group_moderators = torch.tensor(
                    moderators_by_group[group], dtype=torch.float64, device=self.device
                )
                moderators_coef = self.moderators_linear.weight
            else:
                group_moderators, moderators_coef = None, None

            ll_single_group_kwargs = {
                "moderators_coef": moderators_coef,
                "coef_spline_bases": coef_spline_bases_tensor,
                "group_moderators": group_moderators,
                "group_foci_per_voxel": group_foci_per_voxel,
                "group_foci_per_study": group_foci_per_study,
                "device": self.device,
            }

            if hasattr(self, "overdispersion"):
                # NOTE(review): `firth_penalty` squares the overdispersion when `square_root`
                # is set, but it is passed unsquared here — confirm which parameterization
                # `_log_likelihood_single_group` expects.
                ll_single_group_kwargs["group_overdispersion"] = self.overdispersion[group]

            # create a negative log-likelihood function w.r.t. spatial coefficients
            def nll_spatial_coef(group_spatial_coef):
                return -self._log_likelihood_single_group(
                    group_spatial_coef=group_spatial_coef,
                    **ll_single_group_kwargs,
                )

            f_spatial_coef = torch.func.hessian(nll_spatial_coef)(group_spatial_coef)
            f_spatial_coef = f_spatial_coef.reshape((self.spatial_coef_dim, self.spatial_coef_dim))
            # `.cpu()` first: `.numpy()` fails on tensors that live on a non-CPU device.
            cov_spatial_coef = np.linalg.inv(f_spatial_coef.detach().cpu().numpy())
            var_spatial_coef = np.diag(cov_spatial_coef)
            se_spatial_coef = np.sqrt(var_spatial_coef)
            spatial_regression_coef_se[group] = se_spatial_coef

            # Delta method: Var(X @ beta) is the diagonal of X @ Cov(beta) @ X.T,
            # computed row-wise without materializing the full voxel-by-voxel matrix.
            var_log_spatial_intensity = np.einsum(
                "ij,ji->i",
                coef_spline_bases,
                cov_spatial_coef @ coef_spline_bases.T,
            )
            se_log_spatial_intensity = np.sqrt(var_log_spatial_intensity)
            log_spatial_intensity_se[group] = se_log_spatial_intensity

            group_studywise_spatial_intensity = np.exp(
                np.matmul(coef_spline_bases, group_spatial_coef.detach().cpu().numpy().T)
            ).flatten()
            se_spatial_intensity = group_studywise_spatial_intensity * se_log_spatial_intensity
            spatial_intensity_se[group] = se_spatial_intensity

            # Inference on regression coefficient of moderators
            if self.moderators_coef_dim:
                # NOTE(review): this is recomputed for every group and overwritten, so the
                # final `se_moderators` reflects only the last group's likelihood even though
                # the moderator coefficients are shared by all groups — confirm intent.
                # Fix spatial_coef and let moderators_coef vary.
                del ll_single_group_kwargs["moderators_coef"]
                ll_single_group_kwargs["group_spatial_coef"] = group_spatial_coef

                def nll_moderators_coef(moderators_coef):
                    return -self._log_likelihood_single_group(
                        moderators_coef=moderators_coef,
                        **ll_single_group_kwargs,
                    )

                f_moderators_coef = torch.func.hessian(nll_moderators_coef)(moderators_coef)
                f_moderators_coef = f_moderators_coef.reshape(
                    (self.moderators_coef_dim, self.moderators_coef_dim)
                )
                cov_moderators_coef = np.linalg.inv(f_moderators_coef.detach().cpu().numpy())
                var_moderators = np.diag(cov_moderators_coef).reshape(
                    (1, self.moderators_coef_dim)
                )
                se_moderators = np.sqrt(var_moderators)
            else:
                se_moderators = None

        self.spatial_regression_coef_se = spatial_regression_coef_se
        self.log_spatial_intensity_se = log_spatial_intensity_se
        self.spatial_intensity_se = spatial_intensity_se
        self.se_moderators = se_moderators

    def summary(self):
        """Summarize the main results of the fitted model.

        Returns
        -------
        maps : :obj:`dict`
            Estimated spatial intensity per group.
        tables : :obj:`dict` of :obj:`pandas.DataFrame`
            Optimized regression coefficients, moderator effects (if any), and standard
            errors of regression coefficients and (log-)spatial intensity.

        Raises
        ------
        ValueError
            If the model has not been fitted yet.
        """
        params = (
            self.spatial_regression_coef,
            self.spatial_intensity_estimation,
            self.spatial_regression_coef_se,
            self.log_spatial_intensity_se,
            self.spatial_intensity_se,
        )
        if any(param is None for param in params):
            raise ValueError("Run fit first")
        tables = dict()
        # Extract optimized regression coefficients from model and store them in 'tables'
        tables["spatial_regression_coef"] = pd.DataFrame.from_dict(
            self.spatial_regression_coef, orient="index"
        )
        maps = self.spatial_intensity_estimation
        if self.moderators_coef_dim:
            tables["moderators_regression_coef"] = pd.DataFrame(
                data=self.moderators_coef, columns=self.moderators
            )
            tables["moderators_effect"] = pd.DataFrame.from_dict(
                data=self.moderators_effect, orient="index"
            )

        # Standard errors of regression coefficients and (log-)spatial intensity
        tables["spatial_regression_coef_se"] = pd.DataFrame.from_dict(
            self.spatial_regression_coef_se, orient="index"
        )
        tables["log_spatial_intensity_se"] = pd.DataFrame.from_dict(
            self.log_spatial_intensity_se, orient="index"
        )
        tables["spatial_intensity_se"] = pd.DataFrame.from_dict(
            self.spatial_intensity_se, orient="index"
        )
        if self.moderators_coef_dim:
            tables["moderators_regression_se"] = pd.DataFrame(
                data=self.se_moderators, columns=self.moderators
            )
        return maps, tables

    def fisher_info_multiple_group_spatial(
        self,
        involved_groups,
        coef_spline_bases,
        moderators_by_group,
        foci_per_voxel,
        foci_per_study,
    ):
        """Estimate the Fisher information matrix of spatial regression coefficients.

        Fisher information matrix is estimated by negative Hessian of the log-likelihood.

        Parameters
        ----------
        involved_groups : :obj:`list`
            Group names involved in generalized linear hypothesis (GLH) testing in `CBMRInference`.
        coef_spline_bases : :obj:`numpy.ndarray`
            Coefficient of B-spline bases evaluated at each voxel.
        moderators_by_group : :obj:`dict`, optional
            Dictionary of group-wise study-level moderators. Default is None.
        foci_per_voxel : :obj:`dict`
            Dictionary of group-wise number of foci per voxel.
        foci_per_study : :obj:`dict`
            Dictionary of group-wise number of foci per study.

        Returns
        -------
        numpy.ndarray
            Fisher information matrix of spatial regression coefficients (for involved groups),
            of shape ``(n_involved_groups * spatial_coef_dim, n_involved_groups *
            spatial_coef_dim)``.
        """
        n_involved_groups = len(involved_groups)
        involved_foci_per_voxel = [
            torch.tensor(foci_per_voxel[group], dtype=torch.float64, device=self.device)
            for group in involved_groups
        ]
        involved_foci_per_study = [
            torch.tensor(foci_per_study[group], dtype=torch.float64, device=self.device)
            for group in involved_groups
        ]
        spatial_coef = [self.spatial_coef_linears[group].weight.T for group in involved_groups]
        spatial_coef = torch.stack(spatial_coef, dim=0)
        if self.moderators_coef_dim:
            involved_moderators_by_group = [
                torch.tensor(moderators_by_group[group], dtype=torch.float64, device=self.device)
                for group in involved_groups
            ]
            moderators_coef = torch.tensor(
                self.moderators_coef.T, dtype=torch.float64, device=self.device
            )
        else:
            involved_moderators_by_group, moderators_coef = None, None

        ll_mult_group_kwargs = {
            "moderator_coef": moderators_coef,
            "coef_spline_bases": torch.tensor(
                coef_spline_bases, dtype=torch.float64, device=self.device
            ),
            "foci_per_voxel": involved_foci_per_voxel,
            "foci_per_study": involved_foci_per_study,
            "moderators": involved_moderators_by_group,
            "device": self.device,
        }

        if hasattr(self, "overdispersion"):
            ll_mult_group_kwargs["overdispersion_coef"] = [
                self.overdispersion[group] for group in involved_groups
            ]

        # create a negative log-likelihood function w.r.t. stacked spatial coefficients
        def nll_spatial_coef(spatial_coef):
            return -self._log_likelihood_mult_group(
                spatial_coef=spatial_coef,
                **ll_mult_group_kwargs,
            )

        h = torch.func.hessian(nll_spatial_coef)(spatial_coef)
        # `reshape` (rather than `view`) also copes with a non-contiguous Hessian.
        h = h.reshape(n_involved_groups * self.spatial_coef_dim, -1)

        return h.detach().cpu().numpy()

    def fisher_info_multiple_group_moderator(
        self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study
    ):
        """Estimate the Fisher information matrix of regression coefficients of moderators.

        Fisher information matrix is estimated by negative Hessian of the log-likelihood.
        Only meaningful when the model has moderators (``moderators_coef_dim`` is set).

        Parameters
        ----------
        coef_spline_bases : :obj:`numpy.ndarray`
            Coefficient of B-spline bases evaluated at each voxel.
        moderators_by_group : :obj:`dict`, optional
            Dictionary of group-wise study-level moderators. Default is None.
        foci_per_voxel : :obj:`dict`
            Dictionary of group-wise number of foci per voxel.
        foci_per_study : :obj:`dict`
            Dictionary of group-wise number of foci per study.

        Returns
        -------
        numpy.ndarray
            Fisher information matrix of study-level moderator regressors.
        """
        foci_per_voxel_tensors = [
            torch.tensor(foci_per_voxel[group], dtype=torch.float64, device=self.device)
            for group in self.groups
        ]
        foci_per_study_tensors = [
            torch.tensor(foci_per_study[group], dtype=torch.float64, device=self.device)
            for group in self.groups
        ]
        spatial_coef = [self.spatial_coef_linears[group].weight.T for group in self.groups]
        spatial_coef = torch.stack(spatial_coef, dim=0)

        if self.moderators_coef_dim:
            moderators_tensors = [
                torch.tensor(moderators_by_group[group], dtype=torch.float64, device=self.device)
                for group in self.groups
            ]
            moderator_coef = torch.tensor(
                self.moderators_coef.T, dtype=torch.float64, device=self.device
            )
        else:
            moderators_tensors, moderator_coef = None, None

        ll_mult_group_kwargs = {
            "spatial_coef": spatial_coef,
            "coef_spline_bases": torch.tensor(
                coef_spline_bases, dtype=torch.float64, device=self.device
            ),
            "foci_per_voxel": foci_per_voxel_tensors,
            "foci_per_study": foci_per_study_tensors,
            "moderators": moderators_tensors,
            "device": self.device,
        }
        if hasattr(self, "overdispersion"):
            ll_mult_group_kwargs["overdispersion_coef"] = [
                self.overdispersion[group] for group in self.groups
            ]

        # create a negative log-likelihood function w.r.t moderator coefficients
        def nll_moderator_coef(moderator_coef):
            return -self._log_likelihood_mult_group(
                moderator_coef=moderator_coef,
                **ll_mult_group_kwargs,
            )

        h = torch.func.hessian(nll_moderator_coef)(moderator_coef)
        h = h.reshape(self.moderators_coef_dim, self.moderators_coef_dim)

        return h.detach().cpu().numpy()

    def firth_penalty(
        self,
        foci_per_voxel,
        foci_per_study,
        moderators,
        coef_spline_bases,
        overdispersion=False,
    ):
        """Compute Firth's penalized log-likelihood term.

        For each group, the penalty is ``0.5 * sum(log(eigvals(H)))``, i.e. half the
        log-determinant of the Hessian ``H`` of the group's negative log-likelihood with
        respect to its spatial coefficients. The returned value is the sum over all groups.

        Parameters
        ----------
        foci_per_voxel : :obj:`dict`
            Dictionary of group-wise number of foci per voxel.
        foci_per_study : :obj:`dict`
            Dictionary of group-wise number of foci per study.
        moderators : :obj:`dict`, optional
            Dictionary of group-wise study-level moderators. Default is None.
        coef_spline_bases : :obj:`torch.Tensor`
            Coefficient of B-spline bases evaluated at each voxel.
        overdispersion : :obj:`bool`
            Whether the model contains overdispersion parameter. Default is False.

        Returns
        -------
        torch.Tensor
            Firth-type regularization term, summed over groups.
        """
        # NOTE(review): `_hessian_kwargs` uses `create_graph=False`, which returns a detached
        # Hessian; the penalty then adds to the loss value but not to its gradient — confirm.
        firth_penalty = 0
        for group in self.groups:
            partial_kwargs = {"coef_spline_bases": coef_spline_bases}
            if overdispersion:
                partial_kwargs["group_overdispersion"] = self.overdispersion[group]
                if getattr(self, "square_root", False):
                    partial_kwargs["group_overdispersion"] = (
                        partial_kwargs["group_overdispersion"] ** 2
                    )
            partial_kwargs["group_foci_per_voxel"] = foci_per_voxel[group]
            partial_kwargs["group_foci_per_study"] = foci_per_study[group]
            if self.moderators_coef_dim:
                moderators_coef = self.moderators_linear.weight
                group_moderators = moderators[group]
            else:
                moderators_coef, group_moderators = None, None
            partial_kwargs["moderators_coef"] = moderators_coef
            partial_kwargs["group_moderators"] = group_moderators

            # create a negative log-likelihood function w.r.t spatial coefficients
            def nll_spatial_coef(group_spatial_coef):
                return -self._log_likelihood_single_group(
                    group_spatial_coef=group_spatial_coef,
                    **partial_kwargs,
                )

            group_spatial_coef = self.spatial_coef_linears[group].weight
            group_f = torch.autograd.functional.hessian(
                nll_spatial_coef,
                group_spatial_coef,
                **self._hessian_kwargs,
            )

            group_f = group_f.reshape((self.spatial_coef_dim, self.spatial_coef_dim))
            group_eig_vals = torch.real(torch.linalg.eigvals(group_f))
            del group_f
            group_firth_penalty = 0.5 * torch.sum(torch.log(group_eig_vals))
            del group_eig_vals
            # Accumulate into a separate total: previously the accumulator was overwritten
            # by each group's value and then doubled, returning 2x the last group's penalty.
            firth_penalty += group_firth_penalty

        return firth_penalty
|
684
|
+
|
685
|
+
|
686
|
+
class OverdispersionModelEstimator(GeneralLinearModelEstimator):
    """Base class for CBMR models with over-dispersion parameter.

    If ``square_root=True`` is passed, the learnable per-group parameter stores
    the square root of the overdispersion. Such a model must square it before use.
    """

    def __init__(self, **kwargs):
        self.square_root = kwargs.pop("square_root", False)
        super().__init__(**kwargs)

    def init_overdispersion_weights(self):
        """Initialize weights for overdispersion parameters.

        Default is 1e-2 (stored as sqrt(1e-2) when ``self.square_root`` is set).
        One learnable float64 scalar is created per group in ``self.groups``.
        """
        overdispersion = dict()
        for group in self.groups:
            # initialization for alpha
            overdispersion_init_group = torch.tensor(1e-2).double()
            if self.square_root:
                overdispersion_init_group = torch.sqrt(overdispersion_init_group)
            overdispersion[group] = torch.nn.Parameter(
                overdispersion_init_group, requires_grad=True
            )
        self.overdispersion = torch.nn.ParameterDict(overdispersion)

    def init_weights(self, groups, moderators, spatial_coef_dim, moderators_coef_dim):
        """Initialize weights for spatial and study-level moderator coefficients.

        Also initializes the per-group overdispersion parameters.
        """
        super().init_weights(groups, moderators, spatial_coef_dim, moderators_coef_dim)
        self.init_overdispersion_weights()

    def inference_outcome(
        self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study
    ):
        """Summarize inference outcome into `maps` and `tables`.

        Add the optimized overdispersion parameter to the tables, under the key
        ``"overdispersion_coef"``. When the model optimizes the square root of
        the overdispersion, the stored value is squared first. The reported
        number is then the overdispersion actually used in the likelihood.
        """
        maps, tables = super().inference_outcome(
            coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study
        )
        overdispersion_param = dict()
        for group in self.groups:
            group_overdispersion = self.overdispersion[group]
            if self.square_root:
                # parameter holds sqrt(alpha); report alpha itself
                group_overdispersion = group_overdispersion**2
            overdispersion_param[group] = group_overdispersion.cpu().detach().numpy()
        tables["overdispersion_coef"] = pd.DataFrame.from_dict(
            overdispersion_param, orient="index", columns=["overdispersion"]
        )

        return maps, tables
|
734
|
+
|
735
|
+
|
736
|
+
class PoissonEstimator(GeneralLinearModelEstimator):
    """CBMR framework with Poisson model.

    Poisson model is the most basic model for Coordinate-based Meta-regression (CBMR).
    It's based on the assumption that foci arise from a realisation of a (continuous)
    inhomogeneous Poisson process, so that the (discrete) voxel-wise foci counts will
    be independently distributed as Poisson random variables, with rate equal to the
    integral of the (true, unobserved, continuous) intensity function over each voxel.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _log_likelihood_single_group(
        self,
        group_spatial_coef,
        moderators_coef,
        coef_spline_bases,
        group_moderators,
        group_foci_per_voxel,
        group_foci_per_study,
        device=None,
    ):
        """Compute the Poisson log-likelihood of one group.

        Log-factorial terms are not included.

        Parameters
        ----------
        device : str or torch.device, optional
            Device of the all-zero moderator term built when ``moderators_coef``
            is None. Defaults to the device of ``group_foci_per_study`` so that
            all operands live on the same device.
        """
        log_mu_spatial = torch.matmul(coef_spline_bases, group_spatial_coef.T)
        mu_spatial = torch.exp(log_mu_spatial)
        if moderators_coef is None:
            if device is None:
                device = group_foci_per_study.device
            n_study, _ = group_foci_per_study.shape
            # no moderators: log-effect of 0 (effect of 1) for every study
            log_mu_moderators = torch.zeros((n_study, 1), dtype=torch.float64, device=device)
        else:
            log_mu_moderators = torch.matmul(group_moderators, moderators_coef.T)
        mu_moderators = torch.exp(log_mu_moderators)
        log_l = (
            torch.sum(torch.mul(group_foci_per_voxel, log_mu_spatial))
            + torch.sum(torch.mul(group_foci_per_study, log_mu_moderators))
            - torch.sum(mu_spatial) * torch.sum(mu_moderators)
        )
        return log_l

    def _log_likelihood_mult_group(
        self,
        spatial_coef,
        moderator_coef,
        coef_spline_bases,
        foci_per_voxel,
        foci_per_study,
        moderators,
        device=None,
    ):
        """Compute the Poisson log-likelihood summed over all groups.

        ``spatial_coef`` is indexed as ``spatial_coef[i, :, :]`` per group, and
        the per-group sequences are indexed by position. When ``device`` is None,
        each group's zero moderator term is created on that group's
        ``foci_per_study`` device.
        """
        n_groups = len(spatial_coef)
        log_spatial_intensity = [
            torch.matmul(coef_spline_bases, spatial_coef[i, :, :]) for i in range(n_groups)
        ]
        spatial_intensity = [
            torch.exp(group_log_spatial_intensity)
            for group_log_spatial_intensity in log_spatial_intensity
        ]
        if moderator_coef is not None:
            log_moderator_effect = [
                torch.matmul(group_moderator, moderator_coef) for group_moderator in moderators
            ]
        else:
            log_moderator_effect = [
                torch.zeros(
                    (foci_per_study_i.shape[0], 1),
                    dtype=torch.float64,
                    device=foci_per_study_i.device if device is None else device,
                )
                for foci_per_study_i in foci_per_study
            ]
        moderator_effect = [
            torch.exp(group_log_moderator_effect)
            for group_log_moderator_effect in log_moderator_effect
        ]
        log_l = 0
        for i in range(n_groups):
            log_l += (
                torch.sum(foci_per_voxel[i] * log_spatial_intensity[i])
                + torch.sum(foci_per_study[i] * log_moderator_effect[i])
                - torch.sum(spatial_intensity[i]) * torch.sum(moderator_effect[i])
            )
        return log_l

    def forward(self, coef_spline_bases, moderators, foci_per_voxel, foci_per_study):
        """Define the loss function (negative log-likelihood function) for Poisson model.

        Model refactorization is applied to reduce the dimensionality of variables.

        Returns
        -------
        torch.Tensor
            Loss (negative log-likelihood) of Poisson model at current iteration.
        """
        log_l = 0
        for group in self.groups:
            group_spatial_coef = self.spatial_coef_linears[group].weight
            group_foci_per_voxel = foci_per_voxel[group]
            group_foci_per_study = foci_per_study[group]
            if isinstance(moderators, dict):
                moderators_coef = self.moderators_linear.weight
                group_moderators = moderators[group]
            else:
                moderators_coef, group_moderators = None, None
            group_log_l = self._log_likelihood_single_group(
                group_spatial_coef,
                moderators_coef,
                coef_spline_bases,
                group_moderators,
                group_foci_per_voxel,
                group_foci_per_study,
            )
            log_l += group_log_l

        if self.penalty:
            # Firth-type penalty
            log_l += self.firth_penalty(
                foci_per_voxel,
                foci_per_study,
                moderators,
                coef_spline_bases,
                overdispersion=False,
            )
        return -log_l
|
863
|
+
|
864
|
+
|
865
|
+
class NegativeBinomialEstimator(OverdispersionModelEstimator):
    """CBMR framework with Negative Binomial (NB) model.

    Negative Binomial (NB) model is a generalized Poisson model with overdispersion.
    It's a more flexible model, but more difficult to estimate. In practice, foci
    counts often display over-dispersion (the variance of response variable
    substantially exceeds the mean), which is not captured by Poisson model.

    This model sets ``square_root=True``: the learnable overdispersion parameter
    is squared before it enters the likelihood.
    """

    def __init__(self, **kwargs):
        kwargs["square_root"] = True
        super().__init__(**kwargs)

    def _three_term(self, y, r):
        """Sum ``lgamma(y + r) - lgamma(y + 1) - lgamma(r)`` over entries with ``y >= 1``.

        Entries are grouped by their count ``k = 1..max(y)``. Entries with
        ``y == 0`` contribute zero, so they are skipped.
        NOTE(review): assumes ``y`` holds non-negative integer counts and ``r``
        is indexable by the same row positions as ``y`` — confirm with callers.
        """
        max_foci = torch.max(y).to(dtype=torch.int64, device=self.device)
        sum_three_term = 0
        for k in range(max_foci):
            foci_index = (y == k + 1).nonzero()[:, 0]
            r_j = r[foci_index]
            n_voxel = list(foci_index.shape)[0]
            y_j = torch.tensor([k + 1] * n_voxel, device=self.device).double()
            y_j = y_j.reshape((n_voxel, 1))
            # y=0 => sum_three_term = 0
            sum_three_term += torch.sum(
                torch.lgamma(y_j + r_j) - torch.lgamma(y_j + 1) - torch.lgamma(r_j)
            )

        return sum_three_term

    def _log_likelihood_single_group(
        self,
        group_overdispersion,
        group_spatial_coef,
        moderators_coef,
        coef_spline_bases,
        group_moderators,
        group_foci_per_voxel,
        group_foci_per_study,
        device=None,
    ):
        """Compute the NB log-likelihood of one group (moment-matching approximation).

        ``group_overdispersion`` is the overdispersion alpha itself (already squared
        by the caller). When ``device`` is None, the zero moderator term used
        without moderators is created on ``group_foci_per_study``'s device.
        """
        log_mu_spatial = torch.matmul(coef_spline_bases, group_spatial_coef.T)
        mu_spatial = torch.exp(log_mu_spatial)
        if moderators_coef is not None:
            log_mu_moderators = torch.matmul(group_moderators, moderators_coef.T)
        else:
            if device is None:
                device = group_foci_per_study.device
            n_study, _ = group_foci_per_study.shape
            log_mu_moderators = torch.zeros((n_study, 1), dtype=torch.float64, device=device)
        mu_moderators = torch.exp(log_mu_moderators)
        # parameter of a NB variable to approximate a sum of NB variables
        r = 1 / group_overdispersion * torch.sum(mu_moderators) ** 2 / torch.sum(mu_moderators**2)
        p = 1 / (
            1
            + torch.sum(mu_moderators)
            / (group_overdispersion * mu_spatial * torch.sum(mu_moderators**2))
        )
        # log-likelihood (moment matching approach)
        log_l = torch.sum(
            torch.lgamma(group_foci_per_voxel + r)
            - torch.lgamma(group_foci_per_voxel + 1)
            - torch.lgamma(r)
            + r * torch.log(1 - p)
            + group_foci_per_voxel * torch.log(p)
        )

        return log_l

    def _log_likelihood_mult_group(
        self,
        overdispersion_coef,
        spatial_coef,
        coef_spline_bases,
        foci_per_voxel,
        foci_per_study,
        moderator_coef=None,
        moderators=None,
        device=None,
    ):
        """Compute the NB log-likelihood summed over all groups.

        Per-group sequences are indexed by position. When ``device`` is None,
        each group's zero moderator term is created on that group's
        ``foci_per_study`` device.
        """
        n_groups = len(foci_per_voxel)
        log_spatial_intensity = [
            torch.matmul(coef_spline_bases, spatial_coef[i, :, :]) for i in range(n_groups)
        ]
        spatial_intensity = [
            torch.exp(group_log_spatial_intensity)
            for group_log_spatial_intensity in log_spatial_intensity
        ]
        if moderator_coef is not None:
            log_moderator_effect = [
                torch.matmul(group_moderator, moderator_coef) for group_moderator in moderators
            ]
        else:
            log_moderator_effect = [
                torch.zeros(
                    (foci_per_study_i.shape[0], 1),
                    dtype=torch.float64,
                    device=foci_per_study_i.device if device is None else device,
                )
                for foci_per_study_i in foci_per_study
            ]
        moderator_effect = [
            torch.exp(group_log_moderator_effect)
            for group_log_moderator_effect in log_moderator_effect
        ]
        # After simplification, we have:
        # r' = 1/alpha * sum(mu^Z_i)^2 / sum((mu^Z_i)^2)
        # p'_j = 1 / (1 + sum(mu^Z_i) / (alpha * mu^X_j * sum((mu^Z_i)^2)))
        r = [
            1
            / overdispersion_coef[i]
            * torch.sum(moderator_effect[i]) ** 2
            / torch.sum(moderator_effect[i] ** 2)
            for i in range(n_groups)
        ]
        p_frac = [
            torch.sum(moderator_effect[i])
            / (overdispersion_coef[i] * spatial_intensity[i] * torch.sum(moderator_effect[i] ** 2))
            for i in range(n_groups)
        ]
        p = [1 / (1 + p_frac[i]) for i in range(n_groups)]

        log_l = 0
        for i in range(n_groups):
            group_log_l = torch.sum(
                torch.lgamma(foci_per_voxel[i] + r[i])
                - torch.lgamma(foci_per_voxel[i] + 1)
                - torch.lgamma(r[i])
                + r[i] * torch.log(1 - p[i])
                + foci_per_voxel[i] * torch.log(p[i])
            )
            log_l += group_log_l

        return log_l

    def forward(self, coef_spline_bases, moderators, foci_per_voxel, foci_per_study):
        """Define the loss function (negative log-likelihood function) for NB model.

        Model refactorization is applied to reduce the dimensionality of variables.

        Returns
        -------
        torch.Tensor
            Loss (negative log-likelihood) of NB model at current iteration.
        """
        log_l = 0
        for group in self.groups:
            # learnable parameter stores sqrt(alpha)
            group_overdispersion = self.overdispersion[group] ** 2
            group_spatial_coef = self.spatial_coef_linears[group].weight
            group_foci_per_voxel = foci_per_voxel[group]
            group_foci_per_study = foci_per_study[group]
            if isinstance(moderators, dict):
                moderators_coef = self.moderators_linear.weight
                group_moderators = moderators[group]
            else:
                moderators_coef, group_moderators = None, None
            group_log_l = self._log_likelihood_single_group(
                group_overdispersion,
                group_spatial_coef,
                moderators_coef,
                coef_spline_bases,
                group_moderators,
                group_foci_per_voxel,
                group_foci_per_study,
            )

            log_l += group_log_l

        if self.penalty:
            # Firth-type penalty
            log_l += self.firth_penalty(
                foci_per_voxel,
                foci_per_study,
                moderators,
                coef_spline_bases,
                overdispersion=True,
            )

        return -log_l
|
1046
|
+
|
1047
|
+
|
1048
|
+
class ClusteredNegativeBinomialEstimator(OverdispersionModelEstimator):
    """CBMR framework with Clustered Negative Binomial (Clustered NB) model.

    Clustered NB model can also accommodate over-dispersion in foci counts.
    In NB model, the latent Gamma random variable introduces independent variation
    at each voxel. While in Clustered NB model, we assert the random effects are not
    independent voxelwise effects, but rather latent characteristics of each study,
    and represent a shared effect over the entire brain for a given study.

    This model sets ``square_root=False``: the learnable overdispersion parameter
    is used directly.
    """

    def __init__(self, **kwargs):
        kwargs["square_root"] = False
        super().__init__(**kwargs)

    def _log_likelihood_single_group(
        self,
        group_overdispersion,
        group_spatial_coef,
        moderators_coef,
        coef_spline_bases,
        group_moderators,
        group_foci_per_voxel,
        group_foci_per_study,
        device=None,
    ):
        """Compute the Clustered NB log-likelihood of one group.

        ``v = 1 / group_overdispersion``. When ``device`` is None, the zero moderator
        term used without moderators is created on ``group_foci_per_study``'s device.
        """
        v = 1 / group_overdispersion
        log_mu_spatial = torch.matmul(coef_spline_bases, group_spatial_coef.T)
        mu_spatial = torch.exp(log_mu_spatial)
        if moderators_coef is not None:
            log_mu_moderators = torch.matmul(group_moderators, moderators_coef.T)
        else:
            if device is None:
                device = group_foci_per_study.device
            n_study, _ = group_foci_per_study.shape
            log_mu_moderators = torch.zeros((n_study, 1), dtype=torch.float64, device=device)
        mu_moderators = torch.exp(log_mu_moderators)
        # expected total number of foci of each study
        mu_sum_per_study = torch.sum(mu_spatial) * mu_moderators
        group_n_study, _ = group_foci_per_study.shape

        log_l = (
            group_n_study * v * torch.log(v)
            - group_n_study * torch.lgamma(v)
            + torch.sum(torch.lgamma(group_foci_per_study + v))
            - torch.sum((group_foci_per_study + v) * torch.log(mu_sum_per_study + v))
            + torch.sum(group_foci_per_voxel * log_mu_spatial)
            + torch.sum(group_foci_per_study * log_mu_moderators)
        )

        return log_l

    def _log_likelihood_mult_group(
        self,
        overdispersion_coef,
        spatial_coef,
        coef_spline_bases,
        foci_per_voxel,
        foci_per_study,
        moderator_coef=None,
        moderators=None,
        device=None,
    ):
        """Compute the Clustered NB log-likelihood summed over all groups.

        Per-group sequences are indexed by position. When ``device`` is None,
        each group's zero moderator term is created on that group's
        ``foci_per_study`` device.
        """
        n_groups = len(foci_per_voxel)
        v = [1 / group_overdispersion_coef for group_overdispersion_coef in overdispersion_coef]
        # estimated intensity and log estimated intensity
        log_spatial_intensity = [
            torch.matmul(coef_spline_bases, spatial_coef[i, :, :]) for i in range(n_groups)
        ]
        spatial_intensity = [
            torch.exp(group_log_spatial_intensity)
            for group_log_spatial_intensity in log_spatial_intensity
        ]
        if moderator_coef is not None:
            log_moderator_effect = [
                torch.matmul(group_moderator, moderator_coef) for group_moderator in moderators
            ]
        else:
            log_moderator_effect = [
                torch.zeros(
                    (foci_per_study_i.shape[0], 1),
                    dtype=torch.float64,
                    device=foci_per_study_i.device if device is None else device,
                )
                for foci_per_study_i in foci_per_study
            ]
        moderator_effect = [
            torch.exp(group_log_moderator_effect)
            for group_log_moderator_effect in log_moderator_effect
        ]
        mu_sum_per_study = [
            torch.sum(spatial_intensity[i]) * moderator_effect[i] for i in range(n_groups)
        ]
        n_study_list = [group_foci_per_study.shape[0] for group_foci_per_study in foci_per_study]

        log_l = 0
        for i in range(n_groups):
            log_l += (
                n_study_list[i] * v[i] * torch.log(v[i])
                - n_study_list[i] * torch.lgamma(v[i])
                + torch.sum(torch.lgamma(foci_per_study[i] + v[i]))
                - torch.sum((foci_per_study[i] + v[i]) * torch.log(mu_sum_per_study[i] + v[i]))
                + torch.sum(foci_per_voxel[i] * log_spatial_intensity[i])
                + torch.sum(foci_per_study[i] * log_moderator_effect[i])
            )

        return log_l

    def forward(self, coef_spline_bases, moderators, foci_per_voxel, foci_per_study):
        """Define the loss function (negative log-likelihood function) for Clustered NB model.

        Model refactorization is applied to reduce the dimensionality of variables.

        Returns
        -------
        torch.Tensor
            Loss (negative log-likelihood) of Clustered NB model at current iteration.
        """
        log_l = 0
        for group in self.groups:
            group_overdispersion = self.overdispersion[group]
            group_spatial_coef = self.spatial_coef_linears[group].weight
            group_foci_per_voxel = foci_per_voxel[group]
            group_foci_per_study = foci_per_study[group]
            if isinstance(moderators, dict):
                moderators_coef = self.moderators_linear.weight
                group_moderators = moderators[group]
            else:
                moderators_coef, group_moderators = None, None
            group_log_l = self._log_likelihood_single_group(
                group_overdispersion,
                group_spatial_coef,
                moderators_coef,
                coef_spline_bases,
                group_moderators,
                group_foci_per_voxel,
                group_foci_per_study,
            )
            log_l += group_log_l

        if self.penalty:
            # Firth-type penalty
            log_l += self.firth_penalty(
                foci_per_voxel,
                foci_per_study,
                moderators,
                coef_spline_bases,
                overdispersion=True,
            )

        return -log_l
|