nimare 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119) hide show
  1. benchmarks/__init__.py +0 -0
  2. benchmarks/bench_cbma.py +57 -0
  3. nimare/__init__.py +45 -0
  4. nimare/_version.py +21 -0
  5. nimare/annotate/__init__.py +21 -0
  6. nimare/annotate/cogat.py +213 -0
  7. nimare/annotate/gclda.py +924 -0
  8. nimare/annotate/lda.py +147 -0
  9. nimare/annotate/text.py +75 -0
  10. nimare/annotate/utils.py +87 -0
  11. nimare/base.py +217 -0
  12. nimare/cli.py +124 -0
  13. nimare/correct.py +462 -0
  14. nimare/dataset.py +685 -0
  15. nimare/decode/__init__.py +33 -0
  16. nimare/decode/base.py +115 -0
  17. nimare/decode/continuous.py +462 -0
  18. nimare/decode/discrete.py +753 -0
  19. nimare/decode/encode.py +110 -0
  20. nimare/decode/utils.py +44 -0
  21. nimare/diagnostics.py +510 -0
  22. nimare/estimator.py +139 -0
  23. nimare/extract/__init__.py +19 -0
  24. nimare/extract/extract.py +466 -0
  25. nimare/extract/utils.py +295 -0
  26. nimare/generate.py +331 -0
  27. nimare/io.py +667 -0
  28. nimare/meta/__init__.py +39 -0
  29. nimare/meta/cbma/__init__.py +6 -0
  30. nimare/meta/cbma/ale.py +951 -0
  31. nimare/meta/cbma/base.py +947 -0
  32. nimare/meta/cbma/mkda.py +1361 -0
  33. nimare/meta/cbmr.py +970 -0
  34. nimare/meta/ibma.py +1683 -0
  35. nimare/meta/kernel.py +501 -0
  36. nimare/meta/models.py +1199 -0
  37. nimare/meta/utils.py +494 -0
  38. nimare/nimads.py +492 -0
  39. nimare/reports/__init__.py +24 -0
  40. nimare/reports/base.py +664 -0
  41. nimare/reports/default.yml +123 -0
  42. nimare/reports/figures.py +651 -0
  43. nimare/reports/report.tpl +160 -0
  44. nimare/resources/__init__.py +1 -0
  45. nimare/resources/atlases/Harvard-Oxford-LICENSE +93 -0
  46. nimare/resources/atlases/HarvardOxford-cort-maxprob-thr25-2mm.nii.gz +0 -0
  47. nimare/resources/database_file_manifest.json +142 -0
  48. nimare/resources/english_spellings.csv +1738 -0
  49. nimare/resources/filenames.json +32 -0
  50. nimare/resources/neurosynth_laird_studies.json +58773 -0
  51. nimare/resources/neurosynth_stoplist.txt +396 -0
  52. nimare/resources/nidm_pain_dset.json +1349 -0
  53. nimare/resources/references.bib +541 -0
  54. nimare/resources/semantic_knowledge_children.txt +325 -0
  55. nimare/resources/semantic_relatedness_children.txt +249 -0
  56. nimare/resources/templates/MNI152_2x2x2_brainmask.nii.gz +0 -0
  57. nimare/resources/templates/tpl-MNI152NLin6Asym_res-01_T1w.nii.gz +0 -0
  58. nimare/resources/templates/tpl-MNI152NLin6Asym_res-01_desc-brain_mask.nii.gz +0 -0
  59. nimare/resources/templates/tpl-MNI152NLin6Asym_res-02_T1w.nii.gz +0 -0
  60. nimare/resources/templates/tpl-MNI152NLin6Asym_res-02_desc-brain_mask.nii.gz +0 -0
  61. nimare/results.py +225 -0
  62. nimare/stats.py +276 -0
  63. nimare/tests/__init__.py +1 -0
  64. nimare/tests/conftest.py +229 -0
  65. nimare/tests/data/amygdala_roi.nii.gz +0 -0
  66. nimare/tests/data/data-neurosynth_version-7_coordinates.tsv.gz +0 -0
  67. nimare/tests/data/data-neurosynth_version-7_metadata.tsv.gz +0 -0
  68. nimare/tests/data/data-neurosynth_version-7_vocab-terms_source-abstract_type-tfidf_features.npz +0 -0
  69. nimare/tests/data/data-neurosynth_version-7_vocab-terms_vocabulary.txt +100 -0
  70. nimare/tests/data/neurosynth_dset.json +2868 -0
  71. nimare/tests/data/neurosynth_laird_studies.json +58773 -0
  72. nimare/tests/data/nidm_pain_dset.json +1349 -0
  73. nimare/tests/data/nimads_annotation.json +1 -0
  74. nimare/tests/data/nimads_studyset.json +1 -0
  75. nimare/tests/data/test_baseline.txt +2 -0
  76. nimare/tests/data/test_pain_dataset.json +1278 -0
  77. nimare/tests/data/test_pain_dataset_multiple_contrasts.json +1242 -0
  78. nimare/tests/data/test_sleuth_file.txt +18 -0
  79. nimare/tests/data/test_sleuth_file2.txt +10 -0
  80. nimare/tests/data/test_sleuth_file3.txt +5 -0
  81. nimare/tests/data/test_sleuth_file4.txt +5 -0
  82. nimare/tests/data/test_sleuth_file5.txt +5 -0
  83. nimare/tests/test_annotate_cogat.py +32 -0
  84. nimare/tests/test_annotate_gclda.py +86 -0
  85. nimare/tests/test_annotate_lda.py +27 -0
  86. nimare/tests/test_dataset.py +99 -0
  87. nimare/tests/test_decode_continuous.py +132 -0
  88. nimare/tests/test_decode_discrete.py +92 -0
  89. nimare/tests/test_diagnostics.py +168 -0
  90. nimare/tests/test_estimator_performance.py +385 -0
  91. nimare/tests/test_extract.py +46 -0
  92. nimare/tests/test_generate.py +247 -0
  93. nimare/tests/test_io.py +294 -0
  94. nimare/tests/test_meta_ale.py +298 -0
  95. nimare/tests/test_meta_cbmr.py +295 -0
  96. nimare/tests/test_meta_ibma.py +240 -0
  97. nimare/tests/test_meta_kernel.py +209 -0
  98. nimare/tests/test_meta_mkda.py +234 -0
  99. nimare/tests/test_nimads.py +21 -0
  100. nimare/tests/test_reports.py +110 -0
  101. nimare/tests/test_stats.py +101 -0
  102. nimare/tests/test_transforms.py +272 -0
  103. nimare/tests/test_utils.py +200 -0
  104. nimare/tests/test_workflows.py +221 -0
  105. nimare/tests/utils.py +126 -0
  106. nimare/transforms.py +907 -0
  107. nimare/utils.py +1367 -0
  108. nimare/workflows/__init__.py +14 -0
  109. nimare/workflows/base.py +189 -0
  110. nimare/workflows/cbma.py +165 -0
  111. nimare/workflows/ibma.py +108 -0
  112. nimare/workflows/macm.py +77 -0
  113. nimare/workflows/misc.py +65 -0
  114. nimare-0.4.2.dist-info/LICENSE +21 -0
  115. nimare-0.4.2.dist-info/METADATA +124 -0
  116. nimare-0.4.2.dist-info/RECORD +119 -0
  117. nimare-0.4.2.dist-info/WHEEL +5 -0
  118. nimare-0.4.2.dist-info/entry_points.txt +2 -0
  119. nimare-0.4.2.dist-info/top_level.txt +2 -0
