junifer 0.0.5__py3-none-any.whl → 0.0.5.dev11__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (198)
  1. junifer/__init__.py +0 -17
  2. junifer/_version.py +2 -2
  3. junifer/api/__init__.py +1 -4
  4. junifer/api/cli.py +1 -91
  5. junifer/api/decorators.py +0 -9
  6. junifer/api/functions.py +10 -56
  7. junifer/api/parser.py +0 -3
  8. junifer/api/queue_context/__init__.py +1 -4
  9. junifer/api/res/afni/run_afni_docker.sh +1 -1
  10. junifer/api/res/ants/run_ants_docker.sh +1 -1
  11. junifer/api/res/fsl/run_fsl_docker.sh +1 -1
  12. junifer/api/tests/test_api_utils.py +2 -4
  13. junifer/api/tests/test_cli.py +0 -83
  14. junifer/api/tests/test_functions.py +2 -27
  15. junifer/configs/__init__.py +1 -1
  16. junifer/configs/juseless/__init__.py +1 -4
  17. junifer/configs/juseless/datagrabbers/__init__.py +1 -10
  18. junifer/configs/juseless/datagrabbers/aomic_id1000_vbm.py +0 -3
  19. junifer/configs/juseless/datagrabbers/camcan_vbm.py +0 -3
  20. junifer/configs/juseless/datagrabbers/ixi_vbm.py +0 -3
  21. junifer/configs/juseless/datagrabbers/tests/test_ucla.py +3 -1
  22. junifer/configs/juseless/datagrabbers/ucla.py +9 -12
  23. junifer/configs/juseless/datagrabbers/ukb_vbm.py +0 -3
  24. junifer/data/__init__.py +1 -21
  25. junifer/data/coordinates.py +19 -10
  26. junifer/data/masks.py +87 -58
  27. junifer/data/parcellations.py +3 -14
  28. junifer/data/template_spaces.py +1 -4
  29. junifer/data/tests/test_masks.py +37 -26
  30. junifer/data/utils.py +0 -3
  31. junifer/datagrabber/__init__.py +1 -18
  32. junifer/datagrabber/aomic/__init__.py +0 -3
  33. junifer/datagrabber/aomic/id1000.py +37 -70
  34. junifer/datagrabber/aomic/piop1.py +36 -69
  35. junifer/datagrabber/aomic/piop2.py +38 -71
  36. junifer/datagrabber/aomic/tests/test_id1000.py +99 -44
  37. junifer/datagrabber/aomic/tests/test_piop1.py +108 -65
  38. junifer/datagrabber/aomic/tests/test_piop2.py +102 -45
  39. junifer/datagrabber/base.py +6 -13
  40. junifer/datagrabber/datalad_base.py +1 -13
  41. junifer/datagrabber/dmcc13_benchmark.py +53 -36
  42. junifer/datagrabber/hcp1200/__init__.py +0 -3
  43. junifer/datagrabber/hcp1200/datalad_hcp1200.py +0 -3
  44. junifer/datagrabber/hcp1200/hcp1200.py +1 -4
  45. junifer/datagrabber/multiple.py +6 -45
  46. junifer/datagrabber/pattern.py +62 -170
  47. junifer/datagrabber/pattern_datalad.py +12 -25
  48. junifer/datagrabber/tests/test_datagrabber_utils.py +218 -0
  49. junifer/datagrabber/tests/test_datalad_base.py +4 -4
  50. junifer/datagrabber/tests/test_dmcc13_benchmark.py +19 -46
  51. junifer/datagrabber/tests/test_multiple.py +84 -161
  52. junifer/datagrabber/tests/test_pattern.py +0 -45
  53. junifer/datagrabber/tests/test_pattern_datalad.py +4 -4
  54. junifer/datagrabber/utils.py +230 -0
  55. junifer/datareader/__init__.py +1 -4
  56. junifer/datareader/default.py +43 -95
  57. junifer/external/__init__.py +1 -1
  58. junifer/external/nilearn/__init__.py +1 -5
  59. junifer/external/nilearn/junifer_nifti_spheres_masker.py +9 -23
  60. junifer/external/nilearn/tests/test_junifer_nifti_spheres_masker.py +1 -76
  61. junifer/markers/__init__.py +1 -23
  62. junifer/markers/base.py +28 -68
  63. junifer/markers/collection.py +2 -10
  64. junifer/markers/complexity/__init__.py +0 -10
  65. junifer/markers/complexity/complexity_base.py +43 -26
  66. junifer/markers/complexity/hurst_exponent.py +0 -3
  67. junifer/markers/complexity/multiscale_entropy_auc.py +0 -3
  68. junifer/markers/complexity/perm_entropy.py +0 -3
  69. junifer/markers/complexity/range_entropy.py +0 -3
  70. junifer/markers/complexity/range_entropy_auc.py +0 -3
  71. junifer/markers/complexity/sample_entropy.py +0 -3
  72. junifer/markers/complexity/tests/test_hurst_exponent.py +3 -11
  73. junifer/markers/complexity/tests/test_multiscale_entropy_auc.py +3 -11
  74. junifer/markers/complexity/tests/test_perm_entropy.py +3 -11
  75. junifer/markers/complexity/tests/test_range_entropy.py +3 -11
  76. junifer/markers/complexity/tests/test_range_entropy_auc.py +3 -11
  77. junifer/markers/complexity/tests/test_sample_entropy.py +3 -11
  78. junifer/markers/complexity/tests/test_weighted_perm_entropy.py +3 -11
  79. junifer/markers/complexity/weighted_perm_entropy.py +0 -3
  80. junifer/markers/ets_rss.py +42 -27
  81. junifer/markers/falff/__init__.py +0 -3
  82. junifer/markers/falff/_afni_falff.py +2 -5
  83. junifer/markers/falff/_junifer_falff.py +0 -3
  84. junifer/markers/falff/falff_base.py +46 -20
  85. junifer/markers/falff/falff_parcels.py +27 -56
  86. junifer/markers/falff/falff_spheres.py +29 -60
  87. junifer/markers/falff/tests/test_falff_parcels.py +23 -39
  88. junifer/markers/falff/tests/test_falff_spheres.py +23 -39
  89. junifer/markers/functional_connectivity/__init__.py +0 -9
  90. junifer/markers/functional_connectivity/crossparcellation_functional_connectivity.py +60 -63
  91. junifer/markers/functional_connectivity/edge_functional_connectivity_parcels.py +32 -45
  92. junifer/markers/functional_connectivity/edge_functional_connectivity_spheres.py +36 -49
  93. junifer/markers/functional_connectivity/functional_connectivity_base.py +70 -71
  94. junifer/markers/functional_connectivity/functional_connectivity_parcels.py +25 -34
  95. junifer/markers/functional_connectivity/functional_connectivity_spheres.py +30 -40
  96. junifer/markers/functional_connectivity/tests/test_crossparcellation_functional_connectivity.py +7 -11
  97. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_parcels.py +7 -27
  98. junifer/markers/functional_connectivity/tests/test_edge_functional_connectivity_spheres.py +12 -28
  99. junifer/markers/functional_connectivity/tests/test_functional_connectivity_parcels.py +11 -35
  100. junifer/markers/functional_connectivity/tests/test_functional_connectivity_spheres.py +62 -36
  101. junifer/markers/parcel_aggregation.py +61 -47
  102. junifer/markers/reho/__init__.py +0 -3
  103. junifer/markers/reho/_afni_reho.py +2 -5
  104. junifer/markers/reho/_junifer_reho.py +1 -4
  105. junifer/markers/reho/reho_base.py +27 -8
  106. junifer/markers/reho/reho_parcels.py +17 -28
  107. junifer/markers/reho/reho_spheres.py +18 -27
  108. junifer/markers/reho/tests/test_reho_parcels.py +3 -8
  109. junifer/markers/reho/tests/test_reho_spheres.py +3 -8
  110. junifer/markers/sphere_aggregation.py +59 -43
  111. junifer/markers/temporal_snr/__init__.py +0 -3
  112. junifer/markers/temporal_snr/temporal_snr_base.py +32 -23
  113. junifer/markers/temporal_snr/temporal_snr_parcels.py +6 -9
  114. junifer/markers/temporal_snr/temporal_snr_spheres.py +6 -9
  115. junifer/markers/temporal_snr/tests/test_temporal_snr_parcels.py +3 -6
  116. junifer/markers/temporal_snr/tests/test_temporal_snr_spheres.py +3 -6
  117. junifer/markers/tests/test_collection.py +8 -9
  118. junifer/markers/tests/test_ets_rss.py +9 -15
  119. junifer/markers/tests/test_markers_base.py +18 -17
  120. junifer/markers/tests/test_parcel_aggregation.py +32 -93
  121. junifer/markers/tests/test_sphere_aggregation.py +19 -72
  122. junifer/onthefly/__init__.py +1 -4
  123. junifer/onthefly/read_transform.py +0 -3
  124. junifer/pipeline/__init__.py +1 -9
  125. junifer/pipeline/pipeline_step_mixin.py +4 -21
  126. junifer/pipeline/registry.py +0 -3
  127. junifer/pipeline/singleton.py +0 -3
  128. junifer/pipeline/tests/test_registry.py +1 -1
  129. junifer/pipeline/update_meta_mixin.py +0 -3
  130. junifer/pipeline/utils.py +1 -67
  131. junifer/pipeline/workdir_manager.py +0 -3
  132. junifer/preprocess/__init__.py +2 -10
  133. junifer/preprocess/ants/__init__.py +4 -0
  134. junifer/preprocess/ants/ants_apply_transforms_warper.py +185 -0
  135. junifer/preprocess/ants/tests/test_ants_apply_transforms_warper.py +56 -0
  136. junifer/preprocess/base.py +3 -6
  137. junifer/preprocess/bold_warper.py +265 -0
  138. junifer/preprocess/confounds/__init__.py +0 -3
  139. junifer/preprocess/confounds/fmriprep_confound_remover.py +60 -47
  140. junifer/preprocess/confounds/tests/test_fmriprep_confound_remover.py +113 -72
  141. junifer/preprocess/fsl/__init__.py +4 -0
  142. junifer/preprocess/fsl/apply_warper.py +179 -0
  143. junifer/preprocess/fsl/tests/test_apply_warper.py +45 -0
  144. junifer/preprocess/tests/test_bold_warper.py +159 -0
  145. junifer/preprocess/warping/__init__.py +0 -3
  146. junifer/preprocess/warping/_ants_warper.py +0 -3
  147. junifer/preprocess/warping/_fsl_warper.py +0 -3
  148. junifer/stats.py +1 -4
  149. junifer/storage/__init__.py +1 -9
  150. junifer/storage/base.py +1 -40
  151. junifer/storage/hdf5.py +9 -71
  152. junifer/storage/pandas_base.py +0 -3
  153. junifer/storage/sqlite.py +0 -3
  154. junifer/storage/tests/test_hdf5.py +10 -82
  155. junifer/storage/utils.py +0 -9
  156. junifer/testing/__init__.py +1 -4
  157. junifer/testing/datagrabbers.py +6 -13
  158. junifer/testing/tests/test_partlycloudytesting_datagrabber.py +7 -7
  159. junifer/testing/utils.py +0 -3
  160. junifer/utils/__init__.py +2 -13
  161. junifer/utils/fs.py +0 -3
  162. junifer/utils/helpers.py +1 -32
  163. junifer/utils/logging.py +4 -33
  164. junifer/utils/tests/test_logging.py +0 -8
  165. {junifer-0.0.5.dist-info → junifer-0.0.5.dev11.dist-info}/METADATA +16 -17
  166. junifer-0.0.5.dev11.dist-info/RECORD +259 -0
  167. {junifer-0.0.5.dist-info → junifer-0.0.5.dev11.dist-info}/WHEEL +1 -1
  168. junifer/api/res/freesurfer/mri_binarize +0 -3
  169. junifer/api/res/freesurfer/mri_mc +0 -3
  170. junifer/api/res/freesurfer/mri_pretess +0 -3
  171. junifer/api/res/freesurfer/mris_convert +0 -3
  172. junifer/api/res/freesurfer/run_freesurfer_docker.sh +0 -61
  173. junifer/data/masks/ukb/UKB_15K_GM_template.nii.gz +0 -0
  174. junifer/datagrabber/pattern_validation_mixin.py +0 -388
  175. junifer/datagrabber/tests/test_pattern_validation_mixin.py +0 -249
  176. junifer/external/BrainPrint/brainprint/__init__.py +0 -4
  177. junifer/external/BrainPrint/brainprint/_version.py +0 -3
  178. junifer/external/BrainPrint/brainprint/asymmetry.py +0 -91
  179. junifer/external/BrainPrint/brainprint/brainprint.py +0 -441
  180. junifer/external/BrainPrint/brainprint/surfaces.py +0 -258
  181. junifer/external/BrainPrint/brainprint/utils/__init__.py +0 -1
  182. junifer/external/BrainPrint/brainprint/utils/_config.py +0 -112
  183. junifer/external/BrainPrint/brainprint/utils/utils.py +0 -188
  184. junifer/external/nilearn/junifer_connectivity_measure.py +0 -483
  185. junifer/external/nilearn/tests/test_junifer_connectivity_measure.py +0 -1089
  186. junifer/markers/brainprint.py +0 -459
  187. junifer/markers/tests/test_brainprint.py +0 -58
  188. junifer/preprocess/smoothing/__init__.py +0 -9
  189. junifer/preprocess/smoothing/_afni_smoothing.py +0 -119
  190. junifer/preprocess/smoothing/_fsl_smoothing.py +0 -116
  191. junifer/preprocess/smoothing/_nilearn_smoothing.py +0 -69
  192. junifer/preprocess/smoothing/smoothing.py +0 -174
  193. junifer/preprocess/smoothing/tests/test_smoothing.py +0 -94
  194. junifer-0.0.5.dist-info/RECORD +0 -275
  195. {junifer-0.0.5.dist-info → junifer-0.0.5.dev11.dist-info}/AUTHORS.rst +0 -0
  196. {junifer-0.0.5.dist-info → junifer-0.0.5.dev11.dist-info}/LICENSE.md +0 -0
  197. {junifer-0.0.5.dist-info → junifer-0.0.5.dev11.dist-info}/entry_points.txt +0 -0
  198. {junifer-0.0.5.dist-info → junifer-0.0.5.dev11.dist-info}/top_level.txt +0 -0
