amica-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amica/__init__.py +5 -0
- amica/_batching.py +194 -0
- amica/_newton.py +77 -0
- amica/_sklearn_interface.py +387 -0
- amica/_types.py +44 -0
- amica/conftest.py +30 -0
- amica/constants.py +47 -0
- amica/core.py +1165 -0
- amica/datasets.py +15 -0
- amica/kernels.py +1308 -0
- amica/linalg.py +349 -0
- amica/state.py +385 -0
- amica/tests/test_amica.py +497 -0
- amica/utils/__init__.py +36 -0
- amica/utils/_logging.py +64 -0
- amica/utils/_progress.py +34 -0
- amica/utils/_verbose.py +14 -0
- amica/utils/fetch.py +274 -0
- amica/utils/fortran.py +387 -0
- amica/utils/imports.py +46 -0
- amica/utils/mne.py +74 -0
- amica/utils/parallel.py +72 -0
- amica/utils/simulation.py +36 -0
- amica/utils/tests/test_fetch.py +9 -0
- amica/utils/tests/test_fortran.py +47 -0
- amica/utils/tests/test_imports.py +0 -0
- amica/utils/tests/test_logger.py +29 -0
- amica/utils/tests/test_mne.py +27 -0
- amica_python-0.1.0.dist-info/METADATA +196 -0
- amica_python-0.1.0.dist-info/RECORD +33 -0
- amica_python-0.1.0.dist-info/WHEEL +5 -0
- amica_python-0.1.0.dist-info/licenses/LICENSE +25 -0
- amica_python-0.1.0.dist-info/top_level.txt +1 -0
amica/linalg.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
"""Whitening, unmixing helpers, determinant computation, etc."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from amica._types import ComponentsVector, DataArray2D, WeightsArray
|
|
9
|
+
|
|
10
|
+
from .utils._logging import log
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_unmixing_matrices(
    *,
    c,
    A,
    W,
):
    """Get unmixing matrices for AMICA.

    Parameters
    ----------
    c : torch.Tensor, shape (n_components,)
        Model center/bias vector (``c(:,h)`` in the Fortran code).
    A : torch.Tensor, shape (n_components, n_components)
        Mixing matrix for this model. Must be invertible.
    W : torch.Tensor, shape (n_components, n_components)
        Output buffer; overwritten in place with ``inv(A)``.

    Returns
    -------
    W : torch.Tensor
        The unmixing matrix (inverse of ``A``), written into the input buffer.
    wc : torch.Tensor
        The product ``W @ c`` (the unmixed center).

    Raises
    ------
    RuntimeError
        If ``A`` (and hence ``W``) is singular and cannot be inverted.
    """
    wc = torch.zeros_like(c)
    #--------------------------FORTRAN CODE-------------------------
    # call DCOPY(nw*nw,A(:,comp_list(:,h)),1,W(:,:,h),1)
    #---------------------------------------------------------------
    # Slice assignment copies element-wise, so an explicit .clone() is redundant.
    W[:, :] = A
    #--------------------------FORTRAN CODE-------------------------
    # call DGETRF(nw,nw,W(:,:,h),nw,ipivnw,info)
    # call DGETRI(nw,W(:,:,h),nw,ipivnw,work,lwork,info)
    #---------------------------------------------------------------
    try:
        W[:, :] = torch.linalg.inv(W)
    except RuntimeError:
        log("Matrix W is singular!")
        raise  # bare re-raise preserves the original traceback
    #--------------------------FORTRAN CODE-------------------------
    # call DGEMV('N',nw,nw,dble(1.0),W(:,:,h),nw,c(:,h),1,dble(0.0),wc(:,h),1)
    #---------------------------------------------------------------
    wc[:] = W @ c
    return W, wc
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def compute_sign_log_determinant(
|
|
45
|
+
*,
|
|
46
|
+
unmixing_matrix: WeightsArray,
|
|
47
|
+
minlog: float = -1500,
|
|
48
|
+
mode: Literal["strict", "fallback"] = "strict",
|
|
49
|
+
) -> tuple[Literal[-1, 1], float]:
|
|
50
|
+
"""Compute the sign and log-determinant of the unmixing matrix for a single model.
|
|
51
|
+
|
|
52
|
+
Parameters
|
|
53
|
+
----------
|
|
54
|
+
unmixing_matrix: array, shape (n_components, n_features)
|
|
55
|
+
The unmixing matrix W for a single model h (i.e. a 2D slice of state.W).
|
|
56
|
+
minlog: float
|
|
57
|
+
Minimum log value: log absolute determinant to return if the computed
|
|
58
|
+
log-determinant is zero. Default is -1500, but currently if the computed
|
|
59
|
+
log-determinant is zero, an error is raised instead.
|
|
60
|
+
mode: str
|
|
61
|
+
Mode for handling cases where the computed log-determinant is zero.
|
|
62
|
+
default is "strict", Options are:
|
|
63
|
+
- "strict": Raise a ValueError if the log-determinant is zero.
|
|
64
|
+
- "fallback": Issue a warning and then set the log-determinant to minlog
|
|
65
|
+
(default: -1500) and set sign to -1.
|
|
66
|
+
|
|
67
|
+
Returns
|
|
68
|
+
-------
|
|
69
|
+
sign: {-1, 1}
|
|
70
|
+
The sign of the determinant (+1 or -1). In "fallback" mode, sign is set to -1 if
|
|
71
|
+
the determinant is zero, to maintain the invariant that sign (never 0).
|
|
72
|
+
logabsdet: float
|
|
73
|
+
The (natural) log-determinant of the unmixing matrix. In "fallback" mode, this
|
|
74
|
+
is set to minlog (default: -1500).
|
|
75
|
+
"""
|
|
76
|
+
#--------------------------------FORTRAN CODE------------------------------
|
|
77
|
+
# call DCOPY(nw*nw,W(:,:,h),1,Wtmp,1)
|
|
78
|
+
# ....
|
|
79
|
+
# call DGEQRF(nw,nw,Wtmp,nw,wr,work,lwork,info)
|
|
80
|
+
# ...
|
|
81
|
+
# Dtemp(h) = dble(0.0)
|
|
82
|
+
# ...
|
|
83
|
+
# Dtemp(h) = Dtemp(h) + log(abs(Wtmp(i,i)))
|
|
84
|
+
# ------------------------------------------------------------------------
|
|
85
|
+
# Alias for clarity with Fortran code
|
|
86
|
+
W = unmixing_matrix
|
|
87
|
+
sign, logabsdet = torch.linalg.slogdet(W)
|
|
88
|
+
# TODO: slogdet requires a square unmixing matrix. Does AMICA gaurantee this?
|
|
89
|
+
if logabsdet == -torch.inf or sign == 0: # Model fit has collapsed.
|
|
90
|
+
msg = (
|
|
91
|
+
"Unmixing matrix (W) is singular (determinant = 0)\n\n"
|
|
92
|
+
"Details:\n"
|
|
93
|
+
f"- shape of W: {W.shape}\n"
|
|
94
|
+
f"- sign={sign}, log|det|={logabsdet}\n\n"
|
|
95
|
+
"Things to try:\n"
|
|
96
|
+
"- Check that your input data is rank-sufficient\n"
|
|
97
|
+
"- Reduce the number of components\n"
|
|
98
|
+
)
|
|
99
|
+
if mode == "strict":
|
|
100
|
+
# By default Let's raise an error until we can test this case properly
|
|
101
|
+
raise ValueError(msg)
|
|
102
|
+
else:
|
|
103
|
+
log(msg, level="warning")
|
|
104
|
+
log(f"Setting log-determinant to {minlog} and sign to -1", level="warning")
|
|
105
|
+
# fallback values (numerical hack to let training continue)
|
|
106
|
+
logabsdet = minlog
|
|
107
|
+
sign = -1 # matches dsign = 1 if det > 0 else -1 in Fortran
|
|
108
|
+
return sign, logabsdet
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def get_initial_model_log_likelihood(
    *,
    unmixing_logdet: float,
    whitening_logdet: float,
    model_weight: float,
) -> float:
    """
    Initialize the per-sample model log-likelihood with baseline terms.

    Parameters
    ----------
    unmixing_logdet : float
        The log-determinant of the unmixing matrix (W) for this model.
    whitening_logdet : float
        The log-determinant of the sphering/whitening transform (S),
        computed from the input data's whitening/sphering matrix, e.g. as
        -0.5 * sum(log(lambda_i)) where lambda_i are covariance eigenvalues.
    model_weight : float
        The mixture proportion (prior probability) for this model.
        Must be strictly positive.

    Returns
    -------
    initial_modloglik : float
        A scalar baseline log-likelihood value (returned as a 0-d float64
        tensor). This should be broadcast across all samples of the model
        log-likelihood array at the call site.

    Raises
    ------
    ValueError
        If ``model_weight`` is not strictly positive.

    Notes
    -----
    - The Jacobian from x -> u is |det(W S)|, so log|det(W S)| = log|det(W)| +
      log|det(S)| = Dsum[h] + sldet.
    - S is positive-definite with full-rank whitening, so sldet has no sign
      issue.
    - In the Fortran code, the variable Ptmp(bstrt:bstp,h) holds the initial
      model log-likelihood for model h across the data block (bstrt:bstp). This
      gets copied into modloglik.
    """
    # Validate before any tensor conversion so the error shows the raw value.
    if model_weight <= 0:
        raise ValueError(f"model_weight must be > 0, got {model_weight}")
    whitening_logdet = torch.as_tensor(whitening_logdet, dtype=torch.float64)
    unmixing_logdet = torch.as_tensor(unmixing_logdet, dtype=torch.float64)
    # Bug fix: torch.log() rejects plain Python floats, so model_weight must be
    # converted to a tensor as well (previously only the two logdets were).
    model_weight = torch.as_tensor(model_weight, dtype=torch.float64)
    #--------------------------FORTRAN CODE-------------------------
    # Ptmp(bstrt:bstp,h) = Dsum(h) + log(gm(h)) + sldet
    #---------------------------------------------------------------
    initial_modloglik = unmixing_logdet + torch.log(model_weight) + whitening_logdet
    return initial_modloglik
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def pre_whiten(
    *,
    X: DataArray2D,
    n_components: int | None = None,
    mineig: float = 1e-6,
    do_mean: bool = True,
    do_sphere: bool = True,
    do_approx_sphere: bool = True,
    inplace: bool = True,
) -> tuple[DataArray2D, WeightsArray, float, WeightsArray, ComponentsVector | None]:
    """
    Pre-whiten the input data matrix X prior to ICA.

    Parameters
    ----------
    X : array, shape (``n_samples``, ``n_features``)
        Input data matrix to be whitened. If ``inplace`` is ``True``, X will be
        mutated and returned as the whitened data. Otherwise a copy will be made
        and returned.
    n_components : int or None
        Number of components to keep. If ``None``, all components are kept.
    mineig : float
        Minimum eigenvalue threshold for keeping components. Eigenvalues below
        this will be discarded.
    do_mean : bool
        If ``True``, mean-center the data before whitening.
    do_sphere : bool
        If ``True``, perform sphering (whitening). If ``False``, only variance
        normalization is performed.
    do_approx_sphere : bool
        If ``True``, data is whitened with the inverse of the symmetric square
        root of the covariance matrix (ZCA whitening). If ``False``, PCA
        whitening is performed. Only used if ``do_sphere`` is ``True``.
    inplace : bool
        If ``True``, modify X in place. If ``False``, make a copy of X and
        modify that.

    Returns
    -------
    X : array, shape (``n_samples``, ``n_features``)
        The whitened data matrix. This is a copy of the input data if inplace is
        False, otherwise it is the mutated input data itself.
    whitening_matrix : array, shape (``n_features``, ``n_features``)
        The whitening/sphering matrix applied to the data. If do_sphere is
        False, then this is the variance normalization matrix.
    sldet : float
        The log-determinant of the whitening matrix.
    whitening_inverse : array, shape (``n_features``, ``n_features``)
        The pseudoinverse of the whitening matrix.
    mean : array, shape (``n_features``,)
        The mean of each feature that was subtracted if ``do_mean`` is ``True``.
        Only returned if ``do_mean`` is ``True``, otherwise ``None``.
    """
    dataseg = X if inplace else X.copy()
    assert dataseg.ndim == 2, f"X must be 2D, got {dataseg.ndim}D"
    # !---------------------------- get the mean --------------------------------
    n_samples, nx = dataseg.shape
    if n_components is None:
        n_components = nx

    # ---- Mean-centering ----
    if do_mean:
        log("getting the mean ...")
        mean = dataseg.mean(axis=0)
        # !--- subtract the mean
        dataseg -= mean[None, :]  # Subtract mean from each channel

    # ---- Covariance ----
    log(" Getting the covariance matrix ...")
    # Compute the covariance matrix.
    # The Fortran code only computes the upper triangular part (DSYRK); the
    # full symmetric product below is numerically equivalent.

    # -------------------- FORTRAN CODE ------------------------------------------------
    # call DSCAL(nx*nx,dble(0.0),Stmp,1)
    # call DSYRK('L','N',nx,blk_size(seg),dble(1.0),dataseg(seg)%data(:,bstrt:bstp)...
    # call DSCAL(nx*nx,dble(1.0)/dble(cnt),S,1)
    #-----------------------------------------------------------------------------------
    Cov = dataseg.T @ dataseg / n_samples

    # ---- Eigen-decomposition
    log(f"doing eigenvalue decomposition for {nx} features ...")
    eigvals, eigvecs = np.linalg.eigh(Cov)  # ascending order

    # Report the extreme eigenvalues for rank diagnostics.
    min_eigs = eigvals[:min(nx//2, 3)]
    max_eigs = eigvals[::-1][:3]
    log(f"minimum eigenvalues: {min_eigs}")
    log(f"maximum eigenvalues: {max_eigs}")

    # keep only valid eigs (pcakeep)
    # Do we need to pass numeigs to optimize if sum(eigvals > mineig) < n_components?
    numeigs = min(n_components, sum(eigvals > mineig))  # np.linalg.matrix_rank?
    log(f"num eigvals kept: {numeigs}")

    # Log determinant of the whitening matrix: -0.5 * sum(log(eigenvalues kept))
    if numeigs == nx:
        sldet = -0.5 * np.sum(np.log(eigvals))
    else:
        # eigvals is ascending, so reverse to take the numeigs largest.
        sldet = -0.5 * np.sum(np.log(eigvals[::-1][:numeigs]))

    # 1) reorder eigenvectors (descending eigenvalues)
    order = np.argsort(eigvals)[::-1]
    eigvals_desc = eigvals[order]
    # UNSCALED eigenvectors (rows = eigenvectors, descending eigenvalue order)
    Stmp = eigvecs[:, order].T.copy()
    Stmp2 = Stmp.copy()
    # 2) SCALED eigenvectors (rows divided by sqrt(eigenvalue) -> PCA whitener)
    Stmp2[:numeigs, :] /= np.sqrt(eigvals_desc[:numeigs, None])

    # ---- Sphere or variance normalize ----
    if do_sphere:
        log("Sphering the data...", level="info", color="blue", weight="bold")
        if numeigs == nx:
            # call DSCAL(nx*nx,dble(0.0),S,1)
            if do_approx_sphere:
                # ZCA whitening: S = V diag(1/sqrt(lambda)) V^T (symmetric
                # inverse square root of the covariance).
                S = (eigvecs * (1.0 / np.sqrt(eigvals))) @ eigvecs.T
            else:
                # call DCOPY(nx*nx,Stmp2,1,S,1)
                # PCA whitening: rows are scaled eigenvectors.
                S = Stmp2.copy()
        else:
            if do_approx_sphere:
                # -------------------- FORTRAN CODE ------------------------------------
                # call DSCAL(nx*blk_size(seg),dble(0.0),xtmp(:,1:blk_size(seg)),1)
                # call DGEMM('N','N',nx,blk_size(seg),nx,dble(1.0),S,nx,dataseg(seg)...
                # call DCOPY(nx*blk_size(seg),xtmp(:,1:blk_size(seg)),1,dataseg(seg)
                # ----------------------------------------------------------------------

                # Direct translation of the (hard to follow) Fortran code.
                # This appears to be rank-deficient ZCA whitening: an orthogonal
                # rotation derived from the SVD of the leading eigenvector block
                # is applied on top of the PCA whitener.

                # 3) Unscaled eigenvectors go into the SVD
                S = Stmp.copy()

                # 4) SVD on leading block of unscaled eigenvectors
                U, s, VT = np.linalg.svd(S[:numeigs, :numeigs], full_matrices=True)
                Stmp[:numeigs, :numeigs] = VT.T
                S[:numeigs, :numeigs] = U.T

                # 5) Stmp3 = Stmp^T @ S^T (numeigs x numeigs block) = V @ U^T
                Stmp3 = Stmp[:numeigs, :numeigs] @ S[:numeigs, :numeigs]

                # 6) zero S and form final S = Stmp3 @ Stmp2
                S.fill(0.0)
                S[:numeigs, :] = Stmp3 @ Stmp2[:numeigs, :]
            else:
                # PCA whitening (rows beyond numeigs retain unscaled eigenvectors).
                S = Stmp2.copy()
    else:
        # !--- just normalize by the channel variances (don't sphere)
        # -------------------- FORTRAN CODE ---------------------------------------
        # call DCOPY(nx*nx,S,1,Stmp,1)
        # call DSCAL(nx*nx,dble(0.0),S,1)
        #------------------------------------------------------------------------
        S = np.zeros_like(Cov)  # This is S in Fortran code
        # Zero out the lower triangle to have parity with Fortran
        sldet = 0.0
        for i in range(nx):
            # NOTE(review): np.triu(Cov)[i, i] == Cov[i, i] on the diagonal, so
            # the triu call is redundant work per iteration — confirm and simplify.
            if np.triu(Cov)[i, i] > 0:
                S[i, i] = 1.0 / np.sqrt(Cov[i, i])
                # NOTE(review): log|det S| of a diagonal matrix is sum(log(S_ii));
                # the extra 0.5 factor here looks off — confirm against Fortran.
                sldet += 0.5 * np.log(S[i, i])
        numeigs = nx
    # -------------------- FORTRAN CODE ---------------------------------------
    # call DSCAL(nx*blk_size(seg),dble(0.0),xtmp(:,1:blk_size(seg)),1)
    # call DGEMM('N','N',nx,blk_size(seg),nx,dble(1.0),S,nx,dataseg(seg)%data(:,bstrt...
    # call DCOPY(nx*blk_size(seg),xtmp(:,1:blk_size(seg)),1,dataseg(seg)%data(:,bstrt...
    # -------------------------------------------------------------------------
    # NumPy's matmul detects the output/input overlap and buffers internally.
    dataseg = np.matmul(dataseg, S.T, out=dataseg)  # In-place if possible

    nw = numeigs  # Number of weights, as per Fortran code
    log(f"numeigs = {numeigs}, nw = {nw}")
    # ! get the pseudoinverse of the sphering matrix
    # call DGESVD( 'A', 'S', numeigs, nx, Stmp2, nx, eigvals, sUtmp, numeigs, sVtmp...
    # NOTE(review): this is the inverse symmetric square root's inverse, i.e. the
    # exact inverse only for the full-rank ZCA branch — confirm for other branches.
    Winv = (eigvecs * np.sqrt(eigvals)) @ eigvecs.T  # Inverse of the whitening matrix

    if n_components is None:
        # NOTE(review): dead branch — n_components was already defaulted to nx above.
        n_components = nw
    elif n_components > nw:
        raise ValueError(
            f"n_components must be less than or equal to the rank of the data. "
            f"Got a rank of {nw} but {n_components} requested components."
        )

    if not do_mean:
        mean = None
    # Bug fix: the second string was missing its f prefix, so the placeholders
    # were never interpolated in the assertion message.
    assert dataseg.shape == (n_samples, nx), (
        f"dataseg shape {dataseg.shape} "
        f"!= (n_samples, n_features) = ({n_samples}, {nx}) "
    )
    return dataseg, S, sldet, Winv, mean
|