nimare/meta/models.py ADDED
@@ -0,0 +1,1199 @@
1
+ """CBMR Models."""
2
+
3
+ import abc
4
+ import logging
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+
9
+ try:
10
+ import torch
11
+ except ImportError as e:
12
+ raise ImportError(
13
+ "Torch is required to use `CBMR` models. Install with `pip install 'nimare[cbmr]'`."
14
+ ) from e
15
+
16
+ LGR = logging.getLogger(__name__)
17
+
18
+
19
+ class GeneralLinearModelEstimator(torch.nn.Module):
20
+ """Base class for GLM estimators.
21
+
22
+ Parameters
23
+ ----------
24
+ spatial_coef_dim : :obj:`int`
25
+ Number of spatial B-spline bases. Default is None.
26
+ moderators_coef_dim : :obj:`int`, optional
27
+ Number of study-level moderators. Default is None.
28
+ penalty : :obj:`bool`
29
+ Whether to add a Firth-type regularization term. Default is False.
30
+ lr : :obj:`float`
31
+ Learning rate. Default is 1.
32
+ lr_decay : :obj:`float`
33
+ Learning rate decay for each iteration. Default is 0.999.
34
+ n_iter : :obj:`int`
35
+ Maximum number of iterations. Default is 2000.
36
+ tol : :obj:`float`
37
+ Tolerance for convergence. Default is 1e-9.
38
+ device : :obj:`str`
39
+ Device to use for computations. Default is "cpu".
40
+ """
41
+
42
+ _hessian_kwargs = {
43
+ "create_graph": False,
44
+ "vectorize": True,
45
+ "outer_jacobian_strategy": "forward-mode",
46
+ }
47
+
48
+ def __init__(
49
+ self,
50
+ spatial_coef_dim=None,
51
+ moderators_coef_dim=None,
52
+ penalty=False,
53
+ lr=1,
54
+ lr_decay=0.999,
55
+ n_iter=2000,
56
+ tol=1e-9,
57
+ device="cpu",
58
+ ):
59
+ super().__init__()
60
+ self.spatial_coef_dim = spatial_coef_dim
61
+ self.moderators_coef_dim = moderators_coef_dim
62
+ self.penalty = penalty
63
+ self.lr = lr
64
+ self.lr_decay = lr_decay
65
+ self.n_iter = n_iter
66
+ self.tol = tol
67
+ self.device = device
68
+
69
+ # initialization for iteration set up
70
+ self.iter = 0
71
+
72
+ # after fitting, the following attributes will be created
73
+ self.spatial_regression_coef = None
74
+ self.spatial_intensity_estimation = None
75
+ self.moderators_coef = None
76
+ self.moderators_effect = None
77
+ self.spatial_regression_coef_se = None
78
+ self.log_spatial_intensity_se = None
79
+ self.spatial_intensity_se = None
80
+ self.se_moderators = None
81
+
82
+ @abc.abstractmethod
83
+ def _log_likelihood_single_group(self, **kwargs):
84
+ """Log-likelihood of a single group.
85
+
86
+ Returns
87
+ -------
88
+ torch.Tensor
89
+ Value of the log-likelihood of a single group.
90
+ """
91
+ pass
92
+
93
+ @abc.abstractmethod
94
+ def _log_likelihood_mult_group(self, **kwargs):
95
+ """Total log-likelihood of all groups in the dataset.
96
+
97
+ Returns
98
+ -------
99
+ torch.Tensor
100
+ Value of total log-likelihood of all groups in the dataset.
101
+ """
102
+ pass
103
+
104
@abc.abstractmethod
def forward(self, **kwargs):
    """Define the loss function (negative log-likelihood function) for each model.

    Concrete models are expected to override this hook; the base
    implementation does nothing.

    Returns
    -------
    torch.Tensor
        Value of the loss (negative log-likelihood) function.
    """
    return None
114
+
115
def init_spatial_weights(self):
    """Create one bias-free linear layer of spatial coefficients per group.

    Each layer maps `spatial_coef_dim` inputs to one output in double
    precision, with weights drawn uniformly from [-0.01, 0.01]. The layers
    are stored in `self.spatial_coef_linears`, keyed by group name.
    """

    def _new_group_layer():
        # Build and initialise one layer at a time so random numbers are
        # consumed group by group, in `self.groups` order.
        layer = torch.nn.Linear(self.spatial_coef_dim, 1, bias=False).double()
        torch.nn.init.uniform_(layer.weight, a=-0.01, b=0.01)
        return layer

    self.spatial_coef_linears = torch.nn.ModuleDict(
        {group: _new_group_layer() for group in self.groups}
    )
129
+
130
def init_moderator_weights(self):
    """Create the bias-free linear layer holding the moderator coefficients.

    The layer maps `moderators_coef_dim` inputs to one output in double
    precision, with weights drawn uniformly from [-0.01, 0.01]. It is stored
    in `self.moderators_linear`.
    """
    moderator_layer = torch.nn.Linear(self.moderators_coef_dim, 1, bias=False).double()
    self.moderators_linear = moderator_layer
    torch.nn.init.uniform_(moderator_layer.weight, a=-0.01, b=0.01)
138
+
139
def init_weights(self, groups, moderators, spatial_coef_dim, moderators_coef_dim):
    """Record the model layout and initialize all regression coefficients.

    Spatial coefficients are always initialized; moderator coefficients are
    only initialized when `moderators_coef_dim` is truthy.
    """
    self.groups, self.moderators = groups, moderators
    self.spatial_coef_dim, self.moderators_coef_dim = spatial_coef_dim, moderators_coef_dim

    self.init_spatial_weights()
    if not moderators_coef_dim:
        return
    self.init_moderator_weights()