junifer/external/nilearn/tests/test_junifer_connectivity_measure.py (deleted)
@@ -1,1089 +0,0 @@
1
- """Provide tests for JuniferConnectivityMeasure class."""
2
-
3
- # Authors: Synchon Mandal <s.mandal@fz-juelich.de>
4
- # License: AGPL
5
-
6
- import copy
7
- import warnings
8
- from math import cosh, exp, log, sinh, sqrt
9
- from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
10
-
11
- import numpy as np
12
- import pytest
13
- from nilearn.connectome.connectivity_matrices import sym_matrix_to_vec
14
- from nilearn.tests.test_signal import generate_signals
15
- from numpy.testing import assert_array_almost_equal, assert_array_equal
16
- from pandas import DataFrame
17
- from scipy import linalg
18
- from sklearn.covariance import EmpiricalCovariance, LedoitWolf
19
-
20
- from junifer.external.nilearn import JuniferConnectivityMeasure
21
- from junifer.external.nilearn.junifer_connectivity_measure import (
22
- _check_spd,
23
- _check_square,
24
- _form_symmetric,
25
- _geometric_mean,
26
- _map_eigenvalues,
27
- is_spd,
28
- )
29
-
30
-
31
- if TYPE_CHECKING:
32
- from numpy.typing import ArrayLike
33
- from sklearn.base import BaseEstimator
34
-
35
- # New BSD License
36
-
37
- # Copyright (c) The nilearn developers.
38
- # All rights reserved.
39
-
40
-
41
- # Redistribution and use in source and binary forms, with or without
42
- # modification, are permitted provided that the following conditions are met:
43
-
44
- # a. Redistributions of source code must retain the above copyright notice,
45
- # this list of conditions and the following disclaimer.
46
- # b. Redistributions in binary form must reproduce the above copyright
47
- # notice, this list of conditions and the following disclaimer in the
48
- # documentation and/or other materials provided with the distribution.
49
- # c. Neither the name of the nilearn developers nor the names of
50
- # its contributors may be used to endorse or promote products
51
- # derived from this software without specific prior written
52
- # permission.
53
-
54
-
55
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
56
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
57
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
58
- # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
59
- # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
60
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
61
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
62
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
63
- # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
64
- # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
65
- # DAMAGE.
66
-
67
-
68
- CONNECTIVITY_KINDS = (
69
- "covariance",
70
- "correlation",
71
- "tangent",
72
- "precision",
73
- "partial correlation",
74
- )
75
-
76
- N_FEATURES = 49
77
-
78
- N_SUBJECTS = 5
79
-
80
-
81
- def random_diagonal(
82
- p: int,
83
- v_min: float = 1.0,
84
- v_max: float = 2.0,
85
- random_state: Union[int, np.random.RandomState] = 0,
86
- ) -> np.ndarray:
87
- """Generate a random diagonal matrix.
88
-
89
- Parameters
90
- ----------
91
- p : int
92
- The first dimension of the array.
93
- v_min : float, optional
94
- Minimal element (default 1.).
95
- v_max : float, optional
96
- Maximal element (default 2.).
97
- random_state : int or numpy.random.RandomState instance, optional
98
- random number generator, or seed (default 0).
99
-
100
- Returns
101
- -------
102
- numpy.ndarray of shape (p, p)
103
- A diagonal matrix with the given minimal and maximal elements.
104
-
105
- """
106
- random_state = np.random.default_rng(random_state)
107
- diag = random_state.random(p) * (v_max - v_min) + v_min
108
- diag[diag == np.amax(diag)] = v_max
109
- diag[diag == np.amin(diag)] = v_min
110
- return np.diag(diag)
111
-
112
-
113
- def random_spd(
114
- p: int,
115
- eig_min: float,
116
- cond: float,
117
- random_state: Union[int, np.random.RandomState] = 0,
118
- ) -> np.ndarray:
119
- """Generate a random symmetric positive definite matrix.
120
-
121
- Parameters
122
- ----------
123
- p : int
124
- The first dimension of the array.
125
- eig_min : float
126
- Minimal eigenvalue.
127
- cond : float
128
- Condition number, defined as the ratio of the maximum eigenvalue to the
129
- minimum one.
130
- random_state : int or numpy.random.RandomState instance, optional
131
- random number generator, or seed (default 0).
132
-
133
- Returns
134
- -------
135
- numpy.ndarray of shape (p, p)
136
- A symmetric positive definite matrix with the given minimal eigenvalue
137
- and condition number.
138
-
139
- """
140
- rand_gen = np.random.default_rng(random_state)
141
- mat = rand_gen.standard_normal((p, p))
142
- unitary, _ = linalg.qr(mat)
143
- diag = random_diagonal(
144
- p, v_min=eig_min, v_max=cond * eig_min, random_state=random_state
145
- )
146
- return unitary.dot(diag).dot(unitary.T)
147
-
148
-
149
- def _signals(
150
- n_subjects: int = N_SUBJECTS,
151
- ) -> Tuple[List[np.ndarray], np.ndarray]:
152
- """Generate signals and compute covariances while applying confounds.
153
-
154
- Parameters
155
- ----------
156
- n_subjects : int
157
- Number of subjects.
158
-
159
- Returns
160
- -------
161
- tuple of list of np.ndarray and np.ndarray.
162
-
163
- """
164
- n_features = N_FEATURES
165
- signals = []
166
- for k in range(n_subjects):
167
- n_samples = 200 + k
168
- signal, _, confounds = generate_signals(
169
- n_features=n_features,
170
- n_confounds=5,
171
- length=n_samples,
172
- same_variance=False,
173
- )
174
- signals.append(signal)
175
- signal -= signal.mean(axis=0)
176
- return signals, confounds
177
-
178
-
179
- @pytest.fixture
180
- def signals() -> List[np.ndarray]:
181
- """Return signals as list of np.ndarray."""
182
- return _signals(N_SUBJECTS)[0]
183
-
184
-
185
- @pytest.fixture
186
- def signals_and_covariances(
187
- cov_estimator: Union[LedoitWolf, EmpiricalCovariance]
188
- ) -> Tuple[List[np.ndarray], List[float]]:
189
- """Return signals and covariances for a covariance estimator.
190
-
191
- Parameters
192
- ----------
193
- cov_estimator : LedoitWolf instance or EmpiricalCovariance instance
194
- The covariance estimator.
195
-
196
- Returns
197
- -------
198
- tuple of list of np.ndarray and list of float
199
-
200
- """
201
- signals, _ = _signals()
202
- emp_covs = []
203
- ledoit_covs = []
204
- ledoit_estimator = LedoitWolf()
205
- for k, signal_ in enumerate(signals):
206
- n_samples = 200 + k
207
- signal_ -= signal_.mean(axis=0)
208
- emp_covs.append((signal_.T).dot(signal_) / n_samples)
209
- ledoit_covs.append(ledoit_estimator.fit(signal_).covariance_)
210
-
211
- if isinstance(cov_estimator, LedoitWolf):
212
- return signals, ledoit_covs
213
- elif isinstance(cov_estimator, EmpiricalCovariance):
214
- return signals, emp_covs
215
-
216
-
217
- def test_check_square() -> None:
218
- """Test square matrix assertion."""
219
- non_square = np.ones((2, 3))
220
- with pytest.raises(ValueError, match="Expected a square matrix"):
221
- _check_square(non_square)
222
-
223
-
224
- @pytest.mark.parametrize(
225
- "invalid_input",
226
- [
227
- np.array([[0, 1], [0, 0]]), # non symmetric
228
- np.ones((3, 3)), # non SPD
229
- ],
230
- )
231
- def test_check_spd(invalid_input: np.ndarray) -> None:
232
- """Test matrix is symmetric positive definite.
233
-
234
- Parameters
235
- ----------
236
- invalid_input : numpy.ndarray
237
- The parametrized invalid input array.
238
-
239
- """
240
- with pytest.raises(
241
- ValueError, match="Expected a symmetric positive definite matrix."
242
- ):
243
- _check_spd(invalid_input)
244
-
245
-
246
- def test_map_eigenvalues_on_exp_map() -> None:
247
- """Test exponential eigenvalues mapping."""
248
- sym = np.ones((2, 2))
249
- sym_exp = exp(1.0) * np.array(
250
- [[cosh(1.0), sinh(1.0)], [sinh(1.0), cosh(1.0)]]
251
- )
252
- assert_array_almost_equal(_map_eigenvalues(np.exp, sym), sym_exp)
253
-
254
-
255
- def test_map_eigenvalues_on_sqrt_map() -> None:
256
- """Test square-root eigenvalues mapping."""
257
- spd_sqrt = np.array(
258
- [[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]]
259
- )
260
- spd = spd_sqrt.dot(spd_sqrt)
261
- assert_array_almost_equal(_map_eigenvalues(np.sqrt, spd), spd_sqrt)
262
-
263
-
264
- def test_map_eigenvalues_on_log_map() -> None:
265
- """Test logarithmic eigenvalues mapping."""
266
- spd = np.array([[1.25, 0.75], [0.75, 1.25]])
267
- spd_log = np.array([[0.0, log(2.0)], [log(2.0), 0.0]])
268
- assert_array_almost_equal(_map_eigenvalues(np.log, spd), spd_log)
269
-
270
-
271
- def test_geometric_mean_couple() -> None:
272
- """Test geometric mean."""
273
- n_features = 7
274
- spd1 = np.ones((n_features, n_features))
275
- spd1 = spd1.dot(spd1) + n_features * np.eye(n_features)
276
- spd2 = np.tril(np.ones((n_features, n_features)))
277
- spd2 = spd2.dot(spd2.T)
278
- vals_spd2, vecs_spd2 = np.linalg.eigh(spd2)
279
- spd2_sqrt = _form_symmetric(np.sqrt, vals_spd2, vecs_spd2)
280
- spd2_inv_sqrt = _form_symmetric(np.sqrt, 1.0 / vals_spd2, vecs_spd2)
281
- geo = spd2_sqrt.dot(
282
- _map_eigenvalues(np.sqrt, spd2_inv_sqrt.dot(spd1).dot(spd2_inv_sqrt))
283
- ).dot(spd2_sqrt)
284
-
285
- assert_array_almost_equal(_geometric_mean([spd1, spd2]), geo)
286
-
287
-
288
- def test_geometric_mean_diagonal() -> None:
289
- """Test geometric mean along diagonal."""
290
- n_matrices = 20
291
- n_features = 5
292
- diags = []
293
- for k in range(n_matrices):
294
- diag = np.eye(n_features)
295
- diag[k % n_features, k % n_features] = 1e4 + k
296
- diag[(n_features - 1) // (k + 1), (n_features - 1) // (k + 1)] = (
297
- k + 1
298
- ) * 1e-4
299
- diags.append(diag)
300
- geo = np.prod(np.array(diags), axis=0) ** (1 / float(len(diags)))
301
-
302
- assert_array_almost_equal(_geometric_mean(diags), geo)
303
-
304
-
305
- def test_geometric_mean_geodesic() -> None:
306
- """Test geometric mean along geodesic."""
307
- n_matrices = 10
308
- n_features = 6
309
- sym = np.arange(n_features) / np.linalg.norm(np.arange(n_features))
310
- sym = sym * sym[:, np.newaxis]
311
- times = np.arange(n_matrices)
312
- non_singular = np.eye(n_features)
313
- non_singular[1:3, 1:3] = np.array([[-1, -0.5], [-0.5, -1]])
314
- spds = [
315
- non_singular.dot(_map_eigenvalues(np.exp, time * sym)).dot(
316
- non_singular.T
317
- )
318
- for time in times
319
- ]
320
- gmean = non_singular.dot(_map_eigenvalues(np.exp, times.mean() * sym)).dot(
321
- non_singular.T
322
- )
323
- assert_array_almost_equal(_geometric_mean(spds), gmean)
324
-
325
-
326
- def test_geometric_mean_properties() -> None:
327
- """Test geometric mean properties."""
328
- n_matrices = 40
329
- n_features = 15
330
- spds = [
331
- random_spd(n_features, eig_min=1.0, cond=10.0, random_state=0)
332
- for _ in range(n_matrices)
333
- ]
334
- input_spds = copy.copy(spds)
335
-
336
- gmean = _geometric_mean(spds)
337
-
338
- # Generic
339
- assert isinstance(spds, list)
340
- for spd, input_spd in zip(spds, input_spds):
341
- assert_array_equal(spd, input_spd)
342
- assert is_spd(gmean, decimal=7)
343
-
344
-
345
- def random_non_singular(
346
- p: int,
347
- sing_min: float = 1.0,
348
- sing_max: float = 2.0,
349
- random_state: Union[int, np.random.RandomState] = 0,
350
- ) -> np.ndarray:
351
- """Generate a random nonsingular matrix.
352
-
353
- Parameters
354
- ----------
355
- p : int
356
- The first dimension of the array.
357
- sing_min : float, optional
358
- Minimal singular value (default 1.).
359
- sing_max : float, optional
360
- Maximal singular value (default 2.).
361
- random_state : int or numpy.random.RandomState instance, optional
362
- random number generator, or seed (default 0).
363
-
364
- Returns
365
- -------
366
- numpy.ndarray of shape (p, p)
367
- A nonsingular matrix with the given minimal and maximal singular
368
- values.
369
-
370
- """
371
- rand_gen = np.random.default_rng(random_state)
372
- diag = random_diagonal(
373
- p, v_min=sing_min, v_max=sing_max, random_state=random_state
374
- )
375
- mat1 = rand_gen.standard_normal((p, p))
376
- mat2 = rand_gen.standard_normal((p, p))
377
- unitary1, _ = linalg.qr(mat1)
378
- unitary2, _ = linalg.qr(mat2)
379
- return unitary1.dot(diag).dot(unitary2.T)
380
-
381
-
382
- def test_geometric_mean_properties_check_invariance() -> None:
383
- """Test geometric mean properties' invariance."""
384
- n_matrices = 40
385
- n_features = 15
386
- spds = [
387
- random_spd(n_features, eig_min=1.0, cond=10.0, random_state=0)
388
- for _ in range(n_matrices)
389
- ]
390
-
391
- gmean = _geometric_mean(spds)
392
-
393
- # Invariance under reordering
394
- spds.reverse()
395
- spds.insert(0, spds[1])
396
- spds.pop(2)
397
- assert_array_almost_equal(_geometric_mean(spds), gmean)
398
-
399
- # Invariance under congruent transformation
400
- non_singular = random_non_singular(n_features, random_state=0)
401
- spds_cong = [non_singular.dot(spd).dot(non_singular.T) for spd in spds]
402
- assert_array_almost_equal(
403
- _geometric_mean(spds_cong), non_singular.dot(gmean).dot(non_singular.T)
404
- )
405
-
406
- # Invariance under inversion
407
- spds_inv = [linalg.inv(spd) for spd in spds]
408
- init = linalg.inv(np.mean(spds, axis=0))
409
- assert_array_almost_equal(
410
- _geometric_mean(spds_inv, init=init), linalg.inv(gmean)
411
- )
412
-
413
-
414
- def grad_geometric_mean(
415
- mats: "ArrayLike",
416
- init: Optional["ArrayLike"] = None,
417
- max_iter: int = 10,
418
- tol: float = 1e-7,
419
- ) -> List[float]:
420
- """Compute gradient of geometric mean.
421
-
422
- Return the norm of the covariant derivative at each iteration step
423
- of geometric_mean. See its docstring for details.
424
-
425
- Norm is intrinsic norm on the tangent space of the manifold of symmetric
426
- positive definite matrices.
427
-
428
- Parameters
429
- ----------
430
- mats : array-like object
431
- Object that can be converted to np.ndarray.
432
- init : array-like object or None, optional
433
- Initialization matrix (default None).
434
- max_iter : int, optional
435
- Maximum iteration for gradient descent (default 10).
436
- tol : float, optional
437
- Tolerance for norm (default 1e-7).
438
-
439
- Returns
440
- -------
441
- grad_norm : list of float
442
- Norm of the covariant derivative in the tangent space at each step.
443
-
444
- """
445
- mats = np.array(mats)
446
-
447
- # Initialization
448
- gmean = init or np.mean(mats, axis=0)
449
-
450
- norm_old = np.inf
451
- step = 1.0
452
- grad_norm = []
453
- for _ in range(max_iter):
454
- # Computation of the gradient
455
- vals_gmean, vecs_gmean = linalg.eigh(gmean)
456
- gmean_inv_sqrt = _form_symmetric(np.sqrt, 1.0 / vals_gmean, vecs_gmean)
457
- whitened_mats = [
458
- gmean_inv_sqrt.dot(mat).dot(gmean_inv_sqrt) for mat in mats
459
- ]
460
- logs = [_map_eigenvalues(np.log, w_mat) for w_mat in whitened_mats]
461
-
462
- # Covariant derivative is - gmean.dot(logs_mean)
463
- logs_mean = np.mean(logs, axis=0)
464
-
465
- # Norm of the covariant derivative on
466
- # the tangent space at point gmean
467
- norm = np.linalg.norm(logs_mean)
468
-
469
- # Update of the minimizer
470
- vals_log, vecs_log = linalg.eigh(logs_mean)
471
- gmean_sqrt = _form_symmetric(np.sqrt, vals_gmean, vecs_gmean)
472
- gmean = gmean_sqrt.dot(
473
- _form_symmetric(np.exp, vals_log * step, vecs_log)
474
- ).dot(gmean_sqrt)
475
-
476
- # Update the norm and the step size
477
- if norm < norm_old:
478
- norm_old = norm
479
- if norm > norm_old:
480
- step = step / 2.0
481
- norm = norm_old
482
-
483
- grad_norm.append(norm / gmean.size)
484
- if tol is not None and norm / gmean.size < tol:
485
- break
486
-
487
- return grad_norm
488
-
489
-
490
- def test_geometric_mean_properties_check_gradient() -> None:
491
- """Test geometric mean properties' gradient."""
492
- n_matrices = 40
493
- n_features = 15
494
- spds = [
495
- random_spd(n_features, eig_min=1.0, cond=10.0, random_state=0)
496
- for _ in range(n_matrices)
497
- ]
498
-
499
- grad_norm = grad_geometric_mean(spds, tol=1e-20)
500
-
501
- # Gradient norm is decreasing
502
- difference = np.diff(grad_norm)
503
- assert np.amax(difference) <= 0.0
504
-
505
- # Check warning if gradient norm in the last step is less than
506
- # tolerance
507
- max_iter = 1
508
- tol = 1e-20
509
- with warnings.catch_warnings(record=True) as w:
510
- warnings.simplefilter("always")
511
- _geometric_mean(spds, max_iter=max_iter, tol=tol)
512
- assert len(w) == 1
513
-
514
- grad_norm = grad_geometric_mean(spds, max_iter=max_iter, tol=tol)
515
-
516
- assert len(grad_norm) == max_iter
517
- assert grad_norm[-1] > tol
518
-
519
-
520
- # proportion of badly conditioned matrices
521
- @pytest.mark.parametrize("p", [0.5, 1.0])
522
- def test_geometric_mean_properties_evaluate_convergence(p: float) -> None:
523
- """Test geometric mean properties' convergence.
524
-
525
- Parameters
526
- ----------
527
- p : float
528
- Convergence criteria.
529
-
530
- """
531
- n_matrices = 40
532
- n_features = 15
533
- # A warning is printed if tolerance is not reached
534
- spds = [
535
- random_spd(n_features, eig_min=1e-2, cond=1e6, random_state=0)
536
- for _ in range(int(p * n_matrices))
537
- ]
538
- spds.extend(
539
- random_spd(n_features, eig_min=1.0, cond=10.0, random_state=0)
540
- for _ in range(int(p * n_matrices), n_matrices)
541
- )
542
- max_iter = 30 if p < 1 else 60
543
-
544
- _geometric_mean(spds, max_iter=max_iter, tol=1e-5)
545
-
546
-
547
- def test_geometric_mean_error_non_square_matrix() -> None:
548
- """Test geometric mean error for non-square matrix."""
549
- n_features = 5
550
- mat1 = np.ones((n_features, n_features + 1))
551
-
552
- with pytest.raises(ValueError, match="Expected a square matrix"):
553
- _geometric_mean([mat1])
554
-
555
-
556
- def test_geometric_mean_error_input_matrices_have_different_shapes() -> None:
557
- """Test geometric mean error for different input matrices shape."""
558
- n_features = 5
559
- mat1 = np.eye(n_features)
560
- mat2 = np.ones((n_features + 1, n_features + 1))
561
-
562
- with pytest.raises(
563
- ValueError, match="Matrices are not of the same shape."
564
- ):
565
- _geometric_mean([mat1, mat2])
566
-
567
-
568
- def test_geometric_mean_error_non_spd_input_matrix() -> None:
569
- """Test geometric mean error for non SPD input matrix."""
570
- n_features = 5
571
- mat2 = np.ones((n_features + 1, n_features + 1))
572
-
573
- with pytest.raises(
574
- ValueError, match="Expected a symmetric positive definite matrix."
575
- ):
576
- _geometric_mean([mat2])
577
-
578
-
579
- def test_connectivity_measure_errors():
580
- """Test errors."""
581
- # Raising error for input subjects not iterable
582
- conn_measure = JuniferConnectivityMeasure()
583
-
584
- with pytest.raises(
585
- ValueError, match="'subjects' input argument must be an iterable"
586
- ):
587
- conn_measure.fit(1.0)
588
-
589
- # input subjects not 2D numpy.ndarrays
590
- with pytest.raises(
591
- ValueError, match="Each subject must be 2D numpy.ndarray."
592
- ):
593
- conn_measure.fit([np.ones((100, 40)), np.ones((10,))])
594
-
595
- # input subjects with different number of features
596
- with pytest.raises(
597
- ValueError, match="All subjects must have the same number of features."
598
- ):
599
- conn_measure.fit([np.ones((100, 40)), np.ones((100, 41))])
600
-
601
- # fit_transform with a single subject and kind=tangent
602
- conn_measure = JuniferConnectivityMeasure(kind="tangent")
603
-
604
- with pytest.raises(
605
- ValueError,
606
- match="Tangent space parametrization .* only be .* group of subjects",
607
- ):
608
- conn_measure.fit_transform([np.ones((100, 40))])
609
-
610
-
611
- @pytest.mark.parametrize(
612
- "cov_estimator", [EmpiricalCovariance(), LedoitWolf()]
613
- )
614
- @pytest.mark.parametrize("kind", CONNECTIVITY_KINDS)
615
- def test_connectivity_measure_generic(
616
- kind: str,
617
- cov_estimator: Type["BaseEstimator"],
618
- signals_and_covariances: Tuple[List[np.ndarray], List[float]],
619
- ) -> None:
620
- """Test generic JuniferConnectivityMeasure.
621
-
622
- Parameters
623
- ----------
624
- kind : str
625
- The parametrized connectivity matrix kind.
626
- cov_estimator : estimator object
627
- The parametrized covariance estimator.
628
- signals_and_covariances : tuple
629
- The signals and covariances for a covariance estimator.
630
-
631
- """
632
- signals, covs = signals_and_covariances
633
-
634
- # Check outputs properties
635
- input_covs = copy.copy(covs)
636
- conn_measure = JuniferConnectivityMeasure(
637
- kind=kind, cov_estimator=cov_estimator
638
- )
639
- connectivities = conn_measure.fit_transform(signals)
640
-
641
- # Generic
642
- assert isinstance(connectivities, np.ndarray)
643
- assert len(connectivities) == len(covs)
644
-
645
- for k, _ in enumerate(connectivities):
646
- assert_array_equal(input_covs[k], covs[k])
647
-
648
- assert is_spd(covs[k], decimal=7)
649
-
650
-
651
- def _assert_connectivity_tangent(connectivities, conn_measure, covs) -> None:
652
- """Assert tangent connectivity matrix.
653
-
654
- Check output value properties for tangent connectivity measure
655
- that they have the expected relationship
656
- to the input covariance matrices.
657
-
658
- - the geometric mean of the eigenvalues
659
- of the mean covariance matrix is positive-definite
660
- - the whitening matrix (used to transform the data
661
- also produces a positive-definite matrix
662
-
663
- """
664
- for true_covariance_matrix, estimated_covariance_matrix in zip(
665
- covs, connectivities
666
- ):
667
- assert_array_almost_equal(
668
- estimated_covariance_matrix, estimated_covariance_matrix.T
669
- )
670
-
671
- assert is_spd(conn_measure.whitening_, decimal=7)
672
-
673
- gmean_sqrt = _map_eigenvalues(np.sqrt, conn_measure.mean_)
674
- assert is_spd(gmean_sqrt, decimal=7)
675
- assert_array_almost_equal(
676
- conn_measure.whitening_.dot(gmean_sqrt),
677
- np.eye(N_FEATURES),
678
- )
679
- assert_array_almost_equal(
680
- gmean_sqrt.dot(
681
- _map_eigenvalues(np.exp, estimated_covariance_matrix)
682
- ).dot(gmean_sqrt),
683
- true_covariance_matrix,
684
- )
685
-
686
-
687
- def _assert_connectivity_precision(connectivities, covs) -> None:
688
- """Assert precision connectivity matrix.
689
-
690
- Estimated precision matrix:
691
- - is positive definite
692
- - its product with the true covariance matrix
693
- is close to the identity matrix
694
-
695
- """
696
- for true_covariance_matrix, estimated_covariance_matrix in zip(
697
- covs, connectivities
698
- ):
699
- assert is_spd(estimated_covariance_matrix, decimal=7)
700
- assert_array_almost_equal(
701
- estimated_covariance_matrix.dot(true_covariance_matrix),
702
- np.eye(N_FEATURES),
703
- )
704
-
705
-
706
- def _assert_connectivity_correlation(
707
- connectivities, cov_estimator, covs
708
- ) -> None:
709
- """Assert correlation connectivity matrix.
710
-
711
- Verify that the estimated covariance matrix:
712
- - is symmetric and positive definite
713
- - has values close to 1 on its diagonal
714
-
715
- If the covariance estimator is EmpiricalCovariance,
716
- the product of:
717
- - the square root of the diagonal of the true covariance matrix
718
- - the estimated covariance matrix
719
- - the square root of the diagonal of the true covariance matrix
720
-
721
- should be close to the true covariance matrix.
722
-
723
- """
724
- for true_covariance_matrix, estimated_covariance_matrix in zip(
725
- covs, connectivities
726
- ):
727
- assert is_spd(estimated_covariance_matrix, decimal=7)
728
-
729
- assert_array_almost_equal(
730
- np.diag(estimated_covariance_matrix), np.ones(N_FEATURES)
731
- )
732
-
733
- if cov_estimator == EmpiricalCovariance():
734
- # square root of the diagonal of the true covariance matrix
735
- d = np.sqrt(np.diag(np.diag(true_covariance_matrix)))
736
-
737
- assert_array_almost_equal(
738
- d.dot(estimated_covariance_matrix).dot(d),
739
- true_covariance_matrix,
740
- )
741
-
742
-
743
- def _assert_connectivity_partial_correlation(connectivities, covs) -> None:
744
- """Assert partial correlation connectivity matrix."""
745
- for true_covariance_matrix, estimated_covariance_matrix in zip(
746
- covs, connectivities
747
- ):
748
- precision_matrix = linalg.inv(true_covariance_matrix)
749
-
750
- # square root of the diagonal elements of the precision matrix
751
- d = np.sqrt(np.diag(np.diag(precision_matrix)))
752
-
753
- # normalize the computed partial correlation matrix
754
- # necessary to ensure that the diagonal elements
755
- # of the partial correlation matrix are equal to 1
756
- normalized_partial_correlation_matrix = d.dot(
757
- estimated_covariance_matrix
758
- ).dot(d)
759
-
760
- # expected value
761
- partial_corrlelation_matrix = -precision_matrix + 2 * np.diag(
762
- np.diag(precision_matrix)
763
- )
764
-
765
- assert_array_almost_equal(
766
- normalized_partial_correlation_matrix,
767
- partial_corrlelation_matrix,
768
- )
769
-
770
-
771
- @pytest.mark.parametrize(
772
- "kind",
773
- ["tangent", "precision", "correlation", "partial correlation"],
774
- )
775
- @pytest.mark.parametrize(
776
- "cov_estimator", [EmpiricalCovariance(), LedoitWolf()]
777
- )
778
- def test_connectivity_measure_specific_for_each_kind(
779
- kind: str,
780
- cov_estimator: Type["BaseEstimator"],
781
- signals_and_covariances: Tuple[List[np.ndarray], List[float]],
782
- ) -> None:
783
- """Test connectivity matrix for each kind.
784
-
785
- Parameters
786
- ----------
787
- kind : str
788
- The parametrized connectivity matrix kind.
789
- cov_estimator : estimator object
790
- The parametrized covariance estimator.
791
- signals_and_covariances : tuple
792
- The signals and covariances for a covariance estimator.
793
-
794
- """
795
- signals, covs = signals_and_covariances
796
-
797
- conn_measure = JuniferConnectivityMeasure(
798
- kind=kind, cov_estimator=cov_estimator
799
- )
800
- connectivities = conn_measure.fit_transform(signals)
801
-
802
- if kind == "tangent":
803
- _assert_connectivity_tangent(connectivities, conn_measure, covs)
804
- elif kind == "precision":
805
- _assert_connectivity_precision(connectivities, covs)
806
- elif kind == "correlation":
807
- _assert_connectivity_correlation(connectivities, cov_estimator, covs)
808
- elif kind == "partial correlation":
809
- _assert_connectivity_partial_correlation(connectivities, covs)
810
-
811
-
812
- @pytest.mark.parametrize("kind", CONNECTIVITY_KINDS)
813
- def test_connectivity_measure_check_mean(
814
- kind: str, signals: List[np.ndarray]
815
- ) -> None:
816
- """Test mean of connectivity matrix for each kind.
817
-
818
- Parameters
819
- ----------
820
- kind : str
821
- The parametrized connectivity matrix kind.
822
- signals : list of np.ndarray
823
- The input signals.
824
-
825
- """
826
- conn_measure = JuniferConnectivityMeasure(kind=kind)
827
- conn_measure.fit_transform(signals)
828
-
829
- assert (conn_measure.mean_).shape == (N_FEATURES, N_FEATURES)
830
-
831
- if kind != "tangent":
832
- assert_array_almost_equal(
833
- conn_measure.mean_,
834
- np.mean(conn_measure.transform(signals), axis=0),
835
- )
836
-
837
- # Check that the mean isn't modified in transform
838
- conn_measure = JuniferConnectivityMeasure(kind="covariance")
839
- conn_measure.fit(signals[:1])
840
- mean = conn_measure.mean_
841
- conn_measure.transform(signals[1:])
842
-
843
- assert_array_equal(mean, conn_measure.mean_)
844
-
845
-
846
- @pytest.mark.parametrize("kind", CONNECTIVITY_KINDS)
847
- def test_connectivity_measure_check_vectorization_option(
848
- kind: str, signals: List[np.ndarray]
849
- ) -> None:
850
- """Test vectorization of connectivity matrix for each kind.
851
-
852
- Parameters
853
- ----------
854
- kind : str
855
- The parametrized connectivity matrix kind.
856
- signals : list of np.ndarray
857
- The input signals.
858
-
859
- """
860
- conn_measure = JuniferConnectivityMeasure(kind=kind)
861
- connectivities = conn_measure.fit_transform(signals)
862
- conn_measure = JuniferConnectivityMeasure(vectorize=True, kind=kind)
863
- vectorized_connectivities = conn_measure.fit_transform(signals)
864
-
865
- assert_array_almost_equal(
866
- vectorized_connectivities, sym_matrix_to_vec(connectivities)
867
- )
868
-
869
- # Check not fitted error
870
- with pytest.raises(ValueError, match="has not been fitted. "):
871
- JuniferConnectivityMeasure().inverse_transform(
872
- vectorized_connectivities
873
- )
874
-
875
-
876
- @pytest.mark.parametrize(
877
- "kind",
878
- ["covariance", "correlation", "precision", "partial correlation"],
879
- )
880
- def test_connectivity_measure_check_inverse_transformation(
881
- kind: str, signals: List[np.ndarray]
882
- ) -> None:
883
- """Test inverse transform.
884
-
885
- Parameters
886
- ----------
887
- kind : str
888
- The parametrized connectivity matrix kind.
889
- signals : list of np.ndarray
890
- The input signals.
891
-
892
- """
893
- # without vectorization: input matrices are returned with no change
894
- conn_measure = JuniferConnectivityMeasure(kind=kind)
895
- connectivities = conn_measure.fit_transform(signals)
896
-
897
- assert_array_almost_equal(
898
- conn_measure.inverse_transform(connectivities), connectivities
899
- )
900
-
901
- # with vectorization: input vectors are reshaped into matrices
902
- # if diagonal has not been discarded
903
- conn_measure = JuniferConnectivityMeasure(kind=kind, vectorize=True)
904
- vectorized_connectivities = conn_measure.fit_transform(signals)
905
-
906
- assert_array_almost_equal(
907
- conn_measure.inverse_transform(vectorized_connectivities),
908
- connectivities,
909
- )
910
-
911
-
912
- @pytest.mark.parametrize(
913
- "kind",
914
- ["covariance", "correlation", "precision", "partial correlation"],
915
- )
916
- def test_connectivity_measure_check_inverse_transformation_discard_diag(
917
- kind: str, signals: List[np.ndarray]
918
- ) -> None:
919
- """Test diagonal for inverse transform.
920
-
921
- Parameters
922
- ----------
923
- kind : str
924
- The parametrized connectivity matrix kind.
925
- signals : list of np.ndarray
926
- The input signals.
927
-
928
- """
929
- # with vectorization
930
- connectivities = JuniferConnectivityMeasure(kind=kind).fit_transform(
931
- signals
932
- )
933
- conn_measure = JuniferConnectivityMeasure(
934
- kind=kind, vectorize=True, discard_diagonal=True
935
- )
936
- vectorized_connectivities = conn_measure.fit_transform(signals)
937
-
938
- if kind in ["correlation", "partial correlation"]:
939
- assert_array_almost_equal(
940
- conn_measure.inverse_transform(vectorized_connectivities),
941
- connectivities,
942
- )
943
- elif kind in ["covariance", "precision"]:
944
- diagonal = np.array(
945
- [np.diagonal(conn) / sqrt(2) for conn in connectivities]
946
- )
947
- inverse_transformed = conn_measure.inverse_transform(
948
- vectorized_connectivities, diagonal=diagonal
949
- )
950
-
951
- assert_array_almost_equal(inverse_transformed, connectivities)
952
- with pytest.raises(
953
- ValueError, match="cannot reconstruct connectivity matrices"
954
- ):
955
- conn_measure.inverse_transform(vectorized_connectivities)
956
-
957
-
958
- def test_connectivity_measure_inverse_transform_tangent(
959
- signals: List[np.ndarray],
960
- ) -> None:
961
- """Test that for 'tangent' kind, covariance matrices are reconstructed.
962
-
963
- Parameters
964
- ----------
965
- signals : list of np.ndarray
966
- The input signals.
967
-
968
- """
969
- # Without vectorization
970
- tangent_measure = JuniferConnectivityMeasure(kind="tangent")
971
- displacements = tangent_measure.fit_transform(signals)
972
- covariances = JuniferConnectivityMeasure(kind="covariance").fit_transform(
973
- signals
974
- )
975
-
976
- assert_array_almost_equal(
977
- tangent_measure.inverse_transform(displacements), covariances
978
- )
979
-
980
- # with vectorization
981
- # when diagonal has not been discarded
982
- tangent_measure = JuniferConnectivityMeasure(
983
- kind="tangent", vectorize=True
984
- )
985
- vectorized_displacements = tangent_measure.fit_transform(signals)
986
-
987
- assert_array_almost_equal(
988
- tangent_measure.inverse_transform(vectorized_displacements),
989
- covariances,
990
- )
991
-
992
- # When diagonal has been discarded
993
- tangent_measure = JuniferConnectivityMeasure(
994
- kind="tangent", vectorize=True, discard_diagonal=True
995
- )
996
- vectorized_displacements = tangent_measure.fit_transform(signals)
997
-
998
- diagonal = np.array(
999
- [np.diagonal(matrix) / sqrt(2) for matrix in displacements]
1000
- )
1001
- inverse_transformed = tangent_measure.inverse_transform(
1002
- vectorized_displacements, diagonal=diagonal
1003
- )
1004
-
1005
- assert_array_almost_equal(inverse_transformed, covariances)
1006
- with pytest.raises(
1007
- ValueError, match="cannot reconstruct connectivity matrices"
1008
- ):
1009
- tangent_measure.inverse_transform(vectorized_displacements)
1010
-
1011
-
1012
- def test_confounds_connectivity_measure() -> None:
1013
- """Test confounds."""
1014
- n_subjects = 10
1015
-
1016
- signals, confounds = _signals(n_subjects)
1017
-
1018
- correlation_measure = JuniferConnectivityMeasure(
1019
- kind="correlation", vectorize=True
1020
- )
1021
-
1022
- # Clean confounds on 10 subjects with confounds filtered to 10 subjects in
1023
- # length
1024
- cleaned_vectors = correlation_measure.fit_transform(
1025
- signals, confounds=confounds[:10]
1026
- )
1027
-
1028
- zero_matrix = np.zeros((confounds.shape[1], cleaned_vectors.shape[1]))
1029
- assert_array_almost_equal(
1030
- np.dot(confounds[:10].T, cleaned_vectors), zero_matrix
1031
- )
1032
- assert isinstance(cleaned_vectors, np.ndarray)
1033
-
1034
- # Confounds as pandas DataFrame
1035
- confounds_df = DataFrame(confounds[:10])
1036
- correlation_measure.fit_transform(signals, confounds=confounds_df)
1037
-
1038
-
1039
- def test_confounds_connectivity_measure_errors() -> None:
1040
- """Test errors for dealing with confounds."""
1041
- # Generate signals and compute covariances and apply confounds while
1042
- # computing covariances
1043
- signals, confounds = _signals()
1044
-
1045
- # Raising error for input confounds are not iterable
1046
- conn_measure = JuniferConnectivityMeasure(vectorize=True)
1047
- msg = "'confounds' input argument must be an iterable"
1048
-
1049
- with pytest.raises(ValueError, match=msg):
1050
- conn_measure._check_input(X=signals, confounds=1.0)
1051
-
1052
- with pytest.raises(ValueError, match=msg):
1053
- conn_measure._fit_transform(
1054
- X=signals, do_fit=True, do_transform=True, confounds=1.0
1055
- )
1056
-
1057
- with pytest.raises(ValueError, match=msg):
1058
- conn_measure.fit_transform(X=signals, y=None, confounds=1.0)
1059
-
1060
- # Raising error for input confounds are given but not vectorize=True
1061
- conn_measure = JuniferConnectivityMeasure(vectorize=False)
1062
- with pytest.raises(
1063
- ValueError, match="'confounds' are provided but vectorize=False"
1064
- ):
1065
- conn_measure.fit_transform(signals, None, confounds[:10])
1066
-
1067
-
1068
- def test_connectivity_measure_standardize(
1069
- signals: List[np.ndarray],
1070
- ) -> None:
1071
- """Check warning is raised and then suppressed with setting standardize.
1072
-
1073
- Parameters
1074
- ----------
1075
- signals : list of np.ndarray
1076
- The input signals.
1077
-
1078
- """
1079
- match = "default strategy for standardize"
1080
-
1081
- with pytest.warns(DeprecationWarning, match=match):
1082
- JuniferConnectivityMeasure(kind="correlation").fit_transform(signals)
1083
-
1084
- with warnings.catch_warnings(record=True) as record:
1085
- JuniferConnectivityMeasure(
1086
- kind="correlation", standardize="zscore_sample"
1087
- ).fit_transform(signals)
1088
- for m in record:
1089
- assert match not in m.message