junifer 0.0.5.dev202__py3-none-any.whl → 0.0.5.dev219__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- junifer/_version.py +2 -2
- junifer/external/nilearn/__init__.py +2 -1
- junifer/external/nilearn/junifer_connectivity_measure.py +483 -0
- junifer/external/nilearn/tests/test_junifer_connectivity_measure.py +1089 -0
- junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +31 -15
- junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py +26 -22
- junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py +33 -27
- junifer/markers/functional_connectivity/functional_connectivity_base.py +42 -30
- junifer/markers/functional_connectivity/functional_connectivity_parcels.py +25 -19
- junifer/markers/functional_connectivity/functional_connectivity_spheres.py +31 -24
- junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py +3 -3
- junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +21 -4
- junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +22 -9
- junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +29 -8
- junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +30 -61
- {junifer-0.0.5.dev202.dist-info → junifer-0.0.5.dev219.dist-info}/METADATA +1 -1
- {junifer-0.0.5.dev202.dist-info → junifer-0.0.5.dev219.dist-info}/RECORD +22 -20
- {junifer-0.0.5.dev202.dist-info → junifer-0.0.5.dev219.dist-info}/WHEEL +1 -1
- {junifer-0.0.5.dev202.dist-info → junifer-0.0.5.dev219.dist-info}/AUTHORS.rst +0 -0
- {junifer-0.0.5.dev202.dist-info → junifer-0.0.5.dev219.dist-info}/LICENSE.md +0 -0
- {junifer-0.0.5.dev202.dist-info → junifer-0.0.5.dev219.dist-info}/entry_points.txt +0 -0
- {junifer-0.0.5.dev202.dist-info → junifer-0.0.5.dev219.dist-info}/top_level.txt +0 -0
junifer/_version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
|
|
12
12
|
__version_tuple__: VERSION_TUPLE
|
13
13
|
version_tuple: VERSION_TUPLE
|
14
14
|
|
15
|
-
__version__ = version = '0.0.5.
|
16
|
-
__version_tuple__ = version_tuple = (0, 0, 5, '
|
15
|
+
__version__ = version = '0.0.5.dev219'
|
16
|
+
__version_tuple__ = version_tuple = (0, 0, 5, 'dev219')
|
@@ -4,6 +4,7 @@
|
|
4
4
|
# License: AGPL
|
5
5
|
|
6
6
|
from .junifer_nifti_spheres_masker import JuniferNiftiSpheresMasker
|
7
|
+
from .junifer_connectivity_measure import JuniferConnectivityMeasure
|
7
8
|
|
8
9
|
|
9
|
-
__all__ = ["JuniferNiftiSpheresMasker"]
|
10
|
+
__all__ = ["JuniferNiftiSpheresMasker", "JuniferConnectivityMeasure"]
|
@@ -0,0 +1,483 @@
|
|
1
|
+
"""Provide JuniferConnectivityMeasure class."""
|
2
|
+
|
3
|
+
# Authors: Synchon Mandal <s.mandal@fz-juelich.de>
|
4
|
+
# License: AGPL
|
5
|
+
|
6
|
+
from typing import Callable, List, Optional
|
7
|
+
|
8
|
+
import numpy as np
|
9
|
+
from nilearn import signal
|
10
|
+
from nilearn.connectome import (
|
11
|
+
ConnectivityMeasure,
|
12
|
+
cov_to_corr,
|
13
|
+
prec_to_partial,
|
14
|
+
sym_matrix_to_vec,
|
15
|
+
)
|
16
|
+
from scipy import linalg
|
17
|
+
from sklearn.base import clone
|
18
|
+
from sklearn.covariance import EmpiricalCovariance
|
19
|
+
|
20
|
+
from ...utils import logger, raise_error, warn_with_log
|
21
|
+
|
22
|
+
|
23
|
+
__all__ = ["JuniferConnectivityMeasure"]
|
24
|
+
|
25
|
+
|
26
|
+
# Default covariance estimator shared by all JuniferConnectivityMeasure
# instances. Sharing one instance is safe because
# ``JuniferConnectivityMeasure._fit_transform`` only ever fits a ``clone`` of
# the configured estimator, never the instance itself.
DEFAULT_COV_ESTIMATOR = EmpiricalCovariance(store_precision=False)
|
27
|
+
|
28
|
+
|
29
|
+
# New BSD License
|
30
|
+
|
31
|
+
# Copyright (c) The nilearn developers.
|
32
|
+
# All rights reserved.
|
33
|
+
|
34
|
+
|
35
|
+
# Redistribution and use in source and binary forms, with or without
|
36
|
+
# modification, are permitted provided that the following conditions are met:
|
37
|
+
|
38
|
+
# a. Redistributions of source code must retain the above copyright notice,
|
39
|
+
# this list of conditions and the following disclaimer.
|
40
|
+
# b. Redistributions in binary form must reproduce the above copyright
|
41
|
+
# notice, this list of conditions and the following disclaimer in the
|
42
|
+
# documentation and/or other materials provided with the distribution.
|
43
|
+
# c. Neither the name of the nilearn developers nor the names of
|
44
|
+
# its contributors may be used to endorse or promote products
|
45
|
+
# derived from this software without specific prior written
|
46
|
+
# permission.
|
47
|
+
|
48
|
+
|
49
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
50
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
51
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
52
|
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
|
53
|
+
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
54
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
55
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
56
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
|
57
|
+
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
|
58
|
+
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
|
59
|
+
# DAMAGE.
|
60
|
+
|
61
|
+
|
62
|
+
def _check_square(matrix: np.ndarray) -> None:
|
63
|
+
"""Raise a ValueError if the input matrix is square.
|
64
|
+
|
65
|
+
Parameters
|
66
|
+
----------
|
67
|
+
matrix : numpy.ndarray
|
68
|
+
Input array.
|
69
|
+
|
70
|
+
Raises
|
71
|
+
------
|
72
|
+
ValueError
|
73
|
+
If ``matrix`` is not a square matrix.
|
74
|
+
|
75
|
+
"""
|
76
|
+
if matrix.ndim != 2 or (matrix.shape[0] != matrix.shape[-1]):
|
77
|
+
raise_error(
|
78
|
+
f"Expected a square matrix, got array of shape {matrix.shape}."
|
79
|
+
)
|
80
|
+
|
81
|
+
|
82
|
+
def is_spd(M: np.ndarray, decimal: int = 15) -> bool:  # noqa: N803
    """Check that input matrix is symmetric positive definite.

    ``M`` must be symmetric down to specified ``decimal`` places.
    The check is performed by checking that all eigenvalues are strictly
    positive (a zero eigenvalue means the matrix is not positive definite).

    Parameters
    ----------
    M : numpy.ndarray
        Square input matrix to check for symmetric positive definite.
    decimal : int, optional
        Decimal places to check (default 15). Symmetry is tested with a
        relative tolerance of ``10**-decimal`` and no absolute tolerance.

    Returns
    -------
    bool
        True if matrix is symmetric positive definite, False otherwise.

    """
    # Purely relative symmetry check (atol=0), so the tolerance scales with
    # the magnitude of the entries
    if not np.allclose(M, M.T, atol=0, rtol=10**-decimal):
        logger.debug(f"matrix not symmetric to {decimal:d} decimals")
        return False
    # ``eigvalsh`` is the symmetric/Hermitian eigenvalue routine, valid here
    # because symmetry has just been verified
    min_eigval = np.linalg.eigvalsh(M).min()
    # Convert to a native bool so the return type matches the annotation
    is_pd = bool(min_eigval > 0)
    if not is_pd:
        # ``> 0`` also rejects zero eigenvalues, hence "non-positive"
        logger.debug(
            f"matrix has a non-positive eigenvalue: {min_eigval:.3f}"
        )
    return is_pd
|
109
|
+
|
110
|
+
|
111
|
+
def _check_spd(matrix: np.ndarray) -> None:
    """Ensure ``matrix`` is symmetric positive definite.

    Symmetry is verified to 7 decimal places (via ``is_spd``) before the
    eigenvalues are inspected.

    Parameters
    ----------
    matrix : numpy.ndarray
        Matrix to validate.

    Raises
    ------
    ValueError
        When ``matrix`` is not symmetric positive definite.

    """
    # Guard clause: nothing to do for a valid matrix
    if is_spd(matrix, decimal=7):
        return
    raise_error("Expected a symmetric positive definite matrix.")
|
127
|
+
|
128
|
+
|
129
|
+
def _form_symmetric(
|
130
|
+
function: Callable[[np.ndarray], np.ndarray],
|
131
|
+
eigenvalues: np.ndarray,
|
132
|
+
eigenvectors: np.ndarray,
|
133
|
+
) -> np.ndarray:
|
134
|
+
"""Return the symmetric matrix.
|
135
|
+
|
136
|
+
Apply ``function`` to ``eigenvalues``, construct symmetric matrix with it
|
137
|
+
and ``eigenvectors`` and return the constructed symmetric matrix.
|
138
|
+
|
139
|
+
Parameters
|
140
|
+
----------
|
141
|
+
function : callable (function numpy.ndarray -> numpy.ndarray)
|
142
|
+
The transform to apply to the eigenvalues.
|
143
|
+
eigenvalues : numpy.ndarray of shape (n_features, )
|
144
|
+
Input argument of the function.
|
145
|
+
eigenvectors : numpy.ndarray of shape (n_features, n_features)
|
146
|
+
Unitary matrix.
|
147
|
+
|
148
|
+
Returns
|
149
|
+
-------
|
150
|
+
numpy.ndarray of shape (n_features, n_features)
|
151
|
+
The symmetric matrix obtained after transforming the eigenvalues, while
|
152
|
+
keeping the same eigenvectors.
|
153
|
+
|
154
|
+
"""
|
155
|
+
return np.dot(eigenvectors * function(eigenvalues), eigenvectors.T)
|
156
|
+
|
157
|
+
|
158
|
+
def _map_eigenvalues(
|
159
|
+
function: Callable[[np.ndarray], np.ndarray], symmetric: np.ndarray
|
160
|
+
) -> np.ndarray:
|
161
|
+
"""Matrix function, for real symmetric matrices.
|
162
|
+
|
163
|
+
The function is applied to the eigenvalues of ``symmetric``.
|
164
|
+
|
165
|
+
Parameters
|
166
|
+
----------
|
167
|
+
function : callable (function numpy.ndarray -> numpy.ndarray)
|
168
|
+
The transform to apply to the eigenvalues.
|
169
|
+
symmetric : numpy.ndarray of shape (n_features, n_features)
|
170
|
+
The input symmetric matrix.
|
171
|
+
|
172
|
+
Returns
|
173
|
+
-------
|
174
|
+
numpy.ndarray of shape (n_features, n_features)
|
175
|
+
The new symmetric matrix obtained after transforming the eigenvalues,
|
176
|
+
while keeping the same eigenvectors.
|
177
|
+
|
178
|
+
Notes
|
179
|
+
-----
|
180
|
+
If input matrix is not real symmetric, no error is reported but result will
|
181
|
+
be wrong.
|
182
|
+
|
183
|
+
"""
|
184
|
+
eigenvalues, eigenvectors = linalg.eigh(symmetric)
|
185
|
+
return _form_symmetric(function, eigenvalues, eigenvectors)
|
186
|
+
|
187
|
+
|
188
|
+
def _geometric_mean(
    matrices: List[np.ndarray],
    init: Optional[np.ndarray] = None,
    max_iter: int = 10,
    tol: Optional[float] = 1e-7,
) -> np.ndarray:
    """Compute the geometric mean of symmetric positive definite matrices.

    The geometric mean of ``n`` positive definite matrices
    ``M_1, ..., M_n`` is the minimizer of the sum of squared distances from an
    arbitrary matrix to each input matrix ``M_k``

    .. math:: gmean(M_1, ..., M_n) = argmin_X sum_{k=1}^n dist(X, M_k)^2

    where the used distance is related to matrices logarithm

    .. math:: dist(X, M_k) = ||log(X^{-1/2} M_k X^{-1/2})||

    In case of positive numbers, this mean is the usual geometric mean.

    See Algorithm 3 of [1]_ .

    Parameters
    ----------
    matrices : list of numpy.ndarray, all of shape (n_features, n_features)
        List of matrices whose geometric mean to compute. Raise an error if the
        list is empty or if the matrices are not all symmetric positive
        definite of the same shape.
    init : numpy.ndarray of shape (n_features, n_features), optional
        Initialization matrix, default to the arithmetic mean of matrices.
        Raise an error if the matrix is not symmetric positive definite of the
        same shape as the elements of matrices (default None).
    max_iter : int, optional
        Maximal number of iterations (default 10). With ``0``, the
        initialization is returned unchanged.
    tol : positive float or None, optional
        The tolerance to declare convergence: if the gradient norm goes below
        this value, the gradient descent is stopped. If None, no check is
        performed (default 1e-7).

    Returns
    -------
    gmean : numpy.ndarray of shape (n_features, n_features)
        Geometric mean of the matrices.

    Raises
    ------
    ValueError
        If ``matrices`` is empty, or if ``matrices`` / ``init`` are not all
        square symmetric positive definite matrices of the same shape.
    FloatingPointError
        If a NaN value appears after the matrix logarithm.

    Warns
    -----
    A warning is logged when ``tol`` is not None and the tolerance is not
    reached within ``max_iter`` iterations.

    References
    ----------
    .. [1] Fletcher, T., P., Joshi, S.
           Riemannian geometry for the statistical analysis of diffusion tensor
           data.
           Signal Processing, Volume 87, Issue 2, 2007, Pages 250-262
           https://doi.org/10.1016/j.sigpro.2005.12.018.

    """
    # ``len`` rather than truthiness so that stacked numpy arrays also work
    if len(matrices) == 0:
        raise_error("Expected at least one matrix, got an empty sequence.")

    # Shape and symmetry positive definiteness checks
    n_features = matrices[0].shape[0]
    for matrix in matrices:
        _check_square(matrix)
        if matrix.shape[0] != n_features:
            raise_error("Matrices are not of the same shape.")
        _check_spd(matrix)

    # Initialization
    matrices = np.array(matrices)
    if init is None:
        gmean = np.mean(matrices, axis=0)
    else:
        _check_square(init)
        if init.shape[0] != n_features:
            raise_error("Initialization has incorrect shape.")
        _check_spd(init)
        gmean = init

    norm_old = np.inf
    step = 1.0

    # Gradient descent
    for _ in range(max_iter):
        # Computation of the gradient: whiten each matrix by gmean^{-1/2}
        vals_gmean, vecs_gmean = linalg.eigh(gmean)
        gmean_inv_sqrt = _form_symmetric(np.sqrt, 1.0 / vals_gmean, vecs_gmean)
        whitened_matrices = [
            gmean_inv_sqrt.dot(matrix).dot(gmean_inv_sqrt)
            for matrix in matrices
        ]
        logs = [_map_eigenvalues(np.log, w_mat) for w_mat in whitened_matrices]
        # Covariant derivative is - gmean.dot(logs_mean)
        logs_mean = np.mean(logs, axis=0)
        if np.any(np.isnan(logs_mean)):
            raise_error(
                klass=FloatingPointError,
                msg="Nan value after logarithm operation.",
            )

        # Norm of the covariant derivative on the tangent space at point gmean
        norm = np.linalg.norm(logs_mean)

        # Update of the minimizer
        vals_log, vecs_log = linalg.eigh(logs_mean)
        gmean_sqrt = _form_symmetric(np.sqrt, vals_gmean, vecs_gmean)
        # Move along the geodesic
        gmean = gmean_sqrt.dot(
            _form_symmetric(np.exp, vals_log * step, vecs_log)
        ).dot(gmean_sqrt)

        # Update the norm and the step size: halve the step whenever the
        # gradient norm increases
        if norm < norm_old:
            norm_old = norm
        elif norm > norm_old:
            step = step / 2.0
            norm = norm_old
        if tol is not None and norm / gmean.size < tol:
            break
    else:
        # Only reached when the loop did not ``break``, i.e. the tolerance was
        # never met. This also covers ``max_iter == 0``, where no gradient
        # norm is ever computed.
        if tol is not None:
            warn_with_log(
                f"Maximum number of iterations {max_iter} reached without "
                f"getting to the requested tolerance level {tol}."
            )

    return gmean
|
306
|
+
|
307
|
+
|
308
|
+
class JuniferConnectivityMeasure(ConnectivityMeasure):
    """Class for custom ConnectivityMeasure.

    Differs from :class:`nilearn.connectome.ConnectivityMeasure` in the
    following ways:

    * default ``cov_estimator`` is
      :class:`sklearn.covariance.EmpiricalCovariance`
    * default ``kind`` is ``"correlation"``

    Parameters
    ----------
    cov_estimator : estimator object, optional
        The covariance estimator
        (default ``EmpiricalCovariance(store_precision=False)``). The estimator
        is cloned before fitting, so the passed instance is not modified.
    kind : {"covariance", "correlation", "partial correlation", \
            "tangent", "precision"}, optional
        The matrix kind. For the use of ``"tangent"`` see [1]_
        (default "correlation").
    vectorize : bool, optional
        If True, connectivity matrices are reshaped into 1D arrays and only
        their flattened lower triangular parts are returned (default False).
    discard_diagonal : bool, optional
        If True, vectorized connectivity coefficients do not include the
        matrices diagonal elements. Used only when vectorize is set to True
        (default False).
    standardize : bool, optional
        If standardize is True, the data are centered and normed: their mean
        is put to 0 and their variance is put to 1 in the time dimension
        (default True). Only used for ``kind="correlation"``, where it is
        passed to ``nilearn.signal.standardize_signal``.

        .. note::

            Added (in nilearn) to control passing value to ``standardize`` of
            ``signal.clean`` to call new behavior since passing ``"zscore"``
            or True (default) is deprecated. This parameter will be
            deprecated in nilearn version 0.13 and removed in nilearn version
            0.15.

    Attributes
    ----------
    cov_estimator_ : estimator object
        A new covariance estimator with the same parameters as
        ``cov_estimator``.
    mean_ : numpy.ndarray
        The mean connectivity matrix across subjects. For ``"tangent"`` kind,
        it is the geometric mean of covariances (a group covariance
        matrix that captures information from both correlation and partial
        correlation matrices). For other values for ``kind``, it is the
        mean of the corresponding matrices.
    whitening_ : numpy.ndarray
        The inverted square-rooted geometric mean of the covariance matrices.

    References
    ----------
    .. [1] Varoquaux, G., Baronnet, F., Kleinschmidt, A. et al.
           Detection of brain functional-connectivity difference in
           post-stroke patients using group-level covariance modeling.
           In Tianzi Jiang, Nassir Navab, Josien P. W. Pluim, and
           Max A. Viergever, editors, Medical image computing and
           computer-assisted intervention - MICCAI 2010, Lecture notes
           in computer science, Pages 200-208. Berlin, Heidelberg, 2010.
           Springer.
           doi:10/cn2h9c.

    """

    def __init__(
        self,
        cov_estimator=DEFAULT_COV_ESTIMATOR,
        kind="correlation",
        vectorize=False,
        discard_diagonal=False,
        standardize=True,
    ):
        super().__init__(
            cov_estimator=cov_estimator,
            kind=kind,
            vectorize=vectorize,
            discard_diagonal=discard_diagonal,
            standardize=standardize,
        )

    def _fit_transform(
        self,
        X,  # noqa: N803
        do_transform=False,
        do_fit=False,
        confounds=None,
    ):
        """Avoid duplication of computation.

        Shared implementation of fitting and transforming: estimates one
        connectivity matrix per element of ``X``, optionally stores group
        statistics (``do_fit``) and optionally post-processes the matrices
        (``do_transform``).

        Parameters
        ----------
        X : list of numpy.ndarray
            Per-subject signals, validated by ``_check_input``.
        do_transform : bool, optional
            If True, apply the tangent projection (for ``"tangent"``),
            stack, and optionally vectorize / clean the matrices
            (default False).
        do_fit : bool, optional
            If True, clone ``cov_estimator`` into ``cov_estimator_`` and
            store ``mean_`` (and ``whitening_`` for ``"tangent"``)
            (default False).
        confounds : array-like or None, optional
            Regressed out of the vectorized connectivities with
            ``signal.clean``. Only allowed when ``vectorize=True``
            (default None).

        Returns
        -------
        list of numpy.ndarray or numpy.ndarray
            The connectivity matrices. They are converted to an array (and
            vectorized / cleaned if requested) only when ``do_transform`` is
            True.

        Raises
        ------
        ValueError
            If ``kind`` is not an allowed value, or if ``confounds`` are
            given for a transform while ``vectorize`` is False.

        """
        self._check_input(X, confounds=confounds)

        # Validate the options before any covariance estimation, so that
        # invalid calls fail fast and never leave fitted attributes behind
        allowed_kinds = (
            "correlation",
            "partial correlation",
            "tangent",
            "covariance",
            "precision",
        )
        if self.kind not in allowed_kinds:
            raise_error(
                f"Allowed connectivity kinds are {allowed_kinds}. "
                f"Got kind {self.kind}."
            )
        if do_transform and confounds is not None and not self.vectorize:
            error_message = (
                "'confounds' are provided but vectorize=False. "
                "Confounds are only cleaned on vectorized matrices "
                "as second level connectome regression "
                "but not on symmetric matrices."
            )
            raise_error(error_message)

        if do_fit:
            self.cov_estimator_ = clone(self.cov_estimator)

        # Compute all the matrices, stored in "connectivities"
        if self.kind == "correlation":
            covariances_std = [
                self.cov_estimator_.fit(
                    signal.standardize_signal(
                        x,
                        detrend=False,
                        standardize=self.standardize,
                    )
                ).covariance_
                for x in X
            ]
            connectivities = [cov_to_corr(cov) for cov in covariances_std]
        else:
            covariances = [self.cov_estimator_.fit(x).covariance_ for x in X]
            if self.kind in ("covariance", "tangent"):
                connectivities = covariances
            elif self.kind == "precision":
                connectivities = [linalg.inv(cov) for cov in covariances]
            else:
                # "partial correlation" (kind was validated above)
                connectivities = [
                    prec_to_partial(linalg.inv(cov)) for cov in covariances
                ]

        # Store the mean
        if do_fit:
            if self.kind == "tangent":
                self.mean_ = _geometric_mean(
                    covariances, max_iter=30, tol=1e-7
                )
                self.whitening_ = _map_eigenvalues(
                    lambda x: 1.0 / np.sqrt(x), self.mean_
                )
            else:
                self.mean_ = np.mean(connectivities, axis=0)
                # Fight numerical instabilities: make symmetric
                self.mean_ = self.mean_ + self.mean_.T
                self.mean_ *= 0.5

        # Compute the vector we return on transform
        if do_transform:
            if self.kind == "tangent":
                # Project each covariance onto the tangent space at mean_
                connectivities = [
                    _map_eigenvalues(
                        np.log, self.whitening_.dot(cov).dot(self.whitening_)
                    )
                    for cov in connectivities
                ]

            connectivities = np.array(connectivities)

            if self.vectorize:
                connectivities = sym_matrix_to_vec(
                    connectivities, discard_diagonal=self.discard_diagonal
                )
                if confounds is not None:
                    connectivities = signal.clean(
                        connectivities, confounds=confounds
                    )

        return connectivities
|