148
+
149
+ def _update(
150
+ self,
151
+ optimizer,
152
+ coef_spline_bases,
153
+ moderators,
154
+ foci_per_voxel,
155
+ foci_per_study,
156
+ prev_loss,
157
+ ):
158
+ """One iteration in optimization with L-BFGS.
159
+
160
+ Adjust learning rate based on the number of iteration (with learning rate decay parameter
161
+ `lr_decay`, default value is 0.999). Reset L-BFGS optimizer (as params in the previous
162
+ iteration) if NaN occurs.
163
+
164
+ Parameters
165
+ ----------
166
+ optimizer : :obj:`torch.optim.lbfgs.LBFGS`
167
+ L-BFGS optimizer.
168
+ coef_spline_bases : :obj:`torch.Tensor`
169
+ Coefficient of B-spline bases evaluated at each voxel.
170
+ moderators : :obj:`dict`, optional
171
+ Dictionary of group-wise study-level moderators. Default is None.
172
+ foci_per_voxel : :obj:`dict`
173
+ Dictionary of group-wise number of foci per voxel.
174
+ foci_per_study : :obj:`dict`
175
+ Dictionary of group-wise number of foci per study.
176
+ prev_loss : :obj:`torch.Tensor`
177
+ Value of the loss function of the previous iteration.
178
+
179
+ Returns
180
+ -------
181
+ torch.Tensor
182
+ Updated value of the loss (negative log-likelihood) function.
183
+ """
184
+ self.iter += 1
185
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
186
+ optimizer, gamma=self.lr_decay
187
+ ) # learning rate decay
188
+
189
+ def closure():
190
+ optimizer.zero_grad()
191
+ loss = self(coef_spline_bases, moderators, foci_per_voxel, foci_per_study)
192
+ loss.backward()
193
+ return loss
194
+
195
+ optimizer.step(closure)
196
+ scheduler.step()
197
+ # recalculate the loss function
198
+ loss = self(coef_spline_bases, moderators, foci_per_voxel, foci_per_study)
199
+
200
+ if torch.isnan(loss):
201
+ raise ValueError(
202
+ f"""The current learing rate {str(self.lr)} or choice of model gives rise to
203
+ NaN log-likelihood, please try Poisson model or adjust learning rate to a smaller
204
+ value."""
205
+ )
206
+
207
+ return loss
208
+
209
+ def _optimizer(self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study):
210
+ """
211
+ Optimize the loss (negative log-likelihood) function with L-BFGS.
212
+
213
+ Parameters
214
+ ----------
215
+ coef_spline_bases : :obj:`numpy.ndarray`
216
+ Coefficient of B-spline bases evaluated at each voxel.
217
+ moderators_by_group : :obj:`dict`, optional
218
+ Dictionary of group-wise study-level moderators.
219
+ foci_per_voxel : :obj:`dict`
220
+ Dictionary of group-wise number of foci per voxel.
221
+ foci_per_study : :obj:`dict`
222
+ Dictionary of group-wise number of foci per study.
223
+ """
224
+ torch.manual_seed(100)
225
+ optimizer = torch.optim.LBFGS(
226
+ params=self.parameters(),
227
+ lr=self.lr,
228
+ max_iter=self.n_iter,
229
+ tolerance_change=self.tol,
230
+ line_search_fn="strong_wolfe",
231
+ )
232
+ # load dataset info to torch.tensor
233
+ coef_spline_bases = torch.tensor(
234
+ coef_spline_bases, dtype=torch.float64, device=self.device
235
+ )
236
+ if moderators_by_group:
237
+ moderators_by_group_tensor = dict()
238
+ for group in self.groups:
239
+ moderators_tensor = torch.tensor(
240
+ moderators_by_group[group], dtype=torch.float64, device=self.device
241
+ )
242
+ moderators_by_group_tensor[group] = moderators_tensor
243
+ else:
244
+ moderators_by_group_tensor = None
245
+ foci_per_voxel_tensor, foci_per_study_tensor = dict(), dict()
246
+ for group in self.groups:
247
+ group_foci_per_voxel_tensor = torch.tensor(
248
+ foci_per_voxel[group], dtype=torch.float64, device=self.device
249
+ )
250
+ group_foci_per_study_tensor = torch.tensor(
251
+ foci_per_study[group], dtype=torch.float64, device=self.device
252
+ )
253
+ foci_per_voxel_tensor[group] = group_foci_per_voxel_tensor
254
+ foci_per_study_tensor[group] = group_foci_per_study_tensor
255
+
256
+ if self.iter == 0:
257
+ prev_loss = torch.tensor(float("inf")) # initialization loss difference
258
+
259
+ self._update(
260
+ optimizer,
261
+ coef_spline_bases,
262
+ moderators_by_group_tensor,
263
+ foci_per_voxel_tensor,
264
+ foci_per_study_tensor,
265
+ prev_loss,
266
+ )
267
+
268
+ return
269
+
270
def fit(self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study):
    """Fit the model and estimate standard error of estimates.

    Runs, in order: the optimizer, extraction of the optimized parameters,
    and standard-error estimation.
    """
    dataset = (coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study)
    self._optimizer(*dataset)
    self.extract_optimized_params(coef_spline_bases, moderators_by_group)
    self.standard_error_estimation(*dataset)
279
+
280
def extract_optimized_params(self, coef_spline_bases, moderators_by_group):
    """Copy the optimized coefficients out of the torch layers as numpy arrays.

    Sets `spatial_regression_coef`, `spatial_intensity_estimation`,
    `moderators_coef` and `moderators_effect` on the instance. The two
    moderator attributes are None when `moderators_coef_dim` is falsy.
    """
    coef_by_group, intensity_by_group = {}, {}
    for group in self.groups:
        # Flattened spatial regression coefficients of this group.
        weights = self.spatial_coef_linears[group].weight.cpu().detach().numpy().flatten()
        coef_by_group[group] = weights
        # Group-specific spatial intensity: exp(bases @ coefficients).
        intensity_by_group["spatialIntensity_group-" + group] = np.exp(
            np.matmul(coef_spline_bases, weights)
        )

    if self.moderators_coef_dim:
        moderators_coef = self.moderators_linear.weight.cpu().detach().numpy()
        # Moderator effect per group: exp(moderators @ coefficients^T).
        moderators_effect = {
            group: np.exp(np.matmul(moderators_by_group[group], moderators_coef.T)).flatten()
            for group in self.groups
        }
    else:
        moderators_coef = moderators_effect = None

    self.spatial_regression_coef = coef_by_group
    self.spatial_intensity_estimation = intensity_by_group
    self.moderators_coef = moderators_coef
    self.moderators_effect = moderators_effect
314
+
315
def standard_error_estimation(
    self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study
):
    """Estimate standard error of estimates.

    For spatial regression coefficients, the covariance matrix is the inverse
    of the Hessian of the negative log-likelihood (Fisher information), and
    the standard errors are the square roots of its diagonal.
    For log spatial intensity, the variance is the diagonal of
    ``bases @ cov @ bases.T``; the spatial-intensity standard error is that
    value scaled by the estimated intensity (delta method).
    If the model has an `overdispersion` attribute, the group's
    overdispersion is passed to the likelihood.

    Sets `spatial_regression_coef_se`, `log_spatial_intensity_se`,
    `spatial_intensity_se` and `se_moderators` on the instance.

    Parameters
    ----------
    coef_spline_bases : :obj:`numpy.ndarray`
        Coefficient of B-spline bases evaluated at each voxel.
    moderators_by_group : :obj:`dict`, optional
        Dictionary of group-wise study-level moderators.
    foci_per_voxel : :obj:`dict`
        Dictionary of group-wise number of foci per voxel.
    foci_per_study : :obj:`dict`
        Dictionary of group-wise number of foci per study.
    """
    spatial_regression_coef_se, log_spatial_intensity_se, spatial_intensity_se = {}, {}, {}

    # The basis matrix is identical for every group, so convert it once.
    coef_spline_bases_tensor = torch.tensor(
        coef_spline_bases, dtype=torch.float64, device=self.device
    )

    for group in self.groups:
        group_foci_per_voxel = torch.tensor(
            foci_per_voxel[group], dtype=torch.float64, device=self.device
        )
        group_foci_per_study = torch.tensor(
            foci_per_study[group], dtype=torch.float64, device=self.device
        )
        group_spatial_coef = self.spatial_coef_linears[group].weight
        if self.moderators_coef_dim:
            group_moderators = torch.tensor(
                moderators_by_group[group], dtype=torch.float64, device=self.device
            )
            moderators_coef = self.moderators_linear.weight
        else:
            group_moderators, moderators_coef = None, None

        ll_single_group_kwargs = {
            "moderators_coef": moderators_coef,
            "coef_spline_bases": coef_spline_bases_tensor,
            "group_moderators": group_moderators,
            "group_foci_per_voxel": group_foci_per_voxel,
            "group_foci_per_study": group_foci_per_study,
            "device": self.device,
        }

        if hasattr(self, "overdispersion"):
            ll_single_group_kwargs["group_overdispersion"] = self.overdispersion[group]

        # Negative log-likelihood as a function of the spatial coefficients only.
        def nll_spatial_coef(group_spatial_coef):
            return -self._log_likelihood_single_group(
                group_spatial_coef=group_spatial_coef,
                **ll_single_group_kwargs,
            )

        f_spatial_coef = torch.func.hessian(nll_spatial_coef)(group_spatial_coef)
        f_spatial_coef = f_spatial_coef.reshape((self.spatial_coef_dim, self.spatial_coef_dim))
        # `.cpu()` is needed before `.numpy()` when the model lives on another device.
        cov_spatial_coef = np.linalg.inv(f_spatial_coef.detach().cpu().numpy())
        spatial_regression_coef_se[group] = np.sqrt(np.diag(cov_spatial_coef))

        # Row-wise quadratic form: diag(bases @ cov @ bases.T).
        var_log_spatial_intensity = np.einsum(
            "ij,ji->i",
            coef_spline_bases,
            cov_spatial_coef @ coef_spline_bases.T,
        )
        se_log_spatial_intensity = np.sqrt(var_log_spatial_intensity)
        log_spatial_intensity_se[group] = se_log_spatial_intensity

        group_studywise_spatial_intensity = np.exp(
            np.matmul(coef_spline_bases, group_spatial_coef.detach().cpu().numpy().T)
        ).flatten()
        spatial_intensity_se[group] = group_studywise_spatial_intensity * se_log_spatial_intensity

    # Inference on regression coefficient of moderators.
    # NOTE(review): this reuses the kwargs and spatial coefficients of the LAST
    # group in `self.groups`, so only that group's likelihood contributes to the
    # moderator standard errors -- confirm whether all groups should be summed.
    if self.moderators_coef_dim:
        # Fix the spatial coefficients and let the moderator coefficients vary.
        del ll_single_group_kwargs["moderators_coef"]
        ll_single_group_kwargs["group_spatial_coef"] = group_spatial_coef

        def nll_moderators_coef(moderators_coef):
            return -self._log_likelihood_single_group(
                moderators_coef=moderators_coef,
                **ll_single_group_kwargs,
            )

        f_moderators_coef = torch.func.hessian(nll_moderators_coef)(moderators_coef)
        f_moderators_coef = f_moderators_coef.reshape(
            (self.moderators_coef_dim, self.moderators_coef_dim)
        )
        cov_moderators_coef = np.linalg.inv(f_moderators_coef.detach().cpu().numpy())
        var_moderators = np.diag(cov_moderators_coef).reshape((1, self.moderators_coef_dim))
        se_moderators = np.sqrt(var_moderators)
    else:
        se_moderators = None

    self.spatial_regression_coef_se = spatial_regression_coef_se
    self.log_spatial_intensity_se = log_spatial_intensity_se
    self.spatial_intensity_se = spatial_intensity_se
    self.se_moderators = se_moderators
415
+
416
def summary(self):
    """Summarize the main results of the fitted model.

    Returns
    -------
    maps : :obj:`dict`
        The estimated spatial intensity (`self.spatial_intensity_estimation`).
    tables : :obj:`dict`
        DataFrames with the regression coefficients and their standard
        errors, plus moderator tables when `moderators_coef_dim` is truthy.

    Raises
    ------
    ValueError
        If any of the spatial estimates or standard errors is still None.
    """
    fitted = (
        self.spatial_regression_coef,
        self.spatial_intensity_estimation,
        self.spatial_regression_coef_se,
        self.log_spatial_intensity_se,
        self.spatial_intensity_se,
    )
    if any(value is None for value in fitted):
        raise ValueError("Run fit first")

    def _by_group(mapping):
        # One row per group (dictionary key).
        return pd.DataFrame.from_dict(mapping, orient="index")

    has_moderators = bool(self.moderators_coef_dim)
    maps = self.spatial_intensity_estimation

    # Optimized regression coefficients.
    tables = {"spatial_regression_coef": _by_group(self.spatial_regression_coef)}
    if has_moderators:
        tables["moderators_regression_coef"] = pd.DataFrame(
            data=self.moderators_coef, columns=self.moderators
        )
        tables["moderators_effect"] = _by_group(self.moderators_effect)

    # Standard errors of the coefficients and of the (log-)spatial intensity.
    tables["spatial_regression_coef_se"] = _by_group(self.spatial_regression_coef_se)
    tables["log_spatial_intensity_se"] = _by_group(self.log_spatial_intensity_se)
    tables["spatial_intensity_se"] = _by_group(self.spatial_intensity_se)
    if has_moderators:
        tables["moderators_regression_se"] = pd.DataFrame(
            data=self.se_moderators, columns=self.moderators
        )
    return maps, tables
462
+
463
def fisher_info_multiple_group_spatial(
    self,
    involved_groups,
    coef_spline_bases,
    moderators_by_group,
    foci_per_voxel,
    foci_per_study,
):
    """Estimate the Fisher information matrix of spatial regression coefficients.

    The matrix is the Hessian of the negative multi-group log-likelihood with
    respect to the stacked spatial coefficients of `involved_groups`.

    Parameters
    ----------
    involved_groups : :obj:`list`
        Group names involved in generalized linear hypothesis (GLH) testing in `CBMRInference`.
    coef_spline_bases : :obj:`numpy.ndarray`
        Coefficient of B-spline bases evaluated at each voxel.
    moderators_by_group : :obj:`dict`, optional
        Dictionary of group-wise study-level moderators. Default is None.
    foci_per_voxel : :obj:`dict`
        Dictionary of group-wise number of foci per voxel.
    foci_per_study : :obj:`dict`
        Dictionary of group-wise number of foci per study.

    Returns
    -------
    numpy.ndarray
        Fisher information matrix of spatial regression coefficients (for involved groups),
        with ``len(involved_groups) * spatial_coef_dim`` rows.
    """

    def _as_tensor(values):
        return torch.tensor(values, dtype=torch.float64, device=self.device)

    n_rows = len(involved_groups) * self.spatial_coef_dim
    voxel_counts = [_as_tensor(foci_per_voxel[name]) for name in involved_groups]
    study_counts = [_as_tensor(foci_per_study[name]) for name in involved_groups]

    # One transposed weight matrix per group, stacked along a new leading axis.
    stacked_coef = torch.stack(
        [self.spatial_coef_linears[name].weight.T for name in involved_groups], dim=0
    )

    if self.moderators_coef_dim:
        group_moderators = [_as_tensor(moderators_by_group[name]) for name in involved_groups]
        fixed_moderator_coef = _as_tensor(self.moderators_coef.T)
    else:
        group_moderators = fixed_moderator_coef = None

    # Everything except the spatial coefficients is held fixed.
    fixed_kwargs = {
        "moderator_coef": fixed_moderator_coef,
        "coef_spline_bases": _as_tensor(coef_spline_bases),
        "foci_per_voxel": voxel_counts,
        "foci_per_study": study_counts,
        "moderators": group_moderators,
        "device": self.device,
    }
    if hasattr(self, "overdispersion"):
        fixed_kwargs["overdispersion_coef"] = [
            self.overdispersion[name] for name in involved_groups
        ]

    def negative_ll(spatial_coef):
        return -self._log_likelihood_mult_group(spatial_coef=spatial_coef, **fixed_kwargs)

    hess = torch.func.hessian(negative_ll)(stacked_coef)
    return hess.view(n_rows, -1).detach().cpu().numpy()
542
+
543
def fisher_info_multiple_group_moderator(
    self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study
):
    """Estimate the Fisher information matrix of regression coefficients of moderators.

    The matrix is the Hessian of the negative multi-group log-likelihood
    (over all of `self.groups`) with respect to the moderator coefficients,
    with the spatial coefficients held fixed.

    Parameters
    ----------
    coef_spline_bases : :obj:`numpy.ndarray`
        Coefficient of B-spline bases evaluated at each voxel.
    moderators_by_group : :obj:`dict`, optional
        Dictionary of group-wise study-level moderators. Default is None.
    foci_per_voxel : :obj:`dict`
        Dictionary of group-wise number of foci per voxel.
    foci_per_study : :obj:`dict`
        Dictionary of group-wise number of foci per study.

    Returns
    -------
    numpy.ndarray
        Fisher information matrix of study-level moderator regressors, of shape
        ``(moderators_coef_dim, moderators_coef_dim)``.
    """

    def _as_tensor(values):
        return torch.tensor(values, dtype=torch.float64, device=self.device)

    voxel_counts = [_as_tensor(foci_per_voxel[name]) for name in self.groups]
    study_counts = [_as_tensor(foci_per_study[name]) for name in self.groups]

    # One transposed weight matrix per group, stacked along a new leading axis.
    stacked_spatial_coef = torch.stack(
        [self.spatial_coef_linears[name].weight.T for name in self.groups], dim=0
    )

    if self.moderators_coef_dim:
        group_moderators = [_as_tensor(moderators_by_group[name]) for name in self.groups]
        current_moderator_coef = _as_tensor(self.moderators_coef.T)
    else:
        group_moderators = current_moderator_coef = None

    # Everything except the moderator coefficients is held fixed.
    fixed_kwargs = {
        "spatial_coef": stacked_spatial_coef,
        "coef_spline_bases": _as_tensor(coef_spline_bases),
        "foci_per_voxel": voxel_counts,
        "foci_per_study": study_counts,
        "moderators": group_moderators,
        "device": self.device,
    }
    if hasattr(self, "overdispersion"):
        fixed_kwargs["overdispersion_coef"] = [self.overdispersion[name] for name in self.groups]

    def negative_ll(moderator_coef):
        return -self._log_likelihood_mult_group(moderator_coef=moderator_coef, **fixed_kwargs)

    hess = torch.func.hessian(negative_ll)(current_moderator_coef)
    hess = hess.view(self.moderators_coef_dim, self.moderators_coef_dim)
    return hess.detach().cpu().numpy()
614
+
615
def firth_penalty(
    self,
    foci_per_voxel,
    foci_per_study,
    moderators,
    coef_spline_bases,
    overdispersion=False,
):
    """Compute the Firth-type penalty summed over all groups.

    For each group, the Hessian of the negative single-group log-likelihood
    with respect to the spatial coefficients is computed, and half the sum of
    the logs of the real parts of its eigenvalues is added to the total.

    Parameters
    ----------
    foci_per_voxel : :obj:`dict`
        Dictionary of group-wise number of foci per voxel.
    foci_per_study : :obj:`dict`
        Dictionary of group-wise number of foci per study.
    moderators : :obj:`dict`, optional
        Dictionary of group-wise study-level moderators. Default is None.
    coef_spline_bases : :obj:`torch.Tensor`
        Coefficient of B-spline bases evaluated at each voxel.
    overdispersion : :obj:`bool`
        Whether the model contains overdispersion parameter. Default is False.

    Returns
    -------
    torch.Tensor
        Firth-type regularization term (the integer 0 if there are no groups).
    """
    # Running total across groups. The per-group value used to overwrite this
    # accumulator and then be added to itself, so the result was twice the
    # last group's penalty instead of the sum over groups.
    total_firth_penalty = 0
    for group in self.groups:
        partial_kwargs = {"coef_spline_bases": coef_spline_bases}
        if overdispersion:
            group_overdispersion = self.overdispersion[group]
            if getattr(self, "square_root", False):
                # The stored parameter is squared before being passed to the likelihood.
                group_overdispersion = group_overdispersion**2
            partial_kwargs["group_overdispersion"] = group_overdispersion
        partial_kwargs["group_foci_per_voxel"] = foci_per_voxel[group]
        partial_kwargs["group_foci_per_study"] = foci_per_study[group]
        if self.moderators_coef_dim:
            moderators_coef = self.moderators_linear.weight
            group_moderators = moderators[group]
        else:
            moderators_coef, group_moderators = None, None
        partial_kwargs["moderators_coef"] = moderators_coef
        partial_kwargs["group_moderators"] = group_moderators

        # Negative log-likelihood as a function of the spatial coefficients only.
        def nll_spatial_coef(group_spatial_coef):
            return -self._log_likelihood_single_group(
                group_spatial_coef=group_spatial_coef,
                **partial_kwargs,
            )

        group_spatial_coef = self.spatial_coef_linears[group].weight
        # NOTE(review): `_hessian_kwargs` sets create_graph=False, so this Hessian
        # is presumably not differentiable w.r.t. the parameters -- confirm that
        # the penalty is meant to contribute no gradient.
        group_f = torch.autograd.functional.hessian(
            nll_spatial_coef,
            group_spatial_coef,
            **self._hessian_kwargs,
        )

        group_f = group_f.reshape((self.spatial_coef_dim, self.spatial_coef_dim))
        group_eig_vals = torch.real(torch.linalg.eigvals(group_f))
        del group_f  # drop the Hessian as soon as it is no longer needed
        total_firth_penalty = total_firth_penalty + 0.5 * torch.sum(torch.log(group_eig_vals))
        del group_eig_vals

    return total_firth_penalty
684
+
685
+
686
class OverdispersionModelEstimator(GeneralLinearModelEstimator):
    """Base class for CBMR models with over-dispersion parameter.

    Accepts one extra keyword argument, `square_root` (default False). When it
    is True the overdispersion parameter is initialized to the square root of
    its starting value.
    """

    def __init__(self, **kwargs):
        # Remove the subclass-only option before delegating to the base class.
        self.square_root = kwargs.pop("square_root", False)
        super().__init__(**kwargs)

    def init_overdispersion_weights(self):
        """Initialize weights for overdispersion parameters.

        Default is 1e-2 (or its square root when `square_root` is True). One
        trainable scalar per group is stored in `self.overdispersion`.
        """
        overdispersion = {}
        for group in self.groups:
            # initialization for alpha
            overdispersion_init_group = torch.tensor(1e-2).double()
            if self.square_root:
                overdispersion_init_group = torch.sqrt(overdispersion_init_group)
            overdispersion[group] = torch.nn.Parameter(
                overdispersion_init_group, requires_grad=True
            )
        self.overdispersion = torch.nn.ParameterDict(overdispersion)

    def init_weights(self, groups, moderators, spatial_coef_dim, moderators_coef_dim):
        """Initialize spatial, moderator and overdispersion coefficients."""
        super().init_weights(groups, moderators, spatial_coef_dim, moderators_coef_dim)
        self.init_overdispersion_weights()

    def inference_outcome(
        self, coef_spline_bases, moderators_by_group, foci_per_voxel, foci_per_study
    ):
        """Summarize inference outcome into `maps` and `tables`.

        Add optimized overdispersion parameter to the tables. The four
        arguments are accepted for interface compatibility and are not read.

        Returns
        -------
        maps : :obj:`dict`
            Maps returned by the base-class `summary`.
        tables : :obj:`dict`
            Tables returned by the base-class `summary`, plus an
            ``"overdispersion_coef"`` table with one row per group.
        """
        # The base class defines `summary()`, not `inference_outcome()`; the
        # previous `super().inference_outcome(...)` call raised AttributeError.
        maps, tables = super().summary()
        overdispersion_param = {}
        for group in self.groups:
            # NOTE(review): the stored parameter is reported as-is; when
            # `square_root` is True this is presumably the square root of the
            # overdispersion -- confirm which value should be reported.
            overdispersion_param[group] = self.overdispersion[group].cpu().detach().numpy()
        tables["overdispersion_coef"] = pd.DataFrame.from_dict(
            overdispersion_param, orient="index", columns=["overdispersion"]
        )

        return maps, tables
734
+
735
+
736
class PoissonEstimator(GeneralLinearModelEstimator):
    """CBMR framework with Poisson model.

    The Poisson model is the most basic model for Coordinate-based Meta-regression (CBMR).
    It is based on the assumption that foci arise from a realisation of a (continuous)
    inhomogeneous Poisson process, so that the (discrete) voxel-wise foci counts will
    be independently distributed as Poisson random variables, with rate equal to the
    integral of the (true, unobserved, continuous) intensity function over each voxel.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _log_likelihood_single_group(
        self,
        group_spatial_coef,
        moderators_coef,
        coef_spline_bases,
        group_moderators,
        group_foci_per_voxel,
        group_foci_per_study,
        device="cpu",
    ):
        """Evaluate the Poisson log-likelihood expression for one group.

        Parameters
        ----------
        group_spatial_coef : torch.Tensor
            Spatial coefficients of the group; used as
            ``coef_spline_bases @ group_spatial_coef.T``.
        moderators_coef : torch.Tensor or None
            Moderator coefficients, used as ``group_moderators @ moderators_coef.T``.
            ``None`` disables the moderator term (log-effect of zero for every study).
        coef_spline_bases : torch.Tensor
            Spline bases multiplied with the spatial coefficients.
        group_moderators : torch.Tensor or None
            Moderators of the group; only read when ``moderators_coef`` is not ``None``.
        group_foci_per_voxel : torch.Tensor
            Foci counts multiplied element-wise with the log spatial intensity.
        group_foci_per_study : torch.Tensor
            Two-dimensional tensor whose first dimension is the number of studies.
        device : str or torch.device, optional
            Device on which the zero log-moderator tensor is created when there are no
            moderators. Default is "cpu".

        Returns
        -------
        torch.Tensor
            ``sum(y_voxel * log_mu_X) + sum(y_study * log_mu_Z) - sum(mu_X) * sum(mu_Z)``.
        """
        log_mu_spatial = torch.matmul(coef_spline_bases, group_spatial_coef.T)
        mu_spatial = torch.exp(log_mu_spatial)
        if moderators_coef is None:
            n_study, _ = group_foci_per_study.shape
            log_mu_moderators = torch.zeros((n_study, 1), dtype=torch.float64, device=device)
        else:
            log_mu_moderators = torch.matmul(group_moderators, moderators_coef.T)
        mu_moderators = torch.exp(log_mu_moderators)
        log_l = (
            torch.sum(torch.mul(group_foci_per_voxel, log_mu_spatial))
            + torch.sum(torch.mul(group_foci_per_study, log_mu_moderators))
            - torch.sum(mu_spatial) * torch.sum(mu_moderators)
        )
        return log_l

    def _log_likelihood_mult_group(
        self,
        spatial_coef,
        moderator_coef,
        coef_spline_bases,
        foci_per_voxel,
        foci_per_study,
        moderators,
        device="cpu",
    ):
        """Evaluate the Poisson log-likelihood expression summed over several groups.

        Parameters
        ----------
        spatial_coef : torch.Tensor
            Indexed as ``spatial_coef[i, :, :]`` for group ``i``; its length gives the
            number of groups.
        moderator_coef : torch.Tensor or None
            Moderator coefficients, used as ``group_moderator @ moderator_coef``.
            ``None`` disables the moderator term.
        coef_spline_bases : torch.Tensor
            Spline bases multiplied with each group's spatial coefficients.
        foci_per_voxel, foci_per_study : sequence of torch.Tensor
            Per-group foci counts, indexed by group position.
        moderators : sequence of torch.Tensor
            Per-group moderators; only read when ``moderator_coef`` is not ``None``.
        device : str or torch.device, optional
            Device for the zero log-moderator tensors created when there are no
            moderators. Default is "cpu".

        Returns
        -------
        torch.Tensor or int
            Sum of the per-group log-likelihood terms (``0`` when there are no groups).
        """
        n_groups = len(spatial_coef)
        log_spatial_intensity = [
            torch.matmul(coef_spline_bases, spatial_coef[i, :, :]) for i in range(n_groups)
        ]
        spatial_intensity = [
            torch.exp(group_log_spatial_intensity)
            for group_log_spatial_intensity in log_spatial_intensity
        ]
        if moderator_coef is not None:
            log_moderator_effect = [
                torch.matmul(group_moderator, moderator_coef) for group_moderator in moderators
            ]
        else:
            log_moderator_effect = [
                torch.zeros(
                    (group_foci_per_study.shape[0], 1), dtype=torch.float64, device=device
                )
                for group_foci_per_study in foci_per_study
            ]
        moderator_effect = [
            torch.exp(group_log_moderator_effect)
            for group_log_moderator_effect in log_moderator_effect
        ]
        log_l = 0
        for i in range(n_groups):
            log_l += (
                torch.sum(foci_per_voxel[i] * log_spatial_intensity[i])
                + torch.sum(foci_per_study[i] * log_moderator_effect[i])
                - torch.sum(spatial_intensity[i]) * torch.sum(moderator_effect[i])
            )
        return log_l

    def forward(self, coef_spline_bases, moderators, foci_per_voxel, foci_per_study):
        """Define the loss function (negative log-likelihood function) for Poisson model.

        Model refactorization is applied to reduce the dimensionality of variables.

        Parameters
        ----------
        coef_spline_bases : torch.Tensor
            Spline bases shared by all groups.
        moderators : dict or other
            When a dict, ``moderators[group]`` is used together with
            ``self.moderators_linear.weight``; any other value disables moderators.
        foci_per_voxel, foci_per_study : mapping of group to torch.Tensor
            Per-group foci counts.

        Returns
        -------
        torch.Tensor
            Loss (negative log-likelihood) of Poisson model at current iteration.
        """
        log_l = 0
        for group in self.groups:
            group_spatial_coef = self.spatial_coef_linears[group].weight
            group_foci_per_voxel = foci_per_voxel[group]
            group_foci_per_study = foci_per_study[group]
            if isinstance(moderators, dict):
                moderators_coef = self.moderators_linear.weight
                group_moderators = moderators[group]
            else:
                moderators_coef, group_moderators = None, None
            group_log_l = self._log_likelihood_single_group(
                group_spatial_coef,
                moderators_coef,
                coef_spline_bases,
                group_moderators,
                group_foci_per_voxel,
                group_foci_per_study,
                # The helper defaults to "cpu"; create its no-moderator placeholder on
                # the device of the tensor it is multiplied with instead.
                device=group_foci_per_study.device,
            )
            log_l += group_log_l

        if self.penalty:
            # Firth-type penalty
            log_l += self.firth_penalty(
                foci_per_voxel,
                foci_per_study,
                moderators,
                coef_spline_bases,
                overdispersion=False,
            )
        return -log_l
863
+
864
+
865
class NegativeBinomialEstimator(OverdispersionModelEstimator):
    """CBMR framework with Negative Binomial (NB) model.

    Negative Binomial (NB) model is a generalized Poisson model with overdispersion.
    It's a more flexible model, but more difficult to estimate. In practice, foci
    counts often display over-dispersion (the variance of response variable
    substantially exceeds the mean), which is not captured by Poisson model.
    """

    def __init__(self, **kwargs):
        # The stored parameter is the square root of the over-dispersion; ``forward``
        # squares it before use.
        kwargs["square_root"] = True
        super().__init__(**kwargs)

    def _three_term(self, y, r):
        """Sum ``lgamma(y + r) - lgamma(y + 1) - lgamma(r)`` over entries with ``y >= 1``.

        Entries are processed one count value at a time (``1 .. max(y)``); entries with
        ``y == 0`` are never visited and therefore contribute nothing.

        Parameters
        ----------
        y : torch.Tensor
            Foci counts. Rows are selected via the first column of ``nonzero()``.
        r : torch.Tensor
            Indexed with the selected row indices.
            NOTE(review): presumably shaped like ``y`` with one column -- confirm.

        Returns
        -------
        torch.Tensor or int
            The accumulated sum (``0`` when ``max(y)`` is zero).
        """
        max_foci = torch.max(y).to(dtype=torch.int64, device=self.device)
        sum_three_term = 0
        for count in range(1, int(max_foci) + 1):
            foci_index = (y == count).nonzero()[:, 0]
            r_j = r[foci_index]
            n_voxel = foci_index.shape[0]
            y_j = torch.full((n_voxel, 1), count, dtype=torch.float64, device=self.device)
            # y=0 => sum_three_term = 0
            sum_three_term += torch.sum(
                torch.lgamma(y_j + r_j) - torch.lgamma(y_j + 1) - torch.lgamma(r_j)
            )

        return sum_three_term

    def _log_likelihood_single_group(
        self,
        group_overdispersion,
        group_spatial_coef,
        moderators_coef,
        coef_spline_bases,
        group_moderators,
        group_foci_per_voxel,
        group_foci_per_study,
        device="cpu",
    ):
        """Evaluate the NB log-likelihood (moment matching approach) for one group.

        Parameters
        ----------
        group_overdispersion : torch.Tensor
            Over-dispersion parameter (alpha) of the group.
        group_spatial_coef : torch.Tensor
            Spatial coefficients of the group; used as
            ``coef_spline_bases @ group_spatial_coef.T``.
        moderators_coef : torch.Tensor or None
            Moderator coefficients, used as ``group_moderators @ moderators_coef.T``.
            ``None`` disables the moderator term (log-effect of zero for every study).
        coef_spline_bases : torch.Tensor
            Spline bases multiplied with the spatial coefficients.
        group_moderators : torch.Tensor or None
            Moderators of the group; only read when ``moderators_coef`` is not ``None``.
        group_foci_per_voxel : torch.Tensor
            Foci counts entering the likelihood.
        group_foci_per_study : torch.Tensor
            Two-dimensional tensor; only its first dimension (number of studies) is used,
            and only when there are no moderators.
        device : str or torch.device, optional
            Device on which the zero log-moderator tensor is created when there are no
            moderators. Default is "cpu".

        Returns
        -------
        torch.Tensor
            Log-likelihood of the group.
        """
        log_mu_spatial = torch.matmul(coef_spline_bases, group_spatial_coef.T)
        mu_spatial = torch.exp(log_mu_spatial)
        if moderators_coef is not None:
            log_mu_moderators = torch.matmul(group_moderators, moderators_coef.T)
        else:
            n_study, _ = group_foci_per_study.shape
            log_mu_moderators = torch.zeros((n_study, 1), dtype=torch.float64, device=device)
        mu_moderators = torch.exp(log_mu_moderators)
        # parameter of a NB variable to approximate a sum of NB variables
        r = 1 / group_overdispersion * torch.sum(mu_moderators) ** 2 / torch.sum(mu_moderators**2)
        p = 1 / (
            1
            + torch.sum(mu_moderators)
            / (group_overdispersion * mu_spatial * torch.sum(mu_moderators**2))
        )
        # log-likelihood (moment matching approach)
        log_l = torch.sum(
            torch.lgamma(group_foci_per_voxel + r)
            - torch.lgamma(group_foci_per_voxel + 1)
            - torch.lgamma(r)
            + r * torch.log(1 - p)
            + group_foci_per_voxel * torch.log(p)
        )

        return log_l

    def _log_likelihood_mult_group(
        self,
        overdispersion_coef,
        spatial_coef,
        coef_spline_bases,
        foci_per_voxel,
        foci_per_study,
        moderator_coef=None,
        moderators=None,
        device="cpu",
    ):
        """Evaluate the NB log-likelihood (moment matching approach) over several groups.

        Parameters
        ----------
        overdispersion_coef : sequence of torch.Tensor
            Over-dispersion parameter (alpha) per group, indexed by group position.
        spatial_coef : torch.Tensor
            Indexed as ``spatial_coef[i, :, :]`` for group ``i``.
        coef_spline_bases : torch.Tensor
            Spline bases multiplied with each group's spatial coefficients.
        foci_per_voxel, foci_per_study : sequence of torch.Tensor
            Per-group foci counts; ``len(foci_per_voxel)`` gives the number of groups.
        moderator_coef : torch.Tensor or None, optional
            Moderator coefficients, used as ``group_moderator @ moderator_coef``.
            ``None`` (default) disables the moderator term.
        moderators : sequence of torch.Tensor or None, optional
            Per-group moderators; only read when ``moderator_coef`` is not ``None``.
        device : str or torch.device, optional
            Device for the zero log-moderator tensors created when there are no
            moderators. Default is "cpu".

        Returns
        -------
        torch.Tensor or int
            Sum of the per-group log-likelihoods (``0`` when there are no groups).
        """
        n_groups = len(foci_per_voxel)
        log_spatial_intensity = [
            torch.matmul(coef_spline_bases, spatial_coef[i, :, :]) for i in range(n_groups)
        ]
        spatial_intensity = [
            torch.exp(group_log_spatial_intensity)
            for group_log_spatial_intensity in log_spatial_intensity
        ]
        if moderator_coef is not None:
            log_moderator_effect = [
                torch.matmul(group_moderator, moderator_coef) for group_moderator in moderators
            ]
        else:
            # Distinct loop-variable name: do not shadow the ``foci_per_study`` argument.
            log_moderator_effect = [
                torch.zeros(
                    (group_foci_per_study.shape[0], 1), dtype=torch.float64, device=device
                )
                for group_foci_per_study in foci_per_study
            ]
        moderator_effect = [
            torch.exp(group_log_moderator_effect)
            for group_log_moderator_effect in log_moderator_effect
        ]
        # After simplification, we have:
        # r' = 1/alpha * sum(mu^Z_i)^2 / sum((mu^Z_i)^2)
        # p'_j = 1 / (1 + sum(mu^Z_i) / (alpha * mu^X_j * sum((mu^Z_i)^2)))
        r = [
            1
            / overdispersion_coef[i]
            * torch.sum(moderator_effect[i]) ** 2
            / torch.sum(moderator_effect[i] ** 2)
            for i in range(n_groups)
        ]
        p_frac = [
            torch.sum(moderator_effect[i])
            / (overdispersion_coef[i] * spatial_intensity[i] * torch.sum(moderator_effect[i] ** 2))
            for i in range(n_groups)
        ]
        p = [1 / (1 + p_frac[i]) for i in range(n_groups)]

        log_l = 0
        for i in range(n_groups):
            group_log_l = torch.sum(
                torch.lgamma(foci_per_voxel[i] + r[i])
                - torch.lgamma(foci_per_voxel[i] + 1)
                - torch.lgamma(r[i])
                + r[i] * torch.log(1 - p[i])
                + foci_per_voxel[i] * torch.log(p[i])
            )
            log_l += group_log_l

        return log_l

    def forward(self, coef_spline_bases, moderators, foci_per_voxel, foci_per_study):
        """Define the loss function (negative log-likelihood function) for NB model.

        Model refactorization is applied to reduce the dimensionality of variables.

        Parameters
        ----------
        coef_spline_bases : torch.Tensor
            Spline bases shared by all groups.
        moderators : dict or other
            When a dict, ``moderators[group]`` is used together with
            ``self.moderators_linear.weight``; any other value disables moderators.
        foci_per_voxel, foci_per_study : mapping of group to torch.Tensor
            Per-group foci counts.

        Returns
        -------
        torch.Tensor
            Loss (negative log-likelihood) of NB model at current iteration.
        """
        log_l = 0
        for group in self.groups:
            # The stored parameter is squared before it is used as over-dispersion.
            group_overdispersion = self.overdispersion[group] ** 2
            group_spatial_coef = self.spatial_coef_linears[group].weight
            group_foci_per_voxel = foci_per_voxel[group]
            group_foci_per_study = foci_per_study[group]
            if isinstance(moderators, dict):
                moderators_coef = self.moderators_linear.weight
                group_moderators = moderators[group]
            else:
                moderators_coef, group_moderators = None, None
            group_log_l = self._log_likelihood_single_group(
                group_overdispersion,
                group_spatial_coef,
                moderators_coef,
                coef_spline_bases,
                group_moderators,
                group_foci_per_voxel,
                group_foci_per_study,
                # The helper defaults to "cpu"; create its no-moderator placeholder on
                # the device of the group's data instead.
                device=group_foci_per_study.device,
            )

            log_l += group_log_l

        if self.penalty:
            # Firth-type penalty
            log_l += self.firth_penalty(
                foci_per_voxel,
                foci_per_study,
                moderators,
                coef_spline_bases,
                overdispersion=True,
            )

        return -log_l
1046
+
1047
+
1048
class ClusteredNegativeBinomialEstimator(OverdispersionModelEstimator):
    """CBMR framework with Clustered Negative Binomial (Clustered NB) model.

    Clustered NB model can also accommodate over-dispersion in foci counts.
    In NB model, the latent Gamma random variable introduces independent variation
    at each voxel. While in Clustered NB model, we assert the random effects are not
    independent voxelwise effects, but rather latent characteristics of each study,
    and represent a shared effect over the entire brain for a given study.
    """

    def __init__(self, **kwargs):
        # The over-dispersion parameter is stored and used as-is (no square root).
        kwargs["square_root"] = False
        super().__init__(**kwargs)

    def _log_likelihood_single_group(
        self,
        group_overdispersion,
        group_spatial_coef,
        moderators_coef,
        coef_spline_bases,
        group_moderators,
        group_foci_per_voxel,
        group_foci_per_study,
        device="cpu",
    ):
        """Evaluate the Clustered NB log-likelihood expression for one group.

        Parameters
        ----------
        group_overdispersion : torch.Tensor
            Over-dispersion parameter of the group; the code works with its
            reciprocal ``v``.
        group_spatial_coef : torch.Tensor
            Spatial coefficients of the group; used as
            ``coef_spline_bases @ group_spatial_coef.T``.
        moderators_coef : torch.Tensor or None
            Moderator coefficients, used as ``group_moderators @ moderators_coef.T``.
            ``None`` disables the moderator term (log-effect of zero for every study).
        coef_spline_bases : torch.Tensor
            Spline bases multiplied with the spatial coefficients.
        group_moderators : torch.Tensor or None
            Moderators of the group; only read when ``moderators_coef`` is not ``None``.
        group_foci_per_voxel : torch.Tensor
            Foci counts multiplied element-wise with the log spatial intensity.
        group_foci_per_study : torch.Tensor
            Two-dimensional tensor whose first dimension is the number of studies.
        device : str or torch.device, optional
            Device on which the zero log-moderator tensor is created when there are no
            moderators. Default is "cpu".

        Returns
        -------
        torch.Tensor
            Log-likelihood term of the group.
        """
        v = 1 / group_overdispersion
        log_mu_spatial = torch.matmul(coef_spline_bases, group_spatial_coef.T)
        mu_spatial = torch.exp(log_mu_spatial)
        group_n_study, _ = group_foci_per_study.shape
        if moderators_coef is not None:
            log_mu_moderators = torch.matmul(group_moderators, moderators_coef.T)
        else:
            log_mu_moderators = torch.zeros(
                (group_n_study, 1), dtype=torch.float64, device=device
            )
        mu_moderators = torch.exp(log_mu_moderators)
        # Total spatial intensity scaled by each study's moderator effect.
        mu_sum_per_study = torch.sum(mu_spatial) * mu_moderators

        log_l = (
            group_n_study * v * torch.log(v)
            - group_n_study * torch.lgamma(v)
            + torch.sum(torch.lgamma(group_foci_per_study + v))
            - torch.sum((group_foci_per_study + v) * torch.log(mu_sum_per_study + v))
            + torch.sum(group_foci_per_voxel * log_mu_spatial)
            + torch.sum(group_foci_per_study * log_mu_moderators)
        )

        return log_l

    def _log_likelihood_mult_group(
        self,
        overdispersion_coef,
        spatial_coef,
        coef_spline_bases,
        foci_per_voxel,
        foci_per_study,
        moderator_coef=None,
        moderators=None,
        device="cpu",
    ):
        """Evaluate the Clustered NB log-likelihood expression over several groups.

        Parameters
        ----------
        overdispersion_coef : iterable of torch.Tensor
            Over-dispersion parameter per group; the code works with the reciprocals.
        spatial_coef : torch.Tensor
            Indexed as ``spatial_coef[i, :, :]`` for group ``i``.
        coef_spline_bases : torch.Tensor
            Spline bases multiplied with each group's spatial coefficients.
        foci_per_voxel, foci_per_study : sequence of torch.Tensor
            Per-group foci counts; ``len(foci_per_voxel)`` gives the number of groups.
        moderator_coef : torch.Tensor or None, optional
            Moderator coefficients, used as ``group_moderator @ moderator_coef``.
            ``None`` (default) disables the moderator term.
        moderators : sequence of torch.Tensor or None, optional
            Per-group moderators; only read when ``moderator_coef`` is not ``None``.
        device : str or torch.device, optional
            Device for the zero log-moderator tensors created when there are no
            moderators. Default is "cpu".

        Returns
        -------
        torch.Tensor or int
            Sum of the per-group log-likelihood terms (``0`` when there are no groups).
        """
        n_groups = len(foci_per_voxel)
        v = [1 / group_overdispersion_coef for group_overdispersion_coef in overdispersion_coef]
        # estimated intensity and log estimated intensity
        log_spatial_intensity = [
            torch.matmul(coef_spline_bases, spatial_coef[i, :, :]) for i in range(n_groups)
        ]
        spatial_intensity = [
            torch.exp(group_log_spatial_intensity)
            for group_log_spatial_intensity in log_spatial_intensity
        ]
        if moderator_coef is not None:
            log_moderator_effect = [
                torch.matmul(group_moderator, moderator_coef) for group_moderator in moderators
            ]
        else:
            # Distinct loop-variable name: do not shadow the ``foci_per_study`` argument.
            log_moderator_effect = [
                torch.zeros(
                    (group_foci_per_study.shape[0], 1), dtype=torch.float64, device=device
                )
                for group_foci_per_study in foci_per_study
            ]
        moderator_effect = [
            torch.exp(group_log_moderator_effect)
            for group_log_moderator_effect in log_moderator_effect
        ]
        mu_sum_per_study = [
            torch.sum(spatial_intensity[i]) * moderator_effect[i] for i in range(n_groups)
        ]
        n_study_list = [group_foci_per_study.shape[0] for group_foci_per_study in foci_per_study]

        log_l = 0
        for i in range(n_groups):
            log_l += (
                n_study_list[i] * v[i] * torch.log(v[i])
                - n_study_list[i] * torch.lgamma(v[i])
                + torch.sum(torch.lgamma(foci_per_study[i] + v[i]))
                - torch.sum((foci_per_study[i] + v[i]) * torch.log(mu_sum_per_study[i] + v[i]))
                + torch.sum(foci_per_voxel[i] * log_spatial_intensity[i])
                + torch.sum(foci_per_study[i] * log_moderator_effect[i])
            )

        return log_l

    def forward(self, coef_spline_bases, moderators, foci_per_voxel, foci_per_study):
        """Define the loss function (negative log-likelihood function) for Clustered NB model.

        Model refactorization is applied to reduce the dimensionality of variables.

        Parameters
        ----------
        coef_spline_bases : torch.Tensor
            Spline bases shared by all groups.
        moderators : dict or other
            When a dict, ``moderators[group]`` is used together with
            ``self.moderators_linear.weight``; any other value disables moderators.
        foci_per_voxel, foci_per_study : mapping of group to torch.Tensor
            Per-group foci counts.

        Returns
        -------
        torch.Tensor
            Loss (negative log-likelihood) of Clustered NB model at current iteration.
        """
        log_l = 0
        for group in self.groups:
            group_overdispersion = self.overdispersion[group]
            group_spatial_coef = self.spatial_coef_linears[group].weight
            group_foci_per_voxel = foci_per_voxel[group]
            group_foci_per_study = foci_per_study[group]
            if isinstance(moderators, dict):
                moderators_coef = self.moderators_linear.weight
                group_moderators = moderators[group]
            else:
                moderators_coef, group_moderators = None, None
            group_log_l = self._log_likelihood_single_group(
                group_overdispersion,
                group_spatial_coef,
                moderators_coef,
                coef_spline_bases,
                group_moderators,
                group_foci_per_voxel,
                group_foci_per_study,
                # The helper defaults to "cpu"; create its no-moderator placeholder on
                # the device of the tensor it is multiplied with instead.
                device=group_foci_per_study.device,
            )
            log_l += group_log_l

        if self.penalty:
            # Firth-type penalty
            log_l += self.firth_penalty(
                foci_per_voxel,
                foci_per_study,
                moderators,
                coef_spline_bases,
                overdispersion=True,
            )

        return -log_l